"""
Kraus operator functions for Lindblad master equation
"""
from math import isqrt
from typing import Literal, overload
import jax
import jax.numpy as jnp
import numpy as np
from opt_einsum import contract
from scipy.linalg import expm, svd
from pytdscf._const_cls import const
[docs]
def lindblad_to_kraus(
    Lops: list[np.ndarray],
    dt: float,
    backend: Literal["numpy", "jax"] = "numpy",
) -> np.ndarray | jnp.ndarray:
    """
    Convert set of Lindblad operators {L_j} to a Kraus operator set {B_q}
    exp(D dt) = ∑_q=1^k B_q ⊗ B_q*
    where
    D = ∑_j [L_j⊗L_j* - 1/2 (L_j†L_j ⊗ I + I⊗L_j⊤L_j*)]
    See DOI: https://doi.org/10.1103/PhysRevLett.116.237201
    Args:
        Lops (list[np.ndarray]): Lindblad operators
        dt (float): time step
    Returns:
        np.ndarray | jnp.ndarray: Kraus operator tensor
    """
    assert all(L.ndim == 2 for L in Lops)
    assert all(L.shape[0] == L.shape[1] for L in Lops)
    assert dt > 0
    L = Lops.pop()
    Ldag = L.conj().T
    I = np.eye(L.shape[0])
    D = np.kron(L, L.conj()) - 1 / 2 * (
        np.kron(Ldag @ L, I) + np.kron(I, L.T @ L.conj())
    )
    if np.allclose(D.imag, 0):
        D = D.real
    while Lops:
        L = Lops.pop()
        Ldag = L.conj().T
        _D = np.kron(L, L.conj()) - 1 / 2 * (
            np.kron(Ldag @ L, I) + np.kron(I, L.T @ L.conj())
        )
        if np.allclose(_D.imag, 0):
            _D = _D.real
        else:
            D = D.astype(complex)
        D += _D
    dissipator = expm(D * dt)
    # eigenvalues of dissipator should be positive
    assert np.all(np.linalg.eigvals(dissipator) > -1e-14), np.linalg.eigvals(
        dissipator
    )
    # Kraus operators
    def supergate_to_kraus(G, d, tol=1e-14):
        S4 = G.reshape(d, d, d, d, order="F")  # S[α,β,μ,ν]
        J = np.transpose(S4, (0, 2, 1, 3)).reshape(
            d * d, d * d, order="F"
        )  # J[(αμ),(βν)] = S[α,β,μ,ν]
        J = (J + J.conj().T) / 2  # hermitize
        if np.allclose(J, J.conj().T):
            w, V = np.linalg.eigh(J)
        else:
            w, V = np.linalg.eig(J)
        kraus = []
        for lam, v in zip(w, V.T, strict=True):
            lam = lam.real
            if lam > tol:
                kraus.append(np.sqrt(lam) * v.reshape(d, d, order="F"))
        return kraus  # satisfies sum_q B_q† B_q ≈ I
    Bs = supergate_to_kraus(dissipator, isqrt(dissipator.shape[0]))
    # Confirm exp(D dt) = ∑_q=1^k B_q ⊗ B_q*
    np.testing.assert_allclose(
        dissipator,
        np.sum([np.kron(B, B.conj()) for B in Bs], axis=0),
        atol=1e-14,
    )
    """
    TN diagram of Kraus operators
      d
      |
    --B
    | |
    k d
    tensor shape (k, d, d)
    """
    k = len(Bs)
    d = Bs[0].shape[0]
    B: np.ndarray | jax.Array
    match backend:
        case "numpy":
            Bs = [np.array(B, dtype=np.complex128) for B in Bs]
            B = np.stack(Bs, axis=0)
        case "jax":
            Bs = [jnp.array(B, dtype=jnp.complex128) for B in Bs]
            B = jnp.stack(Bs, axis=0)
        case _:
            raise ValueError(f"Invalid backend: {backend}")
    assert B.shape == (k, d, d), (
        f"Kraus operator shape mismatch: {B.shape} != ({k}, {d}, {d})"
    )
    return B 
# NumPy Kraus tensor and NumPy core -> NumPy core.
@overload
def kraus_contract_single_site(
    B: np.ndarray, core: np.ndarray
) -> np.ndarray: ...
# JAX Kraus tensor and JAX core -> JAX core. Both arguments must be jax.Array;
# the runtime dispatcher rejects mixed NumPy/JAX inputs with ValueError.
@overload
def kraus_contract_single_site(B: jax.Array, core: jax.Array) -> jax.Array: ...
[docs]
def kraus_contract_single_site(
    B: np.ndarray | jax.Array, core: np.ndarray | jax.Array
) -> np.ndarray | jax.Array:
    if isinstance(B, np.ndarray) and isinstance(core, np.ndarray):
        return _kraus_contract_single_site_np(B, core)
    elif isinstance(B, jax.Array) and isinstance(core, jax.Array):
        return _kraus_contract_single_site_jax(B, core)
    else:
        raise ValueError(f"Invalid backend: {type(B)=} while {type(core)=}") 
def _kraus_contract_single_site_np(B: np.ndarray, A: np.ndarray) -> np.ndarray:
    """
    Apply the Kraus operators ``B`` to the physical leg of one core ``A``
    (NumPy backend) and recompress the enlarged Kraus leg back to size K.

      x
      |
    k-B
      |
      d
      dK
      |
    m-A-n

    Args:
        B: Kraus tensor of shape (k, d, d), contracted with A over its last
            axis.
        A: core of shape (m, d*K, n); the middle leg is split as (d, K) with
            d the major index.

    Returns:
        Core of shape (m, x*K, n) with x == d. The index pair (k, K) is
        compressed to K by keeping the K largest singular values. If fewer
        than K singular values exist (m*n*x < K), the missing columns are
        zero-filled so the output shape is always (m, x*K, n).
    """
    k, d, x = B.shape
    assert d == x
    m, dK, n = A.shape
    assert dK % d == 0, f"Kraus contract: dK={dK} must be divisible by d={d}"
    K = dK // d
    # 1. Split the physical leg: (m, dK, n) -> (m, d, K, n)
    A = A.reshape(m, d, K, n)
    # 2. Contract the "d" legs: C[m, n, x, k, K] = Σ_d B[k, x, d] A[m, d, K, n]
    C = np.einsum("kxd,mdKn->mnxkK", B, A)
    # 3. Matricize: rows (m, n, x), columns (k, K)
    C = C.reshape(m * n * x, k * K)
    # 4. Thin SVD. Only U and S are needed because (U S)(U S)^† = C C^†.
    U, S, _ = svd(C, full_matrices=False, overwrite_a=True)
    # 5. Keep at most K singular triplets and absorb S into U.
    r = min(K, S.shape[0])
    A = U[:, :r] * S[np.newaxis, :r]
    if r < K:
        # Zero columns keep the Kraus leg at dimension K without changing A A^†.
        A = np.pad(A, ((0, 0), (0, K - r)))
    # 6. (m*n*x, K) -> (m, n, x*K) -> swap to (m, x*K, n)
    A = np.ascontiguousarray(A.reshape(m, n, x * K).swapaxes(1, 2))
    return A
@jax.jit
def _kraus_contract_single_site_jax(B: jax.Array, A: jax.Array) -> jax.Array:
    """
    JAX (jitted) counterpart of the NumPy single-site Kraus contraction.

    Args:
        B: Kraus tensor of shape (k, d, d).
        A: core of shape (m, d*K, n), middle leg split as (d, K), d major.

    Returns:
        Core of shape (m, x*K, n) with x == d. The (k, K) pair is compressed
        to K singular values; when fewer than K exist (m*n*x < K) the missing
        columns are zero-filled. Shapes are static under jit, so the padding
        branch is decided at trace time.
    """
    k, d, x = B.shape
    m, dK, n = A.shape
    K = dK // d
    # 1. reshape A to (m, d, K, n)
    A = A.reshape(m, d, K, n)
    # 2. contract "d" legs of B and A
    C = jnp.einsum("kxd,mdKn->mnxkK", B, A)
    # 3. matricize: rows (m, n, x), columns (k, K)
    C = C.reshape(m * n * x, k * K)
    # 4. thin SVD; (U S)(U S)^† = C C^†, so Vh is not needed
    U, S, _ = jnp.linalg.svd(C, full_matrices=False)
    # 5. keep at most K singular triplets and absorb S into U
    r = min(K, S.shape[0])
    A = U[:, :r] * S[jnp.newaxis, :r]
    if r < K:
        # zero columns keep the Kraus leg at dimension K
        A = jnp.pad(A, ((0, 0), (0, K - r)))
    # 6. reshape to (m, n, xK) and swap indices n and xK
    A = A.reshape(m, n, x * K).swapaxes(1, 2)
    return A
@overload
def kraus_contract_two_site(
    B: np.ndarray,
    core_1: np.ndarray,
    core_2: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
    """NumPy Kraus tensor and NumPy cores -> pair of NumPy cores."""
@overload
def kraus_contract_two_site(
    B: jax.Array,
    core_1: jax.Array,
    core_2: jax.Array,
) -> tuple[jax.Array, jax.Array]:
    """JAX Kraus tensor and JAX cores -> pair of JAX cores."""
[docs]
def kraus_contract_two_site(
    B: np.ndarray | jax.Array,
    core_1: np.ndarray | jax.Array,
    core_2: np.ndarray | jax.Array,
) -> tuple[np.ndarray, np.ndarray] | tuple[jax.Array, jax.Array]:
    if (
        isinstance(B, np.ndarray)
        and isinstance(core_1, np.ndarray)
        and isinstance(core_2, np.ndarray)
    ):
        return _kraus_contract_two_site_np(B, core_1, core_2)
    elif (
        isinstance(B, jax.Array)
        and isinstance(core_1, jax.Array)
        and isinstance(core_2, jax.Array)
    ):
        return _kraus_contract_two_site_jax(B, core_1, core_2)
    else:
        raise ValueError(
            f"Invalid backend: {type(B)=} while {type(core_1)=} and {type(core_2)=}"
        ) 
def _kraus_contract_two_site_np(
    B: np.ndarray, A1: np.ndarray, A2: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
    """
    Apply the Kraus operators ``B`` to the physical leg of ``A1`` (NumPy
    backend), absorb the Kraus index into the K leg of ``A2``, and split the
    result back into two cores.

      x
      |
      B-k
      |
      d
      d      K
      |      |
    m-A1-l l-A2-n

    Typically, m=l=n > K >> k=x=d

    Args:
        B: Kraus tensor of shape (k, x, d) with x == d.
        A1: core of shape (m, d, l).
        A2: core of shape (l, K, n).

    Returns:
        (A1_new, A2_new) of shapes (m, x, l) and (l, K, n). The (k, K) pair is
        compressed to K singular values, then the merged tensor is re-split
        keeping l singular values. When an SVD yields fewer singular values
        than the kept dimension, the factors are zero-padded so the output
        shapes are always as stated.
    """
    k, x, d = B.shape
    assert x == d
    assert A1.shape[1] == d
    m, d, l = A1.shape
    assert A1.shape[2] == A2.shape[0]
    l, K, n = A2.shape
    # 1. Contract B, A1, A2 in one call; opt_einsum picks the path.
    C = contract("kxd,mdl,lKn->mxnkK", B, A1, A2)
    # mxn-C-kK
    C = C.reshape(m * x * n, k * K)
    U, S, _ = svd(C, full_matrices=False, overwrite_a=True)
    if const.pytest_enabled:
        print(f"truncation percentage: {1 - S[:K].sum() / S.sum():.2%}")
    # 2. Keep at most K singular triplets and absorb S into U (in place).
    r = min(K, S.shape[0])
    U = U[:, :r]
    U *= S[np.newaxis, :r]  # in-place multiplication to save memory
    if r < K:
        # m*x*n < K: zero columns keep the Kraus leg at dimension K without
        # changing (U S)(U S)^†.
        U = np.pad(U, ((0, 0), (0, K - r)))
    # 3. (mxn, K) -> (m, x, n, K) -> swap n and K -> (mx, Kn)
    C = U.reshape(m, x, n, K).swapaxes(2, 3).reshape(m * x, K * n)
    C = np.ascontiguousarray(C)
    # 4. Thin SVD of C, then truncate to the bond dimension l.
    U, S, Vh = svd(C, full_matrices=False, overwrite_a=True)
    r = min(l, S.shape[0])
    U = U[:, :r]
    Vh = Vh[:r, :]
    U *= S[np.newaxis, :r]  # in-place multiplication to save memory
    if r < l:
        # fewer than l singular values: pad so the bond keeps dimension l
        U = np.pad(U, ((0, 0), (0, l - r)))
        Vh = np.pad(Vh, ((0, l - r), (0, 0)))
    # 5. mx-A1-l l-A2-Kn -> reshape to (m, x, l) and (l, K, n)
    A1 = U.reshape(m, x, l)
    A2 = Vh.reshape(l, K, n)
    return A1, A2
@jax.jit
def _kraus_contract_two_site_jax(
    B: jax.Array, A1: jax.Array, A2: jax.Array
) -> tuple[jax.Array, jax.Array]:
    """
    JAX (jitted) counterpart of the NumPy two-site Kraus contraction.

      x
      |
      B-k
      |
      d
      d      K
      |      |
    m-A1-l l-A2-n

    Typically, m=l=n > K >> k=x=d

    Args:
        B: Kraus tensor of shape (k, x, d).
        A1: core of shape (m, d, l).
        A2: core of shape (l, K, n).

    Returns:
        (A1_new, A2_new) of shapes (m, x, l) and (l, K, n). When an SVD yields
        fewer singular values than the kept dimension (K or l), the factors
        are zero-padded. Shapes are static under jit, so the padding branches
        are decided at trace time.
    """
    k, x, d = B.shape
    m, d, l = A1.shape
    l, K, n = A2.shape
    # 1. Contract all B, A1, A2: mxn-C-kK
    C = jnp.einsum("kxd,mdl,lKn->mxnkK", B, A1, A2)
    C = C.reshape(m * x * n, k * K)
    # 2. Thin SVD; keep at most K singular triplets and absorb S into U.
    U, S, _ = jnp.linalg.svd(C, full_matrices=False)
    r = min(K, S.shape[0])
    US = U[:, :r] * S[jnp.newaxis, :r]
    if r < K:
        US = jnp.pad(US, ((0, 0), (0, K - r)))
    # 3. (mxn, K) -> (m, x, n, K) -> swap n and K -> (mx, Kn)
    C = US.reshape(m, x, n, K).swapaxes(2, 3).reshape(m * x, K * n)
    # 4. Second SVD restores the bond l between the two cores.
    U, S, Vh = jnp.linalg.svd(C, full_matrices=False)
    r = min(l, S.shape[0])
    US = U[:, :r] * S[jnp.newaxis, :r]
    Vh = Vh[:r, :]
    if r < l:
        US = jnp.pad(US, ((0, 0), (0, l - r)))
        Vh = jnp.pad(Vh, ((0, l - r), (0, 0)))
    # 5. mx-A1-l l-A2-Kn -> reshape to (m, x, l) and (l, K, n)
    A1 = US.reshape(m, x, l)
    A2 = Vh.reshape(l, K, n)
    return A1, A2
[docs]
def trace_kraus_dim(rdm: np.ndarray, d: int) -> np.ndarray:
    r"""
    Trace out the Kraus (ancilla) index K of a reduced density matrix whose
    row and column legs are the fused index (d, K), d major and K minor.

    dK    d   K      d
    |      \ /|      |
    C   =>  C |  =>  C
    |      / \|      |
    dK    d   K      d

    Args:
        rdm: array of shape (dK, dK), or 3-D with the fused legs last
            (reshaped to (-1, d, K, d, K)).
        d: physical dimension; must divide the last axis length dK.

    Returns:
        Array of shape (d, d), or (t, d, d) for 3-D input.
    """
    dK = rdm.shape[-1]
    assert dK % d == 0, (
        f"Kraus dimension reduction: dK={dK} must be divisible by d={d}"
    )
    n_kraus = dK // d
    if rdm.ndim != 2:
        assert rdm.ndim == 3, f"rdm.ndim={rdm.ndim} must be 2 or 3"
        batched = rdm.reshape(-1, d, n_kraus, d, n_kraus)
        return np.einsum("tdKxK->tdx", batched)
    square = rdm.reshape(d, n_kraus, d, n_kraus)
    return np.einsum("dKxK->dx", square)