diff --git a/deepmd/calculator.py b/deepmd/calculator.py index 6f863ab09b..0fbc447aaa 100644 --- a/deepmd/calculator.py +++ b/deepmd/calculator.py @@ -45,6 +45,8 @@ class DP(Calculator): will infer this information from model, by default None neighbor_list : ase.neighborlist.NeighborList, optional The neighbor list object. If None, then build the native neighbor list. + head : Union[str, None], optional + a specific model branch choosing from pretrained model, by default None Examples -------- @@ -84,10 +86,15 @@ def __init__( label: str = "DP", type_dict: Optional[dict[str, int]] = None, neighbor_list=None, + head=None, **kwargs, ) -> None: Calculator.__init__(self, label=label, **kwargs) - self.dp = DeepPot(str(Path(model).resolve()), neighbor_list=neighbor_list) + self.dp = DeepPot( + str(Path(model).resolve()), + neighbor_list=neighbor_list, + head=head, + ) if type_dict: self.type_dict = type_dict else: