Source code for pytdscf.ase_handler

"""
Multiprocessing caller for mesh electronic structure
calculations using the ASE library
"""

import itertools
import os
import pickle
import shutil
import time
import types
from collections import deque
from concurrent.futures import ProcessPoolExecutor
from copy import deepcopy
from pathlib import Path
from typing import Callable

import numpy as np
from ase.atoms import Atoms
from ase.calculators.calculator import Calculator
from ase.calculators.genericfileio import GenericFileIOCalculator
from ase.db import connect
from loguru import logger

from pytdscf import units
from pytdscf._helper import from_dbkey, progressbar
from pytdscf.basis.abc import DVRPrimitivesMixin

logger = logger.bind(name="main")


def _run(atoms: Atoms) -> tuple[Atoms]:
    """Run for multiprocessing
    (do not include in DVR_Mesh class because it may become slow.)
    """
    try:
        atoms.get_total_energy()  # 1 / Hartree
        atoms.get_forces()  # 1 / Hartree * Bohr
        atoms.get_dipole_moment()  # 1 / Bohr
    except Exception as e:
        logger.warning(f"ERROR: {e}")
    return (atoms,)


def _todict(self):
    """
    If calculator has not attribute 'todict',
    add this method to enable db.update
    """
    original_dict = vars(self)
    return original_dict["parameters"] | original_dict["results"]


class DVR_Mesh:
    """DVR grid coordinates

    Args:
        dvr_prims (List[DVRPrimitivesMixin]) : DVR primitive list
        atoms (List[Tuple[str, Tuple[float, float, float]]] or ase.atoms.Atoms) : \
            reference coordinate. The format is the same as the PySCF one.
        disp_vec (np.ndarray) : displacement vectors (row vectors) in angstrom
        unit (Optional[str]) : \
            unit of the input reference coordinate. Defaults to 'angstrom'.

    """

    displace: dict
    geometry: dict
    grid_id: dict
    jobname: str
    remain_jobs: deque
    reset_calc: bool
    done_jobs: deque
    thrown_jobs: deque

    def __init__(
        self,
        dvr_prims: list[DVRPrimitivesMixin],
        atoms: Atoms,
        disp_vec: np.ndarray,
        unit: str = "angstrom",
    ):
        self.grid_list = [g.get_grids() for g in dvr_prims]
        self.dvr_prims = dvr_prims
        self.ndof = len(dvr_prims)
        self.disp_vec = disp_vec
        self.zero_indices = [None for _ in range(self.ndof)]
        if len(dvr_prims) != len(disp_vec):
            raise TypeError("dvr_prims and disp_vec must have the same length")
        if disp_vec.shape[-1] != 3:
            raise TypeError("disp_vec must consist of (x, y, z) row vectors")
        # Remember, for each DOF, the index of the grid point closest to zero
        # (i.e. the undisplaced reference geometry).
        for i, prim in enumerate(dvr_prims):
            grids = prim.get_grids()
            for j, grid in enumerate(grids):
                if abs(grid) < 1.0e-10:
                    self.zero_indices[i] = j
                    break
        if type(atoms) is Atoms:
            unit = "angstrom"
            self.symbols = atoms.symbols
            self.positions = atoms.positions
            self.masses = atoms.get_masses()
        else:
            if unit.lower() == "angstrom":
                self.positions = np.array([position for _, position in atoms])
            elif unit.lower() in ["bohr", "au", "a.u."]:
                self.positions = (
                    np.array([position for _, position in atoms])
                    * units.au_in_angstrom
                )
            else:
                raise NotImplementedError
            self.symbols = [element for element, _ in atoms]
        if disp_vec.shape[1] != len(self.positions):
            raise TypeError("disp_vec must hold one row vector per atom")
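    # A usage sketch (illustration only): build a mesh for water. The
    # ``HarmonicOscillator`` primitive (assumed constructor signature) and
    # the random displacement vectors are assumptions for illustration,
    # not values shipped with this module:
    #
    #     import numpy as np
    #     from ase.atoms import Atoms
    #     from pytdscf.basis import HarmonicOscillator  # assumed DVR primitive
    #
    #     water = Atoms(
    #         "OH2",
    #         positions=[[0.00, 0.00, 0.00],
    #                    [0.96, 0.00, 0.00],
    #                    [-0.24, 0.93, 0.00]],
    #     )
    #     prims = [HarmonicOscillator(5, 1600.0), HarmonicOscillator(5, 3700.0)]
    #     # one (natoms, 3) displacement vector per DOF -> shape (ndof, natoms, 3)
    #     disp_vec = np.random.rand(2, 3, 3) * 0.01
    #     mesh = DVR_Mesh(prims, water, disp_vec)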
    def save_geoms(
        self,
        jobname: str,
        nMR: int | None = None,
        overwrite: bool | None = None,
    ) -> dict[str, dict[str, int]]:
        """
        Generate a Cartesian-coordinate geometry for each grid mesh point.

        Args:
            jobname (str) : job name; geometries are stored in ``<jobname>.db``.
            nMR (Optional[int]) : The number of mode representation, which \
                limits the dimension of the mesh. \
                Defaults to ``None``; in that case \
                ``ngrid**ndof`` coordinates will be generated.
            overwrite (Optional[bool]) : whether to overwrite an existing \
                database; prompts interactively if ``None``.

        Returns:
            Dict[str, Dict[str, int]] : DVR mesh coordinates. \
                E.g. ``[(0, 1)][(2, 3)]`` gives the 2nd grid of mode 0 \
                and the 3rd grid of mode 1.
        """
        self.jobname = jobname
        if nMR is None:
            nMR = self.ndof
        # Rebuilt from scratch on every call.
        self.displace = {}
        self.geometry = {}
        logger.warning("START : Displacement Generation")
        for iMR in range(1, nMR + 1):
            if iMR == 1:
                for idof in range(self.ndof):
                    self.displace[(idof,)] = {}
                    for igrid, coef in enumerate(self.grid_list[idof]):
                        self.displace[(idof,)][(igrid,)] = (
                            coef * self.disp_vec[idof]
                        )
            else:
                # Build the iMR-dimensional mesh by combining the
                # (iMR-1)-dimensional mesh with the one-dimensional ones.
                for dof_key in itertools.combinations(range(self.ndof), r=iMR):
                    self.displace[dof_key] = {}
                    add_dof = (dof_key[-1],)
                    orig_dof = dof_key[:-1]
                    for orig_grid_key, orig_disp_vec in self.displace[
                        orig_dof
                    ].items():
                        for add_grid_key, add_disp_vec in self.displace[
                            add_dof
                        ].items():
                            grid_key = orig_grid_key + add_grid_key
                            disp_vec = orig_disp_vec + add_disp_vec
                            self.displace[dof_key][grid_key] = disp_vec
        for dof_key, grid_dict in self.displace.items():
            # When nMR == ndof, only the full-dimensional mesh is kept.
            if len(dof_key) < self.ndof == nMR:
                continue
            self.geometry[dof_key] = {}
            for grid_key, disp_vec in grid_dict.items():
                self.geometry[dof_key][grid_key] = self.positions + disp_vec
        logger.warning("DONE : Displacement Generation")
        if os.path.exists(f"{self.jobname}.db"):
            if overwrite is None:
                yes_or_else = input(
                    f"{self.jobname}.db already exists!"
                    + " Do you want to remove the database? [y/n]"
                    + " (Defaults to 'y')"
                )
            else:
                yes_or_else = "y" if overwrite else "n"
            if yes_or_else.lower() in ["y", "yes", ""]:
                os.remove(f"{self.jobname}.db")
                overwrite = True
            else:
                overwrite = False
        else:
            # Fresh database: there is nothing to overwrite, so write rows.
            overwrite = True
        self.grid_id = {}
        with connect(f"{self.jobname}.db") as db:
            _iter = 0
            for dof_key, grid_dict in self.geometry.items():
                self.grid_id[dof_key] = {}
                for grid_key, coord in progressbar(
                    grid_dict.items(),
                    prefix=f"Save geometry in DB of DOFs = {dof_key}",
                ):
                    if overwrite:
                        atoms_grid = Atoms(self.symbols, positions=coord)
                        grid = deepcopy(self.zero_indices)
                        for p, g in zip(dof_key, grid_key, strict=True):
                            grid[p] = g
                        _id = db.write(
                            atoms_grid,
                            dofs="|" + " ".join(map(str, dof_key)),
                            grids="|" + " ".join(map(str, grid)),
                        )
                        self.grid_id[dof_key][grid_key] = _id
                    else:
                        _iter += 1
                        self.grid_id[dof_key][grid_key] = _iter
        with open(f"{self.jobname}_grid_id.pkl", "wb") as f:
            pickle.dump(self.grid_id, f)
        # nMR coordinates could be rebuilt from (n-1)MR ones; keeping them
        # all may consume too much memory, so free them here.
        del self.geometry
        del self.displace
        return self.grid_id
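    # Continuing the sketch above (illustration only): a 2-mode
    # representation stores only singly and doubly displaced geometries
    # in ``water.db``:
    #
    #     grid_id = mesh.save_geoms("water", nMR=2, overwrite=True)
    #     # grid_id[(0, 1)][(2, 3)] -> db row id of the geometry displaced
    #     # to grid 2 of mode 0 and grid 3 of mode 1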
    def execute_multiproc(
        self,
        calculator: Calculator,
        max_workers: int | None = None,
        timeout: float = 60.0,
        jobname: str | None = None,
        reset_calc: bool = False,
        judge_func: Callable | None = None,
    ):
        """Execute electronic structure calculations by multiprocessing

        Args:
            calculator (Calculator) : calculator for each geometry
            max_workers (Optional[int]) : maximum number of multiprocessing workers. Defaults to None, i.e. cpu_count - 1.
            timeout (float) : timeout per calculation in seconds. Defaults to 60.0.
            jobname (Optional[str]) : job name
            reset_calc (Optional[bool]) : set a new calculator in any case. Defaults to False.
            judge_func (Optional[Callable[[Any], bool]]) : function judging whether re-calculation is needed. Defaults to None.
        """  # noqa: E501
        self.calc = calculator
        if isinstance(jobname, str):
            self.jobname = jobname
        if getattr(self, "jobname", None) is None:
            raise ValueError("jobname argument is required.")
        self.remain_jobs = deque()
        self.thrown_jobs = deque()
        self.done_jobs = deque()
        # In case of duplicated jobs, the duplicates should be removed.
        if judge_func is None:
            self.judge_func = lambda row: True
        else:
            self.judge_func = judge_func
        with connect(f"{self.jobname}.db") as db:
            unique_jobs = dict()
            count_unique_jobs = 0
            for row in db.select():
                if self.judge_func(row):
                    dof_key = tuple(from_dbkey(row.dofs))
                    grid_key_tmp = from_dbkey(row.grids)
                    grid_key = tuple([grid_key_tmp[p] for p in dof_key])
                    if row.grids not in unique_jobs:
                        unique_jobs[row.grids] = row.id
                        count_unique_jobs += 1
                    self.remain_jobs.append(
                        (
                            dof_key,
                            grid_key,
                            row.id,
                            unique_jobs[row.grids],
                            None,
                        )
                    )
        logger.warning(f"unique jobs : {count_unique_jobs}")
        self.reset_calc = reset_calc
        if max_workers is None:
            ncpu = os.cpu_count()
            assert isinstance(ncpu, int)
            max_workers = ncpu - 1
        wait_process = max_workers
        logger.warning("START : Electronic Structure Calculations")
        n = len(self.remain_jobs)
        if n > 0:
            with ProcessPoolExecutor(max_workers) as exe:
                with connect(f"{self.jobname}.db") as db:
                    logger.warning(f"Connected: {self.jobname}.db")
                    for _ in range(min(wait_process, len(self.remain_jobs))):
                        self._throw_job_to_queue(exe, db)
                    for _iter in progressbar(range(n)):
                        self._pick_up_job_from_queue(db, timeout)
                        if self.remain_jobs:
                            self._throw_job_to_queue(exe, db)
                    logger.warning("WAIT : Remaining future tasks")
                    while self.thrown_jobs:
                        self._pick_up_job_from_queue(db, timeout)
                    logger.warning("DONE : Electronic Structure Calculations")
                    if len(self.done_jobs) == n:
                        logger.warning("Your calculation completely finished!")
                    else:
                        logger.warning(
                            f"{n - len(self.done_jobs)} jobs remain!"
                            + " You should execute once again"
                            + " with different conditions and judge_func."
                        )
                # Force-kill lingering workers (private API) so that the
                # executor can shut down without waiting for stuck jobs.
                for process in exe._processes.values():
                    process.kill()
            logger.warning("DONE : Shutdown process executor")
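    # A usage sketch (illustration only); the Gaussian calculator settings
    # below are assumptions for the example, not defaults of this module:
    #
    #     from ase.calculators.gaussian import Gaussian
    #
    #     calc = Gaussian(method="b3lyp", basis="6-31g", charge=0, mult=1)
    #     mesh.execute_multiproc(
    #         calc,
    #         max_workers=4,
    #         timeout=600.0,  # generous timeout for DFT single points
    #     )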
    def _throw_job_to_queue(
        self,
        exe,
        db,
    ):
        while self.remain_jobs:
            dof_key, grid_key, _id, unique_id, error = (
                self.remain_jobs.popleft()
            )
            if _id == unique_id:
                row_unique = db.get(unique_id)
                # If the unique row already holds a converged result,
                # reuse it instead of submitting a new calculation.
                if self.judge_func(row_unique) and hasattr(
                    row_unique, "energy"
                ):
                    atoms = row_unique.toatoms()
                    if not isinstance(atoms.calc, Calculator):
                        atoms.calc.todict = types.MethodType(
                            _todict, atoms.calc
                        )
                    db.update(_id, atoms)
                    self.done_jobs.append((dof_key, grid_key, _id))
                    continue
            atoms = db.get_atoms(_id)
            if atoms.calc is None or self.reset_calc:
                atoms.calc = deepcopy(self.calc)
            if isinstance(atoms.calc, Calculator):
                # Gaussian etc.
                atoms.calc.set_label(f"{self.jobname}/{_id:07}/calc")
            elif isinstance(atoms.calc, GenericFileIOCalculator):
                # ORCA etc.
                atoms.calc.directory = Path(f"{self.jobname}/{_id:07}/calc")
            else:
                raise NotImplementedError
            future = exe.submit(_run, atoms)
            self.thrown_jobs.append(
                (future, dof_key, grid_key, _id, time.time())
            )
            break

    def _pick_up_job_from_queue(self, db, timeout: float = 60.0):
        while True:
            if self.thrown_jobs:
                future, dof_key, grid_key, _id, start_time = (
                    self.thrown_jobs.popleft()
                )
            else:
                break
            if future.done():
                if (error := future.exception()) is None:
                    atoms = future.result()[0]
                    if not isinstance(atoms.calc, Calculator):
                        atoms.calc.todict = types.MethodType(
                            _todict, atoms.calc
                        )
                    db.update(_id, atoms=atoms)
                    self.done_jobs.append((dof_key, grid_key, _id))
                else:
                    logger.warning(
                        f"ERROR:\t{error}\tDOFs={dof_key}\tgrids={grid_key}\tID={_id}"
                    )
                # Remove the scratch directory of the finished job.
                if os.path.exists(f"{self.jobname}/{_id:07}"):
                    shutil.rmtree(f"{self.jobname}/{_id:07}")
                break
            elif time.time() - start_time > timeout:
                try:
                    if not future.running():
                        future.cancel()
                    # Wait 1.0 second, then discard the job.
                    future.result(1.0)
                except Exception:
                    # TimeoutError or CancelledError
                    pass
                logger.warning(
                    f"TIMEOUT:\tDOFs={dof_key}\tgrids={grid_key}\tID={_id}"
                )
                if os.path.exists(f"{self.jobname}/{_id:07}"):
                    shutil.rmtree(f"{self.jobname}/{_id:07}")
                break
            else:
                # Not finished yet: requeue and poll again later.
                time.sleep(1)
                self.thrown_jobs.append(
                    (future, dof_key, grid_key, _id, start_time)
                )
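# A minimal, self-contained sketch (illustration only) of the submit/poll/
# requeue pattern used by the two methods above, stripped of the ASE and
# database specifics. On platforms that spawn workers, run it under an
# ``if __name__ == "__main__":`` guard:
#
#     import time
#     from collections import deque
#     from concurrent.futures import ProcessPoolExecutor
#
#     def work(x):
#         return x * x
#
#     pending = deque()
#     with ProcessPoolExecutor(2) as exe:
#         for x in range(4):
#             pending.append((exe.submit(work, x), time.time()))
#         while pending:
#             future, t0 = pending.popleft()
#             if future.done():
#                 print(future.result())
#             elif time.time() - t0 > 60.0:  # give up on stragglers
#                 future.cancel()
#             else:
#                 time.sleep(0.1)
#                 pending.append((future, t0))  # requeue and poll later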