"""
Bands workchain with a more flexible input
TODO:
- Add option to use alternative pathways obtained using sumo-interface
- Improve the hybrid workchain by performing local dryrun to extract the full kpoints
- If running SOC, the ISYM should be turned to 0 or -1.
"""
from copy import deepcopy
from logging import getLogger
from typing import Any, List
import numpy as np
from aiida import orm
from aiida.common.extendeddicts import AttributeDict
from aiida.common.links import LinkType
from aiida.engine import WorkChain, append_, calcfunction, if_
from aiida.orm.nodes.data.base import to_aiida_type
from aiida.plugins import WorkflowFactory
from aiida.tools import get_explicit_kpoints_path
from ase.geometry import cell_to_cellpar
from aiida_vasp.common import OVERRIDE_NAMESPACE
from aiida_vasp.common.dryrun import dryrun_relax_builder
from aiida_vasp.common.transform import magnetic_structure_decorate, magnetic_structure_dedecorate
from aiida_vasp.data.chargedensity import ChargedensityData
from aiida_vasp.parsers.content_parsers.vasprun import VasprunParser
from aiida_vasp.protocols import ProtocolMixin, recursive_merge
from aiida_vasp.utils.extended_dicts import update_nested_dict, update_nested_dict_node
from aiida_vasp.utils.kmesh import get_ir_kpoints_data
from aiida_vasp.utils.opthold import BandOptions
from .mixins import WithBuilderUpdater
from .relax import VaspRelaxWorkChain
from .vasp import VaspWorkChain
# Threshold for treating a site as magnetic. The code that consumes this constant
# is not in the visible part of this file - NOTE(review): confirm units and comparison.
SITE_MAG_THRESHOLD = 0 # Threshold for considering a site to be magnetic
# Module-level logger named after this module.
logger = getLogger(__name__)
[docs]
class VaspBandsWorkChain(WorkChain, WithBuilderUpdater, ProtocolMixin):
"""
Workchain for running bands calculations.
This workchain does the following:
1. Relax the structure if requested (eg. inputs passed to the relax namespace).
2. Do a SCF singlepoint calculation.
3. Do a non-scf calculation for bands and dos.
Inputs must be passed for the SCF calculation, others are optional. The dos calculation
will only run if the kpoints for DOS are passed or a full specification is given under the
`dos` input namespace.
The SCF calculation may be skipped by passing a CHGCAR file/remote folder. In which case the SCF inputs
are carried on for non-scf calculations.
The band structure calculation will run unless `only_dos` is set to `Bool(True)`.
For magnetic structures, the workchain will internally create additional species for the symmetry
analysis and regenerate "undecorated" structures with corresponding initial magnetic moments. This
works for both FM and AFM species. Care should be taken when the MAGMOM is obtained from site projected
values in case of unexpected symmetry breaking.
Input for bands and dos calculations are optional. However, if they are needed, the full list of inputs must
be passed. For the `parameters` node, one may choose to only specify those fields that need to be updated.
For optics calculations, one should run with `only_dos`, set 'NBANDS' to a high value and
set 'LOPTICS' to be True.
"""
# Entry point name and class of the workchain used for SCF, bands and DOS runs.
_base_wk_string = 'vasp.v2.vasp'
_base_workchain = VaspWorkChain
# Entry point name and class of the workchain used for the optional relaxation.
_relax_wk_string = 'vasp.v2.relax'
_relax_workchain = VaspRelaxWorkChain
# Tag associated with this workchain's protocols; consumed outside this file.
_protocol_tag = 'band'
# Option class that describes, validates and serialises the ``band_settings`` input.
option_class = BandOptions
[docs]
@classmethod
def define(cls, spec: Any) -> None:
    """Declare the inputs, outline, outputs and exit codes of the workchain.

    :param spec: the process specification, populated in place.
    """
    super().define(spec)
    relax_cls = WorkflowFactory(cls._relax_wk_string)
    base_cls = WorkflowFactory(cls._base_wk_string)

    # Top-level inputs.
    spec.input('structure', valid_type=orm.StructureData, help='The input structure')
    spec.input(
        'bs_kpoints',
        valid_type=orm.KpointsData,
        required=False,
        help='Explicit kpoints for the bands. Will not generate kpoints if supplied.',
    )
    spec.input(
        'band_settings',
        valid_type=orm.Dict,
        validator=BandOptions.aiida_validate,
        serializer=BandOptions.aiida_serialize,
        help=BandOptions.aiida_description(),
    )

    # Exposed sub-workchain namespaces as rows of
    # (process class, namespace, excluded ports, mandatory, help).
    # ``populate_defaults`` is switched on only for the mandatory namespace.
    exposed_namespaces = (
        (relax_cls, 'relax', ('structure',), False, 'Inputs for Relaxation workchain, if needed'),
        (base_cls, 'scf', ('structure',), True, 'Inputs for SCF workchain, mandatory'),
        (base_cls, 'bands', ('structure', 'kpoints'), False, 'Inputs for bands calculation, if needed'),
        (base_cls, 'dos', ('structure',), False, 'Inputs for DOS calculation, if needed'),
    )
    for process_cls, namespace, excluded, mandatory, description in exposed_namespaces:
        spec.expose_inputs(
            process_cls,
            namespace=namespace,
            exclude=excluded,
            namespace_options={
                'required': mandatory,
                'populate_defaults': mandatory,
                'help': description,
            },
        )

    # Remaining optional inputs.
    spec.input(
        'clean_children_workdir',
        valid_type=orm.Str,
        serializer=to_aiida_type,
        required=False,
        default=lambda: orm.Str('none'),
        help='What part of the called children to clean',
    )
    spec.input(
        'chgcar',
        valid_type=ChargedensityData,
        required=False,
        help='Explicit CHGCAR file used for DOS/Bands calculations',
    )
    spec.input(
        'restart_folder',
        valid_type=orm.RemoteData,
        required=False,
        help='A remote folder containing the CHGCAR file to be used',
    )

    # Relax (optional) -> k-path (optional) -> SCF (optional) -> bands/DOS.
    relax_steps = if_(cls.should_do_relax)(cls.run_relax, cls.verify_relax)
    path_steps = if_(cls.should_generate_path)(cls.generate_path)
    scf_steps = if_(cls.should_run_scf)(cls.run_scf, cls.verify_scf)
    spec.outline(
        cls.setup,
        relax_steps,
        path_steps,
        scf_steps,
        cls.run_bands_dos,
        cls.inspect_bands_dos,
    )

    # Outputs, all optional.
    spec.output(
        'primitive_structure',
        help='Primitive structure used for band structure calculations',
        required=False,
    )
    spec.output('band_structure', help='Computed band structure with labels', required=False)
    spec.output('seekpath_parameters', required=False, help='Parameters used by seekpath')
    for bare_output in ('dos', 'projectors'):
        spec.output(bare_output, required=False)

    # Exit codes as (status, label, message).
    exit_code_table = (
        (501, 'ERROR_SUB_PROC_RELAX_FAILED', 'Relaxation workchain failed'),
        (502, 'ERROR_SUB_PROC_SCF_FAILED', 'SCF workchain failed'),
        (503, 'ERROR_SUB_PROC_BANDS_FAILED', 'Band structure workchain failed'),
        (504, 'ERROR_SUB_PROC_DOS_FAILED', 'DOS workchain failed'),
        (601, 'ERROR_INPUT_STRUCTURE_NOT_PRIMITIVE', 'The input structure is not the primitive one!'),
    )
    for status, label, message in exit_code_table:
        spec.exit_code(status, label, message=message)
[docs]
def get_appended_label(self, suffix):
    """Return this workchain's label, a single space, then ``suffix``.

    A missing or empty label is treated as an empty string, so the result then
    starts with the separating space.
    """
    prefix = self.inputs.metadata.get('label', '')
    if not prefix:
        prefix = ''
    return prefix + ' ' + suffix
[docs]
@classmethod
def get_builder_from_protocol(
    cls,
    code: orm.AbstractCode,
    structure: orm.StructureData,
    protocol=None,
    run_relax=True,
    overrides=None,
    options=None,
    band_settings=None,
    **kwargs,
):
    """Return a builder for this workchain populated from a protocol.

    :param code: the code forwarded to the SCF and relaxation sub-builders.
    :param structure: the input structure of the workchain.
    :param protocol: protocol name passed to ``get_protocol_inputs``; also used for
        a sub-workchain when its section does not name its own ``protocol``.
    :param run_relax: when true, populate the ``relax`` namespace as well.
    :param overrides: mapping of overrides applied on top of the protocol.
    :param options: forwarded unchanged to the sub-workchain builders.
    :param band_settings: extra band settings merged on top of
        ``overrides['band_settings']``.
    :param kwargs: forwarded unchanged to the sub-workchain builders.
    :return: the populated builder.
    """
    overrides = overrides or {}
    if band_settings:
        # Merge before resolving the protocol, otherwise ``band_settings`` never
        # reaches ``inputs``. A new dict is built so the caller's mapping is not
        # mutated, and the left operand is always a mapping.
        merged_settings = recursive_merge(overrides.get('band_settings') or {}, band_settings)
        overrides = {**overrides, 'band_settings': merged_settings}
    inputs = cls.get_protocol_inputs(protocol, overrides)

    scf_section = inputs.get('scf', {})
    scf_builder = cls._base_workchain.get_builder_from_protocol(
        code=code,
        structure=structure,
        protocol=scf_section.get('protocol', protocol),
        overrides=scf_section,
        options=options,
        **kwargs,
    )
    # Configure the relaxation step of the workchain
    if run_relax:
        relax_section = inputs.get('relax', {})
        relax_builder = cls._relax_workchain.get_builder_from_protocol(
            code=code,
            structure=structure,
            protocol=relax_section.get('protocol', protocol),
            overrides=relax_section,
            options=options,
            **kwargs,
        )
        # The structure is set once at the top level of this workchain
        relax_builder.pop('structure')
    else:
        relax_builder = None
    scf_builder.pop('structure')

    builder = cls.get_builder()
    builder.scf = scf_builder
    builder.structure = structure
    if relax_builder is not None:
        builder.relax = relax_builder
    if inputs.get('band_settings'):
        builder.band_settings = inputs.get('band_settings')
    if inputs.get('clean_children_workdir'):
        builder.clean_children_workdir = inputs.get('clean_children_workdir')
    return builder
[docs]
def setup(self) -> None:
    """Initialise the context from the inputs.

    Stores the input structure and the optional explicit band k-points. The
    ``magmom`` of the SCF parameters is kept in the context only when it is
    present and a band structure is going to be computed; otherwise it is
    ``None``.
    """
    ctx = self.ctx
    ctx.current_structure = self.inputs.structure
    ctx.bs_kpoints = self.inputs.get('bs_kpoints')

    scf_overrides = self.inputs.scf.parameters.get_dict()[OVERRIDE_NAMESPACE]
    magnetic_for_bands = 'magmom' in scf_overrides and not self.inputs.band_settings['only_dos']
    if not magnetic_for_bands:
        ctx.magmom = None
        return
    self.report('Magnetic system passed for BS')
    ctx.magmom = scf_overrides['magmom']
[docs]
def should_do_relax(self) -> bool:
    """Return whether the ``relax`` input namespace was supplied."""
    relax_requested = 'relax' in self.inputs
    return relax_requested
[docs]
def run_relax(self) -> Any:
    """Submit the relaxation workchain for the current structure.

    When ``hybrid_reuse_wavecar`` is enabled in the band settings, the
    relaxation is additionally asked to keep the final singlepoint working
    directory and to write the WAVECAR.

    :return: the result of registering the submitted workchain in the context
        under ``workchain_relax``.
    """
    relax_cls = WorkflowFactory(self._relax_wk_string)
    relax_inputs = AttributeDict(self.exposed_inputs(relax_cls, 'relax', agglomerate=True))
    relax_inputs.structure = self.ctx.current_structure
    relax_inputs.metadata.label = self.get_appended_label('RELAX')
    relax_inputs.metadata.call_link_label = 'relax'

    if self.inputs.band_settings.get('hybrid_reuse_wavecar', False):
        new_params = relax_inputs.vasp.parameters.get_dict()

        # Do not clean the final singlepoint calculation of the relaxation
        old_settings = relax_inputs.relax_settings.get_dict()
        new_settings = {**old_settings, 'keep_sp_workdir': True}
        if new_settings != old_settings:
            relax_inputs.relax_settings = orm.Dict(dict=new_settings)

        # Make sure the WAVECAR is written; only create a new node on change
        new_params['incar']['lwave'] = True
        if new_params != relax_inputs.vasp.parameters.get_dict():
            relax_inputs.vasp.parameters = orm.Dict(dict=new_params)

    submitted = self.submit(relax_cls, **relax_inputs)
    return self.to_context(workchain_relax=submitted)
[docs]
def verify_relax(self) -> Any:
    """Check the relaxation and adopt the relaxed structure.

    :return: ``ERROR_SUB_PROC_RELAX_FAILED`` if the relaxation workchain did
        not finish successfully, otherwise ``None``.
    """
    relax_node = self.ctx.workchain_relax
    if relax_node.is_finished_ok:
        # Subsequent steps operate on the relaxed structure
        self.ctx.current_structure = relax_node.outputs.relax.structure
        return None
    self.report('Relaxation finished with Error')
    return self.exit_codes.ERROR_SUB_PROC_RELAX_FAILED
[docs]
def should_run_scf(self) -> bool:
    """Return whether an SCF calculation has to be run.

    ``select_chgcar_from_inputs`` is called first (it is defined outside the
    visible part of this file); the SCF run is needed only if afterwards
    neither ``ctx.chgcar`` nor ``ctx.restart_folder`` is set.
    """
    self.select_chgcar_from_inputs()
    charge_source = self.ctx.chgcar or self.ctx.restart_folder
    return not charge_source
[docs]
def should_generate_path(self) -> bool:
    """Return whether a band path has to be generated.

    No path is generated when explicit ``bs_kpoints`` are supplied, or when
    only the DOS is requested (``only_dos`` in the band settings).
    """
    if 'bs_kpoints' in self.inputs:
        return False
    return not self.inputs.band_settings['only_dos']
[docs]
def generate_path(self) -> Any:
    """Obtain the primitive structure and the band path k-points.

    With ``band_mode == 'seekpath-aiida'`` the analysis is done by
    ``seekpath_structure_analysis``; any other mode is passed to the sumo
    interface. If a ``magmom`` is stored in the context, the structure is first
    decorated so that differently magnetised sites are distinguished, and the
    resulting primitive structure is converted back together with a matching
    ``magmom`` list.

    The context is updated with the primitive structure and the band k-points,
    and the ``primitive_structure`` (and, if available, ``seekpath_parameters``)
    outputs are attached.

    :return: ``ERROR_INPUT_STRUCTURE_NOT_PRIMITIVE`` when the cell changed and
        explicit ``scf.kpoints`` were supplied, otherwise ``None``.
    :raises ImportError: if a sumo mode is requested but the sumo interface
        cannot be imported.
    """
    current_structure_backup = self.ctx.current_structure
    band_settings = self.inputs.band_settings
    mode = band_settings['band_mode']
    if mode == 'seekpath-aiida':
        inputs = {
            'band_settings': orm.Dict(
                {
                    'reference_distance': band_settings['band_kpoints_distance'],
                    'symprec': band_settings['symprec'],
                    **band_settings['additional_band_analysis_parameters'],
                }
            ),
            'metadata': {'call_link_label': 'seekpath'},
        }
        func = seekpath_structure_analysis
    else:
        # Using sumo interface
        try:
            from aiida_vasp.common.sumo_kpath import kpath_from_sumo_v2  # noqa: PLC0415
        except ImportError as exc:
            # Keep the original failure attached as the cause
            raise ImportError('Sumo is not installed, please install it to use this feature.') from exc
        inputs = {
            'band_settings': orm.Dict(
                {
                    'line_density': band_settings['line_density'],
                    'symprec': band_settings['symprec'],
                    'mode': mode,
                    **band_settings['additional_band_analysis_parameters'],
                }
            ),
            'metadata': {'call_link_label': 'sumo_kpath'},
        }
        func = kpath_from_sumo_v2

    magmom = self.ctx.get('magmom', None)
    # For magnetic structures, create different kinds for the analysis in case that the
    # symmetry should be lowered. This also makes sure that the magnetic moments are consistent
    if magmom:
        decorate_result = magnetic_structure_decorate(self.ctx.current_structure, orm.List(list=magmom))
        decorated = decorate_result['structure']
        # Run the path analysis on the decorated structure
        kpath_results = func(decorated, **inputs)
        decorated_primitive = kpath_results['primitive_structure']
        # Convert back to undecorated structures and add consistent magmom input
        dedecorate_result = magnetic_structure_dedecorate(decorated_primitive, decorate_result['mapping'])
        self.ctx.magmom = dedecorate_result['magmom'].get_list()
        self.ctx.current_structure = dedecorate_result['structure']
    else:
        kpath_results = func(self.ctx.current_structure, **inputs)
        self.ctx.current_structure = kpath_results['primitive_structure']

    if not np.allclose(self.ctx.current_structure.cell, current_structure_backup.cell):
        # Explicit SCF k-points were generated for the original cell and cannot be reused
        if self.inputs.scf.get('kpoints'):
            self.report(
                'The primitive structure is not the same as the input structure but explicit kpoints are supplied'
                ' - aborting the workchain.'
            )
            return self.exit_codes.ERROR_INPUT_STRUCTURE_NOT_PRIMITIVE  # pylint: disable=no-member
        self.report(
            'The primitive structure is not the same as the input structure - using the former for all calculations'
            ' from now.'
        )

    self.ctx.bs_kpoints = kpath_results['explicit_kpoints']
    self.out('primitive_structure', self.ctx.current_structure)
    if 'parameters' in kpath_results:
        self.out('seekpath_parameters', kpath_results['parameters'])
    return None
[docs]
def run_scf(self) -> None:
    """Submit the SCF calculation for the current structure.

    The SCF inputs are adjusted so that the last working directory is kept and
    the charge density is written, and the ``magmom`` stored in the context (if
    any) replaces the one from the inputs. The submitted workchain is
    registered in the context under ``workchain_scf``.
    """
    base_work = WorkflowFactory(self._base_wk_string)
    inputs = AttributeDict(self.exposed_inputs(base_work, namespace='scf'))
    inputs.metadata.call_link_label = 'scf'
    inputs.metadata.label = self.get_appended_label('SCF')
    inputs.structure = self.ctx.current_structure

    # Turn off cleaning of the working directory
    if not inputs.get('keep_last_workdir', False):
        inputs.keep_last_workdir = orm.Bool(True)

    # Ensure that writing the CHGCAR file is on
    pdict = inputs.parameters.get_dict()
    overrides = pdict[OVERRIDE_NAMESPACE]
    if (overrides.get('lcharg') is False) or (overrides.get('LCHARG') is False):
        # Drop the upper-case spelling so no contradicting ``LCHARG`` entry is
        # left next to the corrected lower-case one.
        overrides.pop('LCHARG', None)
        overrides['lcharg'] = True
        inputs.parameters = orm.Dict(dict=pdict)
        self.report('Correction: setting LCHARG to True')

    # Take magmom from the context, in case that the magmom is rearranged in the primitive cell
    magmom = self.ctx.get('magmom')
    if magmom:
        inputs.parameters = update_nested_dict_node(inputs.parameters, {OVERRIDE_NAMESPACE: {'magmom': magmom}})

    running = self.submit(base_work, **inputs)
    self.report(f'Running SCF calculation {running}')
    self.to_context(workchain_scf=running)
[docs]
def verify_scf(self) -> Any:
    """Check the SCF calculation and record its charge density sources.

    On success the ``chgcar`` output (or ``None`` when it is absent) and the
    ``remote_folder`` output are stored in the context.

    :return: ``ERROR_SUB_PROC_SCF_FAILED`` if the SCF workchain did not finish
        successfully, otherwise ``None``.
    """
    scf_node = self.ctx.workchain_scf
    if not scf_node.is_finished_ok:
        self.report('SCF workchain finished with Error')
        return self.exit_codes.ERROR_SUB_PROC_SCF_FAILED

    scf_outputs = scf_node.outputs
    self.ctx.chgcar = scf_outputs.chgcar if 'chgcar' in scf_outputs else None
    self.ctx.restart_folder = scf_outputs.remote_folder
    self.report(f'SCF calculation {scf_node} completed')
    return None
[docs]
def run_bands_dos(self) -> Any:
    """Submit the non-SCF band structure and/or DOS calculations.

    Both calculations start from the SCF inputs, pointed at the charge density
    found in the context (``restart_folder`` and/or ``chgcar``). Inputs from the
    ``bands`` / ``dos`` namespaces are layered on top, their ``parameters``
    being merged into the SCF parameters rather than replacing them. Each
    calculation gets its own copy of the base inputs, so nothing set up for the
    band structure leaks into the DOS run.

    The band structure runs unless ``only_dos`` is set. The DOS runs when
    ``run_dos`` is set or the ``dos`` namespace is supplied. Its k-points are
    taken from the ``dos`` namespace if given there, otherwise a mesh is
    generated from ``dos_kpoints_distance``.

    :return: the result of registering the submitted workchains in the context
        under ``bands_workchain`` / ``dos_workchain``.
    :raises RuntimeError: if neither a restart folder nor a CHGCAR is available.
    """
    base_work = WorkflowFactory(self._base_wk_string)

    # Use the SCF inputs as the base
    base_inputs = AttributeDict(self.exposed_inputs(base_work, namespace='scf'))
    base_inputs.structure = self.ctx.current_structure
    if self.ctx.restart_folder:
        base_inputs.restart_folder = self.ctx.restart_folder
    if self.ctx.chgcar:
        base_inputs.chgcar = self.ctx.chgcar
    if not (base_inputs.get('restart_folder') or base_inputs.get('chgcar')):
        raise RuntimeError('One of the restart_folder or chgcar must be set for non-scf calculations')

    running = {}
    only_dos = self.inputs.band_settings['only_dos']
    if only_dos is False:
        # Private copy: AttributeDict re-wraps nested mappings such as ``metadata``
        inputs = AttributeDict(base_inputs)
        if 'bands' in self.inputs:
            bands_input = AttributeDict(self.exposed_inputs(base_work, namespace='bands'))
        else:
            bands_input = AttributeDict(
                {
                    'settings': orm.Dict(dict={'parser_settings': {'include_node': ['bands']}}),
                    'parameters': orm.Dict(dict={'charge': {'constant_charge': True}}),
                }
            )

        # Special treatment - combine the parameters
        parameters = inputs.parameters.get_dict()
        bands_parameters = bands_input.parameters.get_dict()
        bands_parameters.setdefault('charge', {})['constant_charge'] = True
        update_nested_dict(parameters, bands_parameters)

        # Apply updated parameters
        inputs.update(bands_input)
        inputs.parameters = orm.Dict(dict=parameters)

        # Make sure the bands are parsed
        settings = inputs.get('settings')
        essential = {'parser_settings': {'include_node': ['bands']}}
        if settings is None:
            inputs.settings = orm.Dict(dict=essential)
        else:
            inputs.settings = update_nested_dict_node(settings, essential, extend_list=True)

        # Swap with the band path kpoints
        inputs.kpoints = self.ctx.bs_kpoints

        # Tag the calculation
        inputs.metadata.label = self.get_appended_label('BS')
        inputs.metadata.call_link_label = 'bs'

        bands_calc = self.submit(base_work, **inputs)
        running['bands_workchain'] = bands_calc
        self.report(f'Submitted workchain {bands_calc} for band structure')

    # Do DOS calculation if requested in the settings or if the dos input
    # namespace is populated.
    if (self.inputs.band_settings['run_dos']) or ('dos' in self.inputs):
        # Start again from the SCF based inputs, not from the band structure ones
        inputs = AttributeDict(base_inputs)
        if 'dos' in self.inputs:
            dos_input = AttributeDict(self.exposed_inputs(base_work, namespace='dos'))
        else:
            dos_input = AttributeDict(
                {
                    'parameters': orm.Dict(dict={'charge': {'constant_charge': True}}),
                }
            )

        if 'kpoints' not in dos_input:
            # Use the supplied kpoints density for DOS
            dos_kpoints = orm.KpointsData()
            dos_kpoints.set_cell_from_structure(self.ctx.current_structure)
            dos_kpoints.set_kpoints_mesh_from_density(self.inputs.band_settings['dos_kpoints_distance'] * 2 * np.pi)
            dos_input.kpoints = dos_kpoints

        # Special treatment - combine the parameters. The constant charge flag
        # has to be set before merging, otherwise it never reaches ``parameters``.
        parameters = inputs.parameters.get_dict()
        dos_parameters = dos_input.parameters.get_dict()
        dos_parameters.setdefault('charge', {})['constant_charge'] = True
        update_nested_dict(parameters, dos_parameters)

        # Apply updated parameters
        inputs.update(dos_input)
        inputs.parameters = orm.Dict(dict=parameters)

        if 'dos' not in self.inputs:
            # Make sure the DOS is parsed when the ``dos`` input namespace is
            # not explicitly defined.
            settings = inputs.get('settings')
            essential = {'parser_settings': {'include_node': ['dos', 'bands']}}
            if settings is None:
                inputs.settings = orm.Dict(dict=essential)
            else:
                inputs.settings = update_nested_dict_node(settings, essential, extend_list=True)

        # Set the label
        inputs.metadata.label = self.get_appended_label('DOS')
        inputs.metadata.call_link_label = 'dos'

        dos_calc = self.submit(base_work, **inputs)
        running['dos_workchain'] = dos_calc
        self.report(f'Submitted workchain {dos_calc} for DOS')

    return self.to_context(**running)
[docs]
def inspect_bands_dos(self) -> Any:
    """Inspect the bands and DOS calculations and attach their outputs.

    Outputs are attached only when the child workchain actually produced them,
    so a failed child results in the corresponding exit code instead of an
    exception raised while reading a missing output.

    :return: ``ERROR_SUB_PROC_BANDS_FAILED`` / ``ERROR_SUB_PROC_DOS_FAILED`` if
        the respective workchain failed (the DOS code takes precedence when
        both failed), otherwise ``None``.
    """
    exit_code = None

    if 'bands_workchain' in self.ctx:
        bands = self.ctx.bands_workchain
        if not bands.is_finished_ok:
            self.report(f'Bands calculation finished with error, exit_status: {bands.exit_status}')
            exit_code = self.exit_codes.ERROR_SUB_PROC_BANDS_FAILED
        if 'bands' in bands.outputs:
            self.out(
                'band_structure',
                compose_labelled_bands(bands.outputs.bands, bands.inputs.kpoints),
            )
        else:
            self.report(f'No bands output found for {bands} - band_structure is not attached')

    if 'dos_workchain' in self.ctx:
        dos = self.ctx.dos_workchain
        if not dos.is_finished_ok:
            self.report(f'DOS calculation finished with error, exit_status: {dos.exit_status}')
            exit_code = self.exit_codes.ERROR_SUB_PROC_DOS_FAILED
        # Attach outputs
        if 'dos' in dos.outputs:
            self.out('dos', dos.outputs.dos)
        else:
            self.report(f'No dos output found for {dos} - dos is not attached')
        if 'projectors' in dos.outputs:
            self.out('projectors', dos.outputs.projectors)

    return exit_code
[docs]
def on_terminated(self) -> None:
    """Clean the remote folders of the called calculation jobs, if requested.

    Nothing is cleaned when ``clean_children_workdir`` is ``'none'``. Cleaning
    is best effort: calculations whose remote folder is missing or cannot be
    cleaned are silently skipped.
    """
    super().on_terminated()
    if self.inputs.clean_children_workdir.value == 'none':
        return

    cleaned_pks = []
    calc_jobs = (node for node in self.node.called_descendants if isinstance(node, orm.CalcJobNode))
    for calc_job in calc_jobs:
        try:
            calc_job.outputs.remote_folder._clean()  # pylint: disable=protected-access
        except (OSError, KeyError):
            continue
        cleaned_pks.append(calc_job.pk)

    if cleaned_pks:
        self.report(f'cleaned remote folders of calculations: {" ".join(map(str, cleaned_pks))}')
[docs]
@calcfunction
def seekpath_structure_analysis(structure: orm.StructureData, band_settings: orm.Dict) -> Any:
    """Primitivize the structure with SeeKpath and generate the high symmetry k-point path.

    The structure is passed to ``get_explicit_kpoints_path`` together with the
    content of ``band_settings``, which is unpacked as keyword arguments. Note
    that the returned primitive cell may differ from the original structure, in
    which case the k-points are only congruent with the primitive cell.

    Examples of Seekpath parameters that can be placed in ``band_settings``:

    - with_time_reversal: True
    - reference_distance: 0.025
    - recipe: 'hpkot'
    - threshold: 1e-07
    - symprec: 1e-05
    - angle_tolerance: -1.0

    The exact parameters that are available and their defaults depend on the
    Seekpath version.

    :param structure: the structure to analyse.
    :param band_settings: keyword arguments for the path generation.
    :return: whatever ``get_explicit_kpoints_path`` returns.
    """
    seekpath_kwargs = band_settings.get_dict()
    return get_explicit_kpoints_path(structure, **seekpath_kwargs)
[docs]
@calcfunction
def compose_labelled_bands(bands: orm.BandsData, kpoints: orm.KpointsData) -> orm.BandsData:
    """Return a copy of ``bands`` whose k-points are taken from ``kpoints``.

    The band structure labels are presumably carried by ``kpoints`` and picked
    up through ``set_kpointsdata``; this is not visible in this file.

    :param bands: the band structure to copy.
    :param kpoints: the k-points to attach to the copy.
    :return: the copied bands with the k-points attached.
    """
    labelled = deepcopy(bands)
    labelled.set_kpointsdata(kpoints)
    return labelled
[docs]
@calcfunction
def get_primitive_strucrture_and_scf_kpoints(structure: orm.StructureData) -> Any:
    """Dry-run a calculation on the primitive structure to get the SCF k-points.

    The process that created ``structure`` through a ``relax__structure`` link
    is located and its restart builder is pointed at the primitive structure
    obtained from the seekpath analysis. The builder is then dry-run and the
    reported k-points and weights are stored in a ``KpointsData``.

    :param structure: a structure with exactly one incoming ``relax__structure`` link.
    :return: a dictionary with the ``primitive`` structure and the ``scf_kpoints``.
    """
    # Locate the relaxation work
    creator_link = structure.base.links.get_incoming(link_label_filter='relax__structure').one()
    relax_work = creator_link.node

    primitive = get_explicit_kpoints_path(structure)['primitive_structure']

    # Create a restart builder for the primitive structure
    builder = relax_work.get_builder_restart()
    builder.structure = primitive

    # Dryrun and construct the SCF kpoints: first three columns are the
    # coordinates, the last column holds the weights
    dryrun_output = dryrun_relax_builder(builder)
    kpoint_table = np.array(dryrun_output['kpoints_and_weights'])
    coordinates = kpoint_table[:, :3]
    weights = kpoint_table[:, -1]
    scf_kpoints = orm.KpointsData()
    scf_kpoints.set_kpoints(coordinates, weights=weights)
    return {'primitive': primitive, 'scf_kpoints': scf_kpoints}
[docs]
class VaspHybridBandsWorkChain(VaspBandsWorkChain):
"""
Bands workchain for hybrid calculations
This workchain compute the bandstructure by adding band path segments as zero-weighted
kpoints for self-consistent calculations. This is mainly for hybrid calculations, but can
also be used for GGA calculations, although it would be not as efficient as the non-SCF
method implemented in ``VaspBandsWorkChain``.
In contrast to ``VaspBandsWorkChain`` this workflow requires an explicitly defined kpoints
set for the ``scf.kpoints`` port. This can be obtained by parsing the ``IBZKPT`` file from
an existing calculation or dryrun, or by parsing the ``vasprun.xml`` file.
If a relaxation workchain is run as part of the process, the ``kpoints`` output returned can
be used for this purpose automatically.
Only the `scf` namespace will be used for performing the calculation
TODO:
- Warn if the calculation is not actually a hybrid one
- Automatic Kpoints from dryruns
"""
[docs]
@classmethod
def define(cls, spec: Any) -> None:
    """Declare the specification of the hybrid bands workchain.

    All inputs, outputs and the exit codes 501-504 and 601 are inherited from
    the parent specification through ``super().define``. Only the outline is
    replaced and the exit code 505 is added here.

    :param spec: the process specification, populated in place.
    """
    super().define(spec)
    spec.outline(
        cls.setup,
        if_(cls.should_do_relax)(
            cls.run_relax,
            cls.verify_relax,
        ),
        if_(cls.should_generate_path)(cls.generate_path),
        # Find the SCF kpoints from geometry optimisation
        if_(cls.no_scf_kpoints)(cls.get_scf_kpoints_relax),
        # Generate the SCF kpoints using spglib
        if_(cls.no_scf_kpoints)(cls.get_scf_kpoints_spglib),
        # If the above fails, we need to run a SCF calculation to get the kpoints
        if_(cls.no_scf_kpoints)(cls.run_scf_for_kpoints, cls.verify_scf_for_kpoints),
        cls.make_splitted_kpoints,  # Split the kpoints
        cls.run_scf_multi,  # Launch split calculation
        cls.inspect_and_combine_bands,  # Combined the band structure
    )
    spec.exit_code(
        505,
        'ERROR_NO_VALID_SCF_KPOINTS_INPUT',
        message='Cannot found valid inputs for SCF kpoints',
    )
[docs]
def setup(self) -> None:
    """Run the parent setup and record the explicit SCF k-points, if any.

    ``ctx.scf_kpoints`` is set to ``scf.kpoints`` when that input is supplied
    and to ``None`` otherwise.
    """
    super().setup()
    scf_inputs = self.inputs.scf
    self.ctx.scf_kpoints = scf_inputs.kpoints if 'kpoints' in scf_inputs else None
[docs]
def no_scf_kpoints(self) -> bool:
    """Return ``True`` while no SCF k-points are stored in the context."""
    return self.ctx.scf_kpoints is None
[docs]
def get_scf_kpoints_relax(self) -> None:
    """Try to take the SCF k-points from the relaxation workchain.

    The k-points are reused only if a relaxation was run, its ``isym`` setting
    matches that of the SCF inputs, and the cell parameters of the relaxed
    structure match those of the current structure. In every other case the
    context is left untouched and a message is reported.
    """
    if 'workchain_relax' not in self.ctx:
        self.report('No workchain_relax found in context - skip extracting scf kpoints from previous calculation')
        return
    relax_node = self.ctx.workchain_relax

    # The irreducible kpoints depend on the symmetry setting, which must agree
    relax_isym = relax_node.inputs.vasp.parameters['incar'].get('isym', 2)
    scf_isym = self.inputs.scf.parameters['incar'].get('isym', 2)
    if scf_isym != relax_isym:
        self.report(
            'The symmetry of the SCF calculation is not consistent with the relaxation calculation. '
            'Cannot reuse the IBZ kpoitns from the relaxation calculation.'
        )
        return

    # The kpoints are only valid for the cell they were generated for
    kpt_cell = np.array(relax_node.outputs.relax.structure.cell)
    current_cell = np.array(self.ctx.current_structure.cell)
    assert kpt_cell.shape == (3, 3)
    assert current_cell.shape == (3, 3)
    same_cell = np.allclose(cell_to_cellpar(kpt_cell), cell_to_cellpar(current_cell), 1e-5)
    if not same_cell:
        self.report(
            'Cell of the last relaxation step does not match the current structure.SCF kpoints cannot be used'
        )
        return

    # Prefer the kpoints output; otherwise extract them from the calculation
    if 'kpoints' in relax_node.outputs:
        self.report(f'Using output from <{relax_node}> for SCF kpoints.')
        self.ctx.scf_kpoints = relax_node.outputs.kpoints
        return
    self.report(f'Extracted SCF kpoints from retrieved vasprun.xml of <{relax_node}>.')
    self.ctx.scf_kpoints = extract_kpoints_from_calc(relax_node)
[docs]
def get_scf_kpoints_spglib(self) -> None:
    """Generate the SCF k-points for the current structure with ``get_ir_kpoints_data``.

    Symmetry reduction is requested when ``isym`` (default 2) is positive. If
    it is requested and a ``magmom`` is given, generation is skipped whenever
    atoms of the same species carry different magnetic moments.
    """
    scf_incar = self.inputs.scf.parameters['incar']
    magmom = scf_incar.get('magmom', None)
    symmetry_reduce = scf_incar.get('isym', 2) > 0

    if symmetry_reduce and magmom is not None:
        # Check if symmetry is broken by magnetic moments
        species = self.inputs.structure.get_ase().get_chemical_symbols()
        if isinstance(magmom, str):
            magmom = magmom.split()
        assert len(magmom) == len(species), (
            f'Mismatch between the magmom ({len(magmom)}) and the number of atoms ({len(species)}).'
        )
        # More distinct (moment, species) pairs than species means that some
        # species occurs with more than one moment
        distinct_pairs = set(zip(magmom, species))
        distinct_species = set(species)
        if len(distinct_pairs) != len(distinct_species):
            self.report('Symmetry is broken by magnetic magmoms, cannot use spglib to generate the kpoints')
            return

    # Note that aiida-vasp assumes a 2pi factor in the unit
    spacing = self.inputs.scf.kpoints_spacing * np.pi * 2
    self.ctx.scf_kpoints = get_ir_kpoints_data(
        self.ctx.current_structure,
        spacing,
        symprec=self.inputs.band_settings.get('symprec', 1e-5),
        symmetry_reduce=symmetry_reduce,
    )
[docs]
def make_splitted_kpoints(self) -> Any:
    """Split the band path k-points into chunks to be run with the SCF k-points.

    The number of band k-points per chunk is ``kpoints_per_split`` minus the
    number of SCF k-points. If that is not more than half the number of SCF
    k-points, it is raised to half the number of SCF k-points (but at least
    one) and a warning is reported. The result of ``split_kpoints`` is stored
    in ``ctx.kpoints_for_calc``.

    :return: ``ERROR_NO_VALID_SCF_KPOINTS_INPUT`` if no SCF k-points are
        available, otherwise ``None``.
    """
    # Fully specified band structure kpoints
    full_kpoints = self.ctx.bs_kpoints
    scf_kpoints = self.ctx.scf_kpoints
    if scf_kpoints is None:
        self.report('No valid SCF kpoints is avaliable to use. Please define scf.kpoints explicitly!')
        return self.exit_codes.ERROR_NO_VALID_SCF_KPOINTS_INPUT  # pylint: disable=no-member

    # Number of kpoints per split, NOT including the SCF kpoints
    nscf = scf_kpoints.get_kpoints().shape[0]
    per_split = self.inputs.band_settings['kpoints_per_split'] - nscf
    if (per_split / nscf) <= 0.5:
        # Never go below one band kpoint per split, otherwise no progress is made
        per_split = max(1, int(nscf * 0.5))
        self.report(f'WARNING: Too few actual band k points per split, setting it to: {per_split + nscf}')

    # Always hand over a node, whichever branch determined the value
    self.ctx.kpoints_for_calc = split_kpoints(scf_kpoints, full_kpoints, orm.Int(per_split))
    return None
[docs]
def should_do_scf_for_scf_kpoints(self) -> bool:
    """Return whether an SCF run is needed to obtain k-points for the current cell.

    This is the case when no SCF k-points are available, or when the cell
    parameters of their cell differ from those of the current structure.

    NOTE(review): this relies on ``self.get_scf_kpoints()``, which is not
    defined in the visible part of this class - confirm that it exists.
    """
    scf_kpoints = self.get_scf_kpoints()
    if scf_kpoints is None:
        return True

    # Check if the scf_kpoints is consistent with the current structure
    kpoints_cell = np.array(scf_kpoints.cell)
    structure_cell = np.array(self.ctx.current_structure.cell)
    assert kpoints_cell.shape == (3, 3)
    assert structure_cell.shape == (3, 3)

    # Different cell parameters mean the kpoints belong to another cell
    cells_match = np.allclose(cell_to_cellpar(kpoints_cell), cell_to_cellpar(structure_cell), 1e-5)
    return not cells_match
[docs]
def run_scf_for_kpoints(self) -> Any:
    """Submit an SCF calculation whose purpose is to obtain the k-points.

    The SCF inputs are used with the current structure, and the parser is
    asked to include the ``kpoints`` node. When ``hybrid_reuse_wavecar`` is
    enabled, ``istart`` is set to 1 in the INCAR parameters. Ideally this would
    be a dryrun on the local machine.

    :return: the result of registering the submitted workchain in the context
        under ``workchain_scf_for_kpoints``.
    """
    scf_cls = WorkflowFactory(self._base_wk_string)
    scf_inputs = self.exposed_inputs(scf_cls, 'scf')
    incar_params = scf_inputs.parameters.get_dict()

    # Reuse the wavecar if requested
    if self.inputs.band_settings.get('hybrid_reuse_wavecar', False):
        self.report('Setting ISTART=1 to reuse WAVECAR from the previous calculation.')
        incar_params['incar']['istart'] = 1
        scf_inputs.parameters = orm.Dict(incar_params)

    # Ensure that the kpoints are returned by the parser
    kpoints_parsing = {'parser_settings': {'include_node': ['kpoints']}}
    if 'settings' in scf_inputs:
        # Merge with the existing settings
        scf_inputs.settings = update_nested_dict_node(scf_inputs.settings, kpoints_parsing, extend_list=True)
    else:
        scf_inputs.settings = orm.Dict(dict=kpoints_parsing)

    scf_inputs.metadata.label = self.get_appended_label('SCF KPOINTS')
    scf_inputs.metadata.call_link_label = 'scf_for_kpoints'
    # Use the current structure as reference
    scf_inputs.structure = self.ctx.current_structure

    submitted = self.submit(scf_cls, **scf_inputs)
    return self.to_context(workchain_scf_for_kpoints=submitted)
[docs]
def verify_scf_for_kpoints(self) -> Any:
    """Check the SCF-for-kpoints calculation and store its ``kpoints`` output.

    :return: ``ERROR_SUB_PROC_SCF_FAILED`` if the workchain did not finish
        successfully, otherwise ``None``.
    """
    kpoints_run = self.ctx.workchain_scf_for_kpoints
    if kpoints_run.is_finished_ok:
        # Save the obtained kpoints
        self.ctx.scf_kpoints = kpoints_run.outputs.kpoints
        self.report(f'SCF calculation {kpoints_run} completed and obtained kpoints {self.ctx.scf_kpoints}')
        return None
    self.report('SCF for kpoints workchain finished with Error')
    return self.exit_codes.ERROR_SUB_PROC_SCF_FAILED
[docs]
def run_scf_multi(self) -> Any:
    """
    Launch multiple SCF calculations with zero-weighted kpoints for segments of the band structure.

    One base workchain is submitted for every entry of ``self.ctx.kpoints_for_calc`` using the
    inputs exposed under the ``scf`` namespace, and each submitted process is appended to
    ``self.ctx.workchains``. The numeric suffix of each key is used as the split index in the
    label and in the ``bandstructure_split_XXX`` call link label.

    The INCAR parameters are adjusted before submission:

    - If a relaxation workchain is available and ``ISPIN=2`` was requested, the site
      magnetization of the relaxation is inspected and spin polarization is switched off for
      non-magnetic systems. When no site magnetization was stored, the parameters are left as is.
    - If ``hybrid_reuse_wavecar`` is enabled in ``band_settings``, ``ISTART=1`` is set and the
      remote folder of the relaxation is passed as the restart folder. This requires a
      relaxation workchain in the context; without one the reuse is skipped with a report
      instead of failing on the missing workchain.
    """
    workflow_class = WorkflowFactory(self._base_wk_string)
    inputs = self.exposed_inputs(workflow_class, 'scf')
    pdict = inputs.parameters.get_dict()
    parameters_modified = False

    # Check if we really need to run spin polarized calculation
    relax_work = self.ctx.get('workchain_relax', None)
    if relax_work is not None and pdict.get('incar', {}).get('ispin') == 2:
        self.report('Checking the magnetization of the relaxed structure.')
        mag = relax_work.outputs.misc.get('site_magnetization')
        if mag is None:
            # Nothing to base the decision on - keep the spin polarization requested by the user
            self.report('No site magnetization found in the relaxation outputs, keeping spin polarization.')
        elif not _is_magnetic_via_site_moment(mag):
            pdict['incar']['ispin'] = 1
            self.report('Turnning off spin polarization for band structure calculation for non-magnetic system.')
            parameters_modified = True

    # Reuse the wavecar if requested - the WAVECAR comes from the relaxation workchain
    restart_folder = None
    if self.inputs.band_settings.get('hybrid_reuse_wavecar', False):
        if relax_work is None:
            self.report('Cannot reuse WAVECAR as no relaxation workchain is available, running without restart.')
        else:
            self.report('Setting ISTART=1 to reuse WAVECAR from the previous calculation.')
            pdict.setdefault('incar', {})['istart'] = 1
            parameters_modified = True
            restart_folder = relax_work.outputs.remote_folder

    # Only create a new node if the parameters differ from the ones supplied
    pnode = orm.Dict(pdict) if parameters_modified else inputs.parameters

    for key, value in self.ctx.kpoints_for_calc.items():
        idx = int(key.split('_')[-1])
        inputs = self.exposed_inputs(workflow_class, 'scf')
        # Use the updated parameters
        inputs.parameters = pnode
        if restart_folder is not None:
            inputs.restart_folder = restart_folder

        # Ensure that the bands are parsed
        if 'settings' not in inputs:
            inputs.settings = orm.Dict(dict={'parser_settings': {'include_node': ['bands']}})
        else:
            # Merge with 'parser_settings'
            inputs.settings = update_nested_dict_node(
                inputs.settings, {'parser_settings': {'include_node': ['bands']}}, extend_list=True
            )

        # Swap the kpoints to the ones with the zero-weighted part
        inputs.kpoints = value
        inputs.metadata.label = self.get_appended_label(f' SPLIT {idx}')
        inputs.metadata.call_link_label = f'bandstructure_split_{idx:03d}'
        inputs.structure = self.ctx.current_structure

        running = self.submit(workflow_class, **inputs)
        self.report(f'launching {workflow_class.__name__}<{running.pk}> for split #{idx}')
        self.to_context(workchains=append_(running))
[docs]
def inspect_and_combine_bands(self) -> None:
    """
    Inspect the split calculations and merge their bands into the ``band_structure`` output.

    A non-zero exit status of any calculation is only reported; the bands are extracted
    regardless. The split index is recovered from the label of the incoming ``CALL_WORK``
    link of each workchain.
    """
    finished = self.ctx.workchains
    if any(work.exit_status for work in finished):
        self.report('At least one calculation did not have zero return code!')

    # Extract the bands information
    self.report(f'Extracting output bandstructure from {len(self.ctx.workchains)} workchains.')
    segments = {}
    for work in finished:
        call_link = work.base.links.get_incoming(link_type=LinkType.CALL_WORK).one()
        # The call link label ends with the index of the split
        link_idx = int(call_link.link_label.rsplit('_', 1)[-1])
        segments[f'band_{link_idx:03d}'] = work.outputs.bands
        segments[f'kpoint_{link_idx:03d}'] = work.inputs.kpoints

    self.out('band_structure', combine_bands_data(self.ctx.bs_kpoints, **segments))
[docs]
@calcfunction
def split_kpoints(scf_kpoints: orm.KpointsData, band_kpoints: orm.KpointsData, kpn_per_split: orm.Int) -> Any:
    """
    Calcfunction wrapper around ``_split_kpoints``.

    Divides the band structure kpoints into segments of ``kpn_per_split`` points and
    combines each segment, with zero weights, with the SCF kpoints.
    """
    segments = _split_kpoints(
        scf_kpoints=scf_kpoints,
        band_kpoints=band_kpoints,
        kpn_per_split=kpn_per_split,
    )
    return segments
[docs]
def _split_kpoints(scf_kpoints: orm.KpointsData, band_kpoints: orm.KpointsData, kpn_per_split: orm.Int) -> Any:
    """
    Split the band structure kpoints into segments and combine each with the SCF kpoints.

    In every resulting node the SCF kpoints come first and keep their weights, followed by
    one segment of the band structure kpoints with zero weights.

    :param scf_kpoints: kpoints (with weights) of the SCF calculation.
    :param band_kpoints: explicit kpoints along the band structure path.
    :param kpn_per_split: maximum number of band structure kpoints per segment.
    :returns: a dictionary of unstored ``KpointsData`` keyed by ``bs_kpoints_XXX`` where
        ``XXX`` is the zero-padded index of the segment.
    :raises ValueError: if ``kpn_per_split`` is not a positive integer.
    """
    kpn_per_split = int(kpn_per_split)
    # A zero step is rejected by ``range`` and a negative one silently yields no segment at all
    if kpn_per_split < 1:
        raise ValueError(f'kpn_per_split must be a positive integer, got {kpn_per_split}')

    scf_kpoints_array, scf_weights_array = scf_kpoints.get_kpoints(also_weights=True)
    band_kpn = band_kpoints.get_kpoints()
    nband_kpts = band_kpn.shape[0]
    nscf_kpts = scf_kpoints_array.shape[0]

    splitted_kpoints = {}
    for isplit, start in enumerate(range(0, nband_kpts, kpn_per_split)):
        skpts = band_kpn[start : start + kpn_per_split]
        kpt_array = np.concatenate([scf_kpoints_array, skpts], axis=0)
        # Only the SCF kpoints carry weights, the band structure part stays at zero
        weights_array = np.zeros(kpt_array.shape[0])
        weights_array[:nscf_kpts] = scf_weights_array

        kpt = orm.KpointsData()
        kpt.set_kpoints(kpt_array, weights=weights_array)
        kpt.label = f'SPLIT {isplit:03d}'
        kpt.description = 'Splitted kpoints'
        splitted_kpoints[f'bs_kpoints_{isplit:03d}'] = kpt
    return splitted_kpoints
[docs]
def dryrun_split_kpoints(
    structure: orm.StructureData,
    scf_kpoints: orm.KpointsData,
    kpn_per_split: orm.Int,
    kpoints_args: Any = None,
    verbose: bool = True,
) -> Any:
    """
    Preview how the band structure path of a structure would be split.

    The explicit kpoints path is generated with ``get_explicit_kpoints_path`` (receiving
    ``kpoints_args`` as keyword arguments) and split with ``_split_kpoints``.

    :returns: a tuple of the results from ``get_explicit_kpoints_path`` and the dictionary of
        the split kpoints.
    """
    path_options = {} if kpoints_args is None else kpoints_args
    seekpath_results = get_explicit_kpoints_path(structure, **path_options)
    splitted = _split_kpoints(scf_kpoints, seekpath_results['explicit_kpoints'], kpn_per_split)

    if not verbose:
        return seekpath_results, splitted

    nseg = len(splitted)
    nkpts = [kpn.get_kpoints().shape[0] for kpn in splitted.values()]
    print(f'Splitted in to {nseg} segements with number of kpoints: {nkpts}')
    return seekpath_results, splitted
[docs]
@calcfunction
def combine_bands_data(bs_kpoints: orm.KpointsData, **kwargs: Any) -> orm.BandsData:
    """
    Combine the bands and kpoints of the split calculations.

    The nodes of each calculation are passed as keyword arguments named like ``band_001``
    and ``kpoint_001``; the number after the underscore defines the order in which they are
    combined. ``bs_kpoints`` is the originally generated band structure path.

    Returns a ``BandsData`` by combining the zero-weighted bands from each calculation.
    """

    def _ordered_nodes(tag: str) -> list:
        """Return the nodes whose keyword contains ``tag``, sorted by the index in the keyword."""
        indexed = [(int(key.split('_')[1]), node) for key, node in kwargs.items() if tag in key]
        # Sort on the index only so that the nodes themselves are never compared
        indexed.sort(key=lambda pair: pair[0])
        return [node for _, node in indexed]

    return _combine_bands_data(bs_kpoints, _ordered_nodes('kpoint'), _ordered_nodes('band'))
[docs]
def _combine_bands_data(
    bs_kpoints: orm.KpointsData,
    kpoints_list: List[orm.KpointsData],
    bands_list: List[orm.BandsData],
) -> orm.BandsData:
    """
    Combine bands from splitted kpoints into a single bands node.

    From each calculation only the entries belonging to zero-weighted kpoints are kept.
    The list of kpoints and bands must be sorted in the right order.

    :param bs_kpoints: the original band structure path, set as the kpoints of the result.
    :param kpoints_list: kpoints (with weights) used by each split calculation.
    :param bands_list: bands obtained from each split calculation, in the same order.
    :returns: an unstored ``BandsData`` with the combined bands, the occupations (if any
        were present) and the ``fermi_level``/``efermi`` attributes of the first calculation.
    :raises ValueError: if no data is given, if the two lists differ in length, or if the
        combined zero-weighted kpoints do not reproduce ``bs_kpoints``.
    """
    if not kpoints_list or not bands_list:
        raise ValueError('No kpoints or bands were supplied for combining the band structure!')
    if len(kpoints_list) != len(bands_list):
        # ``zip`` would silently drop the unmatched nodes
        raise ValueError(
            f'The numbers of kpoints nodes ({len(kpoints_list)}) and bands nodes ({len(bands_list)}) do not match!'
        )

    bands_array_combine = []
    occu_array_combine = []
    kpoints_combine = []
    fermi_levels = []
    for skpts, sbands in zip(kpoints_list, bands_list):
        fermi_levels.append(sbands.base.attributes.get('fermi_level', None))
        kpt_array, weights_array = skpts.get_kpoints(also_weights=True)
        zero_weight_mask = weights_array == 0.0
        kpoints_combine.append(kpt_array[zero_weight_mask, :])

        bands_array = sbands.get_bands()
        # Bands array can have three or two dimensions: the kpoints are indexed by the
        # second axis in the former case and by the first axis in the latter
        if bands_array.ndim == 3:
            selector = (slice(None), zero_weight_mask, slice(None))
        else:
            selector = (zero_weight_mask, slice(None))
        bands_array_combine.append(bands_array[selector])
        if 'occupations' in sbands.get_arraynames():
            occu_array_combine.append(sbands.get_array('occupations')[selector])

    # Concatenate the arrays along the kpoint axis
    kpoint_axis = 1 if bands_array_combine[0].ndim == 3 else 0
    band_array_full = np.concatenate(bands_array_combine, axis=kpoint_axis)
    occu_array_full = np.concatenate(occu_array_combine, axis=kpoint_axis) if occu_array_combine else None

    # Sanity check all valid kpoints should combine into the original path
    all_kpoints = np.concatenate(kpoints_combine, axis=0)
    reference_kpoints = bs_kpoints.get_kpoints()
    # Compare the shapes explicitly so that broadcasting cannot hide missing kpoints
    if all_kpoints.shape != reference_kpoints.shape or not np.allclose(all_kpoints, reference_kpoints):
        raise ValueError('The k-path segements do not much the original path when combined!')

    # Compose the node
    band_data = orm.BandsData()
    band_data.set_kpointsdata(bs_kpoints)
    band_data.set_bands(band_array_full, occupations=occu_array_full)

    # Set the fermi level of the combined bands - warn if a calculation has no fermi level
    # or deviates from the first one by more than 0.01
    reference_fermi = fermi_levels[0]
    if any(x is None for x in fermi_levels) or any(abs(entry - reference_fermi) > 0.01 for entry in fermi_levels):
        logger.warning(
            'Fermi level of the splitted calculations (%s) are not consistent! '
            'Using the first one as the combined fermi level.',
            fermi_levels,
        )
    band_data.base.attributes.set('fermi_level', reference_fermi)
    band_data.base.attributes.set('efermi', reference_fermi)  # Alias for fermi level
    return band_data
[docs]
def _is_magnetic_via_site_moment(mag: Any) -> bool:
    """
    Tell whether any site moment exceeds ``SITE_MAG_THRESHOLD`` in absolute value.

    The moments are read from ``mag['sphere']['x']['site_moment']``, a mapping of
    per-site dictionaries whose values are the moments to check.
    """
    per_site_moments = mag['sphere']['x']['site_moment']
    return any(
        abs(moment) > SITE_MAG_THRESHOLD
        for site in per_site_moments.values()
        for moment in site.values()
    )