import sys
import os
import subprocess
import numpy as np
import argparse

# ── Configuration ────────────────────────────────────────────────────────────

# Absolute path to the MSYS2/MinGW-w64 Python interpreter intended for the f2py build.
# NOTE(review): no active benchmark target in this file launches it; it is kept as
# an alternative to sys.executable for the "Python (f2py)" target.
MSYS_PYTHON = r"D:\msys64\mingw64\bin\python.exe"

# Benchmark targets, listed in the order they are run and reported.
# Each entry is {"name": <display label>, "command": <argv prefix>}.
# Python targets are scripts under "benchmarks" started with the current
# interpreter; Fortran targets are executables under "bin".
# The "Python (f2py)" target may alternatively be launched with MSYS_PYTHON
# in place of sys.executable.
BENCHMARK_TARGETS = [
    {"name": label, "command": [sys.executable, os.path.join("benchmarks", script)]}
    for label, script in (
        ("Python (base)", "benchmark_svd_python_base.py"),
        ("Python (Vectorized)", "benchmark_svd_python_vectorized.py"),
        ("Python (Base+JIT)", "benchmark_svd_python_base_jit.py"),
        ("Python (Base+JIT-Cached)", "benchmark_svd_python_base_jit_cached.py"),
        ("Python (f2py)", "benchmark_svd_python_f2py.py"),
        ("Python (f2py_intent_out)", "benchmark_svd_python_f2py_intent_out.py"),
        ("Python (ctypes)", "benchmark_svd_python_ctypes.py"),
    )
] + [
    {"name": label, "command": [os.path.join("bin", executable)]}
    for label, executable in (
        ("Fortran (Original)", "benchmark_svd_fortran_original.exe"),
        ("Fortran (Refactored)", "benchmark_svd_fortran_refactored.exe"),
    )
]

# Input matrices, smallest first.  Every file is a square matrix stored as
# test_data/performance_input/matrix_<N>x<N>_rand_uniform.bin; "size" is the
# "<N>x<N>" label used in the console output.
DATA_FILES = [
    {
        "size": f"{edge}x{edge}",
        "path": os.path.join(
            "test_data", "performance_input", f"matrix_{edge}x{edge}_rand_uniform.bin"
        ),
    }
    for edge in (100, 200, 400, 800, 1600, 2400, 3200)
]

# ── Helpers ──────────────────────────────────────────────────────────────────
def run_single_benchmark(target_name: str, command: list[str], data_path: str, num_runs: int, base_level: str):
    """Run one benchmark target against one data file and summarise its timings.

    The target is started ``num_runs`` times as ``command + [data_path]``.  The
    first whitespace-separated token each run writes to stdout is parsed as
    that run's timing value.

    Args:
        target_name: Display label of the target; ``"Python (base)"`` is the
            only label subject to the ``base_level`` skip rules.
        command: Argv prefix to execute (the data path is appended).
        data_path: Path of the input matrix file; must exist.
        num_runs: How many times to execute the target (at least 1).
        base_level: Value of ``--base-level``.  ``"skip-800"`` and
            ``"skip-1600"`` skip the ``"Python (base)"`` target when
            ``data_path`` contains one of the corresponding size labels.

    Returns:
        A 4-tuple ``(best, mean, std, cv)``:

        * success -- minimum, mean and standard deviation (``np.std``) of the
          collected timings, plus the coefficient of variation in percent
          (``0.0`` when the mean is not positive);
        * skipped -- ``(None, None, None, None)``;
        * failure -- ``([], None, None, None)``.  The empty list in the first
          slot is what lets ``print_results_table`` tell an error from a skip.

        Failures are reported on stderr and never raised, so one broken
        target cannot abort the whole benchmark sweep.
    """
    print(f"  Running: {target_name}...")

    # The input file must exist before anything is launched.
    if not os.path.exists(data_path):
        print(f"    -> ERROR: Data file not found at '{data_path}'", file=sys.stderr)
        return [], None, None, None

    # Optional skip: the pure-Python baseline is painfully slow on the biggest
    # matrices.  Which sizes are skipped is controlled by --base-level.
    if target_name == "Python (base)":
        skipped_sizes = {
            "skip-800": ("800x800", "1600x1600", "2400x2400", "3200x3200"),
            "skip-1600": ("1600x1600", "2400x2400", "3200x3200"),
        }.get(base_level, ())
        if any(size in data_path for size in skipped_sizes):
            print(f"     -> Skipped (per --base-level {base_level}).")
            return None, None, None, None

    # Without at least one run there is nothing to measure; say so instead of
    # silently producing an unexplained ERROR row.
    if num_runs < 1:
        print(f"    -> ERROR: num_runs must be at least 1 (got {num_runs}).", file=sys.stderr)
        return [], None, None, None

    full_command = command + [data_path]  # identical for every run
    times: list[float] = []
    for _ in range(num_runs):
        try:
            # Capture text, but *ignore* undecodable bytes so that output in a
            # non-UTF-8 console encoding cannot raise UnicodeDecodeError.
            result = subprocess.run(
                full_command,
                capture_output=True,
                text=True,
                encoding="utf-8",
                errors="ignore",
                check=True,
            )
        except FileNotFoundError:
            print(f"    -> ERROR: Command not found. Is '{command[0]}' built and on PATH?",
                  file=sys.stderr)
            return [], None, None, None
        except Exception as e:
            # Deliberately broad: any launch problem or non-zero exit status is
            # reported and turned into an ERROR row.
            print(f"    -> ERROR: Could not run command '{' '.join(full_command)}'",
                  file=sys.stderr)
            print(f"    -> REASON: {e}", file=sys.stderr)

            # Show the child's stderr when it is available.
            if isinstance(e, subprocess.CalledProcessError):
                stderr = e.stderr
                if isinstance(stderr, bytes):
                    stderr = stderr.decode("utf-8", errors="ignore")
                if stderr:
                    print(f"    -> STDERR: {stderr.strip()}", file=sys.stderr)
            return [], None, None, None

        # The first whitespace-separated token on stdout is the timing number.
        tokens = result.stdout.split()
        try:
            times.append(float(tokens[0]))
        except (IndexError, ValueError):
            # The command ran fine but printed nothing usable; report that
            # explicitly rather than as a failure to run the command.
            shown = result.stdout.strip()[:200] or "<empty>"
            print(f"    -> ERROR: No timing value at the start of the output of "
                  f"'{' '.join(full_command)}'", file=sys.stderr)
            print(f"    -> STDOUT: {shown}", file=sys.stderr)
            return [], None, None, None

    # np.std returns 0.0 for a single run, which is the desired result.
    best, mean, std = np.min(times), np.mean(times), np.std(times)
    # Coefficient of variation in percent; guard against division by zero.
    cv = (std / mean) * 100 if mean > 0 else 0.0
    return best, mean, std, cv


def print_results_table(results: list[dict]):
    """Print the result rows as a fixed-width ASCII table on stdout.

    Args:
        results: One dict per target with the keys ``"name"``, ``"best"``,
            ``"mean"``, ``"std"`` and ``"cv"``.

    Each row is rendered in one of three ways:

    * ``best`` and ``mean`` both ``None``: every value column shows ``N/A``;
    * otherwise, any of the four values ``None``: every value column shows
      ``ERROR``;
    * otherwise: the three timings with seven decimals and an ``s`` suffix,
      and ``cv`` with two decimals and a ``%`` suffix.

    All rows share the header's column widths (28 / 15 / 15 / 15 / 10), so
    every printed line is 95 characters wide.
    """
    print(f"{'Target':<28} | {'Best':>15s} | {'Mean':>15s} | {'Std Dev':>15s} | {'CV (%)':>10s}")
    print("-" * 95)
    for row in results:
        if row["best"] is None and row["mean"] is None:
            best = mean = std = cv = "N/A"
        elif None in (row["best"], row["mean"], row["std"], row["cv"]):
            best = mean = std = cv = "ERROR"
        else:
            best = f"{row['best']:.7f}s"
            mean = f"{row['mean']:.7f}s"
            std = f"{row['std']:.7f}s"
            cv = f"{row['cv']:.2f}%"
        # Pad each cell through its format spec.  Calling .rjust() on adjacent
        # f-string literals would pad the implicitly concatenated string
        # rather than the individual cell and break the column alignment.
        print(f"{row['name']:<28} | {best:>15} | {mean:>15} | {std:>15} | {cv:>10}")


# ── Main ─────────────────────────────────────────────────────────────────────
def main():
    """Parse the command line and benchmark every active target on every data file.

    For each entry of DATA_FILES a banner is printed, every active target is
    run through run_single_benchmark, and the collected rows are printed with
    print_results_table.  With ``--base-level none`` the "Python (base)"
    target is dropped from the run entirely.
    """
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="Run SVD implementation benchmarks.",
    )
    cli.add_argument(
        "-n", "--num-runs",
        default=3,
        type=int,
        help="Number of times to run each benchmark for a given data size.",
    )
    cli.add_argument(
        "--base-level",
        default="skip-800",
        type=str,
        choices=["full", "skip-1600", "skip-800", "none"],
        help=(
            "Control execution level for the slow 'Python (base)' implementation. "
            "'full': run all sizes. "
            "'skip-1600': skip sizes 1600x1600 and larger. "
            "'skip-800': skip sizes 800x800 and larger. "
            "'none': disable 'Python (base)' entirely."
        ),
    )
    args = cli.parse_args()

    # Only a hint: the run continues even when PYTHONPATH is missing.
    if os.environ.get("PYTHONPATH") is None:
        print("Warning: PYTHONPATH is not set. Did you run './setup_pythonpath.ps1'?",
              file=sys.stderr)

    if args.base_level == 'none':
        active_targets = [
            target for target in BENCHMARK_TARGETS if target['name'] != 'Python (base)'
        ]
        print("--> 'Python (base)' benchmark is disabled via --base-level=none")
    else:
        active_targets = BENCHMARK_TARGETS

    banner = "=" * 95
    for data in DATA_FILES:
        print("\n" + banner)
        print(f"--- Benchmarking with {data['size']} data ({args.num_runs} run(s)) ---")
        print(banner)

        # Run every target first; the table is printed once all of them are done.
        table_rows = []
        for target in active_targets:
            best, mean, std, cv = run_single_benchmark(
                target["name"], target["command"], data["path"], args.num_runs, args.base_level
            )
            table_rows.append(
                {"name": target["name"], "best": best, "mean": mean, "std": std, "cv": cv}
            )
        print_results_table(table_rows)

    print("\nBenchmark finished.")


if __name__ == "__main__":
    main()
