"""
Utils for AiiDA.
Utilities for making working against AiiDA a bit easier. Mostly here due to
historical reasons when AiiDA was rapidly developed. In the future most routines
that have now standardized in AiiDA will be removed.
"""
# ruff: noqa: PLC0415
from __future__ import annotations
import warnings
from functools import wraps
from typing import Any, Callable
import numpy as np
from aiida import __version__ as aiida_version_
from aiida import orm
from aiida.common.exceptions import MissingEntryPointError
from aiida.orm import AuthInfo, QueryBuilder, User, load_node
from aiida.plugins import DataFactory
from packaging import version
BASIC_DATA_TYPES: list[str] = ['core.bool', 'core.float', 'core.int', 'core.list', 'core.str', 'core.dict']
[docs]
def querybuild(cls: type, **kwargs: Any) -> QueryBuilder:
"""
Instantiates and returns a QueryBuilder instance.
The QueryBuilder's path has one vertice so far, namely this class.
Additional parameters (e.g. filters or a label),
can be passes as keyword arguments.
:param label: Label to give
:param filters: filters to apply
:param project: projections
:returns: a QueryBuilder instance.
"""
query_builder = QueryBuilder()
filters = kwargs.pop('filters', {})
query_builder.append(cls, filters=filters, **kwargs)
return query_builder
[docs]
def get_data_class(data_type: str) -> type:
"""Provide access to the orm.data classes with deferred dbenv loading."""
data_cls = None
try:
data_cls = DataFactory(data_type)
except MissingEntryPointError as err:
raise err
return data_cls
[docs]
def get_current_user() -> User:
"""Get current user."""
current_user = User.collection.get_default()
return current_user
[docs]
def copy_parameter(old_parameter: orm.Dict) -> orm.Dict:
"""Assemble a new Dict."""
return orm.Dict(dict=old_parameter.get_dict())
[docs]
def displaced_structure(structure: orm.StructureData, displacement: np.ndarray, entry: int) -> orm.StructureData:
disp_structure = structure.clone()
displace_position(disp_structure, displacement, entry)
return disp_structure
[docs]
def compressed_structure(structure: orm.StructureData, volume_change: float) -> orm.StructureData:
comp_structure = structure.clone()
compress_cell(comp_structure, volume_change)
return comp_structure
[docs]
def displace_position(structure: orm.StructureData, displacement: np.ndarray, entry: int) -> None:
"""Displace a position in the StructureData."""
sites = structure.sites
positions = []
for site in sites:
positions.append(site.position)
new_position = np.asarray(positions[entry - 1]) + displacement
new_position = new_position.tolist()
positions[entry - 1] = tuple(new_position)
structure.reset_sites_positions(positions)
[docs]
def compress_cell(structure: orm.StructureData, volume_change: float) -> None:
"""Apply compression or tensile forces to the unit cell."""
cell = structure.cell
new_cell = np.array(cell) * volume_change
structure.reset_cell(new_cell.tolist())
[docs]
def aiida_version() -> version.Version:
return version.parse(aiida_version_)
[docs]
def cmp_version(string: str) -> version.Version:
return version.parse(string)
[docs]
def cmp_load_verdi_data() -> Any:
"""Load the verdi data click command group for any version since 0.11."""
verdi_data = None
import_errors = []
try:
from aiida.cmdline.commands import data_cmd as verdi_data
except ImportError as err:
import_errors.append(err)
if not verdi_data:
try:
from aiida.cmdline.commands import verdi_data
except ImportError as err:
import_errors.append(err)
if not verdi_data:
try:
from aiida.cmdline.commands.cmd_data import verdi_data
except ImportError as err:
import_errors.append(err)
if not verdi_data:
err_messages = '\n'.join([f' * {err}' for err in import_errors])
raise ImportError('The verdi data base command group could not be found:\n' + err_messages)
return verdi_data
[docs]
def create_authinfo(computer: orm.Computer, store: bool = False) -> AuthInfo:
"""Allow the current user to use the given computer."""
authinfo = AuthInfo(computer=computer, user=get_current_user())
if store:
authinfo.store()
return authinfo
[docs]
def cmp_get_authinfo(computer: orm.Computer) -> AuthInfo | None:
"""Get an existing authinfo or None for the given computer and current user."""
return computer.get_authinfo(get_current_user())
[docs]
def cmp_get_transport(computer: orm.Computer) -> Any:
if hasattr(computer, 'get_transport'):
return computer.get_transport()
authinfo = cmp_get_authinfo(computer)
return authinfo.get_transport()
[docs]
def ensure_node_first_arg(func: Callable[..., Any]) -> Callable[..., Any]:
"""Decorator to load a node if it is passed as a string."""
@wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
"""Make sure the first node is a Node instance."""
if len(args) > 0:
node = args[0]
if not isinstance(node, orm.Node):
node = load_node(node)
args = list(args)
args[0] = node
return func(*args, **kwargs)
return wrapper
[docs]
def ensure_node_kwargs(func: Callable[..., Any]) -> Callable[..., Any]:
"""Decorator to load a node if it is passed as a key word argument ends with 'node'."""
@wraps(func)
def wrapper(node: Any, *args: Any, **kwargs: Any) -> Any:
"""Make sure the key world arguments ends with '_node' node is a Node instance."""
new_kwargs = dict(kwargs)
for name, value in kwargs.items():
if name.endswith('node'):
if not isinstance(value, orm.Node):
new_kwargs[name] = load_node(value)
return func(node, *args, **new_kwargs)
return wrapper
[docs]
def convert_dict_case(
dict_in: dict[str, Any],
recursive: bool = True,
warn: bool = False,
lower: bool = True,
raise_convert: bool = False,
) -> dict[str, Any]:
"""
Recursively convert the keys of a dictionary to lower or upper cases, returns a new dictionary.
:param dict_in: The input dictionary whose keys need to be converted.
:param recursive: If True, the function will recursively convert keys in nested dictionaries.
:param warn: If True, the function will print a warning if a key is converted.
:param lower: If True, convert keys to lowercase; otherwise, convert to uppercase.
:param raise_convert: If True, raise an error if a key is converted.
:return: A new dictionary with keys converted to the specified case.
"""
converted_dict = {}
for key, value in dict_in.items():
new_key = key.lower() if lower else key.upper()
if new_key != key:
expected = 'upper' if lower is False else 'lower'
if warn:
expected = 'upper' if lower is False else 'lower'
warnings.warn(f"Key '{key}' converted to '{new_key}' - please use {expected} case keys")
if raise_convert:
raise ValueError(f"Key '{key}' converted to '{new_key}' - please use {expected} case keys")
if recursive and isinstance(value, dict):
converted_dict[new_key] = convert_dict_case(value, recursive, warn, lower, raise_convert)
else:
converted_dict[new_key] = value
return converted_dict