"""
The model class for the Time-Dependent Schrodinger Equation (TDSE) calculation.
"""
from __future__ import annotations
import copy
from typing import Literal
import discvar
from discvar.abc import DVRPrimitivesMixin
from loguru import logger
import pytdscf
from pytdscf._const_cls import const
from pytdscf.hamiltonian_cls import (
HamiltonianMixin,
PolynomialHamiltonian,
TensorHamiltonian,
)
[docs]
class Model:
    """The wavefunction and operator information class.

    Bundles the basis information, the Hamiltonian and any further
    observable operators. Most methods simply forward to ``basinfo``.

    Args:
        basinfo (BasInfo) : The wavefunction basis information
        operators (Dict) : The operators. ``operators[name]`` gives a operator
            named ``name``, such as ``hamiltonian``. The key
            ``"hamiltonian"`` is required. The dict passed in is not
            modified; a shallow copy is taken.
        build_td_hamiltonian (PolynomialHamiltonian) :
            Time dependent PolynomialHamiltonian. Defaults to ``None``.
        space (str) : ``"hilbert"`` or ``"liouville"`` (case-insensitive).
            Defaults to ``"hilbert"``.

    Raises:
        KeyError : ``operators`` has no ``"hamiltonian"`` entry.
        ValueError : ``hamiltonian.nstate`` differs from
            ``basinfo.get_nstate()``, or ``space`` is not one of the two
            accepted values.

    Attributes:
        init_weight_VIBSTATE (List[List[float]]) :
            Initial weight of VIBSTATE. List length is [nstate, ndof].
        init_weight_GRID (List[List[float]]) :
            Initial weight of GRID. List length is [nstate, ndof].
            NOTE(review): documented but not declared on this class --
            confirm it is still in use.
        init_weight_ESTATE (List[float]) : Initial weight of ESTATE.
            List length is nstate.
        init_HartreeProduct (List[List[List[float]]]) : Indexed as
            ``[state][dof][basis]``. Defaults to ``None``.
        ints_prim_file (str) : The file name of the primitive integrals
        m_aux_max (int) : The maximum number of auxiliary basis.
        basinfo (BasInfo) : The wavefunction basis information
        hamiltonian (PolynomialHamiltonian) : PolynomialHamiltonian
        observables (Dict) : Observable operators, such as
            PolynomialHamiltonian, dipole moment, occupation number etc.
            This is every entry of ``operators`` except ``"hamiltonian"``.
        build_td_hamiltonian (PolynomialHamiltonian) :
            Time-dependent PolynomialHamiltonian.
        nstate (int) : ``hamiltonian.nstate``.
        use_mpo (bool) : ``True`` when ``hamiltonian`` is a
            ``TensorHamiltonian`` instance.
        space (str) : Lower-cased ``space`` argument.
    """

    init_weight_VIBSTATE: list[list[float]] | None = None
    init_weight_ESTATE: list[float] | None = None
    init_HartreeProduct: list[list[list[float]]] | None = (
        None  # [state][dof][basis]
    )
    ints_prim_file: str | None = None
    m_aux_max: int | None = None

    def __init__(
        self,
        basinfo: BasInfo,
        operators: dict[str, HamiltonianMixin],
        *,
        build_td_hamiltonian: PolynomialHamiltonian | None = None,
        space: Literal["hilbert", "liouville"] = "hilbert",
    ):
        self.basinfo = basinfo
        # Pop from a shallow copy: the caller's dict must keep its
        # "hamiltonian" entry so it can be reused (e.g. for a second Model)
        # and stays intact when validation below fails.
        observables = dict(operators)
        self.hamiltonian = observables.pop("hamiltonian")
        self.observables = observables
        self.build_td_hamiltonian = build_td_hamiltonian
        if self.hamiltonian.nstate != basinfo.get_nstate():
            raise ValueError(
                "The number of states in Hamiltonian and BasInfo are different."
            )
        self.nstate = self.hamiltonian.nstate
        self.use_mpo = isinstance(self.hamiltonian, TensorHamiltonian)
        normalized_space = space.lower()
        if normalized_space not in ("hilbert", "liouville"):
            # Report the value exactly as the caller supplied it.
            raise ValueError(
                f"space must be 'hilbert' or 'liouville' but got {space}"
            )
        self.space: Literal["hilbert", "liouville"] = normalized_space  # type: ignore

    def get_nstate(self) -> int:
        """
        Returns:
            int : Number of electronic states
        """
        return self.basinfo.get_nstate()

    def get_ndof(self) -> int:
        """
        Returns:
            int : Degree of freedoms
        """
        return self.basinfo.get_ndof()

    def get_ndof_per_sites(self) -> list[int]:
        """Forward to ``basinfo.get_ndof_per_sites()``.

        Not yet implemented (N.Y.I.): ``BasInfo.get_ndof_per_sites``
        currently raises ``NotImplementedError``.
        """
        return self.basinfo.get_ndof_per_sites()

    def get_primbas(
        self, istate: int, idof: int
    ) -> pytdscf.basis._primints_cls.PrimBas_HO:
        """
        Args:
            istate (int) : index of electronic states
            idof (int) : index of degree of freedom

        Returns:
            primints_cls.PrimBas_HO : The primitive basis in istate, idof.
        """
        return self.basinfo.get_primbas(istate, idof)

    def get_nspf(self, istate: int, idof: int) -> int:
        """
        Args:
            istate (int) : index of electronic states
            idof (int) : index of degree of freedom

        Returns:
            int : The number of SPF in istate, idof.
        """
        return self.basinfo.get_nspf(istate, idof)

    def get_nprim(self, istate: int, idof: int) -> int:
        """
        Args:
            istate (int) : index of electronic states
            idof (int) : index of degree of freedom

        Returns:
            int : The number of primitive basis in istate, idof.
        """
        return self.basinfo.get_nprim(istate, idof)

    def get_nspf_list(self, istate: int) -> list[int]:
        """
        Args:
            istate (int) : index of electronic states

        Returns:
            List(int) : Number of SPFs. e.g. ``[2, 2]``
        """
        return self.basinfo.get_nspf_list(istate)

    def get_nspf_rangelist(self, istate: int) -> list[list[int]]:
        """
        Args:
            istate (int) : index of electronic states

        Returns:
            List[List[int]] : the indices of SPFs. e.g.
            ``[[0,1,2],[0,1,2]]``
        """
        return self.basinfo.get_nspf_rangelist(istate)
[docs]
class BasInfo:
    """The wave function basis information class.

    Args:
        prim_info (List[List[PrimBas_HO or DVRPrimitivesMixin]]) :
            ``prim_info[istate][idof]`` gives ``PrimBas_HO``.
            A deep copy is stored.
        spf_info (List[List[int]]) : ``spf_info[istate][idof]`` gives
            the number of SPF. When ``None``, the SPF layer is not used and
            the count is taken as ``len(prim_info[istate][idof])``.
        ndof_per_sites (bool) : Defaults to ``None``. N.Y.I.; any truthy
            value raises ``NotImplementedError``.

    Raises:
        NotImplementedError : ``ndof_per_sites`` is truthy.

    Attributes:
        prim_info (List[List[PrimBas_HO or DVRPrimitivesMixin]]) :
            ``prim_info[istate][idof]`` gives ``PrimBas_HO``.
        spf_info (List[List[int]]) : ``spf_info[istate][idof]`` gives
            the number of SPF.
        is_DVR (bool) : ``True`` if any basis of the first state is a
            ``DVRPrimitivesMixin`` (pytdscf or discvar).
        need_primints (bool) : ``True`` if any basis of the first state is
            a ``PrimBas_HO`` (pytdscf or discvar).
        is_standard_method (bool) : ``True`` when ``spf_info`` was not
            given, ``False`` otherwise.
        ndof_per_sites : The ``ndof_per_sites`` argument (N.Y.I.).
        nstate (int) : The number of electronic states.
            Set on the first ``get_nstate()`` call.
        ndof (int) : The degree of freedoms.
            Set on the first ``get_ndof()`` call.
        nspf (int) : The number of SPF.
            NOTE(review): never assigned in this class -- confirm.
        nspf_list (List[int]) : The number of SPFs, as computed by the most
            recent ``get_nspf_list(istate)`` call.
        nspf_rangelist (List[List[int]]) : The indices of SPFs, as computed
            by the most recent ``get_nspf_rangelist(istate)`` call.
    """

    def __init__(self, prim_info, spf_info=None, ndof_per_sites=None):
        self.prim_info = copy.deepcopy(prim_info)
        # The basis kind is probed on the first electronic state only.
        self.is_DVR = any(
            isinstance(basis, pytdscf.basis.abc.DVRPrimitivesMixin)
            or isinstance(basis, DVRPrimitivesMixin)
            for basis in prim_info[0]
        )
        self.need_primints = any(
            isinstance(basis, pytdscf.PrimBas_HO | discvar.ho.PrimBas_HO)
            for basis in prim_info[0]
        )
        if spf_info is None:
            if const.verbose > 1:
                logger.info("The layer of SPF is not used.")
            # Without an SPF layer the SPF count is len() of each
            # primitive basis.
            self.spf_info = [
                [
                    len(self.prim_info[istate][idof])
                    for idof in range(self.get_ndof())
                ]
                for istate in range(self.get_nstate())
            ]
            self.is_standard_method = True
        else:
            self.spf_info = copy.deepcopy(spf_info)
            self.is_standard_method = False
        if ndof_per_sites:
            raise NotImplementedError
        self.ndof_per_sites = ndof_per_sites

    def get_nstate(self) -> int:
        """Get ``nstate`` attributes

        The value is ``len(prim_info)``, computed on first use and then
        stored in ``self.nstate``.

        Returns:
            int : Number of electronic states
        """
        if not hasattr(self, "nstate"):
            self.nstate = len(self.prim_info)
        return self.nstate

    def get_ndof(self) -> int:
        """Get ``ndof`` attributes

        The value is ``len(prim_info[0])``, computed on first use and then
        stored in ``self.ndof``.

        Returns:
            int : Degree of freedom
        """
        if not hasattr(self, "ndof"):
            self.ndof = len(self.prim_info[0])
        return self.ndof

    def get_ndof_per_sites(self) -> list[int]:
        """Get ``ndof_per_sites`` attributes

        Raises:
            NotImplementedError : Always; this feature is N.Y.I.
        """
        raise NotImplementedError

    def get_primbas(
        self, istate: int, idof: int
    ) -> pytdscf.basis._primints_cls.PrimBas_HO:
        """Get ``prim_info[istate][idof]`` attributes

        Args:
            istate (int) : index of electronic states
            idof (int) : index of degree of freedom

        Returns:
            PrimBas_HO : The primitive basis of istate, idof.
        """
        # NYI->i_set = self.state_label[istate][idof]
        return self.prim_info[istate][idof]

    def get_nspf(self, istate: int, idof: int) -> int:
        """Get number of SPF

        Args:
            istate (int) : index of electronic states
            idof (int) : index of degree of freedom

        Returns:
            int : Number of SPF
        """
        # NYI->i_set = self.state_label[istate][idof]
        return self.spf_info[istate][idof]

    def get_nprim(self, istate: int, idof: int) -> int:
        """Get number of primitive basis

        Args:
            istate (int) : index of electronic states
            idof (int) : index of degree of freedom

        Returns:
            int : Number of primitive basis
        """
        return self.prim_info[istate][idof].nprim

    def get_ngrid(self, istate: int, idof: int) -> int:
        """Get number of grid points; same value as ``get_nprim``.

        Args:
            istate (int) : index of electronic states
            idof (int) : index of degree of freedom

        Returns:
            int : ``get_nprim(istate, idof)``
        """
        return self.get_nprim(istate, idof)

    def get_nspf_list(self, istate: int) -> list[int]:
        """Get number of SPFs list ``nspf_list`` attributes

        The list is rebuilt on every call because it depends on ``istate``;
        ``self.nspf_list`` is overwritten with the newest result.

        Args:
            istate (int) : index of electronic states

        Returns:
            list(int) : Number of SPFs. e.g. ``[2, 2]``
        """
        # Deliberately not cached: a single cached list would be returned
        # for every ``istate`` and be wrong for all but the first one asked.
        self.nspf_list = [
            self.get_nspf(istate, idof) for idof in range(self.get_ndof())
        ]
        return self.nspf_list

    def get_nspf_rangelist(self, istate: int) -> list[list[int]]:
        """Get number of SPFs list ``nspf_rangelist`` attributes

        The list is rebuilt on every call because it depends on ``istate``;
        ``self.nspf_rangelist`` is overwritten with the newest result.

        Args:
            istate (int) : index of electronic states

        Returns:
            List[List[int]] : the indices of SPFs. e.g.
            ``[[0,1,2],[0,1,2]]``
        """
        # Deliberately not cached, for the same reason as get_nspf_list.
        self.nspf_rangelist = [
            list(range(self.get_nspf(istate, idof)))
            for idof in range(self.get_ndof())
        ]
        return self.nspf_rangelist