import numpy as np
import ctypes
from pathlib import Path

# --- DLLのパスを解決し、ロードする ---
try:
    dll_path = Path(__file__).parent.parent / 'bin' / 'svd_ctypes.dll'
    if not dll_path.exists():
        raise FileNotFoundError(f"DLL not found at {dll_path}. Please build it first using 'build-benchmarks.ps1'.")
    
    fortran_lib = ctypes.CDLL(str(dll_path))
except (FileNotFoundError, OSError) as e:
    print(f"Error loading DLL: {e}")
    fortran_lib = None

# --- Fortran関数のインターフェースを定義する ---
if fortran_lib:
    # Fortranラッパーで指定されたC言語名を指定
    dsvdc_c_func = fortran_lib.dsvdc_c
    
    # Fortran関数は info を整数(c_int)として返す
    dsvdc_c_func.restype = ctypes.c_int
    
    # 引数の型を定義。Fortranラッパーの仕様に完全に一致させる。
    dsvdc_c_func.argtypes = [
        np.ctypeslib.ndpointer(dtype=np.float64, flags='F_CONTIGUOUS'), # x(n, p)
        ctypes.c_int, # n (値渡し)
        ctypes.c_int, # p (値渡し)
        np.ctypeslib.ndpointer(dtype=np.float64), # s(n)
        np.ctypeslib.ndpointer(dtype=np.float64), # e(p)
        np.ctypeslib.ndpointer(dtype=np.float64, flags='F_CONTIGUOUS'), # u(n, n)
        np.ctypeslib.ndpointer(dtype=np.float64, flags='F_CONTIGUOUS'), # v(p, p)
        ctypes.c_int, # job (値渡し)
    ]

def dsvdc(x, n, p, s, e, u, v, job):
    """
    fortran_src/ctypes/dsvdc_wrapper.f90 の dsvdc_c 関数を呼び出す。
    出力用の配列 s, e, u, v は呼び出し元で確保し、引数として渡す。
    これらの配列は関数内で直接書き換えられる。

    Args:
        x (np.ndarray): 入力行列 (N x P)。
        n (int): The number of rows in x.
        p (int): The number of columns in x.
        s (np.ndarray): 特異値を格納する1次元配列。
        e (np.ndarray): 特異値を格納する1次元配列(追加情報)。
        u (np.ndarray): 左特異ベクトルを格納する2次元配列。
        v (np.ndarray): 右特異ベクトルを格納する2次元配列。
        job (int): 計算内容を指定するフラグ。

    Returns:
        int: Fortran関数からのinfoコード。0なら成功。
    """
    if not fortran_lib:
        raise RuntimeError("Fortran library is not loaded. Cannot perform SVD.")

    # Fortranは列優先配列を期待するため、入力配列をF-contiguousに変換
    # (すでになっている場合は変換されない)
    x_f = np.asfortranarray(x, dtype=np.float64)
    u_f = np.asfortranarray(u, dtype=np.float64)
    v_f = np.asfortranarray(v, dtype=np.float64)
    
    # The line `n, p = x_f.shape` is no longer needed as n and p are passed in.
    
    # C関数を直接呼び出す。n, p, job は単純な整数として渡す。
    info = dsvdc_c_func(x_f, n, p, s, e, u_f, v_f, job)
    
    # u, vはFortran側で書き換えられるため、元の配列に戻す(必要な場合のみ)
    # np.asfortranarrayがコピーを作成した場合、結果を元の配列にコピーし直す
    if u is not u_f:
        u[...] = u_f
    if v is not v_f:
        v[...] = v_f
        
    return info