Skip to content

Commit

Permalink
add category property to OutputVariableDef (#3228)
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: Han Wang <[email protected]>
  • Loading branch information
njzjz and wanghan-iapcm authored Feb 6, 2024
1 parent 79f98ca commit 6c12380
Show file tree
Hide file tree
Showing 2 changed files with 199 additions and 2 deletions.
103 changes: 101 additions & 2 deletions deepmd/dpmodel/output_def.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import functools
from enum import (
IntEnum,
)
from typing import (
Dict,
List,
Expand Down Expand Up @@ -107,6 +110,38 @@ def __call__(
return wrapper


class OutputVariableOperation(IntEnum):
"""Defines the operation of the output variable."""

_NONE = 0
"""No operation."""
REDU = 1
"""Reduce the output variable."""
DERV_R = 2
"""Derivative w.r.t. coordinates."""
DERV_C = 4
"""Derivative w.r.t. cell."""
_SEC_DERV_R = 8
"""Second derivative w.r.t. coordinates."""


class OutputVariableCategory(IntEnum):
"""Defines the category of the output variable."""

OUT = OutputVariableOperation._NONE
"""Output variable. (e.g. atom energy)"""
REDU = OutputVariableOperation.REDU
"""Reduced output variable. (e.g. system energy)"""
DERV_R = OutputVariableOperation.DERV_R
"""Negative derivative w.r.t. coordinates. (e.g. force)"""
DERV_C = OutputVariableOperation.DERV_C
"""Atomic component of the virial, see PRB 104, 224202 (2021) """
DERV_C_REDU = OutputVariableOperation.DERV_C | OutputVariableOperation.REDU
"""Virial, the transposed negative gradient with cell tensor times cell tensor, see eq 40 JCP 159, 054801 (2023). """
DERV_R_DERV_R = OutputVariableOperation.DERV_R | OutputVariableOperation._SEC_DERV_R
"""Hession matrix, the second derivative w.r.t. coordinates."""


class OutputVariableDef:
"""Defines the shape and other properties of the one output variable.
Expand All @@ -129,7 +164,8 @@ class OutputVariableDef:
If the variable is differentiated with respect to coordinates
of atoms and cell tensor (pbc case). Only reduciable variable
are differentiable.
category : int
The category of the output variable.
"""

def __init__(
Expand All @@ -139,6 +175,7 @@ def __init__(
reduciable: bool = False,
differentiable: bool = False,
atomic: bool = True,
category: int = OutputVariableCategory.OUT.value,
):
self.name = name
self.shape = list(shape)
Expand All @@ -149,6 +186,7 @@ def __init__(
raise ValueError("only reduciable variable are differentiable")
if self.reduciable and not self.atomic:
raise ValueError("only reduciable variable should be atomic")
self.category = category


class FittingOutputDef:
Expand Down Expand Up @@ -255,6 +293,60 @@ def get_deriv_name(name: str) -> Tuple[str, str]:
return name + "_derv_r", name + "_derv_c"


def apply_operation(var_def: OutputVariableDef, op: OutputVariableOperation) -> int:
"""Apply a operation to the category of a variable definition.
Parameters
----------
var_def : OutputVariableDef
The variable definition.
op : OutputVariableOperation
The operation to be applied.
Returns
-------
int
The new category of the variable definition.
Raises
------
ValueError
If the operation has been applied to the variable definition,
and exceed the maximum limitation.
"""
if op == OutputVariableOperation.REDU or op == OutputVariableOperation.DERV_C:
if check_operation_applied(var_def, op):
raise ValueError(f"operation {op} has been applied")
elif op == OutputVariableOperation.DERV_R:
if check_operation_applied(var_def, OutputVariableOperation.DERV_R):
op = OutputVariableOperation._SEC_DERV_R
if check_operation_applied(var_def, OutputVariableOperation._SEC_DERV_R):
raise ValueError(f"operation {op} has been applied twice")
else:
raise ValueError(f"operation {op} not supported")
return var_def.category | op.value


def check_operation_applied(
var_def: OutputVariableDef, op: OutputVariableOperation
) -> bool:
"""Check if a operation has been applied to a variable definition.
Parameters
----------
var_def : OutputVariableDef
The variable definition.
op : OutputVariableOperation
The operation to be checked.
Returns
-------
bool
True if the operation has been applied, False otherwise.
"""
return var_def.category & op.value == op.value


def do_reduce(
def_outp_data: Dict[str, OutputVariableDef],
) -> Dict[str, OutputVariableDef]:
Expand All @@ -263,7 +355,12 @@ def do_reduce(
if vv.reduciable:
rk = get_reduce_name(kk)
def_redu[rk] = OutputVariableDef(
rk, vv.shape, reduciable=False, differentiable=False, atomic=False
rk,
vv.shape,
reduciable=False,
differentiable=False,
atomic=False,
category=apply_operation(vv, OutputVariableOperation.REDU),
)
return def_redu

Expand All @@ -282,12 +379,14 @@ def do_derivative(
reduciable=False,
differentiable=False,
atomic=True,
category=apply_operation(vv, OutputVariableOperation.DERV_R),
)
def_derv_c[rkc] = OutputVariableDef(
rkc,
vv.shape + [3, 3], # noqa: RUF005
reduciable=True,
differentiable=False,
atomic=True,
category=apply_operation(vv, OutputVariableOperation.DERV_C),
)
return def_derv_r, def_derv_c
98 changes: 98 additions & 0 deletions source/tests/common/dpmodel/test_output_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
model_check_output,
)
from deepmd.dpmodel.output_def import (
OutputVariableCategory,
OutputVariableOperation,
apply_operation,
check_var,
)

Expand Down Expand Up @@ -103,6 +106,101 @@ def test_model_output_def(self):
self.assertEqual(md["energy_derv_r"].atomic, True)
self.assertEqual(md["energy_derv_c"].atomic, True)
self.assertEqual(md["energy_derv_c_redu"].atomic, False)
# category
self.assertEqual(md["energy"].category, OutputVariableCategory.OUT)
self.assertEqual(md["dos"].category, OutputVariableCategory.OUT)
self.assertEqual(md["foo"].category, OutputVariableCategory.OUT)
self.assertEqual(md["energy_redu"].category, OutputVariableCategory.REDU)
self.assertEqual(md["energy_derv_r"].category, OutputVariableCategory.DERV_R)
self.assertEqual(md["energy_derv_c"].category, OutputVariableCategory.DERV_C)
self.assertEqual(
md["energy_derv_c_redu"].category, OutputVariableCategory.DERV_C_REDU
)
# flag
self.assertEqual(md["energy"].category & OutputVariableOperation.REDU, 0)
self.assertEqual(md["energy"].category & OutputVariableOperation.DERV_R, 0)
self.assertEqual(md["energy"].category & OutputVariableOperation.DERV_C, 0)
self.assertEqual(md["dos"].category & OutputVariableOperation.REDU, 0)
self.assertEqual(md["dos"].category & OutputVariableOperation.DERV_R, 0)
self.assertEqual(md["dos"].category & OutputVariableOperation.DERV_C, 0)
self.assertEqual(md["foo"].category & OutputVariableOperation.REDU, 0)
self.assertEqual(md["foo"].category & OutputVariableOperation.DERV_R, 0)
self.assertEqual(md["foo"].category & OutputVariableOperation.DERV_C, 0)
self.assertEqual(
md["energy_redu"].category & OutputVariableOperation.REDU,
OutputVariableOperation.REDU,
)
self.assertEqual(md["energy_redu"].category & OutputVariableOperation.DERV_R, 0)
self.assertEqual(md["energy_redu"].category & OutputVariableOperation.DERV_C, 0)
self.assertEqual(md["energy_derv_r"].category & OutputVariableOperation.REDU, 0)
self.assertEqual(
md["energy_derv_r"].category & OutputVariableOperation.DERV_R,
OutputVariableOperation.DERV_R,
)
self.assertEqual(
md["energy_derv_r"].category & OutputVariableOperation.DERV_C, 0
)
self.assertEqual(md["energy_derv_c"].category & OutputVariableOperation.REDU, 0)
self.assertEqual(
md["energy_derv_c"].category & OutputVariableOperation.DERV_R, 0
)
self.assertEqual(
md["energy_derv_c"].category & OutputVariableOperation.DERV_C,
OutputVariableOperation.DERV_C,
)
self.assertEqual(
md["energy_derv_c_redu"].category & OutputVariableOperation.REDU,
OutputVariableOperation.REDU,
)
self.assertEqual(
md["energy_derv_c_redu"].category & OutputVariableOperation.DERV_R, 0
)
self.assertEqual(
md["energy_derv_c_redu"].category & OutputVariableOperation.DERV_C,
OutputVariableOperation.DERV_C,
)

# apply_operation
self.assertEqual(
apply_operation(md["energy"], OutputVariableOperation.REDU),
md["energy_redu"].category,
)
self.assertEqual(
apply_operation(md["energy"], OutputVariableOperation.DERV_R),
md["energy_derv_r"].category,
)
self.assertEqual(
apply_operation(md["energy"], OutputVariableOperation.DERV_C),
md["energy_derv_c"].category,
)
self.assertEqual(
apply_operation(md["energy_derv_c"], OutputVariableOperation.REDU),
md["energy_derv_c_redu"].category,
)
# raise ValueError
with self.assertRaises(ValueError):
apply_operation(md["energy_redu"], OutputVariableOperation.REDU)
with self.assertRaises(ValueError):
apply_operation(md["energy_derv_c"], OutputVariableOperation.DERV_C)
with self.assertRaises(ValueError):
apply_operation(md["energy_derv_c_redu"], OutputVariableOperation.REDU)
# hession
hession_cat = apply_operation(
md["energy_derv_r"], OutputVariableOperation.DERV_R
)
self.assertEqual(
hession_cat & OutputVariableOperation.DERV_R, OutputVariableOperation.DERV_R
)
self.assertEqual(
hession_cat & OutputVariableOperation._SEC_DERV_R,
OutputVariableOperation._SEC_DERV_R,
)
self.assertEqual(hession_cat, OutputVariableCategory.DERV_R_DERV_R)
hession_vardef = OutputVariableDef(
"energy_derv_r_derv_r", [1], False, False, category=hession_cat
)
with self.assertRaises(ValueError):
apply_operation(hession_vardef, OutputVariableOperation.DERV_R)

def test_raise_no_redu_deriv(self):
with self.assertRaises(ValueError) as context:
Expand Down

0 comments on commit 6c12380

Please sign in to comment.