"""
Multiprocessing mesh electronic structure calculation
caller using ASE library
"""
import itertools
import os
import pickle
import shutil
import time
import types
from collections import deque
from concurrent.futures import ProcessPoolExecutor
from copy import deepcopy
from pathlib import Path
from typing import Callable
import numpy as np
from ase.atoms import Atoms
from ase.calculators.calculator import Calculator
from ase.calculators.genericfileio import GenericFileIOCalculator
from ase.db import connect
from loguru import logger
from pytdscf import units
from pytdscf._helper import from_dbkey, progressbar
from pytdscf.basis.abc import DVRPrimitivesMixin
# Module-level logger used by every function below. ``bind`` returns a logger
# whose records carry ``name="main"`` as extra context (presumably so the
# package's log configuration can route them -- confirm against the sink setup).
logger = logger.bind(name="main")
def _run(atoms: Atoms) -> tuple[Atoms]:
    """Worker entry point executed in a child process.

    Kept as a module-level function (not a ``DVR_Mesh`` method) so that only
    ``atoms`` has to be shipped to the worker.

    The energy, forces and dipole moment are requested in that order so the
    attached calculator computes them. Any exception raised on the way is
    logged and swallowed; the (possibly partially evaluated) ``atoms`` object
    is returned in all cases.

    Args:
        atoms: structure with a calculator attached.

    Returns:
        A one-element tuple holding the same ``atoms`` object.
    """
    # NOTE(review): the original inline comments labelled these quantities in
    # Hartree/Bohr units; the unit convention is not verifiable from here.
    requested = ("get_total_energy", "get_forces", "get_dipole_moment")
    try:
        for getter_name in requested:
            # Look the getter up lazily so a failure stops at the same call
            # as three sequential statements would.
            getattr(atoms, getter_name)()
    except Exception as exc:
        logger.warning(f"ERROR: {exc}")
    return (atoms,)
def _todict(self):
"""
If calculator has not attribute 'todict',
add this method to enable db.update
"""
original_dict = vars(self)
return original_dict["parameters"] | original_dict["results"]
[docs]
class DVR_Mesh:
    """DVR grid coordinate

    Builds displaced geometries on a mesh of DVR grid points, stores them in
    an ASE database and runs a calculator on every stored geometry in a
    process pool.

    Args:
        dvr_prims (List[DVRPrimitivesMixin]) : DVR primitive list
        atoms (List[List[str, Tuple[float,float,float]]] or ase.atoms.Atoms) : \
            reference coordinate. Format is the same as PySCF ones
        disp_vec (np.ndarray) : displacement vectors (row vectors) in angstrom
        unit (str) : \
            Input reference coordinate unit. Defaults to 'angstrom'
    """

    # dof_key -> grid_key -> displacement (sum of grid value * disp_vec terms)
    displace: dict
    # dof_key -> grid_key -> displaced positions (reference + displacement)
    geometry: dict
    # dof_key -> grid_key -> row id of that geometry in the database
    grid_id: dict
    # base name of the database file, the pickle file and the work directory
    jobname: str
    # jobs not yet submitted: (dof_key, grid_key, row id, unique row id, None)
    remain_jobs: deque
    # when True, a copy of the given calculator replaces any stored one
    reset_calc: bool
    # finished jobs: (dof_key, grid_key, row id)
    done_jobs: deque
    # submitted jobs: (future, dof_key, grid_key, row id, submission time)
    thrown_jobs: deque
def __init__(
    self,
    dvr_prims: list[DVRPrimitivesMixin],
    atoms: Atoms,
    disp_vec: np.ndarray,
    unit: str = "angstrom",
):
    """Store the DVR grids, displacement vectors and reference geometry.

    Args:
        dvr_prims: one DVR primitive per degree of freedom; the result of
            ``get_grids()`` of each primitive is kept in ``grid_list``.
        atoms: reference geometry, either an ``ase.atoms.Atoms`` object
            (subclasses included) or a sequence of
            ``(symbol, (x, y, z))`` pairs.
        disp_vec: displacement vectors, one entry per degree of freedom.
            The checks below require ``shape[1]`` to equal the number of
            atoms and ``shape[-1]`` to be 3.
        unit: unit of the coordinates when ``atoms`` is a sequence of
            pairs: ``"angstrom"`` or one of ``"bohr"``, ``"au"``,
            ``"a.u."`` (case-insensitive). Not used for ``Atoms`` input.

    Raises:
        TypeError: if the number of primitives and displacement vectors
            differ, or ``disp_vec`` has an incompatible shape.
        NotImplementedError: if ``unit`` is not one of the supported names.
    """
    # Accept array-likes as well; an ndarray is passed through unchanged.
    disp_vec = np.asarray(disp_vec)

    self.dvr_prims = dvr_prims
    self.ndof = len(dvr_prims)
    self.disp_vec = disp_vec
    self.grid_list = [prim.get_grids() for prim in dvr_prims]

    if len(dvr_prims) != len(disp_vec):
        raise TypeError(
            f"got {len(dvr_prims)} DVR primitives "
            f"but {len(disp_vec)} displacement vectors"
        )
    if disp_vec.ndim < 2 or disp_vec.shape[-1] != 3:
        raise TypeError(
            "disp_vec must have a last axis of length 3, "
            f"got shape {disp_vec.shape}"
        )

    # For every degree of freedom remember the index of the first grid
    # point that is numerically zero (None when there is no such point).
    self.zero_indices = [
        next((j for j, grid in enumerate(grids) if abs(grid) < 1.0e-10), None)
        for grids in self.grid_list
    ]

    if isinstance(atoms, Atoms):
        # isinstance (not an exact type check) so Atoms subclasses are
        # handled here instead of failing in the pair-unpacking branch.
        self.symbols = atoms.symbols
        self.positions = atoms.positions
        self.masses = atoms.get_masses()
    else:
        # NOTE: ``self.masses`` is only defined for Atoms input.
        unit_key = unit.lower()
        if unit_key == "angstrom":
            to_angstrom = None
        elif unit_key in ("bohr", "au", "a.u."):
            to_angstrom = units.au_in_angstrom
        else:
            raise NotImplementedError(f"unsupported unit: {unit!r}")
        coords = np.array([position for _, position in atoms])
        if to_angstrom is not None:
            coords = coords * to_angstrom
        self.positions = coords
        self.symbols = [element for element, _ in atoms]

    if disp_vec.shape[1] != len(self.positions):
        raise TypeError(
            f"disp_vec describes {disp_vec.shape[1]} atoms "
            f"but the reference geometry has {len(self.positions)}"
        )
[docs]
def save_geoms(
    self,
    jobname: str,
    nMR: int | None = None,
    overwrite: bool | None = None,
) -> dict[tuple[int, ...], dict[tuple[int, ...], int]]:
    """
    Generate cartesian coordinate geometry for each grid mesh.

    The geometries are written to the ASE database ``{jobname}.db`` and the
    resulting id table is pickled to ``{jobname}_grid_id.pkl``.

    Args:
        jobname (str) : base name of the database and pickle files.
        nMR (Optional[int]) : The number of mode representation. \
            limits n dimensional mesh. \
            Defaults to ``None``, thus, \
            ``ngrid**ndof`` coords will be generated.
        overwrite (Optional[bool]) : what to do when the database file \
            already exists. ``True`` removes and rewrites it, ``False`` \
            keeps it, ``None`` asks interactively. When the file does \
            not exist the geometries are always written.

    Returns:
        Dict[Tuple[int, ...], Dict[Tuple[int, ...], int]] : database row \
            ids of the DVR mesh coordinates. \
            E.g. [(0,1)][(2,3)] gives 2nd, 3rd grids of 0-mode, \
            1-mode coordinate.
    """
    self.jobname = jobname
    if nMR is None:
        nMR = self.ndof

    # These containers are only annotated on the class and are deleted at
    # the end of this method, so they must be (re)created on every call.
    self.displace = {}
    self.geometry = {}

    logger.warning("START : Displacement Generation")
    for iMR in range(1, nMR + 1):
        if iMR == 1:
            # 1-mode displacements: grid value times the mode's vector.
            for idof in range(self.ndof):
                self.displace[(idof,)] = {
                    (igrid,): coef * self.disp_vec[idof]
                    for igrid, coef in enumerate(self.grid_list[idof])
                }
            continue
        # n-mode displacements are the (n-1)-mode displacements of the
        # leading DOFs plus the 1-mode displacement of the last DOF.
        for dof_key in itertools.combinations(range(self.ndof), r=iMR):
            lower = self.displace[dof_key[:-1]]
            added = self.displace[(dof_key[-1],)]
            self.displace[dof_key] = {
                lower_key + added_key: lower_vec + added_vec
                for lower_key, lower_vec in lower.items()
                for added_key, added_vec in added.items()
            }

    # When the full-dimensional mesh is requested, only that mesh is kept;
    # the lower-dimensional ones were just intermediates.
    full_mesh_only = nMR == self.ndof
    for dof_key, grid_dict in self.displace.items():
        if full_mesh_only and len(dof_key) < self.ndof:
            continue
        self.geometry[dof_key] = {
            grid_key: self.positions + disp_vec
            for grid_key, disp_vec in grid_dict.items()
        }
    logger.warning("DONE : Displacement Generation")

    db_path = f"{self.jobname}.db"
    if os.path.exists(db_path):
        if overwrite is None:
            yes_or_else = input(
                f"{self.jobname}.db is already exists!"
                + " Dou you remove the database ?[y/n]"
                + "(Defaults to 'y')"
            )
        else:
            yes_or_else = "y" if overwrite else "n"
        write_rows = yes_or_else.lower() in ["y", "yes", ""]
        if write_rows:
            os.remove(db_path)
    else:
        # Nothing to preserve: a fresh database must always be populated,
        # otherwise the returned ids would point at rows that do not exist.
        write_rows = True

    self.grid_id = {}
    with connect(db_path) as db:
        next_id = 0
        for dof_key, grid_dict in self.geometry.items():
            self.grid_id[dof_key] = {}
            for grid_key, coord in progressbar(
                grid_dict.items(),
                prefix=f"Save geometry in DB of DOFs = {dof_key}",
            ):
                if write_rows:
                    atoms_grid = Atoms(self.symbols, positions=coord)
                    # DOFs outside dof_key sit on their zero grid point.
                    grid = list(self.zero_indices)
                    for p, g in zip(dof_key, grid_key, strict=True):
                        grid[p] = g
                    _id = db.write(
                        atoms_grid,
                        dofs="|" + " ".join(map(str, dof_key)),
                        grids="|" + " ".join(map(str, grid)),
                    )
                    self.grid_id[dof_key][grid_key] = _id
                else:
                    # Existing database kept: ids are a running counter
                    # starting at 1 (assumes the rows were originally
                    # written in this same order -- not verified here).
                    next_id += 1
                    self.grid_id[dof_key][grid_key] = next_id

    with open(f"{self.jobname}_grid_id.pkl", "wb") as f:
        pickle.dump(self.grid_id, f)

    # nMR coordinates can be rebuilt from the (n-1)MR ones, and keeping all
    # of them may consume too much memory, so free them here.
    del self.geometry
    del self.displace
    return self.grid_id
[docs]
def execute_multiproc(
    self,
    calculator: Calculator,
    max_workers: int | None = None,
    timeout: float = 60.0,
    jobname: str | None = None,
    reset_calc: bool = False,
    judge_func: Callable | None = None,
):
    """Execute electronic structure calculation by multiprocessing

    Every row of ``{jobname}.db`` accepted by ``judge_func`` is queued, and
    the queue is worked off by a process pool. Results are written back to
    the same database.

    Args:
        calculator (Calculator) : calculator for each geometry
        max_workers (Optional[int]) : maximum workers in multi-processing.
            Defaults to None. If None, use cpu_count - 1 (at least 1).
        timeout (float) : Timeout calculation in second. Defaults to 60.0
        jobname (Optional[str]) : jobname. When omitted, the jobname set
            by an earlier call on this object is used.
        reset_calc (Optional[bool]) : set new calculator in any case.
            Defaults to False.
        judge_func (Optional[Callable[[Any],bool]]) : called with each database row; rows for which it returns a true value are queued.
            Defaults to None, which queues every row.

    Raises:
        ValueError: if no jobname was given here or set previously.
    """  # noqa: E501
    self.calc = calculator
    if isinstance(jobname, str):
        self.jobname = jobname
    # ``jobname`` is only annotated on the class, so it may not exist yet.
    if getattr(self, "jobname", None) is None:
        raise ValueError("required jobname argument.")
    self.remain_jobs = deque()
    self.thrown_jobs = deque()
    self.done_jobs = deque()

    if judge_func is None:
        self.judge_func = lambda row: True
    else:
        self.judge_func = judge_func

    db_path = f"{self.jobname}.db"
    # Rows sharing the same ``grids`` string describe the same geometry;
    # remember the first row id seen for each of them.
    with connect(db_path) as db:
        unique_jobs = {}
        for row in db.select():
            if not self.judge_func(row):
                continue
            dof_key = tuple(from_dbkey(row.dofs))
            all_grids = from_dbkey(row.grids)
            grid_key = tuple(all_grids[p] for p in dof_key)
            unique_id = unique_jobs.setdefault(row.grids, row.id)
            self.remain_jobs.append(
                (dof_key, grid_key, row.id, unique_id, None)
            )
    logger.warning(f"unique jobs : {len(unique_jobs)}")

    self.reset_calc = reset_calc
    if max_workers is None:
        # cpu_count() may return None, and a single-CPU host would give 0
        # workers, which ProcessPoolExecutor rejects.
        max_workers = max(1, (os.cpu_count() or 1) - 1)
    wait_process = max_workers

    logger.warning("START : Electronic Structure Calculations")
    n = len(self.remain_jobs)
    if n > 0:
        with ProcessPoolExecutor(max_workers) as exe:
            with connect(db_path) as db:
                logger.warning(f"Connected: {self.jobname}.db")
                # Fill the pool, then submit one new job per collected one.
                for _ in range(min(wait_process, len(self.remain_jobs))):
                    self._throw_job_to_queue(exe, db)
                for _iter in progressbar(range(n)):
                    self._pick_up_job_from_queue(db, timeout)
                    if self.remain_jobs:
                        self._throw_job_to_queue(exe, db)
                logger.warning("WAIT : Remaining future task")
                while self.thrown_jobs:
                    self._pick_up_job_from_queue(db, timeout)
            logger.warning("DONE : Electronic Structure Calculations")
            if len(self.done_jobs) == n:
                logger.warning("Your calculation completely finished!")
            else:
                logger.warning(
                    f"Remained {n - len(self.done_jobs)} jobs!"
                    + " you should execute once again "
                    + "with different conditions and judge_func"
                )
            # Must run before leaving the executor context: shutdown resets
            # ``_processes``. Presumably this kills workers still busy with
            # timed-out jobs so that the shutdown does not block on them.
            for process in list((exe._processes or {}).values()):
                process.kill()
        logger.warning("DONE : Shutdown process executor")
def _throw_job_to_queue(
    self,
    exe,
    db,
):
    """Submit the next pending job to the executor.

    Jobs are taken from the left of ``remain_jobs`` until one is actually
    submitted or the queue is empty. A job whose row id equals its unique
    id, whose row is accepted by ``judge_func`` and which already has an
    ``energy`` attribute is not recomputed: its row is written back and the
    job is recorded in ``done_jobs``.

    Args:
        exe: executor providing ``submit``.
        db: open database connection.

    Raises:
        NotImplementedError: if the calculator is neither a ``Calculator``
            nor a ``GenericFileIOCalculator``.
    """
    while self.remain_jobs:
        dofs, grids, row_id, first_id, _ = self.remain_jobs.popleft()
        job_tag = (dofs, grids, row_id)

        # NOTE(review): the comparison is an equality here, so only the
        # first row of each geometry takes this shortcut -- confirm this is
        # the intended treatment of duplicated geometries.
        if row_id == first_id:
            stored_row = db.get(first_id)
            already_computed = self.judge_func(stored_row) and hasattr(
                stored_row, "energy"
            )
            if already_computed:
                stored = stored_row.toatoms()
                if not isinstance(stored.calc, Calculator):
                    stored.calc.todict = types.MethodType(
                        _todict, stored.calc
                    )
                db.update(row_id, stored)
                self.done_jobs.append(job_tag)
                continue

        structure = db.get_atoms(row_id)
        needs_fresh_calc = structure.calc is None or self.reset_calc
        if needs_fresh_calc:
            structure.calc = deepcopy(self.calc)

        # Give every job its own working location below the jobname folder.
        calc = structure.calc
        if isinstance(calc, Calculator):
            # Gaussian etc
            calc.set_label(f"{self.jobname}/{row_id:07}/calc")
        elif isinstance(calc, GenericFileIOCalculator):
            # Orca etc
            calc.directory = Path(f"{self.jobname}/{row_id:07}/calc")
        else:
            raise NotImplementedError

        pending = exe.submit(_run, structure)
        self.thrown_jobs.append((pending, dofs, grids, row_id, time.time()))
        # Exactly one submission per call.
        return
def _pick_up_job_from_queue(self, db, timeout: float = 60.0):
    """Collect one submitted job, waiting until one finishes or times out.

    Jobs in ``thrown_jobs`` are inspected round-robin. The call returns as
    soon as one job is handled (finished, failed or timed out), or
    immediately when nothing has been submitted.

    * finished without exception: the returned atoms are written to the
      database row and the job is recorded in ``done_jobs``;
    * finished with exception: the error is logged, the job is dropped;
    * running longer than ``timeout`` seconds since submission: the job is
      cancelled if it has not started, given one more second, then dropped.

    In all three cases the job's work directory is removed if it exists.

    Args:
        db: open database connection.
        timeout: maximum age of a submitted job in seconds.
    """
    while self.thrown_jobs:
        future, dof_key, grid_key, _id, start_time = (
            self.thrown_jobs.popleft()
        )
        workdir = f"{self.jobname}/{_id:07}"

        if future.done():
            if (error := future.exception()) is None:
                atoms = future.result()[0]
                if not isinstance(atoms.calc, Calculator):
                    # db.update needs calc.todict(); supply the fallback.
                    atoms.calc.todict = types.MethodType(
                        _todict, atoms.calc
                    )
                db.update(_id, atoms=atoms)
                self.done_jobs.append((dof_key, grid_key, _id))
            else:
                # loguru's logger has ``warning`` but no ``warn`` method;
                # calling ``warn`` raised AttributeError on this path.
                logger.warning(
                    f"ERROR:\t{error}\tDOFs={dof_key}\tgrids={grid_key}\tID={_id}"
                )
            if os.path.exists(workdir):
                shutil.rmtree(workdir)
            return

        if time.time() - start_time > timeout:
            try:
                if not future.running():
                    future.cancel()
                # Wait 1.0 second and discard the job
                future.result(1.0)
            except Exception:
                # TimeoutError or CancelledError (or the job's own error);
                # the job is abandoned either way.
                pass
            logger.warning(
                f"TIMEOUT:\tDOFs={dof_key}\tgrids={grid_key}\tID={_id} "
            )
            if os.path.exists(workdir):
                shutil.rmtree(workdir)
            return

        # Neither finished nor expired: look again later, after the others.
        time.sleep(1)
        self.thrown_jobs.append(
            (future, dof_key, grid_key, _id, start_time)
        )