"""
Kraus operator functions for Lindblad master equation
"""
from math import isqrt
from typing import Literal, overload
import jax
import jax.numpy as jnp
import numpy as np
from opt_einsum import contract
from scipy.linalg import expm, svd
from pytdscf._const_cls import const
# [docs]  (Sphinx "view code" link artifact from HTML extraction)
def lindblad_to_kraus(
Lops: list[np.ndarray],
dt: float,
backend: Literal["numpy", "jax"] = "numpy",
) -> np.ndarray | jnp.ndarray:
"""
Convert set of Lindblad operators {L_j} to a Kraus operator set {B_q}
exp(D dt) = ∑_q=1^k B_q ⊗ B_q*
where
D = ∑_j [L_j⊗L_j* - 1/2 (L_j†L_j ⊗ I + I⊗L_j⊤L_j*)]
See DOI: https://doi.org/10.1103/PhysRevLett.116.237201
Args:
Lops (list[np.ndarray]): Lindblad operators
dt (float): time step
Returns:
np.ndarray | jnp.ndarray: Kraus operator tensor
"""
assert all(L.ndim == 2 for L in Lops)
assert all(L.shape[0] == L.shape[1] for L in Lops)
assert dt > 0
L = Lops.pop()
Ldag = L.conj().T
I = np.eye(L.shape[0])
D = np.kron(L, L.conj()) - 1 / 2 * (
np.kron(Ldag @ L, I) + np.kron(I, L.T @ L.conj())
)
if np.allclose(D.imag, 0):
D = D.real
while Lops:
L = Lops.pop()
Ldag = L.conj().T
_D = np.kron(L, L.conj()) - 1 / 2 * (
np.kron(Ldag @ L, I) + np.kron(I, L.T @ L.conj())
)
if np.allclose(_D.imag, 0):
_D = _D.real
else:
D = D.astype(complex)
D += _D
dissipator = expm(D * dt)
# eigenvalues of dissipator should be positive
assert np.all(np.linalg.eigvals(dissipator) > -1e-14), np.linalg.eigvals(
dissipator
)
# Kraus operators
def supergate_to_kraus(G, d, tol=1e-14):
S4 = G.reshape(d, d, d, d, order="F") # S[α,β,μ,ν]
J = np.transpose(S4, (0, 2, 1, 3)).reshape(
d * d, d * d, order="F"
) # J[(αμ),(βν)] = S[α,β,μ,ν]
J = (J + J.conj().T) / 2 # hermitize
if np.allclose(J, J.conj().T):
w, V = np.linalg.eigh(J)
else:
w, V = np.linalg.eig(J)
kraus = []
for lam, v in zip(w, V.T, strict=True):
lam = lam.real
if lam > tol:
kraus.append(np.sqrt(lam) * v.reshape(d, d, order="F"))
return kraus # satisfies sum_q B_q† B_q ≈ I
Bs = supergate_to_kraus(dissipator, isqrt(dissipator.shape[0]))
# Confirm exp(D dt) = ∑_q=1^k B_q ⊗ B_q*
np.testing.assert_allclose(
dissipator,
np.sum([np.kron(B, B.conj()) for B in Bs], axis=0),
atol=1e-14,
)
"""
TN diagram of Kraus operators
d
|
--B
| |
k d
tensor shape (k, d, d)
"""
k = len(Bs)
d = Bs[0].shape[0]
B: np.ndarray | jax.Array
match backend:
case "numpy":
Bs = [np.array(B, dtype=np.complex128) for B in Bs]
B = np.stack(Bs, axis=0)
case "jax":
Bs = [jnp.array(B, dtype=jnp.complex128) for B in Bs]
B = jnp.stack(Bs, axis=0)
case _:
raise ValueError(f"Invalid backend: {backend}")
assert B.shape == (k, d, d), (
f"Kraus operator shape mismatch: {B.shape} != ({k}, {d}, {d})"
)
return B
# Overload: NumPy backend — both the Kraus tensor and the site core are
# np.ndarray, and an np.ndarray is returned.
@overload
def kraus_contract_single_site(
    B: np.ndarray, core: np.ndarray
) -> np.ndarray: ...
# Overload: JAX backend — both arguments must be jax.Array; the runtime
# dispatch rejects mixed numpy/jax inputs, so `core` is jax.Array here
# (the original annotation `core: np.ndarray` advertised a combination
# that actually raises ValueError).
@overload
def kraus_contract_single_site(B: jax.Array, core: jax.Array) -> jax.Array: ...
# [docs]  (Sphinx "view code" link artifact from HTML extraction)
def kraus_contract_single_site(
    B: np.ndarray | jax.Array, core: np.ndarray | jax.Array
) -> np.ndarray | jax.Array:
    """Dispatch the single-site Kraus contraction to the matching backend.

    Both arguments must belong to the same backend (numpy or jax);
    mixed inputs raise ValueError.
    """
    both_numpy = isinstance(B, np.ndarray) and isinstance(core, np.ndarray)
    both_jax = isinstance(B, jax.Array) and isinstance(core, jax.Array)
    if both_numpy:
        return _kraus_contract_single_site_np(B, core)
    if both_jax:
        return _kraus_contract_single_site_jax(B, core)
    raise ValueError(f"Invalid backend: {type(B)=} while {type(core)=}")
def _kraus_contract_single_site_np(B: np.ndarray, A: np.ndarray) -> np.ndarray:
    """Apply a Kraus tensor B to one MPS core A and re-compress the Kraus leg.

    TN diagram::

           x
           |
         k-B          dK
           |           |
           d         m-A-n

    A's fused physical leg dK is split into (d, K), B is contracted onto
    the d leg, and an SVD truncates the enlarged Kraus leg (k*K) back to K.
    Returns a core of shape (m, x*K, n).
    """
    # Unpack B to match the "kxd" einsum subscripts; the two physical legs
    # of B have equal dimension.
    k, x, d = B.shape
    assert x == d
    m, dK, n = A.shape
    assert dK % d == 0, f"Kraus contract: dK={dK} must be divisible by d={d}"
    K = dK // d
    # 1. split A's fused physical leg into (d, K): shape (m, d, K, n)
    A = A.reshape(m, d, K, n)
    # 2. contract the "d" legs of B and A -> theta[m, n, x, k, K]
    theta = np.einsum("kxd,mdKn->mnxkK", B, A)
    # 3. flatten so the Kraus legs (k, K) sit on the columns
    theta = theta.reshape(m * n * x, k * K)
    # 4. SVD (scipy, may overwrite the work array)
    U, S, _ = svd(theta, full_matrices=False, overwrite_a=True)
    # 5. keep only the K dominant directions of the Kraus space
    S = S[:K]
    U = U[:, :K]
    # 6. absorb the singular values: new core = U @ diag(S)
    A = U * S[np.newaxis, :]
    if const.pytest_enabled:
        np.testing.assert_allclose(A, U @ np.diag(S))
    # 7. regroup as (m, x*K, n): physical leg fused with the Kraus leg
    A = np.ascontiguousarray(A.reshape(m, n, x * K).swapaxes(1, 2))
    return A
@jax.jit
def _kraus_contract_single_site_jax(B: jax.Array, A: jax.Array) -> jax.Array:
k, d, x = B.shape
m, dK, n = A.shape
K = dK // d
# 1. reshape A to (m, d, K, n)
A = A.reshape(m, d, K, n)
# 2. contract "d" legs of B and A
C = jnp.einsum("kxd,mdKn->mnxkK", B, A)
# 3. reshape C to (m, n, x, kK)
C = C.reshape(m * n * x, k * K)
# 4. SVD of C
U, S, _ = jnp.linalg.svd(C, full_matrices=False) # type: ignore[not-iterable, unknown-argument]
# 5. truncate singular values
S = S[:K]
U = U[:, :K]
# 6. concatenate U and S as new A
A = U * S[jnp.newaxis, :]
# 7. reshape A to (m, n, xK) and swap indices n and xK
A = A.reshape(m, n, x * K).swapaxes(1, 2)
return A
# Overload: NumPy backend — Kraus tensor and both cores are np.ndarray,
# and a pair of np.ndarray cores is returned.
@overload
def kraus_contract_two_site(
    B: np.ndarray, core_1: np.ndarray, core_2: np.ndarray
) -> tuple[np.ndarray, np.ndarray]: ...
# Overload: JAX backend — Kraus tensor and both cores are jax.Array,
# and a pair of jax.Array cores is returned.
@overload
def kraus_contract_two_site(
    B: jax.Array, core_1: jax.Array, core_2: jax.Array
) -> tuple[jax.Array, jax.Array]: ...
# [docs]  (Sphinx "view code" link artifact from HTML extraction)
def kraus_contract_two_site(
    B: np.ndarray | jax.Array,
    core_1: np.ndarray | jax.Array,
    core_2: np.ndarray | jax.Array,
) -> tuple[np.ndarray, np.ndarray] | tuple[jax.Array, jax.Array]:
    """Dispatch the two-site Kraus contraction to the matching backend.

    All three arguments must belong to the same backend (numpy or jax);
    mixed inputs raise ValueError.
    """
    operands = (B, core_1, core_2)
    if all(isinstance(op, np.ndarray) for op in operands):
        return _kraus_contract_two_site_np(B, core_1, core_2)
    if all(isinstance(op, jax.Array) for op in operands):
        return _kraus_contract_two_site_jax(B, core_1, core_2)
    raise ValueError(
        f"Invalid backend: {type(B)=} while {type(core_1)=} and {type(core_2)=}"
    )
def _kraus_contract_two_site_np(
    B: np.ndarray, A1: np.ndarray, A2: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
    """Contract a Kraus tensor B onto the left core of a two-site pair.

    TN diagram (typically m = l = n > K >> k = x = d)::

           x
           |
           B-k
           |
           d
           d       K
           |       |
         m-A1-l  l-A2-n

    Returns new cores of shapes (m, x, l) and (l, K, n).
    """
    k, x, d = B.shape
    assert x == d
    assert A1.shape[1] == d
    m, d, l = A1.shape
    assert A1.shape[2] == A2.shape[0]
    l, K, n = A2.shape
    # 1. contract B, A1, A2 in one call; opt_einsum picks the cheapest
    #    contraction order automatically: theta[m, x, n, k, K]
    theta = contract("kxd,mdl,lKn->mxnkK", B, A1, A2)
    # 2. first SVD: compress the fused Kraus legs (k*K) down to K columns
    theta = theta.reshape(m * x * n, k * K)
    U, S, _ = svd(theta, full_matrices=False, overwrite_a=True)
    if const.pytest_enabled:
        print(f"truncation percentage: {1 - S[:K].sum() / S.sum():.2%}")
    S = S[:K]
    U = U[:, :K]
    # 3. absorb singular values in place (U now holds U @ diag(S))
    U *= S[np.newaxis, :]
    # 4. rearrange legs to (m, x, K, n) and fuse as (m*x, K*n)
    theta = U.reshape(m, x, n, K).swapaxes(2, 3)
    theta = np.ascontiguousarray(theta.reshape(m * x, K * n))
    # 5. second SVD: split back into two cores, keeping bond dimension l.
    #    NOTE: a full SVD is computed and then truncated to the first l
    #    singular values.
    U, S, Vh = svd(theta, full_matrices=False, overwrite_a=True)
    U = U[:, :l]
    S = S[:l]
    Vh = Vh[:l, :]
    # 6. left core absorbs the singular values; right core is the isometry
    U *= S[np.newaxis, :]
    A1 = U.reshape(m, x, l)
    A2 = Vh.reshape(l, K, n)
    return A1, A2
@jax.jit
def _kraus_contract_two_site_jax(
B: jax.Array, A1: jax.Array, A2: jax.Array
) -> tuple[jax.Array, jax.Array]:
"""
x
|
B-k
|
d
d K
| |
m-A1-l l-A2-n
Typically, m=l=n > K >> k=x=d
"""
k, x, d = B.shape
m, d, l = A1.shape
l, K, n = A2.shape
# 0. Contract all B, A1, A2 with optimized einsum
C = jnp.einsum("kxd,mdl,lKn->mxnkK", B, A1, A2)
"""
mxn-C-kK
"""
C = C.reshape(m * x * n, k * K)
# 1. SVD of C (truncated if beneficial)
U, S, _ = jnp.linalg.svd(C, full_matrices=False) # type: ignore[not-iterable, unknown-argument]
S = S[:K]
U = U[:, :K]
"""
mxn-U-K K-S-K
"""
# 3. concatenate U and S as new C (in-place style)
U *= S[jnp.newaxis, :]
"""
mxn-U-K (now contains US)
"""
# 4. reshape U to (m, x, n, K) and swap indices n and K
C = U.reshape(m, x, n, K).swapaxes(2, 3)
r"""
x K
\ /
m-C-n
"""
# 5. reshape C to (mx, Kn)
C = C.reshape(m * x, K * n)
"""
mx-C-Kn
"""
# 6. SVD of C (second decomposition)
U, S, Vh = jnp.linalg.svd(C, full_matrices=False) # type: ignore[not-iterable, unknown-argument]
U = U[:, :l]
S = S[:l]
Vh = Vh[:l, :]
"""
mx-U-l-S-l-Vh-Kn
"""
# 8. concatenate U and S as new A1, Vh as new A2 (in-place style)
U *= S[jnp.newaxis, :]
"""
mx-A1-l l-A2-Kn
"""
# 9. reshape A1 to (m, x, l) and A2 to (l, K, n)
A1 = U.reshape(m, x, l)
A2 = Vh.reshape(l, K, n)
"""
x K
| |
m-A1-l l-A2-n
"""
return A1, A2
# [docs]  (Sphinx "view code" link artifact from HTML extraction)
def trace_kraus_dim(rdm: np.ndarray, d: int) -> np.ndarray:
    r"""Trace out the Kraus dimension K from a reduced density matrix.

    Each of the last two axes (size dK = d*K) is split into (d, K) and
    the two K legs are traced against each other::

        dK     d  K     d
        |       \ /|    |
        C   =>   C | => C
        |       / \|    |
        dK     d  K     d

    Args:
        rdm: density matrix of shape (dK, dK), or a batch of shape
            (t, dK, dK).
        d: physical dimension to keep (must divide dK).

    Returns:
        np.ndarray: traced matrix of shape (d, d), or (t, d, d) for a
        batched input.
    """
    dK = rdm.shape[-1]
    assert dK % d == 0, (
        f"Kraus dimension reduction: dK={dK} must be divisible by d={d}"
    )
    K = dK // d
    if rdm.ndim == 2:
        # Single matrix: split both axes and trace the K legs.
        return np.einsum("dKxK->dx", rdm.reshape(d, K, d, K))
    assert rdm.ndim == 3, f"rdm.ndim={rdm.ndim} must be 2 or 3"
    # Batched input: same trace applied along the leading axis.
    return np.einsum("tdKxK->tdx", rdm.reshape(-1, d, K, d, K))