diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..a85b033 --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +*.pyc +*.egg-info +*.csv +*.npz +.idea diff --git a/DeepSolid/__init__.py b/DeepSolid/__init__.py new file mode 100644 index 0000000..1c452c8 --- /dev/null +++ b/DeepSolid/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) ByteDance, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. \ No newline at end of file diff --git a/DeepSolid/base_config.py b/DeepSolid/base_config.py new file mode 100644 index 0000000..4d19e6a --- /dev/null +++ b/DeepSolid/base_config.py @@ -0,0 +1,161 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”). +# All Bytedance Modifications are Copyright 2022 Bytedance Inc. + +import ml_collections +from ml_collections import config_dict + + +def default() -> ml_collections.ConfigDict: + """Create set of default parameters for running qmc.py. + + Note: placeholders (cfg.system.molecule and cfg.system.electrons) must be + replaced with appropriate values. + + Returns: + ml_collections.ConfigDict containing default settings. + """ + # wavefunction output. + cfg = ml_collections.ConfigDict({ + 'batch_size': 100, # batch size + # Config module used. Should be set in get_config function as either the + # absolute module or relative to the configs subdirectory. Relative + # imports must start with a '.' (e.g. .atom). Do *not* override on + # command-line. Do *not* set using __name__ from inside a get_config + # function, as config_flags overrides this when importing the module using + # importlib.import_module. + 'config_module': __name__, + 'use_x64': True, # use float64 or 32 + 'optim': { + 'iterations': 1000000, # number of iterations + 'optimizer': 'kfac', + 'local_energy_outlier_width': 5.0, + 'lr': { + 'rate': 5.e-2, # learning rate, different from the reported lr in FermiNet + # since DeepSolid energy gradient is not batch-size dependent + 'decay': 1.0, # exponent of learning rate decay + 'delay': 10000.0, # term that sets the scale of the rate decay + }, + 'clip_el': 5.0, # If not none, scale at which to clip local energy + 'clip_type': 'real', # Clip real and imag part of gradient. + 'gradient_clip': 5.0, + # ADAM hyperparameters. See optax documentation for details. + 'adam': { + 'b1': 0.9, + 'b2': 0.999, + 'eps': 1.e-8, + 'eps_root': 0.0, + }, + 'kfac': { + 'invert_every': 1, + 'cov_update_every': 1, + 'damping': 0.001, + 'cov_ema_decay': 0.95, + 'momentum': 0.0, + 'momentum_type': 'regular', + # Warning: adaptive damping is not currently available. 
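For orientation, a minimal usage sketch (not part of this patch; the override values below are illustrative assumptions): the defaults defined in this file are typically loaded, mutated, and then resolved before training.

from DeepSolid import base_config

cfg = base_config.default()
cfg.batch_size = 4096                     # hypothetical batch size
cfg.optim.laplacian_mode = 'partition'    # trade speed for GPU memory
cfg.optim.partition_number = 6            # must divide ndim * number of electrons
cfg = base_config.resolve(cfg)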
+                'min_damping': 1.e-4,
+                'norm_constraint': 0.001,
+                'mean_center': True,
+                'l2_reg': 0.0,
+                'register_only_generic': False,
+            },
+            'ministeps': 1,
+            'laplacian_mode': 'for',  # laplacian evaluation mode; one of 'for', 'partition' or 'hessian'.
+            # 'for' mode calculates the laplacian of each electron one by one, which is slow but saves GPU memory.
+            # 'hessian' mode calculates the laplacian in a highly parallelized way, which is fast but requires more GPU memory.
+            # 'partition' mode is a compromise between the two.
+            'partition_number': 3,
+            # Only used for 'partition' mode.
+            # (dim * number of electrons) must be divisible by partition_number. Smaller values are faster but require more memory.
+        },
+        'log': {
+            'stats_frequency': 1,  # iterations between logging of stats
+            'save_frequency': 10.0,  # minutes between saving network params
+            'save_frequency_in_step': -1,
+            'save_path': '',
+            # specify the local save path
+            'restore_path': '',
+            # specify the restore path which contains saved model parameters.
+            'local_energies': False,
+            'complex_polarization': False,  # log the complex polarization order parameter, which is useful for the hydrogen chain.
+            'structure_factor': False,
+            # return the structure factor S(k) at reciprocal lattice points of the supercell.
+            # Logging S(k) requires a lot of storage space, be careful.
+            'stats_file_name': 'train_stats'
+        },
+        'system': {
+            'pyscf_cell': None,  # simulation cell object
+            'ndim': 3,  # dimension of the system
+            'internal_cell': None,
+        },
+        'mcmc': {
+            # Note: HMC options are not currently used.
+            # Number of burn in steps after pretraining. If zero do not burn in
+            # or reinitialize walkers.
+            'burn_in': 100,
+            'steps': 20,  # Number of MCMC steps to make between network updates.
+            # Width of (atom-centred) Gaussian used to generate initial electron
+            # configurations.
+            'init_width': 0.8,
+            # Width of Gaussian used for random moves for RMW or step size for
+            # HMC.
+            'move_width': 0.02,
+            # Number of steps after which to update the adaptive MCMC step size
+            'adapt_frequency': 100,
+            'init_means': (),  # Not implemented in JAX.
+            # If true, scale the proposal width for each electron by the harmonic
+            # mean of the distance to the nuclei.
+            'importance_sampling': False,
+            # whether to use importance sampling in the MCMC step (untested);
+            # Metropolis sampling is used if false.
+            'one_electron': False
+            # If true, use one-electron moves (untested).
+        },
+        'network': {
+            'detnet': {
+                'envelope_type': 'isotropic',
+                # only isotropic mode has been tested
+                'bias_orbitals': False,
+                'use_last_layer': False,
+                'full_det': False,
+                'hidden_dims': ((256, 32), (256, 32), (256, 32)),
+                'determinants': 8,
+                'after_determinants': 1,
+            },
+            'twist': (0.0, 0.0, 0.0),  # Define the twist of the wavefunction;
+            # twists are given in terms of fractions of supercell reciprocal vectors.
+        },
+        'debug': {
+            # Check optimizer state, parameters and loss and raise an exception if
+            # NaN is found.
+            'check_nan': False,  # check whether the gradient contains NaNs before optimizing; if True, retry.
+            'deterministic': False,  # Use a deterministic seed.
+        },
+        'pretrain': {
+            'method': 'net',  # Method is one of 'hf', 'net'.
+ 'iterations': 1000, + 'lr': 3e-4, + 'steps': 1, #mcmc steps between each pretrain iterations + }, + }) + + return cfg + + +def resolve(cfg): + cfg = cfg.copy_and_resolve_references() + return cfg diff --git a/DeepSolid/checkpoint.py b/DeepSolid/checkpoint.py new file mode 100644 index 0000000..dba4317 --- /dev/null +++ b/DeepSolid/checkpoint.py @@ -0,0 +1,165 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”). +# All Bytedance Modifications are Copyright 2022 Bytedance Inc. + +import datetime +import os +from typing import Optional +import zipfile + +from absl import logging +import jax +import numpy as np + + +def get_restore_path(restore_path: Optional[str] = None) -> Optional[str]: + """Gets the path containing checkpoints from a previous calculation. + + Args: + restore_path: path to checkpoints. + + Returns: + The path or None if restore_path is falsy. + """ + if restore_path: + ckpt_restore_path = restore_path + else: + ckpt_restore_path = None + return ckpt_restore_path + + +def find_last_checkpoint(ckpt_path: Optional[str] = None) -> Optional[str]: + """Finds most recent valid checkpoint in a directory. + + Args: + ckpt_path: Directory containing checkpoints. + + Returns: + Last QMC checkpoint (ordered by sorting all checkpoints by name in reverse) + or None if no valid checkpoint is found or ckpt_path is not given or doesn't + exist. A checkpoint is regarded as not valid if it cannot be read + successfully using np.load. + """ + if ckpt_path and os.path.exists(ckpt_path): + files = [f for f in os.listdir(ckpt_path) if 'qmcjax_ckpt_' in f] + # Handle case where last checkpoint is corrupt/empty. + for file in sorted(files, reverse=True): + fname = os.path.join(ckpt_path, file) + with open(fname, 'rb') as f: + try: + np.load(f, allow_pickle=True) + return fname + except (OSError, EOFError, zipfile.BadZipFile): + logging.info('Error loading checkpoint %s. Trying next checkpoint...', + fname) + return None + + +def create_save_path(save_path: Optional[str],) -> str: + """Creates the directory for saving checkpoints, if it doesn't exist. + + Args: + save_path: directory to use. If false, create a directory in the working + directory based upon the current time. + + Returns: + Path to save checkpoints to. + """ + timestamp = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S') + default_save_path = os.path.join(os.getcwd(), f'DeepSolid_{timestamp}') + ckpt_save_path = save_path or default_save_path + + + if ckpt_save_path and not os.path.isdir(ckpt_save_path): + os.makedirs(ckpt_save_path) + + return ckpt_save_path + + +def save(save_path: str, t: int, data, params, opt_state, mcmc_width, + remote_save_path: Optional[int] = None) -> str: + """Saves checkpoint information to a npz file. + + Args: + save_path: path to directory to save checkpoint to. The checkpoint file is + save_path/qmcjax_ckpt_$t.npz, where $t is the number of completed + iterations. 
+ t: number of completed iterations. + data: MCMC walker configurations. + params: pytree of network parameters. + opt_state: optimization state. + mcmc_width: width to use in the MCMC proposal distribution. + + Returns: + path to checkpoint file. + """ + ckpt_filename = os.path.join(save_path, f'qmcjax_ckpt_{t:06d}.npz') + logging.info('Saving checkpoint %s', ckpt_filename) + with open(ckpt_filename, 'wb') as f: + np.savez( + f, + t=t, + data=data, + params=params, + opt_state=opt_state, + mcmc_width=mcmc_width) + + return ckpt_filename + + +def restore(restore_filename: str, batch_size: Optional[int] = None, shape_check=True): + """Restores data saved in a checkpoint. + + Args: + restore_filename: filename containing checkpoint. + batch_size: total batch size to be used. If present, check the data saved in + the checkpoint is consistent with the batch size requested for the + calculation. + + Returns: + (t, data, params, opt_state, mcmc_width) tuple, where + t: number of completed iterations. + data: MCMC walker configurations. + params: pytree of network parameters. + opt_state: optimization state. + mcmc_width: width to use in the MCMC proposal distribution. + + Raises: + ValueError: if the leading dimension of data does not match the number of + devices (i.e. the number of devices being parallelised over has changed) or + if the total batch size is not equal to the number of MCMC configurations in + data. + """ + logging.info('Loading checkpoint %s', restore_filename) + with open(restore_filename, 'rb') as f: + ckpt_data = np.load(f, allow_pickle=True) + # Retrieve data from npz file. Non-array variables need to be converted back + # to natives types using .tolist(). + t = ckpt_data['t'].tolist() + 1 # Return the iterations completed. + data = ckpt_data['data'] + params = ckpt_data['params'].tolist() + opt_state = ckpt_data['opt_state'].tolist() + mcmc_width = ckpt_data['mcmc_width'].tolist() + if shape_check: + if data.shape[0] != jax.local_device_count(): + raise ValueError( + 'Incorrect number of devices found. Expected {}, found {}.'.format( + data.shape[0], jax.local_device_count())) + if batch_size and data.shape[0] * data.shape[1] != batch_size: + raise ValueError( + 'Wrong batch size in loaded data. Expected {}, found {}.'.format( + batch_size, data.shape[0] * data.shape[1])) + return t, data, params, opt_state, mcmc_width \ No newline at end of file diff --git a/DeepSolid/config/diamond.py b/DeepSolid/config/diamond.py new file mode 100644 index 0000000..90bd6c7 --- /dev/null +++ b/DeepSolid/config/diamond.py @@ -0,0 +1,36 @@ +# Copyright (c) ByteDance, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
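A usage sketch for the function defined below (the concrete values are illustrative assumptions, not defaults of this file): the input string is parsed as 'X,Y,L_Ang,S,basis', where L_Ang is the conventional cubic lattice constant in Angstrom and S scales the fcc primitive cell into an S x S x S supercell.

# e.g. carbon diamond with a 2x2x2 supercell:
cfg = get_config('C,C,3.567,2,ccpvdz')
simulation_cell = cfg.system.pyscf_cell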
+ +import numpy as np +from pyscf.pbc import gto + +from DeepSolid import base_config +from DeepSolid import supercell +from DeepSolid.utils import units + + +def get_config(input_str): + X, Y, L_Ang, S, basis= input_str.split(',') + S = np.eye(3) * int(S) + cfg = base_config.default() + L_Ang = float(L_Ang) + L_Bohr = units.angstrom2bohr(L_Ang) + + # Set up cell + cell = gto.Cell() + cell.atom = [[X, [0.0, 0.0, 0.0]], + [Y, [0.25 * L_Bohr, 0.25 * L_Bohr, 0.25 * L_Bohr]]] + + cell.basis = basis + cell.a = (np.ones((3, 3)) - np.eye(3)) * L_Bohr / 2 + cell.unit = "B" + cell.verbose = 5 + cell.exp_to_discard = 0.1 + cell.build() + simulation_cell = supercell.get_supercell(cell, S) + cfg.system.pyscf_cell = simulation_cell + + return cfg \ No newline at end of file diff --git a/DeepSolid/config/graphene.py b/DeepSolid/config/graphene.py new file mode 100644 index 0000000..2e445d0 --- /dev/null +++ b/DeepSolid/config/graphene.py @@ -0,0 +1,40 @@ +# Copyright (c) ByteDance, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +from pyscf.pbc import gto + +from DeepSolid import base_config +from DeepSolid import supercell +from DeepSolid.utils import units + + +def get_config(input_str): + X, Y, L_Ang, S, z, basis = input_str.split(',') + S = np.diag([int(S), int(S), 1]) + cfg = base_config.default() + L_Ang = float(L_Ang) + z = float(z) + L_Bohr = units.angstrom2bohr(L_Ang) + + # Set up cell + cell = gto.Cell() + cell.atom = [[X, [3**(-0.5) * L_Bohr, 0.0, 0.0]], + [Y, [2*3**(-0.5) * L_Bohr, 0.0, 0.0]]] + + cell.basis = basis + cell.a = np.array([[L_Bohr * np.cos(np.pi/6), -L_Bohr * 0.5, 0], + [L_Bohr * np.cos(np.pi/6), L_Bohr * 0.5, 0], + [0, 0, z], + ]) + cell.unit = "B" + cell.verbose = 5 + cell.exp_to_discard = 0.1 + cell.build() + simulation_cell = supercell.get_supercell(cell, S) + cfg.system.pyscf_cell = simulation_cell + + return cfg \ No newline at end of file diff --git a/DeepSolid/config/hydrogen_chain.py b/DeepSolid/config/hydrogen_chain.py new file mode 100644 index 0000000..2b83c9a --- /dev/null +++ b/DeepSolid/config/hydrogen_chain.py @@ -0,0 +1,42 @@ +# Copyright (c) ByteDance, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
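A usage sketch for the parser below (values are illustrative assumptions): the input string is 'symbol,Sx,Sy,Sz,L,spin,basis', which builds a one-atom primitive cell of length L Bohr along x (padded by 100 Bohr of vacuum in y and z) and tiles it into an Sx-atom chain.

# e.g. an H10 chain with 1.8 Bohr spacing; the one-atom primitive cell carries
# spin 1, consistent with its single electron:
cfg = get_config('H,10,1,1,1.8,1,ccpvdz')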
+ +import sys +import numpy as np +from pyscf.pbc import gto + +from DeepSolid import base_config +from DeepSolid import supercell + + +def get_config(input_str): + symbol, Sx, Sy, Sz, L, spin, basis= input_str.split(',') + Sx = int(Sx) + Sy = int(Sy) + Sz = int(Sz) + S = np.diag([Sx, Sy, Sz]) + L = float(L) + spin = int(spin) + cfg = base_config.default() + + # Set up cell + cell = gto.Cell() + cell.atom = f""" + {symbol} {L/2} {0} {0} + """ + cell.basis = basis + cell.a = np.array([[L, 0, 0], + [0, 100, 0], + [0, 0, 100]]) + cell.unit = "B" + cell.spin = spin + cell.verbose = 5 + cell.exp_to_discard = 0.1 + cell.build() + simulation_cell = supercell.get_supercell(cell, S) + cfg.system.pyscf_cell = simulation_cell + + return cfg \ No newline at end of file diff --git a/DeepSolid/config/poscar/bcc_li.vasp b/DeepSolid/config/poscar/bcc_li.vasp new file mode 100644 index 0000000..ca580e9 --- /dev/null +++ b/DeepSolid/config/poscar/bcc_li.vasp @@ -0,0 +1,10 @@ +Li2 +1.0 + 3.4268178940 0.0000000000 0.0000000000 + 0.0000000000 3.4268178940 0.0000000000 + 0.0000000000 0.0000000000 3.4268178940 + Li + 2 +Cartesian + 0.000000000 0.000000000 0.000000000 + 1.713408947 1.713408947 1.713408947 diff --git a/DeepSolid/config/read_poscar.py b/DeepSolid/config/read_poscar.py new file mode 100644 index 0000000..8d25c8a --- /dev/null +++ b/DeepSolid/config/read_poscar.py @@ -0,0 +1,31 @@ +# Copyright (c) ByteDance, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from DeepSolid import base_config +from DeepSolid import supercell +from DeepSolid.utils import poscar_to_cell +import numpy as np + + +def get_config(input_str): + poscar_path, S, basis = input_str.split(',') + cell = poscar_to_cell.read_poscar(poscar_path) + S = int(S) + S = np.diag([S, S, S]) + cell.verbose = 5 + cell.basis = basis + cell.exp_to_discard = 0.1 + cell.build() + cfg = base_config.default() + + # Set up cell + + simulation_cell = supercell.get_supercell(cell, S) + if cell.spin != 0: + simulation_cell.hf_type = 'uhf' + cfg.system.pyscf_cell = simulation_cell + + return cfg \ No newline at end of file diff --git a/DeepSolid/config/rock_salt.py b/DeepSolid/config/rock_salt.py new file mode 100644 index 0000000..93d0f5e --- /dev/null +++ b/DeepSolid/config/rock_salt.py @@ -0,0 +1,37 @@ +# Copyright (c) ByteDance, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +import numpy as np +from pyscf.pbc import gto + +from DeepSolid import base_config +from DeepSolid import supercell +from DeepSolid.utils import units + + +def get_config(input_str): + X, Y, L_Ang, S, basis= input_str.split(',') + S = np.eye(3) * int(S) + cfg = base_config.default() + L_Ang = float(L_Ang) + L_Bohr = units.angstrom2bohr(L_Ang) + + # Set up cell + cell = gto.Cell() + cell.atom = [[X, [0.0, 0.0, 0.0]], + [Y, [0.5 * L_Bohr, 0.5 * L_Bohr, 0.5 * L_Bohr]]] + + + cell.basis = basis + cell.a = (np.ones((3, 3)) - np.eye(3)) * L_Bohr / 2 + cell.unit = "B" + cell.verbose = 5 + cell.exp_to_discard = 0.1 + cell.build() + simulation_cell = supercell.get_supercell(cell, S) + cfg.system.pyscf_cell = simulation_cell + + return cfg \ No newline at end of file diff --git a/DeepSolid/config/two_hydrogen_cell.py b/DeepSolid/config/two_hydrogen_cell.py new file mode 100644 index 0000000..45e1bbe --- /dev/null +++ b/DeepSolid/config/two_hydrogen_cell.py @@ -0,0 +1,44 @@ +# Copyright (c) ByteDance, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import sys +import numpy as np +from pyscf.pbc import gto + +from DeepSolid import base_config +from DeepSolid import supercell + + +def get_config(input_str): + symbol, Sx, Sy, Sz, L, spin, basis= input_str.split(',') + Sx = int(Sx) + Sy = int(Sy) + Sz = int(Sz) + S = np.diag([Sx, Sy, Sz]) + L = float(L) + spin = int(spin) + cfg = base_config.default() + + # Set up cell + cell = gto.Cell() + cell.atom = f""" + {symbol} {L} {0} {0} + {symbol} 0 0 0 + """ + cell.basis = basis + cell.a = np.array([[2*L, 0, 0], + [0, 100, 0], + [0, 0, 100]]) + cell.unit = "B" + cell.spin = spin + cell.verbose = 5 + cell.exp_to_discard = 0.1 + cell.build() + simulation_cell = supercell.get_supercell(cell, S) + simulation_cell.hf_type = 'uhf' + cfg.system.pyscf_cell = simulation_cell + + return cfg \ No newline at end of file diff --git a/DeepSolid/constants.py b/DeepSolid/constants.py new file mode 100644 index 0000000..76b8fd8 --- /dev/null +++ b/DeepSolid/constants.py @@ -0,0 +1,57 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”). +# All Bytedance Modifications are Copyright 2022 Bytedance Inc. 
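A short sketch of how the helpers below are meant to compose (illustrative; `mean_loss` is a hypothetical function): inside a pmapped function, `pmean_if_pmap` averages over the named device axis, while outside pmap it returns its input unchanged, so the same code path works in both settings.

import jax.numpy as jnp

@pmap  # jax.pmap bound to PMAP_AXIS_NAME (both defined below)
def mean_loss(x):
    # x: [n_devices, batch_per_device, ...] when called outside the pmap
    per_device = jnp.mean(x)
    return pmean_if_pmap(per_device, PMAP_AXIS_NAME)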
+ +import functools +import jax +import jax.numpy as jnp +from jax import core +from typing import TypeVar + +T = TypeVar("T") + +PMAP_AXIS_NAME = 'qmc_pmap_axis' + +pmap = functools.partial(jax.pmap, axis_name=PMAP_AXIS_NAME) +broadcast_all_local_devices = jax.pmap(lambda x: x) +p_split = jax.pmap(lambda key: tuple(jax.random.split(key))) + + +def wrap_if_pmap(p_func): + def p_func_if_pmap(obj, axis_name): + try: + core.axis_frame(axis_name) + return p_func(obj, axis_name) + except NameError: + return obj + + return p_func_if_pmap + + +pmean_if_pmap = wrap_if_pmap(jax.lax.pmean) +psum_if_pmap = wrap_if_pmap(jax.lax.psum) + + +def replicate_all_local_devices(obj: T) -> T: + n = jax.local_device_count() + obj_stacked = jax.tree_map(lambda x: jnp.stack([x] * n, axis=0), obj) + return broadcast_all_local_devices(obj_stacked) + + +def make_different_rng_key_on_all_devices(rng: jnp.ndarray) -> jnp.ndarray: + rng = jax.random.fold_in(rng, jax.host_id()) + rng = jax.random.split(rng, jax.local_device_count()) + return broadcast_all_local_devices(rng) diff --git a/DeepSolid/curvature_tags_and_blocks.py b/DeepSolid/curvature_tags_and_blocks.py new file mode 100644 index 0000000..6a2b265 --- /dev/null +++ b/DeepSolid/curvature_tags_and_blocks.py @@ -0,0 +1,160 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”). +# All Bytedance Modifications are Copyright 2022 Bytedance Inc. 
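A sketch of how the layer tags defined below are typically attached in a network's forward pass so that KFAC can match a dense layer that is re-applied to every electron (the parameter names here are illustrative):

import jax.numpy as jnp

def tagged_dense(params, h):
    # h: [n_electrons, in_dim]; the same weights are shared across electrons,
    # which RepeatedDenseBlock accounts for when estimating curvature.
    y = jnp.dot(h, params['w']) + params['b']
    return register_repeated_dense(y, h, params['w'], params['b'])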
+ +"""Curvature blocks for FermiNet.""" +import functools +from typing import Optional, Mapping, Union +import jax +import jax.numpy as jnp + +from DeepSolid.utils.kfac_ferminet_alpha import curvature_blocks as blocks +from DeepSolid.utils.kfac_ferminet_alpha import layers_and_loss_tags as tags +from DeepSolid.utils.kfac_ferminet_alpha import utils + + +vmap_psd_inv_cholesky = jax.vmap(utils.psd_inv_cholesky, (0, None), 0) +vmap_matmul = jax.vmap(jnp.matmul, in_axes=(0, 0), out_axes=0) + + +qmc1_tag = tags.LayerTag("qmc1_tag", 1, 1) + + +def register_qmc1(y, x, w, **kwargs): + return qmc1_tag.bind(y, x, w, **kwargs) + + +qmc2_tag = tags.LayerTag("qmc2_tag", 1, 1) + + +def register_qmc2(y, x, w, **kwargs): + return qmc2_tag.bind(y, x, w, **kwargs) + + +repeated_dense_tag = tags.LayerTag("repeated_dense_tag", 1, 1) + + +def register_repeated_dense(y, x, w, b): + if b is None: + return repeated_dense_tag.bind(y, x, w) + return repeated_dense_tag.bind(y, x, w, b) + + +class QmcBlockedDense(blocks.TwoKroneckerFactored): + """A factor that is the Kronecker product of two matrices.""" + + def update_curvature_inverse_estimate(self, diagonal_weight, pmap_axis_name): + self.inputs_factor.sync(pmap_axis_name) + + self.outputs_factor.sync(pmap_axis_name) + vmap_pi_adjusted_inverse = jax.vmap( + functools.partial(utils.pi_adjusted_inverse, + pmap_axis_name=pmap_axis_name), + (0, 0, None), (0, 0) + ) + self.inputs_factor_inverse, self.outputs_factor_inverse = ( + vmap_pi_adjusted_inverse(self.inputs_factor.value, + self.outputs_factor.value, + diagonal_weight / self.extra_scale)) + + def multiply_matpower(self, vec, exp, diagonal_weight): + w, = vec + # kmjn + v = w + k, m, j, n = v.shape + if exp == 1: + inputs_factor = self.inputs_factor.value + outputs_factor = self.outputs_factor.value + scale = self.extra_scale + elif exp == -1: + inputs_factor = self.inputs_factor_inverse + outputs_factor = self.outputs_factor_inverse + scale = 1.0 / self.extra_scale + diagonal_weight = 0.0 + else: + raise NotImplementedError() + # jk(mn) + v = jnp.transpose(v, [2, 0, 1, 3]).reshape([j, k, m * n]) + v = vmap_matmul(inputs_factor, v) + v = vmap_matmul(v, outputs_factor) + # kmjn + v = jnp.transpose(v.reshape([j, k, m, n]), [1, 2, 0, 3]) + v = v * scale + diagonal_weight * w + return (v,) + + def update_curvature_matrix_estimate( + self, + info: blocks._BlockInfo, # pylint: disable=protected-access + batch_size: int, + ema_old: Union[float, jnp.ndarray], + ema_new: Union[float, jnp.ndarray], + pmap_axis_name: str + ) -> None: + (x,), (dy,) = info["inputs"], info["outputs_tangent"] + assert batch_size == x.shape[0] + normalizer = x.shape[0] * x.shape[1] + # The forward computation is + # einsum(x,w): bijk,bkmjn -> bijmn + inputs_cov = jnp.einsum("bijk,bijl->jkl", x, x) / normalizer + dy = jnp.reshape(dy, dy.shape[:-2] + (-1,)) + outputs_cov = jnp.einsum("bijk,bijl->jkl", dy, dy) / normalizer + self.inputs_factor.update(inputs_cov, ema_old, ema_new) + self.outputs_factor.update(outputs_cov, ema_old, ema_new) + + def init(self, rng): + del rng + k, m, j, n = self.params_shapes[0] + return dict( + inputs_factor=utils.WeightedMovingAverage.zero([j, k, k]), + inputs_factor_inverse=jnp.zeros([j, k, k]), + outputs_factor=utils.WeightedMovingAverage.zero([j, m * n, m * n]), + outputs_factor_inverse=jnp.zeros([j, m * n, m * n]), + extra_scale=jnp.asarray(m) + ) + + def input_size(self) -> int: + raise NotImplementedError() + + def output_size(self) -> int: + raise NotImplementedError() + + +class 
RepeatedDenseBlock(blocks.DenseTwoKroneckerFactored): + """Dense block that is repeated.""" + + def compute_extra_scale(self) -> Optional[jnp.ndarray]: + (x_shape,) = self.inputs_shapes + return utils.product(x_shape) // (x_shape[0] * x_shape[-1]) + + def update_curvature_matrix_estimate( + self, + info: Mapping[str, blocks._Arrays], # pylint: disable=protected-access + batch_size: int, + ema_old: Union[float, jnp.ndarray], + ema_new: Union[float, jnp.ndarray], + pmap_axis_name: str + ) -> None: + info = dict(**info) + (x,), (dy,) = info["inputs"], info["outputs_tangent"] + assert x.shape[0] == batch_size + info["inputs"] = (x.reshape([-1, x.shape[-1]]),) + info["outputs_tangent"] = (dy.reshape([-1, dy.shape[-1]]),) + super().update_curvature_matrix_estimate(info, x.size // x.shape[-1], + ema_old, ema_new, pmap_axis_name) + + +blocks.set_default_tag_to_block("qmc1_tag", QmcBlockedDense) +blocks.set_default_tag_to_block("repeated_dense_tag", RepeatedDenseBlock) diff --git a/DeepSolid/distance.py b/DeepSolid/distance.py new file mode 100644 index 0000000..1ce00f8 --- /dev/null +++ b/DeepSolid/distance.py @@ -0,0 +1,185 @@ +# MIT License +# +# Copyright (c) 2019 Lucas K Wagner +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”). +# All Bytedance Modifications are Copyright 2022 Bytedance Inc. 
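A small worked sketch of the minimal-image convention implemented below (numbers are illustrative): in a cubic cell of side 10 Bohr, an electron at x = 9.5 is closest to the periodic image of an ion at the origin, so its displacement is reported as -0.5 rather than 9.5.

import jax.numpy as jnp

latvec = jnp.eye(3) * 10.0                        # hypothetical cubic cell
dist = MinimalImageDistance(latvec)
ion = jnp.zeros(3)                                # one ion at the origin (flattened)
elec = jnp.array([1.0, 0.0, 0.0, 9.5, 0.0, 0.0])  # two electrons (flattened)
d = dist.dist_i(ion, elec)                        # shape [2, 1, 3]; d[1, 0, 0] == -0.5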
+ +from functools import partial +import jax +import jax.numpy as jnp +import logging + + +class MinimalImageDistance: + """Computer minimal image distance between particles and its images""" + + def __init__(self, latvec, verbose=0): + """ + + :param latvec: array with shape [3,3], each row with a lattice vector + """ + + latvec = jnp.asarray(latvec) + ortho_tol = 1e-10 + diagonal = jnp.all(jnp.abs(latvec - jnp.diag(jnp.diagonal(latvec))) < ortho_tol) + if diagonal: + self.dist_i = self.diagonal_dist_i + if verbose == 0: + logging.info("Diagonal lattice vectors") + else: + orthogonal = ( + jnp.dot(latvec[0], latvec[1]) < ortho_tol + and jnp.dot(latvec[1], latvec[2]) < ortho_tol + and jnp.dot(latvec[2], latvec[0]) < ortho_tol + ) + if orthogonal: + self.dist_i = self.orthogonal_dist_i + if verbose == 0: + logging.info("Orthogonal lattice vectors") + else: + self.dist_i = self.general_dist_i + if verbose == 0: + logging.info("Non-orthogonal lattice vectors") + self._latvec = latvec + self._invvec = jnp.linalg.inv(latvec) + self.dim = self._latvec.shape[-1] + # list of all 26 neighboring cells + mesh_grid = jnp.meshgrid(*[jnp.array([0, 1, 2]) for _ in range(3)]) + self.point_list = jnp.stack([m.ravel() for m in mesh_grid], axis=0).T - 1 + self.shifts = self.point_list @ self._latvec + + def general_dist_i(self, configs, vec, return_wrap=False): + """ + calculate minimal distance between electron and ion in the most general lattice vector + + :param configs: ion coordinate with shape [N_atom * 3] + :param vec: electron coordinate with shape [N_ele * 3] + :return: minimal image distance between electron and atom with shape [N_ele, N_atom, 3] + """ + configs = configs.reshape([1, -1, self.dim]) + v = vec.reshape([-1, 1, self.dim]) + d1 = v - configs + shifts = self.shifts.reshape((-1, *[1] * (len(d1.shape) - 1), 3)) + d1all = d1[None] + shifts + dists = jnp.linalg.norm(d1all, axis=-1) + mininds = jnp.argmin(dists, axis=0) + inds = jnp.meshgrid(*[jnp.arange(n) for n in mininds.shape], indexing='ij') + if return_wrap: + return d1all[(mininds, *inds)], -self.point_list[mininds] + else: + return d1all[(mininds, *inds)] + + def orthogonal_dist_i(self, configs, vec, return_wrap=False): + """ + calculate minimal distance between electron and ion in the orthogonal lattice vector + + :param configs: ion coordinate with shape [N_atom * 3] + :param vec: electron coordinate with shape [N_ele * 3] + :return: minimal image distance between electron and atom with shape [N_ele, N_atom, 3] + """ + configs = configs.reshape([1, -1, self.dim]).real + v = vec.reshape([-1, 1, self.dim]).real + d1 = v - configs + frac_disps = jnp.einsum("...ij,jk->...ik", d1, self._invvec) + replace_frac_disps = (frac_disps + 0.5) % 1 - 0.5 + if return_wrap == False: + return jnp.einsum("...ij,jk->...ik", replace_frac_disps, self._latvec) + else: + wrap = -((frac_disps + 0.5) // 1) + return jnp.einsum("...ij,jk->...ik", replace_frac_disps, self._latvec), wrap + + def diagonal_dist_i(self, configs, vec, return_wrap=False): + """ + calculate minimal distance between electron and ion in the diagonal lattice vector + + :param configs: ion coordinate with shape [N_atom * 3] + :param vec: electron coordinate with shape [N_ele * 3] + :return: minimal image distance between electron and atom with shape [N_ele, N_atom, 3] + """ + configs = configs.reshape([1, -1, self.dim]).real + v = vec.reshape([-1, 1, self.dim]).real + d1 = v - configs + latvec_diag = jnp.diagonal(self._latvec) + replace_d1 = (d1 + latvec_diag / 2) % latvec_diag - 
latvec_diag / 2 + if return_wrap == False: + return replace_d1 + else: + ## minus applies after //, order of // and - sign matters + wrap = -((d1 + latvec_diag / 2) // latvec_diag) + return replace_d1, wrap + + def dist_matrix(self, configs): + """ + calculate minimal distance between electrons + + :param configs: electron coordinate with shape [N_ele * 3] + :return: vs: electron coordinate diffs with shape [N_ele, N_ele, 3] + """ + + vs = self.dist_i(configs, configs) + vs = vs * (1 - jnp.eye(vs.shape[0]))[..., None] + + return vs + + +@partial(jax.vmap, in_axes=(None, 0), out_axes=0) +def enforce_pbc(latvec, epos): + """ + Enforces periodic boundary conditions on a set of configs. + + :param lattvecs: orthogonal lattice vectors defining 3D torus: (3,3) + :param epos: attempted new electron coordinates: (N_ele * 3) + :return: final electron coordinates with PBCs imposed: (N_ele * 3) + """ + + # Writes epos in terms of (lattice vecs) fractional coordinates + dim = latvec.shape[-1] + epos = epos.reshape(-1, dim) + recpvecs = jnp.linalg.inv(latvec) + epos_lvecs_coord = jnp.einsum("ij,jk->ik", epos, recpvecs) + + tmp = jnp.divmod(epos_lvecs_coord, 1) + wrap = tmp[0] + final_epos = jnp.matmul(tmp[1], latvec).ravel() + return final_epos, wrap + +import numpy as np + +def np_enforce_pbc(latvec, epos): + """ + Enforces periodic boundary conditions on a set of configs. Used in float 32 mode. + + :param lattvecs: orthogonal lattice vectors defining 3D torus: (3,3) + :param epos: attempted new electron coordinates: (N_ele * 3) + :return: final electron coordinates with PBCs imposed: (N_ele * 3) + """ + + # Writes epos in terms of (lattice vecs) fractional coordinates + dim = latvec.shape[-1] + epos = epos.reshape(-1, dim) + recpvecs = np.linalg.inv(latvec) + epos_lvecs_coord = np.einsum("ij,jk->ik", epos, recpvecs) + + tmp = np.divmod(epos_lvecs_coord, 1) + wrap = tmp[0] + final_epos = np.matmul(tmp[1], latvec).ravel() + return final_epos, wrap diff --git a/DeepSolid/distributed.py b/DeepSolid/distributed.py new file mode 100644 index 0000000..b8385f5 --- /dev/null +++ b/DeepSolid/distributed.py @@ -0,0 +1,55 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”). +# All Bytedance Modifications are Copyright 2022 Bytedance Inc. + +import functools + +from absl import logging +from jax._src.lib import xla_bridge +from jax._src.lib import xla_client +from jax._src.lib import xla_extension + +_service = None + + +def initialize(coordinator_address: str, num_processes: int, process_id: int, + xla_client_config): + """Initialize distributed system for topology discovery. + + Currently, calling ``initialize`` sets up the multi-host GPU backend, and + is not required for CPU or TPU backends. + + Args: + coordinator_address: IP address of the coordinator. + num_processes: Number of processes. + process_id: Id of the current processe. 
+ + """ + if process_id == 0: + global _service + assert _service is None, 'initialize should be called once only' + logging.info('Starting JAX distributed service on %s', coordinator_address) + _service = xla_extension.get_distributed_runtime_service(coordinator_address, + num_processes) + + client = xla_extension.get_distributed_runtime_client(coordinator_address, + process_id, + **xla_client_config) + logging.info('Connecting to JAX distributed service on %s', coordinator_address) + client.connect() + + factory = functools.partial(xla_client.make_gpu_client, client, process_id) + xla_bridge.register_backend_factory('gpu', factory, priority=300) diff --git a/DeepSolid/estimator.py b/DeepSolid/estimator.py new file mode 100644 index 0000000..fb3c6d2 --- /dev/null +++ b/DeepSolid/estimator.py @@ -0,0 +1,78 @@ +# Copyright (c) ByteDance, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import jax +import jax.numpy as jnp +import functools + +import pyscf.pbc.gto +from DeepSolid import constants + + +def make_complex_polarization(simulation_cell: pyscf.pbc.gto.Cell, + direction: int = 0, + ndim=3): + ''' + the order parameter which is used to specify the hydrogen chain + :param simulation_cell: + :param direction: + :param ndim: + :return: + ''' + + rec_vec = simulation_cell.reciprocal_vectors()[direction] + + def complex_polarization(data): + """ + + :param data: electron walkers with shape [batch, ne * ndim] + :return: complex polarization with shape [] + """ + leading_shape = list(data.shape[:-1]) + data = data.reshape(leading_shape + [-1, ndim]) + dots = jnp.einsum('i,...i->...', rec_vec, data) + dots = jnp.sum(dots, axis=-1) + polarization = jnp.exp(1j * dots) + polarization = jnp.mean(polarization, axis=-1) + polarization = constants.pmean_if_pmap(polarization, axis_name=constants.PMAP_AXIS_NAME) + return polarization + + return complex_polarization + +def make_structure_factor(simulation_cell: pyscf.pbc.gto.Cell, + nq=4, + ndim=3): + mesh_grid = jnp.meshgrid(*[jnp.array(range(0, nq)) for _ in range(3)]) + point_list = jnp.stack([m.ravel() for m in mesh_grid], axis=0).T + rec_vec = simulation_cell.reciprocal_vectors() + + qvecs = point_list @ rec_vec + rec_vec = qvecs + nelec = simulation_cell.nelectron + + def structure_factor(data): + """ + + :param data: electron walkers with shape [batch, ne * ndim] + :return: complex polarization with shape [] + """ + leading_shape = list(data.shape[:-1]) + data = data.reshape(leading_shape + [-1, ndim]) + dots = jnp.einsum('kj,...j->...k', rec_vec, data) + # batch ne npoint + rho_k = jnp.exp(1j * dots) + rho_k = jnp.sum(rho_k, axis=1) + rho_k_one = jnp.mean(rho_k, axis=0) + rho_k_one_mean = constants.pmean_if_pmap(rho_k_one, axis_name=constants.PMAP_AXIS_NAME) + rho_k_two = jnp.mean(jnp.abs(rho_k)**2, axis=0) + rho_k_two_mean = constants.pmean_if_pmap(rho_k_two, axis_name=constants.PMAP_AXIS_NAME) + + sk = rho_k_two_mean - jnp.abs(rho_k_one_mean)**2 + sk = sk / nelec + + return sk + + return structure_factor \ No newline at end of file diff --git a/DeepSolid/ewaldsum.py b/DeepSolid/ewaldsum.py new file mode 100644 index 0000000..727264a --- /dev/null +++ b/DeepSolid/ewaldsum.py @@ -0,0 +1,200 @@ +# MIT License +# +# Copyright (c) 2019 Lucas K Wagner +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the 
Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”). +# All Bytedance Modifications are Copyright 2022 Bytedance Inc. + +import logging + +import jax +import jax.numpy as jnp +from DeepSolid import distance + + +class EwaldSum: + def __init__(self, cell, ewald_gmax=200, nlatvec=1): + """ + :parameter cell: pyscf Cell object (simulation cell) + :parameter int ewald_gmax: how far to take reciprocal sum; probably never needs to be changed. + :parameter int nlatvec: how far to take real-space sum; probably never needs to be changed. + """ + self.nelec = cell.nelec + self.atom_coords = jnp.asarray(cell.atom_coords()) + self.atom_charges = jnp.asarray(cell.atom_charges()) + self.latvec = jnp.asarray(cell.lattice_vectors()) + self.dist = distance.MinimalImageDistance(self.latvec) + self.set_lattice_displacements(nlatvec) + self.set_up_reciprocal_ewald_sum(ewald_gmax) + + def set_lattice_displacements(self, nlatvec): + """ + Generates list of lattice-vector displacements to add together for real-space sum + + :parameter int nlatvec: sum goes from `-nlatvec` to `nlatvec` in each lattice direction. 
+ """ + XYZ = jnp.meshgrid(*[jnp.arange(-nlatvec, nlatvec + 1)] * 3, indexing="ij") + xyz = jnp.stack(XYZ, axis=-1).reshape((-1, 3)) + self.lattice_displacements = jnp.asarray(jnp.dot(xyz, self.latvec)) + + def set_up_reciprocal_ewald_sum(self, ewald_gmax): + cellvolume = jnp.linalg.det(self.latvec) + recvec = jnp.linalg.inv(self.latvec).T + + # Determine alpha + smallestheight = jnp.amin(1 / jnp.linalg.norm(recvec, axis=1)) + self.alpha = 5.0 / smallestheight + logging.info(f"Setting Ewald alpha to {self.alpha.item()}") + + # Determine G points to include in reciprocal Ewald sum + gptsXpos = jnp.meshgrid( + jnp.arange(1, ewald_gmax + 1), + *[jnp.arange(-ewald_gmax, ewald_gmax + 1)] * 2, + indexing="ij" + ) + zero = jnp.asarray([0]) + gptsX0Ypos = jnp.meshgrid( + zero, + jnp.arange(1, ewald_gmax + 1), + jnp.arange(-ewald_gmax, ewald_gmax + 1), + indexing="ij", + ) + gptsX0Y0Zpos = jnp.meshgrid( + zero, zero, jnp.arange(1, ewald_gmax + 1), indexing="ij" + ) + gs = zip( + *[ + select_big(x, cellvolume, recvec, self.alpha) + for x in (gptsXpos, gptsX0Ypos, gptsX0Y0Zpos) + ] + ) + self.gpoints, self.gweight = [jnp.concatenate(x, axis=0) for x in gs] + self.set_ewald_constants(cellvolume) + + def set_ewald_constants(self, cellvolume): + self.i_sum = jnp.sum(self.atom_charges) + ii_sum2 = jnp.sum(self.atom_charges ** 2) + ii_sum = (self.i_sum ** 2 - ii_sum2) / 2 + + self.ijconst = -jnp.pi / (cellvolume * self.alpha ** 2) + self.squareconst = -self.alpha / jnp.sqrt(jnp.pi) + self.ijconst / 2 + + self.ii_const = ii_sum * self.ijconst + ii_sum2 * self.squareconst + self.e_single_test = -self.i_sum * self.ijconst + self.squareconst + self.ion_ion = self.ewald_ion() + + # XC correction not used, so we can compare to other codes + # rs = lambda ne: (3 / (4 * np.pi) / (ne * cellvolume)) ** (1 / 3) + # cexc = 0.36 + # xc_correction = lambda ne: cexc / rs(ne) + + def ee_const(self, ne): + return ne * (ne - 1) / 2 * self.ijconst + ne * self.squareconst + + def ei_const(self, ne): + return -ne * self.i_sum * self.ijconst + + def e_single(self, ne): + return ( + 0.5 * (ne - 1) * self.ijconst - self.i_sum * self.ijconst + self.squareconst + ) + + def ewald_ion(self): + # Real space part + if len(self.atom_charges) == 1: + ion_ion_real = 0 + else: + ion_distances = self.dist.dist_matrix(self.atom_coords.ravel()) + rvec = ion_distances[None, :, :, :] + self.lattice_displacements[:, None, None, :] + r = jnp.linalg.norm(rvec, axis=-1) + charge_ij = self.atom_charges[..., None] * self.atom_charges[None, ...] + ion_ion_real = jnp.sum(jnp.triu(charge_ij * jax.lax.erfc(self.alpha * r) / r, k=1)) + # Reciprocal space part + GdotR = jnp.dot(self.gpoints, jnp.asarray(self.atom_coords.T)) + self.ion_exp = jnp.dot(jnp.exp(1j * GdotR), self.atom_charges) + ion_ion_rec = jnp.dot(self.gweight, jnp.abs(self.ion_exp) ** 2) + + ion_ion = ion_ion_real + ion_ion_rec + return ion_ion + + def _real_cij(self, dists): + r = dists[:, :, None, :] + self.lattice_displacements + r = jnp.linalg.norm(r, axis=-1) + cij = jnp.sum(jax.lax.erfc(self.alpha * r) / r, axis=-1) + return cij + + def ewald_electron(self, configs): + nelec = sum(self.nelec) + + # Real space electron-ion part + # ei_distances shape (elec, atom, dim) + ei_distances = self.dist.dist_i(self.atom_coords.ravel(), configs) + ei_cij = self._real_cij(ei_distances) + ei_real_separated = jnp.sum(-self.atom_charges[None, :] * ei_cij) + + # Real space electron-electron part + ee_real_separated = jnp.array(0.) 
+ if nelec > 1: + ee_distances = self.dist.dist_matrix(configs) + rvec = ee_distances[None, :, :, :] + self.lattice_displacements[:, None, None, :] + r = jnp.linalg.norm(rvec, axis=-1) + ee_real_separated = jnp.sum(jnp.triu(jax.lax.erfc(self.alpha * r) / r, k=1)) + + # ee_distances = self.dist.dist_matrix(configs) + # ee_cij = self._real_cij(ee_distances) + # + # for ((i, j), val) in zip(ee_inds, ee_cij.T): + # ee_real_separated[:, i] += val + # ee_real_separated[:, j] += val + # ee_real_separated /= 2 + + ee_recip, ei_recip = self.reciprocal_space_electron(configs) + ee = ee_real_separated + ee_recip + ei = ei_real_separated + ei_recip + return ee, ei + + def reciprocal_space_electron(self, configs): + # Reciprocal space electron-electron part + e_GdotR = jnp.einsum("ik,jk->ij", configs.reshape(sum(self.nelec), -1), self.gpoints) + sum_e_sin = jnp.sin(e_GdotR).sum(axis=0) + sum_e_cos = jnp.cos(e_GdotR).sum(axis=0) + ee_recip = jnp.dot(sum_e_sin ** 2 + sum_e_cos ** 2, self.gweight) + ## Reciprocal space electron-ion part + coscos_sinsin = -self.ion_exp.real * sum_e_cos - self.ion_exp.imag * sum_e_sin + ei_recip = 2 * jnp.dot(coscos_sinsin, self.gweight) + return ee_recip, ei_recip + + def energy(self, configs): + nelec = sum(self.nelec) + ee, ei = self.ewald_electron(configs) + ee += self.ee_const(nelec) + ei += self.ei_const(nelec) + ii = self.ion_ion + self.ii_const + return jnp.asarray(ee), jnp.asarray(ei), jnp.asarray(ii) + + +def select_big(gpts, cellvolume, recvec, alpha): + gpoints = jnp.einsum("j...,jk->...k", gpts, recvec) * 2 * jnp.pi + gsquared = jnp.einsum("...k,...k->...", gpoints, gpoints) + gweight = 4 * jnp.pi * jnp.exp(-gsquared / (4 * alpha ** 2)) + gweight /= cellvolume * gsquared + bigweight = gweight > 1e-12 + return gpoints[bigweight], gweight[bigweight] diff --git a/DeepSolid/hamiltonian.py b/DeepSolid/hamiltonian.py new file mode 100644 index 0000000..a5f975d --- /dev/null +++ b/DeepSolid/hamiltonian.py @@ -0,0 +1,204 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”). +# All Bytedance Modifications are Copyright 2022 Bytedance Inc. + +import jax +import jax.numpy as jnp +from DeepSolid import ewaldsum +from DeepSolid import network + + +def local_kinetic_energy(f): + ''' + holomorphic mode, which seems dangerous since many op don't support complex number now. + :param f: + :return: + ''' + def _lapl_over_f(params, x): + ne = x.shape[-1] + eye = jnp.eye(ne) + grad_f = jax.grad(f, argnums=1, holomorphic=True) + grad_f_closure = lambda y: grad_f(params, y) + + def _body_fun(i, val): + primal, tangent = jax.jvp(grad_f_closure, (x + 0j,), (eye[i] + 0j,)) + return val + tangent[i] + primal[i] ** 2 + + return -0.5 * jax.lax.fori_loop(0, ne, _body_fun, 0.0) + + return _lapl_over_f + + +def local_kinetic_energy_real_imag(f): + ''' + evaluate real and imaginary part of laplacian, which is slower than holomorphic mode but is much safer. 
+ :param f: + :return: + ''' + def _lapl_over_f(params, x): + ne = x.shape[-1] + eye = jnp.eye(ne) + grad_f_real = jax.grad(lambda p, y: f(p, y).real, argnums=1) + grad_f_imag = jax.grad(lambda p, y: f(p, y).imag, argnums=1) + grad_f_real_closure = lambda y: grad_f_real(params, y) + grad_f_imag_closure = lambda y: grad_f_imag(params, y) + + def _body_fun(i, val): + primal_real, tangent_real = jax.jvp(grad_f_real_closure, (x,), (eye[i],)) + primal_imag, tangent_imag = jax.jvp(grad_f_imag_closure, (x,), (eye[i],)) + kine_real = val[0] + tangent_real[i] + primal_real[i] ** 2 - primal_imag[i] ** 2 + kine_imag = val[1] + tangent_imag[i] + 2 * primal_real[i] * primal_imag[i] + return [kine_real, kine_imag] + + result = jax.lax.fori_loop(0, ne, _body_fun, [0.0, 0.0]) + complex = [1., 1j] + return [-0.5 * re * com for re, com in zip(result, complex)] + + return lambda p, y: _lapl_over_f(p, y) + + +def local_kinetic_energy_real_imag_dim_batch(f): + + def _lapl_over_f(params, x): + ne = x.shape[-1] + eye = jnp.eye(ne) + grad_f_real = jax.grad(lambda p, y: f(p, y).real, argnums=1) + grad_f_imag = jax.grad(lambda p, y: f(p, y).imag, argnums=1) + grad_f_real_closure = lambda y: grad_f_real(params, y) + grad_f_imag_closure = lambda y: grad_f_imag(params, y) + + def _body_fun(dummy_eye): + primal_real, tangent_real = jax.jvp(grad_f_real_closure, (x,), (dummy_eye,)) + primal_imag, tangent_imag = jax.jvp(grad_f_imag_closure, (x,), (dummy_eye,)) + kine_real = ((tangent_real + primal_real ** 2 - primal_imag ** 2) * dummy_eye).sum() + kine_imag = ((tangent_imag + 2 * primal_real * primal_imag) * dummy_eye).sum() + return [kine_real, kine_imag] + + # result = jax.lax.fori_loop(0, ne, _body_fun, [0.0, 0.0]) + result = jax.vmap(_body_fun, in_axes=0)(eye) + result = [re.sum() for re in result] + complex = [1., 1j] + return [-0.5 * re * com for re, com in zip(result, complex)] + + return lambda p, y: _lapl_over_f(p, y) + + +def local_kinetic_energy_real_imag_hessian(f): + ''' + Use jax.hessian to evaluate laplacian, which requires huge amount of memory. 
+ :param f: + :return: + ''' + def _lapl_over_f(params, x): + ne = x.shape[-1] + grad_f_real = jax.grad(lambda p, y: f(p, y).real, argnums=1) + grad_f_imag = jax.grad(lambda p, y: f(p, y).imag, argnums=1) + hessian_f_real = jax.hessian(lambda p, y: f(p, y).real, argnums=1) + hessian_f_imag = jax.hessian(lambda p, y: f(p, y).imag, argnums=1) + v_grad_f_real = grad_f_real(params, x) + v_grad_f_imag = grad_f_imag(params, x) + real_kinetic = jnp.trace(hessian_f_real(params, x),) + jnp.sum(v_grad_f_real**2) - jnp.sum(v_grad_f_imag**2) + imag_kinetic = jnp.trace(hessian_f_imag(params, x),) + jnp.sum(2 * v_grad_f_real * v_grad_f_imag) + + complex = [1., 1j] + return [-0.5 * re * com for re, com in zip([real_kinetic, imag_kinetic], complex)] + + return lambda p, y: _lapl_over_f(p, y) + + +def local_kinetic_energy_partition(f, partition_number=3): + ''' + Try to parallelize the evaluation of laplacian + :param f: + :param partition_number: + :return: + ''' + vjvp = jax.vmap(jax.jvp, in_axes=(None, None, 0)) + + def _lapl_over_f(params, x): + n = x.shape[0] + eye = jnp.eye(n) + grad_f_real = jax.grad(lambda p, y: f(p, y).real, argnums=1) + grad_f_imag = jax.grad(lambda p, y: f(p, y).imag, argnums=1) + grad_f_closure_real = lambda y: grad_f_real(params, y) + grad_f_closure_imag = lambda y: grad_f_imag(params, y) + + eyes = jnp.asarray(jnp.array_split(eye, partition_number)) + def _body_fun(val, e): + primal_real, tangent_real = vjvp(grad_f_closure_real, (x,), (e,)) + primal_imag, tangent_imag = vjvp(grad_f_closure_imag, (x,), (e,)) + return val, ([primal_real, primal_imag], [tangent_real, tangent_imag]) + _, (plist, tlist) = \ + jax.lax.scan(_body_fun, None, eyes) + primal = [primal.reshape((-1, primal.shape[-1])) for primal in plist] + tangent = [tangent.reshape((-1, tangent.shape[-1])) for tangent in tlist] + + real_kinetic = jnp.trace(tangent[0]) + jnp.trace(primal[0]**2).sum() - jnp.trace(primal[1]**2).sum() + imag_kinetic = jnp.trace(tangent[1]) + jnp.trace(2 * primal[0] * primal[1]).sum() + return [-0.5 * real_kinetic, -0.5 * 1j * imag_kinetic] + + return _lapl_over_f + + + +def local_ewald_energy(simulation_cell): + ewald = ewaldsum.EwaldSum(simulation_cell) + assert jnp.allclose(simulation_cell.energy_nuc(), + (ewald.ion_ion + ewald.ii_const), + rtol=1e-8, atol=1e-5) + ## check pyscf madelung constant agrees with DeepSolid + + def _local_ewald_energy(x): + energy = ewald.energy(x) + return sum(energy) + + return _local_ewald_energy + + +def local_energy(f, simulation_cell): + ke = local_kinetic_energy(f) + ew = local_ewald_energy(simulation_cell) + + def _local_energy(params, x): + kinetic = ke(params, x) + ewald = ew(x) + return kinetic + ewald + + return _local_energy + + +def local_energy_seperate(f, simulation_cell, mode='for', partition_number=3): + + if mode == 'for': + ke_ri = local_kinetic_energy_real_imag(f) + elif mode == 'hessian': + ke_ri = local_kinetic_energy_real_imag_hessian(f) + elif mode == 'dim_batch': + ke_ri = local_kinetic_energy_real_imag_dim_batch(f) + elif mode == 'partition': + ke_ri = local_kinetic_energy_partition(f, partition_number=partition_number) + else: + raise ValueError('Unrecognized laplacian evaluation mode.') + ke = lambda p, y: sum(ke_ri(p, y)) + # ke = local_kinetic_energy(f) + ew = local_ewald_energy(simulation_cell) + + def _local_energy(params, x): + kinetic = ke(params, x) + ewald = ew(x) + return kinetic, ewald + + return _local_energy diff --git a/DeepSolid/hf.py b/DeepSolid/hf.py new file mode 100644 index 0000000..658c748 --- /dev/null +++ 
b/DeepSolid/hf.py @@ -0,0 +1,193 @@ +# MIT License +# +# Copyright (c) 2019 Lucas K Wagner +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”). +# All Bytedance Modifications are Copyright 2022 Bytedance Inc. + +from pyscf.pbc import gto, scf +from DeepSolid import supercell +from DeepSolid import distance +import numpy as np + +_gldict = {"laplacian": np.s_[:1], "gradient_laplacian": np.s_[0:4]} + + +def _aostack_mol(ao, gl): + return np.concatenate( + [ao[_gldict[gl]], ao[[4, 7, 9]].sum(axis=0, keepdims=True)], axis=0 + ) + + +def _aostack_pbc(ao, gl): + return [_aostack_mol(ak, gl) for ak in ao] + + +class SCF: + def __init__(self, cell, twist=np.ones(3)*0.5): + """ + Hartree Fock wave function class for QMC simulation + + :param cell: pyscf.pbc.gto.Cell, simulation object + :param twist:np.array with shape [3] + """ + self._aostack = _aostack_pbc + self.coeff_key = ("mo_coeff_alpha", "mo_coeff_beta") + self.param_split = {} + self.parameters = {} + self.k_split = {} + self.ns_tol = cell.scale + self.simulation_cell = cell + self.primitive_cell = cell.original_cell + self.sim_nelec = self.simulation_cell.nelec + self.kpts = supercell.get_supercell_kpts(self.simulation_cell) + self.kpts = self.kpts + np.dot(np.linalg.inv(cell.a), np.mod(twist, 1.0)) * 2 * np.pi + if hasattr(self.simulation_cell, 'hf_type'): + hf_type = self.simulation_cell.hf_type + else: + hf_type = 'rhf' + + if hf_type == 'uhf': + self.kmf = scf.KUHF(self.primitive_cell, exxdiv='ewald', kpts=self.kpts).density_fit() + + # break initial guess symmetry for UHF + dm_up, dm_down = self.kmf.get_init_guess() + dm_down[:, :2, :2] = 0 + dm = (dm_up, dm_down) + elif hf_type == 'rhf': + self.kmf = scf.KHF(self.primitive_cell, exxdiv='ewald', kpts=self.kpts).density_fit() + dm = self.kmf.get_init_guess() + else: + raise ValueError('Unrecognized Hartree Fock type.') + + self.kmf.kernel(dm) + # self.init_scf() + + def init_scf(self): + self.klist = [] + for s, key in enumerate(self.coeff_key): + mclist = [] + for k in range(self.kmf.kpts.shape[0]): + # restrict or not + if len(self.kmf.mo_coeff[0][0].shape) == 2: + mca = self.kmf.mo_coeff[s][k][:, np.asarray(self.kmf.mo_occ[s][k] > 0.9)] + else: + minocc = (0.9, 1.1)[s] + mca = self.kmf.mo_coeff[k][:, np.asarray(self.kmf.mo_occ[k] > minocc)] + mclist.append(mca) + self.param_split[key] = np.cumsum([m.shape[1] for m in mclist]) + self.parameters[key] = np.concatenate(mclist, axis=-1) + 
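            # parameters[key] stacks the occupied MO coefficients of all k-points
            # side by side; param_split records the cut points used by
            # eval_mos_pbc to split them back into per-k-point blocks.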
self.k_split[key] = np.array([m.shape[1] for m in mclist]) + self.klist.append(np.concatenate([np.tile(kpt[None, :], (split, 1)) + for kpt, split in + zip(self.kmf.kpts, self.k_split[self.coeff_key[s]])])) + + def eval_orbitals_pbc(self, coord, eval_str="GTOval_sph"): + prim_coord, wrap = distance.np_enforce_pbc(self.primitive_cell.a, coord.reshape([coord.shape[0], -1])) + prim_coord = prim_coord.reshape([-1, 3]) + wrap = wrap.reshape([-1, 3]) + ao = self.primitive_cell.eval_gto("PBC" + eval_str, prim_coord, kpts=self.kmf.kpts) + + kdotR = np.einsum('ij,kj,nk->in', self.kmf.kpts, self.primitive_cell.a, wrap) + wrap_phase = np.exp(1j*kdotR) + ao = [ao[k] * wrap_phase[k][:, None] for k in range(len(self.kmf.kpts))] + + return ao + + def eval_mos_pbc(self, aos, s): + c = self.coeff_key[s] + p = np.split(self.parameters[c], self.param_split[c], axis=-1) + mo = [ao.dot(p[k]) for k, ao in enumerate(aos)] + return np.concatenate(mo, axis=-1) + + def eval_orb_mat(self, coord): + batch, nelec, ndim = coord.shape + aos = self.eval_orbitals_pbc(coord) + aos_shape = (self.ns_tol, batch, nelec, -1) + + aos = np.reshape(aos, aos_shape) + mos = [] + for s in [0, 1]: + i0, i1 = s * self.sim_nelec[0], self.sim_nelec[0] + s * self.sim_nelec[1] + ne = self.sim_nelec[s] + mo = self.eval_mos_pbc(aos[:, :, i0:i1], s).reshape([batch, ne, ne]) + mos.append(mo) + return mos + + def eval_slogdet(self, coord): + mos = self.eval_orb_mat(coord) + slogdets = [np.linalg.slogdet(mo) for mo in mos] + phase, slogdet = list(map(lambda x, y: [x[0] * y[0], x[1] + y[1]], *zip(slogdets)))[0] + + return phase, slogdet + + def eval_phase(self, coord): + """ + + :param coord: + :return: a list of phase with shape [B, nk * nao] + """ + coords = np.split(coord, (self.sim_nelec[0], sum(self.sim_nelec)), axis=1) + kdots = [np.einsum('ijl, kl->ijk', cor, kpt) for cor, kpt in zip(coords, self.klist)] + phase = [np.exp(1j * kdot) for kdot in kdots] + return phase + + def pure_periodic(self, coord): + orbitals = self.eval_orb_mat(coord) + ## minus symbol makes mos to be periodical + phases = self.eval_phase(-coord) + return [orbital * phase for orbital, phase in zip(orbitals, phases)] + + def eval_inverse(self, coord): + mats = self.eval_orb_mat(coord) + inverse = [np.linalg.inv(mat) for mat in mats] + + return inverse + + def _testrow(self, e, vec, inverse, mask=None, spin=None): + """vec is a nconfig,nmo vector which replaces row e""" + s = int(e >= self.sim_nelec[0]) if spin is None else spin + elec = e - s * self.sim_nelec[0] + if mask is None: + return np.einsum("i...j,ij...->i...", vec, inverse[s][:, :, elec]) + + return np.einsum("i...j,ij...->i...", vec, inverse[s][mask][:, :, elec]) + + def laplacian(self, e, coord, inverse): + s = int(e >= self.sim_nelec[0]) + ao = self.eval_orbitals_pbc(coord, eval_str="GTOval_sph_deriv2") + mo = self.eval_mos_pbc(self._aostack(ao, "laplacian"), s) + ratios = np.asarray([self._testrow(e, x, inverse=inverse) for x in mo]) + return ratios[1] / ratios[0] + + def kinetic(self, coord): + ke = np.zeros(coord.shape[0]) + inverse = self.eval_inverse(coord) + for e in range(self.simulation_cell.nelectron): + ke = ke - 0.5 * np.real(self.laplacian(e, + coord[:, e, :], + inverse=inverse)) + return ke + + def __call__(self, coord): + phase, slogdet = self.eval_slogdet(coord) + psi = np.exp(slogdet) * phase + return psi diff --git a/DeepSolid/init_guess.py b/DeepSolid/init_guess.py new file mode 100644 index 0000000..aa830e1 --- /dev/null +++ b/DeepSolid/init_guess.py @@ -0,0 +1,96 @@ +# Copyright 2020 
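Editor's note: a hedged usage sketch for the SCF wrapper above. The cell geometry, basis, and walker count are illustrative assumptions, not values used anywhere in DeepSolid; the only requirement is that the cell passed in is a simulation cell produced by supercell.get_supercell (it carries the scale and original_cell attributes the class relies on).

```python
import numpy as np
from pyscf.pbc import gto
from DeepSolid import supercell, hf

# A hypothetical primitive cell: H2 in a cubic box (Bohr units).
cell = gto.Cell()
cell.atom = 'H 0.0 0.0 0.0; H 1.4 0.0 0.0'
cell.a = np.eye(3) * 7.0
cell.unit = 'B'
cell.basis = 'gth-szv'      # any small PBC basis; illustrative choice
cell.pseudo = 'gth-pade'
cell.build()

# Identity S gives a simulation cell equal to the primitive cell.
sim_cell = supercell.get_supercell(cell, S=np.eye(3, dtype=int))

scf_approx = hf.SCF(sim_cell, twist=np.zeros(3))   # runs the k-point HF kernel in __init__
scf_approx.init_scf()                              # splits occupied MOs per k-point

# Sign and log|det| of the HF Slater determinant for a few random walkers.
walkers = np.random.uniform(0.0, 7.0, size=(4, sim_cell.nelectron, 3))
phase, slogdet = scf_approx.eval_slogdet(walkers)
```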
DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”). +# All Bytedance Modifications are Copyright 2022 Bytedance Inc. + +import jax +import jax.numpy as jnp +import numpy as np +import pyscf.pbc.gto +from typing import Sequence +from DeepSolid.utils import system +from DeepSolid import distance + + +def init_electrons( + key, + cell: Sequence[system.Atom], + latvec, + electrons: Sequence[int], + batch_size: int, + init_width=0.5, +) -> jnp.ndarray: + """ + Initializes electron positions around each atom. + + :param key: jax key for random + :param cell: internal representation of simulation cell + :param latvec: lattice vector of cell + :param electrons: list of up, down electrons + :param batch_size: batch_size for simulation + :param init_width: std of gaussian used for initialization + :return: jnp.array with shape [Batch_size, N_ele * ndim] + """ + if sum(atom.charge for atom in cell) != sum(electrons): + if len(cell) == 1: + atomic_spin_configs = [electrons] + else: + raise NotImplementedError('No initialization policy yet ' + 'exists for charged molecules.') + else: + + atomic_spin_configs = [ + (atom.element.nalpha - int((atom.atomic_number - atom.charge) // 2), + atom.element.nbeta - int((atom.atomic_number - atom.charge) // 2)) + for atom in cell + ] + # element.nalpha return the up spin number of the single element, if ecp is used, [nalpha,nbeta] should be reduce + # with the the core charge which equals atomic_number - atom.charge + assert sum(sum(x) for x in atomic_spin_configs) == sum(electrons) + while tuple(sum(x) for x in zip(*atomic_spin_configs)) != electrons: + i = np.random.randint(len(atomic_spin_configs)) + nalpha, nbeta = atomic_spin_configs[i] + if atomic_spin_configs[i][0] > 0: + atomic_spin_configs[i] = nalpha - 1, nbeta + 1 + + # Assign each electron to an atom initially. + electron_positions = [] + for i in range(2): + for j in range(len(cell)): + atom_position = jnp.asarray(cell[j].coords) + electron_positions.append(jnp.tile(atom_position, atomic_spin_configs[j][i])) + electron_positions = jnp.concatenate(electron_positions) + # Create a batch of configurations with a Gaussian distribution about each + # atom. + key, subkey = jax.random.split(key) + guess = electron_positions + init_width * jax.random.normal(subkey, shape=(batch_size, electron_positions.size)) + replaced_guess, _ = distance.enforce_pbc(latvec, guess) + return replaced_guess + + + +def pyscf_to_cell(cell: pyscf.pbc.gto.Cell): + """ + Converts the pyscf cell to the internal representation. 
+ + :param cell: pyscf.cell object + :return: internal cell representation + """ + internal_cell = [system.Atom(cell.atom_symbol(i), + cell.atom_coords()[i], + charge=cell.atom_charges()[i], ) + for i in range(cell.natm)] + ## cfg.system.pyscf_mol.atom_charges()[i] return the screen charge of i atom if ecp is used + return internal_cell \ No newline at end of file diff --git a/DeepSolid/network.py b/DeepSolid/network.py new file mode 100644 index 0000000..f34835c --- /dev/null +++ b/DeepSolid/network.py @@ -0,0 +1,567 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”). +# All Bytedance Modifications are Copyright 2022 Bytedance Inc. + +"""Implementation of Fermionic Neural Network in JAX.""" +import functools +from collections import namedtuple +from typing import Any, Callable, Iterable, Mapping, Optional, Sequence, Tuple, Union + +from DeepSolid import curvature_tags_and_blocks +from DeepSolid import distance + +import jax +import jax.numpy as jnp + +_MAX_POLY_ORDER = 5 # highest polynomial used in envelopes + +FermiLayers = Tuple[Tuple[int, int], ...] +# Recursive types are not yet supported in pytype - b/109648354. +# pytype: disable=not-supported-yet +ParamTree = Union[jnp.ndarray, Iterable['ParamTree'], Mapping[Any, 'ParamTree']] +# pytype: enable=not-supported-yet +# init(key) -> params +FermiNetInit = Callable[[jnp.ndarray], ParamTree] +# network(params, x) -> sign_out, log_out +FermiNetApply = Callable[[ParamTree, jnp.ndarray], Tuple[jnp.ndarray, + jnp.ndarray]] + +def enforce_pbc(latvec, epos): + """ + Enforces periodic boundary conditions on a set of configs. + + :param latvec: orthogonal lattice vectors defining 3D torus: (3,3) + :param epos: attempted new electron coordinates: (N_ele, 3) + :return: final electron coordinates with PBCs imposed: (N_ele, 3) + """ + + # Writes epos in terms of (lattice vecs) fractional coordinates + recpvecs = jnp.linalg.inv(latvec) + epos_lvecs_coord = jnp.einsum("ij,jk->ik", epos, recpvecs) + wrap = epos_lvecs_coord // 1 + final_epos = jnp.matmul(epos_lvecs_coord - wrap, latvec) + + return final_epos, wrap + + +def init_solid_fermi_net_params( + key: jnp.ndarray, + data, + atoms: jnp.ndarray, + spins: Tuple[int, int], + envelope_type: str = 'full', + bias_orbitals: bool = False, + use_last_layer: bool = False, + eps: float = 0.01, + full_det: bool = True, + hidden_dims: FermiLayers = ((256, 32), (256, 32), (256, 32)), + determinants: int = 16, + after_determinants: Union[int, Tuple[int, ...]] = 1, +): + """Initializes parameters for the Fermionic Neural Network. + + Args: + key: JAX RNG state. + atoms: (natom, 3) array of atom positions. + spins: Tuple of the number of spin-up and spin-down electrons. + envelope_type: Envelope to use to impose orbitals go to zero at infinity. + See solid_fermi_net_orbitals. + bias_orbitals: If true, include a bias in the final linear layer to shape + the outputs into orbitals. 
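Editor's note: a quick numeric check of the enforce_pbc helper defined above (the lattice and position are made-up numbers): coordinates outside the cell are wrapped back in, and the integer wrap counts are returned alongside.

```python
import jax.numpy as jnp
from DeepSolid import network

latvec = jnp.eye(3) * 2.0                 # illustrative cubic cell, 2 Bohr on a side
epos = jnp.array([[0.5, 2.5, -0.3]])      # one electron outside the cell
wrapped, wrap = network.enforce_pbc(latvec, epos)
print(wrapped)   # [[0.5, 0.5, 1.7]] -- wrapped back into the cell
print(wrap)      # [[0., 1., -1.]]   -- how many lattice vectors were subtracted per axis
```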
+ use_last_layer: If true, the outputs of the one- and two-electron streams + are combined into permutation-equivariant features and passed into the + final orbital-shaping layer. Otherwise, just the output of the + one-electron stream is passed into the orbital-shaping layer. + hf_solution: If present, initialise the parameters to match the Hartree-Fock + solution. Otherwise a random initialisation is use. + eps: If hf_solution is present, scale all weights and biases except the + first layer by this factor such that they are initialised close to zero. + full_det: If true, evaluate determinants over all electrons. Otherwise, + block-diagonalise determinants into spin channels. + hidden_dims: Tuple of pairs, where each pair contains the number of hidden + units in the one-electron and two-electron stream in the corresponding + layer of the FermiNet. The number of layers is given by the length of the + tuple. + determinants: Number of determinants to use. + after_determinants: currently ignored. + + Returns: + PyTree of network parameters. + """ + # after_det is from the legacy QMC TF implementation. Reserving for future + # use. + del after_determinants + del data + + natom = atoms.shape[0] + in_dims = (natom * 4, 4) + active_spin_channels = [spin for spin in spins if spin > 0] + nchannels = len(active_spin_channels) + # The input to layer L of the one-electron stream is from + # construct_symmetric_features and shape (nelectrons, nfeatures), where + # nfeatures is i) output from the previous one-electron layer; ii) the mean + # for each spin channel from each layer; iii) the mean for each spin channel + # from each two-electron layer. We don't create features for spin channels + # which contain no electrons (i.e. spin-polarised systems). + dims_one_in = ( + [(nchannels + 1) * in_dims[0] + nchannels * in_dims[1]] + + [(nchannels + 1) * hdim[0] + nchannels * hdim[1] for hdim in hidden_dims]) + if not use_last_layer: + dims_one_in[-1] = hidden_dims[-1][0] + dims_one_out = [hdim[0] for hdim in hidden_dims] + dims_two = [in_dims[1]] + [hdim[1] for hdim in hidden_dims] + + len_double = len(hidden_dims) if use_last_layer else len(hidden_dims) - 1 + params = { + 'single': [{} for _ in range(len(hidden_dims))], + 'double': [{} for _ in range(len_double)], + 'orbital': [], + 'envelope': [{} for _ in active_spin_channels], + } + + # params['envelope'] = [{} for _ in active_spin_channels] + for i, spin in enumerate(active_spin_channels): + nparam = sum(spins) * determinants if full_det else spin * determinants + params['envelope'][i]['pi'] = jnp.ones((natom, nparam)) + if envelope_type == 'isotropic': + params['envelope'][i]['sigma'] = jnp.ones((natom, nparam)) + elif envelope_type == 'diagonal': + params['envelope'][i]['sigma'] = jnp.ones((natom, 3, nparam)) + elif envelope_type == 'full': + params['envelope'][i]['sigma'] = jnp.tile( + jnp.eye(3)[..., None, None], [1, 1, natom, nparam]) + + for i in range(len(hidden_dims)): + key, subkey = jax.random.split(key) + params['single'][i]['w'] = (jax.random.normal( + subkey, shape=(dims_one_in[i], dims_one_out[i])) / + jnp.sqrt(float(dims_one_in[i]))) + + key, subkey = jax.random.split(key) + params['single'][i]['b'] = jax.random.normal( + subkey, shape=(dims_one_out[i],)) + + if i < len_double: + key, subkey = jax.random.split(key) + params['double'][i]['w'] = (jax.random.normal( + subkey, shape=(dims_two[i], dims_two[i + 1])) / + jnp.sqrt(float(dims_two[i]))) + + key, subkey = jax.random.split(key) + params['double'][i]['b'] = jax.random.normal(subkey, 
+ shape=(dims_two[i + 1],)) + + for i, spin in enumerate(active_spin_channels): + nparam = sum(spins) * determinants if full_det else spin * determinants + key, subkey = jax.random.split(key) + params['orbital'].append({}) + params['orbital'][i]['w'] = (jax.random.normal( + subkey, shape=(dims_one_in[-1], 2 * nparam)) / + jnp.sqrt(float(dims_one_in[-1]))) + if bias_orbitals: + key, subkey = jax.random.split(key) + params['orbital'][i]['b'] = jax.random.normal( + subkey, shape=(2 * nparam,)) + + return params + + +def construct_input_features( + x: jnp.ndarray, + atoms: jnp.ndarray, + ndim: int = 3) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: + assert atoms.shape[1] == ndim + ae = jnp.reshape(x, [-1, 1, ndim]) - atoms[None, ...] + ee = jnp.reshape(x, [1, -1, ndim]) - jnp.reshape(x, [-1, 1, ndim]) + + r_ae = jnp.linalg.norm(ae, axis=2, keepdims=True) + # Avoid computing the norm of zero, as is has undefined grad + n = ee.shape[0] + r_ee = ( + jnp.linalg.norm(ee + jnp.eye(n)[..., None], axis=-1) * (1.0 - jnp.eye(n))) + + return ae, ee, r_ae, r_ee[..., None] + + +def scaled_f(w): + return jnp.abs(w) * (1 - jnp.abs(w / jnp.pi) ** 3 / 4.) + + +def scaled_g(w): + return w * (1 - 3. / 2. * jnp.abs(w / jnp.pi) + 1. / 2. * jnp.abs(w / jnp.pi) ** 2) + + +def sin_relative_distance(xea, a, b): + ''' + + :param xea: + :param a: + :param b: + :return: periodic relative distance [ne, na, 6] + ''' + w = jnp.einsum('...ijk,lk->...ijl', xea, b) + mod = (w + jnp.pi) // (2 * jnp.pi) + w = (w - mod * 2 * jnp.pi) + # w = jnp.mod(w + jnp.pi, 2 * jnp.pi) - jnp.pi + # r1 = jnp.einsum('...i,ij->...j', scaled_f(w), a) + r2 = jnp.einsum('...i,ij->...j', scaled_g(w), a) + return r2 + + +def nu_distance(xea, a, b): + w = jnp.einsum('...ijk,lk->...ijl', xea, b) + mod = (w + jnp.pi) // (2 * jnp.pi) + w = (w - mod * 2 * jnp.pi) + r1 = (jnp.linalg.norm(a, axis=-1) * scaled_f(w)) ** 2 + sg = scaled_g(w) + rel = jnp.einsum('...i,ij->...j', sg, a) + r2 = jnp.einsum('ij,kj->ik', a, a) * (sg[..., :, None] * sg[..., None, :]) + result = jnp.sum(r1, axis=-1) + jnp.sum(r2 * (jnp.ones(r2.shape[-2:]) - jnp.eye(r2.shape[-1])), axis=[-1, -2]) + sd = result ** 0.5 + return sd, rel + +def construct_periodic_input_features( + x: jnp.ndarray, + atoms: jnp.ndarray, + simulation_cell=None, + ndim: int = 3) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: + primitive_cell = simulation_cell.original_cell + x = x.reshape(-1, ndim) + n = x.shape[0] + prim_x, _ = enforce_pbc(primitive_cell.a, x) + + # prim_xea = minimal_imag.dist_i(atoms.ravel(), prim_x.ravel()) + prim_xea = prim_x[..., None, :] - atoms + prim_periodic_sea, prim_periodic_xea = nu_distance(prim_xea, + primitive_cell.AV, + primitive_cell.BV) + prim_periodic_sea = prim_periodic_sea[..., None] + + sim_x, _ = enforce_pbc(simulation_cell.a, x) + # sim_xee = sim_minimal_imag.dist_matrix(sim_x.ravel()) + sim_xee = sim_x[:, None, :] - sim_x[None, :, :] + + sim_periodic_see, sim_periodic_xee = nu_distance(sim_xee + jnp.eye(n)[..., None], + simulation_cell.AV, + simulation_cell.BV) + sim_periodic_see = sim_periodic_see * (1.0 - jnp.eye(n)) + sim_periodic_see = sim_periodic_see[..., None] + + sim_periodic_xee = sim_periodic_xee * (1.0 - jnp.eye(n))[..., None] + + return prim_periodic_xea, sim_periodic_xee, prim_periodic_sea, sim_periodic_see + + +def construct_symmetric_features(h_one: jnp.ndarray, h_two: jnp.ndarray, + spins: Tuple[int, int]) -> jnp.ndarray: + # Split features into spin up and spin down electrons + h_ones = jnp.split(h_one, spins[0:1], axis=0) 
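Editor's note: the periodic features above replace the bare distance with a smooth, lattice-periodic surrogate built from scaled_f and scaled_g (used inside nu_distance). A small spot check of the two limits that make this work: both reduce to |w| and w for small arguments, and both flatten or vanish at the zone boundary w = ±π, so the feature stays smooth across periodic images.

```python
import jax
import jax.numpy as jnp
from DeepSolid import network

w = jnp.array([1e-3, jnp.pi])
print(network.scaled_f(w))                 # ~[1e-3, 3*pi/4]: |w| for small w, flat at the boundary
print(network.scaled_g(w))                 # ~[1e-3, 0.0]:    w  for small w, zero at the boundary
print(jax.grad(network.scaled_f)(jnp.pi))  # ~0.0: derivative vanishes at w = pi
```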
+ h_twos = jnp.split(h_two, spins[0:1], axis=0) + + # Construct inputs to next layer + # h.size == 0 corresponds to unoccupied spin channels. + g_one = [jnp.mean(h, axis=0, keepdims=True) for h in h_ones if h.size > 0] + g_two = [jnp.mean(h, axis=0) for h in h_twos if h.size > 0] + + g_one = [jnp.tile(g, [h_one.shape[0], 1]) for g in g_one] + + return jnp.concatenate([h_one] + g_one + g_two, axis=1) + + +def isotropic_envelope(ae, params): + """Computes an isotropic exponentially-decaying multiplicative envelope.""" + return jnp.sum(jnp.exp(-jnp.abs(ae * params['sigma'])) * params['pi'], axis=1) + + +def diagonal_envelope(ae, params): + """Computes a diagonal exponentially-decaying multiplicative envelope.""" + r_ae = jnp.linalg.norm(ae[..., None] * params['sigma'], axis=2) + return jnp.sum(jnp.exp(-r_ae) * params['pi'], axis=1) + + +vdot = jax.vmap(jnp.dot, (0, 0)) + + +def apply_covariance(x, y): + """Equivalent to jnp.einsum('ijk,kmjn->ijmn', x, y).""" + i, _, _ = x.shape + k, m, j, n = y.shape + x = x.transpose((1, 0, 2)) + y = y.transpose((2, 0, 1, 3)).reshape((j, k, m * n)) + return vdot(x, y).reshape((j, i, m, n)).transpose((1, 0, 2, 3)) + + +def full_envelope(ae, params): + """Computes a fully anisotropic exponentially-decaying multiplicative envelope.""" + r_ae = apply_covariance(ae, params['sigma']) + r_ae = curvature_tags_and_blocks.register_qmc1(r_ae, ae, params['sigma'], + type='full') + r_ae = jnp.linalg.norm(r_ae, axis=2) + return jnp.sum(jnp.exp(-r_ae) * params['pi'], axis=1) + + +def output_envelope(ae, params): + """Fully anisotropic envelope, but only one output.""" + sigma = jnp.expand_dims(params['sigma'], -1) + ae_sigma = jnp.squeeze(apply_covariance(ae, sigma), axis=-1) + r_ae = jnp.linalg.norm(ae_sigma, axis=2) + return jnp.sum(jnp.log(jnp.sum(jnp.exp(-r_ae + params['pi']), axis=1))) + + +def slogdet_op(x): + """Computes sign and log of determinants of matrices. + + This is a jnp.linalg.slogdet with a special (fast) path for small matrices. + + Args: + x: square matrix. + + Returns: + sign, (natural) logarithm of the determinant of x. + """ + if x.shape[-1] == 1: + sign = jnp.exp(1j*jnp.angle(x[..., 0, 0])) + logdet = jnp.log(jnp.abs(x[..., 0, 0])) + else: + sign, logdet = jnp.linalg.slogdet(x) + + return sign, logdet + + +def logdet_matmul(xs: Sequence[jnp.ndarray], + w: Optional[jnp.ndarray] = None) -> jnp.ndarray: + """Combines determinants and takes dot product with weights in log-domain. + + We use the log-sum-exp trick to reduce numerical instabilities. + + Args: + xs: FermiNet orbitals in each determinant. Either of length 1 with shape + (ndet, nelectron, nelectron) (full_det=True) or length 2 with shapes + (ndet, nalpha, nalpha) and (ndet, nbeta, nbeta) (full_det=False, + determinants are factorised into block-diagonals for each spin channel). + w: weight of each determinant. If none, a uniform weight is assumed. + + Returns: + sum_i w_i D_i in the log domain, where w_i is the weight of D_i, the i-th + determinant (or product of the i-th determinant in each spin channel, if + full_det is not used). 
+ """ + slogdets = [slogdet_op(x) for x in xs] + sign_in, slogdet = functools.reduce( + lambda a, b: (a[0] * b[0], a[1] + b[1]), slogdets) + max_idx = jnp.argmax(slogdet) + # sign_in_max = sign_in[max_idx] + slogdet_max = slogdet[max_idx] + # log-sum-exp trick + det = sign_in * jnp.exp(slogdet-slogdet_max) + if w is None: + result = jnp.sum(det) + else: + result = jnp.matmul(det, w)[0] + sign_out = jnp.exp(1j*jnp.angle(result)) + slog_out = jnp.log(jnp.abs(result)) + slogdet_max + return sign_out, slog_out + + +def linear_layer(x, w, b=None): + """Evaluates a linear layer, x w + b. + + Args: + x: inputs. + w: weights. + b: optional bias. Only x w is computed if b is None. + + Returns: + x w + b if b is given, x w otherwise. + """ + y = jnp.dot(x, w) + y = y + b if b is not None else y + return curvature_tags_and_blocks.register_repeated_dense(y, x, w, b) + + +vmap_linear_layer = jax.vmap(linear_layer, in_axes=(0, None, None), out_axes=0) + + +def eval_phase(x, klist, ndim=3, spins=None, full_det=False): + x = x.reshape([-1, ndim]) + xs = jnp.split(x, spins[0:1], axis=-2) + if full_det: + klist = jnp.concatenate(klist, axis=0) + kdot_xs = [jnp.matmul(x, klist.T) for x, ne in zip(xs, spins) if ne > 0] + else: + kdot_xs = [jnp.matmul(x, kpt.T) for x, kpt, ne in zip(xs, klist, spins) if ne > 0] + phases = [jnp.exp(1j * kdot_x) for kdot_x in kdot_xs] + return phases + + +def solid_fermi_net_orbitals(params, x, + simulation_cell=None, + klist=None, + atoms=None, + spins=(None, None), + envelope_type=None, + full_det=False): + + ae_, ee_, r_ae, r_ee = construct_periodic_input_features(x, atoms, + simulation_cell=simulation_cell, + ) + ae = jnp.concatenate((r_ae, ae_), axis=2) + ae = jnp.reshape(ae, [jnp.shape(ae)[0], -1]) + ee = jnp.concatenate((r_ee, ee_), axis=2) + + # which variable do we pass to envelope? + to_env = r_ae if envelope_type == 'isotropic' else ae_ + + if envelope_type == 'isotropic': + envelope = isotropic_envelope + elif envelope_type == 'diagonal': + envelope = diagonal_envelope + elif envelope_type == 'full': + envelope = full_envelope + + h_one = ae # single-electron features + h_two = ee # two-electron features + residual = lambda x, y: (x + y) / jnp.sqrt(2.0) if x.shape == y.shape else y + for i in range(len(params['double'])): + h_one_in = construct_symmetric_features(h_one, h_two, spins) + + # Execute next layer + h_one_next = jnp.tanh(linear_layer(h_one_in, **params['single'][i])) + h_two_next = jnp.tanh(vmap_linear_layer(h_two, params['double'][i]['w'], + params['double'][i]['b'])) + h_one = residual(h_one, h_one_next) + h_two = residual(h_two, h_two_next) + if len(params['double']) != len(params['single']): + h_one_in = construct_symmetric_features(h_one, h_two, spins) + h_one_next = jnp.tanh(linear_layer(h_one_in, **params['single'][-1])) + h_one = residual(h_one, h_one_next) + h_to_orbitals = h_one + else: + h_to_orbitals = construct_symmetric_features(h_one, h_two, spins) + # Note split creates arrays of size 0 for spin channels without any electrons. 
+ h_to_orbitals = jnp.split(h_to_orbitals, spins[0:1], axis=0) + + active_spin_channels = [spin for spin in spins if spin > 0] + orbitals = [linear_layer(h, **p) + for h, p in zip(h_to_orbitals, params['orbital'])] + + for i, spin in enumerate(active_spin_channels): + nparams = params['orbital'][i]['w'].shape[-1] // 2 + orbitals[i] = orbitals[i][..., :nparams] + 1j * orbitals[i][..., nparams:] + + if envelope_type in ['isotropic', 'diagonal', 'full']: + orbitals = [envelope(te, param) * orbital for te, orbital, param in + zip(jnp.split(to_env, active_spin_channels[:-1], axis=0), + orbitals, params['envelope'])] + # Reshape into matrices and drop unoccupied spin channels. + orbitals = [jnp.reshape(orbital, [spin, -1, sum(spins) if full_det else spin]) + for spin, orbital in zip(active_spin_channels, orbitals) if spin > 0] + orbitals = [jnp.transpose(orbital, (1, 0, 2)) for orbital in orbitals] + phases = eval_phase(x, klist=klist, ndim=3, spins=spins, full_det=full_det) + + orbitals = [orb * p[None, :, :] for orb, p in zip(orbitals, phases)] + if full_det: + orbitals = [jnp.concatenate(orbitals, axis=1)] + return orbitals, to_env + + +def eval_func(params, x, + klist=None, + simulation_cell=None, + atoms=None, + spins=(None, None), + envelope_type='full', + full_det=False, + method_name='eval_slogdet'): + + orbitals, to_env = solid_fermi_net_orbitals(params, x, + klist=klist, + simulation_cell=simulation_cell, + atoms=atoms, + spins=spins, + envelope_type=envelope_type, + full_det=full_det) + if method_name == 'eval_slogdet': + _, result = logdet_matmul(orbitals) + elif method_name == 'eval_logdet': + sign, slogdet = logdet_matmul(orbitals) + result = jnp.log(sign) + slogdet + elif method_name == 'eval_phase_and_slogdet': + result = logdet_matmul(orbitals) + elif method_name == 'eval_mats': + result = orbitals + else: + raise ValueError('Unrecognized method name') + + return result + + +def make_solid_fermi_net( + envelope_type: str = 'full', + bias_orbitals: bool = False, + use_last_layer: bool = False, + klist=None, + simulation_cell=None, + full_det: bool = True, + hidden_dims: FermiLayers = ((256, 32), (256, 32), (256, 32)), + determinants: int = 16, + after_determinants: Union[int, Tuple[int, ...]] = 1, + method_name='eval_logdet', +): + ''' + + :param envelope_type: specify envelope + :param bias_orbitals: whether to contain bias in the last layer of orbitals + :param use_last_layer: wheter to use two-electron feature in the last layer + :param klist: occupied k points from HF + :param simulation_cell: simulation cell + :param full_det: specify the mode of wavefunction, spin diagonalized or not. + :param hidden_dims: specify the dimension of one-electron and two-electron layer + :param determinants: the number of determinants used + :param after_determinants: deleted + :param method_name: specify the returned function + :return: a haiku like module which contain init and apply method. init is used to initialize the parameter of + network and apply method perform the calculation. 
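Editor's note: as the docstring above describes, the constructor returns a haiku-style (init, apply) pair. A hedged wiring sketch, assuming the sim_cell and scf_approx objects from the earlier Hartree-Fock sketch; the hyperparameters are placeholders:

```python
import jax
from DeepSolid import network

net = network.make_solid_fermi_net(envelope_type='isotropic',
                                   klist=scf_approx.klist,      # occupied k-points from HF
                                   simulation_cell=sim_cell,
                                   full_det=False,
                                   determinants=4,
                                   method_name='eval_slogdet')
params = net.init(key=jax.random.PRNGKey(42), data=None)
x = jax.random.uniform(jax.random.PRNGKey(1), (sim_cell.nelectron * 3,)) * 7.0
logabs_psi = net.apply(params, x)   # log|psi| for one electron configuration
```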
+ ''' + if method_name not in ['eval_slogdet', 'eval_logdet', 'eval_mats', 'eval_phase_and_slogdet']: + raise ValueError('Method name is not in class dir.') + + method = namedtuple('method', ['init', 'apply']) + init = functools.partial( + init_solid_fermi_net_params, + atoms=simulation_cell.original_cell.atom_coords(), + spins=simulation_cell.nelec, + envelope_type=envelope_type, + bias_orbitals=bias_orbitals, + use_last_layer=use_last_layer, + full_det=full_det, + hidden_dims=hidden_dims, + determinants=determinants, + after_determinants=after_determinants, + ) + network = functools.partial( + eval_func, + simulation_cell=simulation_cell, + klist=klist, + atoms=simulation_cell.original_cell.atom_coords(), + spins=simulation_cell.nelec, + envelope_type=envelope_type, + full_det=full_det, + method_name=method_name, + ) + method.init = init + method.apply = network + return method diff --git a/DeepSolid/pretrain.py b/DeepSolid/pretrain.py new file mode 100644 index 0000000..a94d63e --- /dev/null +++ b/DeepSolid/pretrain.py @@ -0,0 +1,228 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”). +# All Bytedance Modifications are Copyright 2022 Bytedance Inc. + +import functools + +import numpy as np +from absl import logging +import jax +import jax.numpy as jnp +import optax + +from DeepSolid import hf +from DeepSolid import qmc +from DeepSolid import constants + + +def _batch_slater_slogdet(scf: hf.SCF, dim=3): + + def batch_slater_slogdet(params, x): + del params + batch = x.shape[0] + x = x.reshape([batch, -1, dim]) + result = scf.eval_slogdet(x)[1] + return result + + return batch_slater_slogdet + + +def make_pretrain_step(batch_orbitals, + batch_network, + latvec, + optimizer, + full_det=False, + ): + """ + + :param batch_orbitals: + :param batch_network: + :param latvec: + :param optimizer: + :return: + """ + + def pretrain_step(data, target, params, state, key): + """One iteration of pretraining to match HF.""" + + def loss_fn(x, p, target): + predict = batch_orbitals(p, x) + if full_det: + batch_size = predict[0].shape[0] + na = target[0].shape[1] + nb = target[1].shape[1] + target = [jnp.concatenate( + (jnp.concatenate((target[0], jnp.zeros((batch_size, na, nb))), axis=-1), + jnp.concatenate((jnp.zeros((batch_size, nb, na)), target[1]), axis=-1)), + axis=-2)] + result = jnp.array([jnp.mean(jnp.abs(tar[:, None, ...] 
- pre)**2) + for tar, pre in zip(target, predict)]).mean() + return constants.pmean_if_pmap(result, axis_name=constants.PMAP_AXIS_NAME) + + val_and_grad = jax.value_and_grad(loss_fn, argnums=1) + loss_val, search_direction = val_and_grad(data, params, target) + search_direction = constants.pmean_if_pmap( + search_direction, axis_name=constants.PMAP_AXIS_NAME) + updates, state = optimizer.update(search_direction, state, params) + params = optax.apply_updates(params, updates) + logprob = 2 * batch_network(params, data) + data, key, logprob, num_accepts = qmc.mh_update(params=params, + f=batch_network, + x1=data, + key=key, + lp_1=logprob, + num_accepts=0, + latvec=latvec) + return data, params, state, loss_val, logprob, num_accepts + + return pretrain_step + + +def pretrain_hartree_fock(params, + data, + batch_network, + batch_orbitals, + sharded_key, + cell, + scf_approx: hf.SCF, + full_det=False, + iterations=1000, + learning_rate=5e-3, + ): + + optimizer = optax.adam(learning_rate) + opt_state_pt = constants.pmap(optimizer.init)(params) + leading_shape = data.shape[:-1] + + pretrain_step = make_pretrain_step(batch_orbitals=batch_orbitals, + batch_network=batch_network, + latvec=cell.lattice_vectors(), + optimizer=optimizer, + full_det=full_det,) + pretrain_step = constants.pmap(pretrain_step) + + for t in range(iterations): + target = scf_approx.eval_orb_mat(np.array(data.reshape([-1, cell.nelectron, 3]), dtype=np.float64)) + # PYSCF PBC eval_gto seems only accept float64 array, float32 array will easily cause nan or underflow. + target = [jnp.array(tar) for tar in target] + target = [tar.reshape([*leading_shape, ne, ne]) for tar, ne in zip(target, cell.nelec) if ne > 0] + + slogprob_target = [2 * jnp.linalg.slogdet(tar)[1] for tar in target] + slogprob_target = functools.reduce(lambda x, y: x+y, slogprob_target) + sharded_key, subkeys = constants.p_split(sharded_key) + data, params, opt_state_pt, loss, logprob, num_accepts = pretrain_step( + data, target, params, opt_state_pt, subkeys) + logging.info('Pretrain iter %05d: Loss=%03.6f, pmove=%0.2f, ' + 'Norm of Net prob=%03.4f, Norm of HF prob=%03.4f', + t, loss[0], + jnp.mean(num_accepts) / leading_shape[-1], + jnp.mean(logprob), + jnp.mean(slogprob_target)) + + return params, data + + +def pretrain_hartree_fock_usingHF(params, + data, + batch_orbitals, + sharded_key, + cell, + scf_approx: hf.SCF, + iterations=1000, + learning_rate=5e-3, + nsteps=1, + full_det=False, + ): + + optimizer = optax.adam(learning_rate) + opt_state_pt = constants.pmap(optimizer.init)(params) + leading_shape = data.shape[:-1] + + def make_pretrain_step(batch_orbitals, + latvec, + optimizer, + ): + + def pretrain_step(data, target, params, state): + + def loss_fn(x, p, target): + predict = batch_orbitals(p, x) + if full_det: + batch_size = predict[0].shape[0] + na = target[0].shape[1] + nb = target[1].shape[1] + target = [jnp.concatenate( + (jnp.concatenate((target[0], jnp.zeros((batch_size, na, nb))), axis=-1), + jnp.concatenate((jnp.zeros((batch_size, nb, na)), target[1]), axis=-1)), + axis=-2)] + result = jnp.array([jnp.mean(jnp.abs(tar[:, None, ...] 
- pre) ** 2) + for tar, pre in zip(target, predict)]).mean() + return constants.pmean_if_pmap(result, axis_name=constants.PMAP_AXIS_NAME) + + val_and_grad = jax.value_and_grad(loss_fn, argnums=1) + loss_val, search_direction = val_and_grad(data, params, target) + search_direction = constants.pmean_if_pmap( + search_direction, axis_name=constants.PMAP_AXIS_NAME) + updates, state = optimizer.update(search_direction, state, params) + params = optax.apply_updates(params, updates) + + return params, state, loss_val + + return pretrain_step + + + pretrain_step = make_pretrain_step(batch_orbitals=batch_orbitals, + latvec=cell.lattice_vectors(), + optimizer=optimizer,) + pretrain_step = constants.pmap(pretrain_step) + batch_network = _batch_slater_slogdet(scf_approx) + logprob = 2 * batch_network(None, data.reshape([-1, cell.nelectron * 3])) + + def step_fn(inputs): + return qmc.mh_update(params, + batch_network, + *inputs, + latvec=cell.lattice_vectors(), + ) + + for t in range(iterations): + + for _ in range(nsteps): + sharded_key, subkeys = constants.p_split(sharded_key) + inputs = (data.reshape([-1, cell.nelectron * 3]), + sharded_key[0], + logprob, + 0.) + data, _, logprob, num_accepts = step_fn(inputs) + + data = data.reshape([*leading_shape, -1]) + target = scf_approx.eval_orb_mat(data.reshape([-1, cell.nelectron, 3])) + target = [tar.reshape([*leading_shape, ne, ne]) for tar, ne in zip(target, cell.nelec) if ne > 0] + + slogprob_net = [2 * jnp.linalg.slogdet(net_mat)[1] for net_mat in constants.pmap(batch_orbitals)(params, data)] + slogprob_net = functools.reduce(lambda x, y: x+y, slogprob_net) + + sharded_key, subkeys = constants.p_split(sharded_key) + params, opt_state_pt, loss = pretrain_step(data, target, params, opt_state_pt) + + logging.info('Pretrain iter %05d: Loss=%03.6f, pmove=%0.2f, ' + 'Norm of Net prob=%03.4f, Norm of HF prob=%03.4f', + t, loss[0], + jnp.mean(num_accepts) / functools.reduce(lambda x, y: x*y, leading_shape), + jnp.mean(slogprob_net), + jnp.mean(logprob)) + + return params, data diff --git a/DeepSolid/process.py b/DeepSolid/process.py new file mode 100644 index 0000000..24b6122 --- /dev/null +++ b/DeepSolid/process.py @@ -0,0 +1,383 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”). +# All Bytedance Modifications are Copyright 2022 Bytedance Inc. 
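Editor's note: when full_det is enabled, both pretraining paths above embed the two HF spin blocks into one block-diagonal matrix before comparing against the network orbitals. A standalone illustration of that padding (shapes are made up):

```python
import jax.numpy as jnp

batch, na, nb = 2, 3, 2                      # illustrative: 3 spin-up, 2 spin-down orbitals
up = jnp.ones((batch, na, na))
down = 2.0 * jnp.ones((batch, nb, nb))

# Same block-diagonal construction as the full_det branch of the pretraining loss.
target = jnp.concatenate(
    (jnp.concatenate((up, jnp.zeros((batch, na, nb))), axis=-1),
     jnp.concatenate((jnp.zeros((batch, nb, na)), down), axis=-1)),
    axis=-2)
print(target.shape)   # (2, 5, 5): spin blocks on the diagonal, zeros off-diagonal
```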
+ +import jax +import jax.numpy as jnp +import numpy as np +import datetime +import ml_collections +import logging +import time +import optax +import chex +import pandas as pd + +from DeepSolid.utils.kfac_ferminet_alpha import optimizer as kfac_optim +from DeepSolid.utils.kfac_ferminet_alpha import utils as kfac_utils + +from DeepSolid import constants +from DeepSolid import network +from DeepSolid import train +from DeepSolid import pretrain +from DeepSolid import qmc +from DeepSolid import init_guess +from DeepSolid import hf +from DeepSolid import checkpoint +from DeepSolid.utils import writers +from DeepSolid import estimator + + +def get_params_initialization_key(deterministic): + ''' + The key point here is to make sure different hosts uses the same RNG key + to initialize network parameters. + ''' + if deterministic: + seed = 888 + else: + + # The overly complicated action here is to make sure different hosts get + # the same seed. + @constants.pmap + def average_seed(seed_array): + return jax.lax.pmean(jnp.mean(seed_array), axis_name=constants.PMAP_AXIS_NAME) + + local_seed = time.time() + float_seed = average_seed(jnp.ones(jax.local_device_count()) * local_seed)[0] + seed = int(1e6 * float_seed) + print(f'params initialization seed: {seed}') + return jax.random.PRNGKey(seed) + + +def process(cfg: ml_collections.ConfigDict): + + num_hosts, host_idx = 1, 0 + + # Device logging + num_devices = jax.local_device_count() + local_batch_size = cfg.batch_size // num_hosts + logging.info('Starting QMC with %i XLA devices', num_devices) + if local_batch_size % num_devices != 0: + raise ValueError('Batch size must be divisible by number of devices, ' + 'got batch size {} for {} devices.'.format( + local_batch_size, num_devices)) + ckpt_save_path = checkpoint.create_save_path(cfg.log.save_path,) + ckpt_restore_path = checkpoint.get_restore_path(cfg.log.restore_path) + ckpt_restore_filename = ( + checkpoint.find_last_checkpoint(ckpt_save_path) or + checkpoint.find_last_checkpoint(ckpt_restore_path)) + + simulation_cell = cfg.system.pyscf_cell + cfg.system.internal_cell = init_guess.pyscf_to_cell(cell=simulation_cell) + + hartree_fock = hf.SCF(cell=simulation_cell, twist=jnp.array(cfg.network.twist)) + hartree_fock.init_scf() + + + if cfg.system.ndim != 3: + # The network (at least the input feature construction) and initial MCMC + # molecule configuration (via system.Atom) assume 3D systems. This can be + # lifted with a little work. 
+ raise ValueError('Only 3D systems are currently supported.') + data_shape = (num_devices, local_batch_size // num_devices) + + if cfg.debug.deterministic: + seed = 666 + else: + seed = int(1e6 * time.time()) + + key = jax.random.PRNGKey(seed) + key = jax.random.fold_in(key, host_idx) + + system_dict = { + 'klist': hartree_fock.klist, + 'simulation_cell': simulation_cell, + } + system_dict.update(cfg.network.detnet) + + slater_mat = network.make_solid_fermi_net(**system_dict, method_name='eval_mats') + slater_logdet = network.make_solid_fermi_net(**system_dict, method_name='eval_logdet') + slater_slogdet = network.make_solid_fermi_net(**system_dict, method_name='eval_slogdet') + + batch_slater_logdet = jax.vmap(slater_logdet.apply, in_axes=(None, 0), out_axes=0) + batch_slater_slogdet = jax.vmap(slater_slogdet.apply, in_axes=(None, 0), out_axes=0) + batch_slater_mat = jax.vmap(slater_mat.apply, in_axes=(None, 0), out_axes=0) + + if ckpt_restore_filename: + t_init, data, params, opt_state_ckpt, mcmc_width_ckpt = checkpoint.restore( + ckpt_restore_filename, local_batch_size) + + else: + logging.info('No checkpoint found. Training new model.') + t_init = 0 + opt_state_ckpt = None + mcmc_width_ckpt = None + data = init_guess.init_electrons(key=key, cell=cfg.system.internal_cell, + latvec=simulation_cell.lattice_vectors(), + electrons=simulation_cell.nelec, + batch_size=local_batch_size, + init_width=cfg.mcmc.init_width) + data = jnp.reshape(data, data_shape + data.shape[1:]) + data = constants.broadcast_all_local_devices(data) + params_initialization_key = get_params_initialization_key(cfg.debug.deterministic) + params = slater_logdet.init(key=params_initialization_key, data=None) + params = constants.replicate_all_local_devices(params) + + pmoves = np.zeros(cfg.mcmc.adapt_frequency) + shared_t = constants.replicate_all_local_devices(jnp.zeros([])) + shared_mom = kfac_utils.replicate_all_local_devices(jnp.zeros([])) + shared_damping = kfac_utils.replicate_all_local_devices( + jnp.asarray(cfg.optim.kfac.damping)) + sharded_key = constants.make_different_rng_key_on_all_devices(key) + sharded_key, subkeys = constants.p_split(sharded_key) + + if (t_init == 0 and cfg.pretrain.method == 'net' and + cfg.pretrain.iterations > 0): + logging.info('Pretrain using Net distribution.') + sharded_key, subkeys = constants.p_split(sharded_key) + params, data = pretrain.pretrain_hartree_fock(params=params, + data=data, + batch_network=batch_slater_slogdet, + batch_orbitals=batch_slater_mat, + sharded_key=subkeys, + scf_approx=hartree_fock, + cell=simulation_cell, + iterations=cfg.pretrain.iterations, + learning_rate=cfg.pretrain.lr, + full_det=cfg.network.detnet.full_det, + ) + + if (t_init == 0 and cfg.pretrain.method == 'hf' and + cfg.pretrain.iterations > 0): + logging.info('Pretrain using Hartree Fock distribution.') + sharded_key, subkeys = constants.p_split(sharded_key) + params, data = pretrain.pretrain_hartree_fock_usingHF(params=params, + data=data, + batch_orbitals=batch_slater_mat, + sharded_key=sharded_key, + cell=simulation_cell, + scf_approx=hartree_fock, + iterations=cfg.pretrain.iterations, + learning_rate=cfg.pretrain.lr, + full_det=cfg.network.detnet.full_det, + nsteps=cfg.pretrain.steps) + if (t_init == 0 and cfg.pretrain.iterations > 0): + logging.info('Saving pretrain params') + checkpoint.save(ckpt_save_path, 0, data, params, None, None,) + + sampling_func = slater_slogdet.apply if cfg.mcmc.importance_sampling else None + mcmc_step = 
qmc.make_mcmc_step(batch_slog_network=batch_slater_slogdet, + batch_per_device=local_batch_size//jax.local_device_count(), + latvec=jnp.asarray(simulation_cell.lattice_vectors()), + steps=cfg.mcmc.steps, + one_electron_moves=cfg.mcmc.one_electron, + importance_sampling=sampling_func, + ) + + total_energy = train.make_loss(network=slater_logdet.apply, + batch_network=batch_slater_logdet, + simulation_cell=simulation_cell, + clip_local_energy=cfg.optim.clip_el, + clip_type=cfg.optim.clip_type, + mode=cfg.optim.laplacian_mode, + partition_number=cfg.optim.partition_number, + ) + + def learning_rate_schedule(t): + return cfg.optim.lr.rate * jnp.power( + (1.0 / (1.0 + (t / cfg.optim.lr.delay))), cfg.optim.lr.decay) + + val_and_grad = jax.value_and_grad(total_energy, argnums=0, has_aux=True) + if cfg.optim.optimizer == 'adam': + optimizer = optax.chain(optax.scale_by_adam(**cfg.optim.adam), + optax.scale_by_schedule(learning_rate_schedule), + optax.scale(-1.)) + elif cfg.optim.optimizer == 'kfac': + optimizer = kfac_optim.Optimizer( + val_and_grad, + l2_reg=cfg.optim.kfac.l2_reg, + norm_constraint=cfg.optim.kfac.norm_constraint, + value_func_has_aux=True, + learning_rate_schedule=learning_rate_schedule, + curvature_ema=cfg.optim.kfac.cov_ema_decay, + inverse_update_period=cfg.optim.kfac.invert_every, + min_damping=cfg.optim.kfac.min_damping, + num_burnin_steps=0, + register_only_generic=cfg.optim.kfac.register_only_generic, + estimation_mode='fisher_exact', + multi_device=True, + pmap_axis_name=constants.PMAP_AXIS_NAME + # debug=True + ) + sharded_key, subkeys = kfac_utils.p_split(sharded_key) + opt_state = optimizer.init(params, subkeys, data) + opt_state = opt_state_ckpt or opt_state # avoid overwriting ckpted state + elif cfg.optim.optimizer == 'none': + total_energy = constants.pmap(total_energy) + opt_state = None + else: + raise ValueError('Unrecognized Optimizer.') + + if cfg.optim.optimizer != 'kfac' and cfg.optim.optimizer != 'none': + optimizer = optax.MultiSteps(optimizer, every_k_schedule=cfg.optim.ministeps) + + opt_state = jax.pmap(optimizer.init)(params) + opt_state = opt_state if opt_state_ckpt is None else optax._src.wrappers.MultiStepsState(*opt_state) + + def opt_update(t, grad, params, opt_state): + del t # Unused. + updates, opt_state = optimizer.update(grad, opt_state, params) + params = optax.apply_updates(params, updates) + return opt_state, params + + step = train.make_training_step(mcmc_step, val_and_grad, opt_update) + + mcmc_step = constants.pmap(mcmc_step) + + if mcmc_width_ckpt is not None: + mcmc_width = constants.broadcast_all_local_devices(jnp.asarray(mcmc_width_ckpt)) + else: + mcmc_width = constants.replicate_all_local_devices(jnp.asarray(cfg.mcmc.move_width)) + + if t_init == 0: + logging.info('Burning in MCMC chain for %d steps', cfg.mcmc.burn_in) + for t in range(cfg.mcmc.burn_in): + sharded_key, subkeys = constants.p_split(sharded_key) + data, pmove = mcmc_step(params, data, subkeys, mcmc_width) + logging.info('Completed burn-in MCMC steps') + logging.info('Initial energy for primitive cell: %03.4f E_h', + constants.pmap(total_energy)(params, data)[0][0] / simulation_cell.scale) + + time_of_last_ckpt = time.time() + + if cfg.optim.optimizer == 'none' and opt_state_ckpt is not None: + # If opt_state_ckpt is None, then we're restarting from a previous inference + # run (most likely due to preemption) and so should continue from the last + # iteration in the checkpoint. Otherwise, starting an inference run from a + # training run. 
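Editor's note: the learning_rate_schedule defined earlier in process() is a plain inverse-time decay. A numeric sketch with hypothetical hyperparameters, chosen only to show the shape of the curve:

```python
import jax.numpy as jnp

rate, delay, decay = 5e-2, 1e4, 1.0   # hypothetical values, for illustration only

def learning_rate_schedule(t):
    return rate * jnp.power(1.0 / (1.0 + t / delay), decay)

print([float(learning_rate_schedule(t)) for t in (0.0, 1e4, 1e5)])
# [0.05, 0.025, ~0.0045]: halved after `delay` steps, roughly rate/11 after ten times that.
```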
+ logging.info('No optimizer provided. Assuming inference run.') + logging.info('Setting initial iteration to 0.') + t_init = 0 + + train_schema = ['step', 'energy', 'variance', 'pmove', 'imaginary', 'kinetic', 'ewald'] + if cfg.log.complex_polarization: + train_schema.append('complex_polarization') + polarization = estimator.make_complex_polarization(simulation_cell) + pmap_polarization = constants.pmap(polarization) + if cfg.log.structure_factor: + structure_factor = estimator.make_structure_factor(simulation_cell) + pmap_structure_factor = constants.pmap(structure_factor) + with writers.Writer(name=cfg.log.stats_file_name, + schema=train_schema, + directory=ckpt_save_path, + iteration_key=None, + log=False) as writer: + for t in range(t_init, cfg.optim.iterations): + sharded_key, subkeys = constants.p_split(sharded_key) + if cfg.optim.optimizer == 'kfac': + new_data, pmove = mcmc_step(params, data, subkeys, mcmc_width) + # Need this split because MCMC step above used subkeys already + sharded_key, subkeys = kfac_utils.p_split(sharded_key) + new_params, new_opt_state, new_stats = optimizer.step( # pytype: disable=attribute-error + params=params, + state=opt_state, + rng=subkeys, + data_iterator=iter([new_data]), + momentum=shared_mom, + damping=shared_damping) + tree = {'params': new_params, 'loss': new_stats['loss'], 'optim': new_opt_state} + try: + # We don't do check_nan by default due to efficiency concern. + # We noticed ~0.2s overhead when performing this nan check + # at transitional medals. + if cfg.debug.check_nan: + chex.assert_tree_all_finite(tree) + data = new_data + params = new_params + opt_state = new_opt_state + stats = new_stats + loss = stats['loss'] + aux_data = stats['aux'] + except AssertionError as e: + # data, params, opt_state, and stats are not updated + logging.warn(str(e)) + loss = aux_data = None + elif cfg.optim.optimizer == 'none': + data, pmove = mcmc_step(params, data, subkeys, mcmc_width) + loss, aux_data = total_energy(params, data) + else: + data, params, opt_state, loss, aux_data, pmove, search_direction = step(shared_t, + data, + params, + opt_state, + subkeys, + mcmc_width) + shared_t = shared_t + 1 + loss = loss[0] / simulation_cell.scale if loss is not None else None + variance = aux_data.variance[0] / simulation_cell.scale ** 2 if aux_data is not None else None + imaginary = aux_data.imaginary[0] / simulation_cell.scale if aux_data is not None else None + kinetic = jnp.mean(aux_data.kinetic) / simulation_cell.scale if aux_data is not None else None + ewald = jnp.mean(aux_data.ewald) / simulation_cell.scale if aux_data is not None else None + pmove = pmove[0] + + if cfg.log.complex_polarization: + polarization_data = pmap_polarization(data)[0] + if cfg.log.structure_factor: + structure_factor_data = pmap_structure_factor(data)[0][None, :] + pd_tabel = pd.DataFrame(structure_factor_data) + pd_tabel.to_csv(str(ckpt_save_path) + '/structure_factor.csv', mode="a", sep=',', header=False) + + + if t % cfg.log.stats_frequency == 0 and loss is not None: + logging.info( + '%s Step %05d: %03.4f E_h, variance=%03.4f E_h^2, pmove=%0.2f, imaginary part=%03.4f, ' + 'kinetic=%03.4f E_h, ewald=%03.4f E_h', + datetime.datetime.now(), t, + loss, variance, pmove, imaginary, + kinetic.real, ewald) + result_dict = { + 'step': t, + 'energy': np.asarray(loss), + 'variance': np.asarray(variance), + 'pmove': np.asarray(pmove), + 'imaginary': np.asarray(imaginary), + 'kinetic': np.asarray(kinetic), + 'ewald': np.asarray(ewald), + } + if cfg.log.complex_polarization: + 
result_dict['complex_polarization'] = np.asarray(polarization_data) + writer.write(t, + **result_dict, + ) + + # Update MCMC move width + if t > 0 and t % cfg.mcmc.adapt_frequency == 0: + if np.mean(pmoves) > 0.55: + mcmc_width *= 1.1 + if np.mean(pmoves) < 0.5: + mcmc_width /= 1.1 + pmoves[:] = 0 + pmoves[t % cfg.mcmc.adapt_frequency] = pmove + + if (time.time() - time_of_last_ckpt > cfg.log.save_frequency * 60 + or t >= cfg.optim.iterations - 1 + or (cfg.log.save_frequency_in_step > 0 and t % cfg.log.save_frequency_in_step == 0)): + # no checkpointing in inference mode + if cfg.optim.optimizer != 'none': + checkpoint.save(ckpt_save_path, t, data, params, opt_state, mcmc_width,) + + time_of_last_ckpt = time.time() diff --git a/DeepSolid/qmc.py b/DeepSolid/qmc.py new file mode 100644 index 0000000..385422f --- /dev/null +++ b/DeepSolid/qmc.py @@ -0,0 +1,358 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”). +# All Bytedance Modifications are Copyright 2022 Bytedance Inc. + +import logging + +import jax +import jax.numpy as jnp +from DeepSolid import constants +from DeepSolid import distance + + +def _log_prob_gaussian(x, mu, sigma): + """Calculates the log probability of Gaussian with diagonal covariance. + + Args: + x: Positions. Shape (batch, nelectron, 1, ndim) - as used in mh_update. + mu: means of Gaussian distribution. Same shape as or broadcastable to x. + sigma: standard deviation of the distribution. Same shape as or + broadcastable to x. + + Returns: + Log probability of Gaussian distribution with shape as required for + mh_update - (batch, nelectron, 1, 1). + """ + numer = jnp.sum(-0.5 * ((x - mu) ** 2) / (sigma ** 2), axis=[1, 2, 3]) + denom = x.shape[-1] * jnp.sum(jnp.log(sigma), axis=[1, 2, 3]) + return numer - denom + + +def _harmonic_mean(x, atoms): + """Calculates the harmonic mean of each electron distance to the nuclei. + + Args: + x: electron positions. Shape (batch, nelectrons, 1, ndim). Note the third + dimension is already expanded, which allows for avoiding additional + reshapes in the MH algorithm. + atoms: atom positions. Shape (natoms, ndim) + + Returns: + Array of shape (batch, nelectrons, 1, 1), where the (i, j, 0, 0) element is + the harmonic mean of the distance of the j-th electron of the i-th MCMC + configuration to all atoms. + """ + ae = x - atoms[None, ...] + r_ae = jnp.linalg.norm(ae, axis=-1, keepdims=True) + return 1.0 / jnp.mean(1.0 / r_ae, axis=-2, keepdims=True) + + +def limdrift(g:jnp.array, cutoff=1): + """ + Limit a vector to have a maximum magnitude of cutoff while maintaining direction + + Args: + g: a [nconf,ndim] vector + + cutoff: the maximum magnitude + + Returns: + The vector with the cut off applied. 
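Editor's note: a quick check of limdrift's behaviour as documented above; per-electron drift vectors below the cutoff pass through unchanged, larger ones are rescaled to the cutoff magnitude (numbers are illustrative).

```python
import jax.numpy as jnp
from DeepSolid import qmc

g = jnp.array([[0.3, 0.0, 0.0,    # first electron: |g| = 0.3, kept as-is
                4.0, 0.0, 0.0]])  # second electron: |g| = 4.0, rescaled to 1.0
print(qmc.limdrift(g, cutoff=1.0))
# [[0.3, 0., 0., 1., 0., 0.]]
```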
+ """ + g_shape = g.shape + g = g.reshape([-1, 3]) + tot = jnp.linalg.norm(g, axis=-1) + normalize = jnp.clip(tot, a_min=cutoff, a_max=jnp.max(tot)) + g = cutoff * g / normalize[:, None] + g = g.reshape(g_shape) + return g + +def importance_update(params, + f, + x1, + key, + lp_1, + num_accepts, + latvec, + stddev=0.02, + atoms=None, + i=0, + ): + """ + + :param params: + :param f: val_and_grad of batch_slogdet + :param x1: + :param key: + :param lp_1: + :param num_accepts: + :param latvec: + :param stddev: + :param atoms: + :param i: + :return: + """ + del i + key, subkey = jax.random.split(key) + if atoms is None: # symmetric proposal, same stddev everywhere + _, grad = f(params, x1) + grad = limdrift(grad) + gauss = stddev * jax.random.normal(subkey, shape=x1.shape) + x2 = x1 + gauss + stddev**2 * grad # proposal + x2, _ = distance.enforce_pbc(latvec, x2) + + # Compute reverse move + lpsi_2, new_grad = f(params, x2) + lp_2 = 2 * lpsi_2 + new_grad = limdrift(new_grad) + forward = jnp.sum(gauss ** 2, axis=-1) + backward = jnp.sum((gauss + stddev**2 * (grad + new_grad)) ** 2, + axis=-1) + lp_2 = lp_2 + 1 / (2 * stddev**2) * (forward - backward) + + ratio = lp_2 - lp_1 + else: # asymmetric proposal, stddev propto harmonic mean of nuclear distances + n = x1.shape[0] + x1 = jnp.reshape(x1, [n, -1, 1, 3]) + hmean1 = _harmonic_mean(x1, atoms) # harmonic mean of distances to nuclei + + x2 = x1 + stddev * hmean1 * jax.random.normal(subkey, shape=x1.shape) + lp_2 = 2. * f(params, x2) # log prob of proposal + hmean2 = _harmonic_mean(x2, atoms) # needed for probability of reverse jump + + lq_1 = _log_prob_gaussian(x1, x2, stddev * hmean1) # forward probability + lq_2 = _log_prob_gaussian(x2, x1, stddev * hmean2) # reverse probability + ratio = lp_2 + lq_2 - lp_1 - lq_1 + + x1 = jnp.reshape(x1, [n, -1]) + x2 = jnp.reshape(x2, [n, -1]) + + key, subkey = jax.random.split(key) + rnd = jnp.log(jax.random.uniform(subkey, shape=lp_1.shape)) + cond = ratio > rnd + x_new = jnp.where(cond[..., None], x2, x1) + lp_new = jnp.where(cond, lp_2, lp_1) + num_accepts += jnp.sum(cond) + + return x_new, key, lp_new, num_accepts + + +def mh_update(params, + f, + x1, + key, + lp_1, + num_accepts, + latvec, + stddev=0.02, + atoms=None, + i=0, + ): + """Performs one Metropolis-Hastings step using an all-electron move. + + Args: + params: Wavefuncttion parameters. + f: Callable with signature f(params, x) which returns the log of the + wavefunction (i.e. the sqaure root of the log probability of x). + x1: Initial MCMC configurations. Shape (batch, nelectrons*ndim). + key: RNG state. + lp_1: log probability of f evaluated at x1 given parameters params. + num_accepts: Number of MH move proposals accepted. + stddev: width of Gaussian move proposal. + atoms: If not None, atom positions. Shape (natoms, 3). If present, then the + Metropolis-Hastings move proposals are drawn from a Gaussian distribution, + N(0, (h_i stddev)^2), where h_i is the harmonic mean of distances between + the i-th electron and the atoms, otherwise the move proposal drawn from + N(0, stddev^2). + + Returns: + (x, key, lp, num_accepts), where: + x: Updated MCMC configurations. + key: RNG state. + lp: log probability of f evaluated at x. + num_accepts: update running total of number of accepted MH moves. + """ + del i + key, subkey = jax.random.split(key) + if atoms is None: # symmetric proposal, same stddev everywhere + x2 = x1 + stddev * jax.random.normal(subkey, shape=x1.shape) # proposal + x2, _ = distance.enforce_pbc(latvec, x2) + lp_2 = 2. 
* f(params, x2) # log prob of proposal + ratio = lp_2 - lp_1 + else: # asymmetric proposal, stddev propto harmonic mean of nuclear distances + n = x1.shape[0] + x1 = jnp.reshape(x1, [n, -1, 1, 3]) + hmean1 = _harmonic_mean(x1, atoms) # harmonic mean of distances to nuclei + + x2 = x1 + stddev * hmean1 * jax.random.normal(subkey, shape=x1.shape) + x2 = jnp.reshape(x2, [n, -1]) + x2, _ = distance.enforce_pbc(latvec, x2) + lp_2 = 2. * f(params, x2) + + x2 = jnp.reshape(x2, [n, -1, 1, 3]) + hmean2 = _harmonic_mean(x2, atoms) # needed for probability of reverse jump + + lq_1 = _log_prob_gaussian(x1, x2, stddev * hmean1) # forward probability + lq_2 = _log_prob_gaussian(x2, x1, stddev * hmean2) # reverse probability + ratio = lp_2 + lq_2 - lp_1 - lq_1 + + x1 = jnp.reshape(x1, [n, -1]) + x2 = jnp.reshape(x2, [n, -1]) + + key, subkey = jax.random.split(key) + rnd = jnp.log(jax.random.uniform(subkey, shape=lp_1.shape)) + cond = ratio > rnd + x_new = jnp.where(cond[..., None], x2, x1) + lp_new = jnp.where(cond, lp_2, lp_1) + num_accepts += jnp.sum(cond) + + return x_new, key, lp_new, num_accepts + + +def mh_one_electron_update(params, + f, + x1, + key, + lp_1, + num_accepts, + latvec, + stddev=0.02, + atoms=None, + i=0): + """Performs one Metropolis-Hastings step for a single electron. + + Args: + params: Wavefuncttion parameters. + f: Callable with signature f(params, x) which returns the log of the + wavefunction (i.e. the sqaure root of the log probability of x). + x1: Initial MCMC configurations. Shape (batch, nelectrons*ndim). + key: RNG state. + lp_1: log probability of f evaluated at x1 given parameters params. + num_accepts: Number of MH move proposals accepted. + stddev: width of Gaussian move proposal. + atoms: Ignored. Asymmetric move proposals are not implemented for + single-electron moves. + i: index of electron to move. + + Returns: + (x, key, lp, num_accepts), where: + x: Updated MCMC configurations. + key: RNG state. + lp: log probability of f evaluated at x. + num_accepts: update running total of number of accepted MH moves. + + Raises: + NotImplementedError: if atoms is supplied. + """ + key, subkey = jax.random.split(key) + n = x1.shape[0] + x1 = jnp.reshape(x1, [n, -1, 1, 3]) + nelec = x1.shape[1] + ii = i % nelec + if atoms is None: # symmetric proposal, same stddev everywhere + x2 = x1.at[:, ii].add(stddev * + jax.random.normal(subkey, shape=x1[:, ii].shape)) + x2, _ = distance.enforce_pbc(latvec, x2) + lp_2 = 2. * f(params, x2) # log prob of proposal + ratio = lp_2 - lp_1 + else: # asymmetric proposal, stddev propto harmonic mean of nuclear distances + raise NotImplementedError('Still need to work out reverse probabilities ' + 'for asymmetric moves.') + + x1 = jnp.reshape(x1, [n, -1]) + x2 = jnp.reshape(x2, [n, -1]) + key, subkey = jax.random.split(key) + rnd = jnp.log(jax.random.uniform(subkey, shape=lp_1.shape)) + cond = ratio > rnd + x_new = jnp.where(cond[..., None], x2, x1) + lp_new = jnp.where(cond, lp_2, lp_1) + num_accepts += jnp.sum(cond) + + return x_new, key, lp_new, num_accepts + + +def make_mcmc_step(batch_slog_network, + batch_per_device, + latvec, + steps=10, + atoms=None, + importance_sampling=None, + one_electron_moves=False, + ): + """Creates the MCMC step function. + + Args: + batch_network: function, signature (params, x), which evaluates the log of + the wavefunction (square root of the log probability distribution) at x + given params. Inputs and outputs are batched. + batch_per_device: Batch size per device. 
+ steps: Number of MCMC moves to attempt in a single call to the MCMC step + function. + atoms: atom positions. If given, an asymmetric move proposal is used based + on the harmonic mean of electron-atom distances for each electron. + Otherwise the (conventional) normal distribution is used. + one_electron_moves: If true, attempt to move one electron at a time. + Otherwise, attempt one all-electron move per MCMC step. + + Returns: + Callable which performs the set of MCMC steps. + """ + if importance_sampling is not None: + if one_electron_moves: + raise ValueError('Importance sampling for one elec move is not implemented yet') + else: + logging.info('Using importance sampling') + func = jax.vmap(jax.value_and_grad(importance_sampling, argnums=1), in_axes=(None, 0)) + inner_fun = importance_update + else: + func = batch_slog_network + if one_electron_moves: + logging.info('Using one electron Metropolis sampling') + inner_fun = mh_one_electron_update + else: + logging.info('Using Metropolis sampling') + inner_fun = mh_update + + @jax.jit + def mcmc_step(params, data, key, width): + """Performs a set of MCMC steps. + + Args: + params: parameters to pass to the network. + data: (batched) MCMC configurations to pass to the network. + key: RNG state. + width: standard deviation to use in the move proposal. + + Returns: + (data, pmove), where data is the updated MCMC configurations, key the + updated RNG state and pmove the average probability a move was accepted. + """ + + def step_fn(i, x): + return inner_fun(params, func, *x, + latvec=latvec, stddev=width, + atoms=atoms, i=i) + + nelec = data.shape[-1] // 3 + nsteps = nelec * steps if one_electron_moves else steps + logprob = 2. * batch_slog_network(params, data) + data, key, _, num_accepts = jax.lax.fori_loop(0, nsteps, step_fn, + (data, key, logprob, 0.)) + pmove = jnp.sum(num_accepts) / (nsteps * batch_per_device) + pmove = constants.pmean_if_pmap(pmove, axis_name=constants.PMAP_AXIS_NAME) + return data, pmove + + return mcmc_step diff --git a/DeepSolid/supercell.py b/DeepSolid/supercell.py new file mode 100644 index 0000000..3220542 --- /dev/null +++ b/DeepSolid/supercell.py @@ -0,0 +1,148 @@ +# MIT License +# +# Copyright (c) 2019 Lucas K Wagner +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”). +# All Bytedance Modifications are Copyright 2022 Bytedance Inc. 
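+"""Construction of the QMC simulation supercell from a pyscf primitive cell.
+
+Illustrative usage (a sketch only; the geometry, lattice and basis below are
+placeholders, not values used elsewhere in DeepSolid):
+
+    import numpy as np
+    import pyscf.pbc.gto
+    from DeepSolid import supercell
+
+    cell = pyscf.pbc.gto.Cell()
+    cell.atom = 'Li 0 0 0; H 3.0 3.0 3.0'
+    cell.a = np.eye(3) * 6.0
+    cell.unit = 'Bohr'
+    cell.basis = 'sto-3g'
+    cell.build()
+    # 2x2x2 supercell; S is the (3, 3) supercell matrix described below.
+    simulation_cell = supercell.get_supercell(cell, S=2 * np.eye(3))
+"""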
+
+import logging
+
+import numpy as np
+import pyscf.pbc.gto
+
+
+def get_supercell_kpts(supercell):
+ """Computes the supercell k points that fold into the primitive-cell reciprocal unit box.
+
+ :param supercell: simulation cell produced by get_supercell (carries S and original_cell).
+ :return: supercell k points which belong to the unit box primitive cell k point space.
+ """
+ Sinv = np.linalg.inv(supercell.S).T
+ u = [0, 1]
+ unit_box = np.stack([x.ravel() for x in np.meshgrid(*[u] * 3, indexing="ij")]).T
+ unit_box_ = np.dot(unit_box, supercell.S.T)
+ xyz_range = np.stack([f(unit_box_, axis=0) for f in (np.amin, np.amax)]).T
+ kptmesh = np.meshgrid(*[np.arange(*r) for r in xyz_range], indexing="ij")
+ possible_kpts = np.dot(np.stack([x.ravel() for x in kptmesh]).T, Sinv)
+ in_unit_box = (possible_kpts >= 0) * (possible_kpts < 1 - 1e-12)
+ select = np.where(np.all(in_unit_box, axis=1))[0]
+ reclatvec = np.linalg.inv(supercell.original_cell.lattice_vectors()).T * 2 * np.pi
+ return np.dot(possible_kpts[select], reclatvec)
+
+
+def get_supercell_copies(latvec, S):
+ Sinv = np.linalg.inv(S).T
+ u = [0, 1]
+ unit_box = np.stack([x.ravel() for x in np.meshgrid(*[u] * 3, indexing="ij")]).T
+ unit_box_ = np.dot(unit_box, S)
+ xyz_range = np.stack([f(unit_box_, axis=0) for f in (np.amin, np.amax)]).T
+ mesh = np.meshgrid(*[np.arange(*r) for r in xyz_range], indexing="ij")
+ possible_pts = np.dot(np.stack([x.ravel() for x in mesh]).T, Sinv.T)
+ in_unit_box = (possible_pts >= 0) * (possible_pts < 1 - 1e-12)
+ select = np.where(np.all(in_unit_box, axis=1))[0]
+ return np.linalg.multi_dot((possible_pts[select], S, latvec))
+
+
+def get_supercell(cell, S, sym_type='minimal') -> pyscf.pbc.gto.Cell:
+ """
+ Generates the QMC simulation supercell from a primitive cell and a supercell matrix S.
+
+ :param cell: pyscf Cell object describing the primitive cell.
+ :param S: (3, 3) supercell matrix for QMC from cell defined by cell.a.
+ :param sym_type: symmetry type forwarded to set_symmetry_lat ('minimal' by default).
+ :return: QMC simulation cell
+ """
+ import pyscf.pbc
+ scale = np.abs(int(np.round(np.linalg.det(S))))
+ superlattice = np.dot(S, cell.lattice_vectors())
+ Rpts = get_supercell_copies(cell.lattice_vectors(), S)
+ atom = []
+ for (name, xyz) in cell._atom:
+ atom.extend([(name, xyz + R) for R in Rpts])
+ supercell = pyscf.pbc.gto.Cell()
+ supercell.a = superlattice
+ supercell.atom = atom
+ supercell.ecp = cell.ecp
+ supercell.basis = cell.basis
+ supercell.exp_to_discard = cell.exp_to_discard
+ supercell.unit = "Bohr"
+ supercell.spin = cell.spin * scale
+ supercell.build()
+ supercell.original_cell = cell
+ supercell.S = S
+ supercell.scale = scale
+ supercell.output = None
+ supercell.stdout = None
+ supercell = set_symmetry_lat(supercell, sym_type)
+ logging.info(f'Using {sym_type} type feature.')
+ return supercell
+
+
+def set_symmetry_lat(supercell, sym_type='minimal'):
+ '''
+ Attach corresponding lattice vectors to the simulation cell.
+
+ :param supercell: simulation cell constructed by get_supercell.
+ :param sym_type: symmetry of the constructed distance features;
+ 'minimal' is the default and the only option that has been tested.
+ :return: simulation cell with symmetry-adapted lattice vectors attached.
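+ Recognized sym_type values are 'minimal' (default), 'fcc', 'bcc' and
+ 'hexagonal'; any other value falls back to the identity ('minimal') mapping.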
+ ''' + prim_bv = supercell.original_cell.reciprocal_vectors() + sim_bv = supercell.reciprocal_vectors() + if sym_type == 'minimal': + mat = np.eye(3) + elif sym_type == 'fcc': + mat = np.array([[1, 0, 0], + [0, 1, 0], + [0, 0, 1], + [1, 1, 1]]) + elif sym_type == 'bcc': + mat = np.array([[1, 0, 0], + [0, 1, 0], + [0, 0, 1], + [1, -1, 0], + [1, 0, -1], + [0, 1, -1]]) + elif sym_type == 'hexagonal': + mat = np.array([[1, 0, 0], + [0, 1, 0], + [0, 0, 1], + [-1, -1, 0]]) + else: + mat = np.eye(3) + + prim_bv = mat @ prim_bv + sim_bv = mat @ sim_bv + + prim_av = np.linalg.pinv(prim_bv).T + sim_av = np.linalg.pinv(sim_bv).T + supercell.BV = sim_bv + supercell.AV = sim_av + supercell.original_cell.BV = prim_bv + supercell.original_cell.AV = prim_av + return supercell + + +def get_k_indices(cell, mf, kpts, tol=1e-6): + """Given a list of kpts, return inds such that mf.kpts[inds] is a list of kpts equivalent to the input list""" + kdiffs = mf.kpts[None] - kpts[:, None] + frac_kdiffs = np.dot(kdiffs, cell.lattice_vectors().T) / (2 * np.pi) + kdiffs = np.mod(frac_kdiffs + 0.5, 1) - 0.5 + return np.nonzero(np.linalg.norm(kdiffs, axis=-1) < tol)[1] diff --git a/DeepSolid/train.py b/DeepSolid/train.py new file mode 100644 index 0000000..4f656b9 --- /dev/null +++ b/DeepSolid/train.py @@ -0,0 +1,175 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”). +# All Bytedance Modifications are Copyright 2022 Bytedance Inc. 
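+"""Total-energy loss with a custom JVP gradient estimator and the pmapped
+training step used by DeepSolid (see make_loss and make_training_step)."""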
+ +import chex +import jax +import jax.numpy as jnp +import functools + +from DeepSolid import hamiltonian +from DeepSolid import constants +from DeepSolid.utils.kfac_ferminet_alpha import loss_functions + + +@chex.dataclass +class AuxiliaryLossData: + variance: jnp.DeviceArray + local_energy: jnp.DeviceArray + imaginary: jnp.DeviceArray + kinetic: jnp.DeviceArray + ewald: jnp.DeviceArray + + +def make_loss(network, batch_network, + simulation_cell, + clip_local_energy=5.0, + clip_type='real', + mode='for', + partition_number=3): + el_fun = hamiltonian.local_energy_seperate(network, + simulation_cell=simulation_cell, + mode=mode, + partition_number=partition_number) + batch_local_energy = jax.vmap(el_fun, in_axes=(None, 0), out_axes=0) + + @jax.custom_jvp + def total_energy(params, data): + """ + + :param params: + :param data: batch electron coord with shape [Batch, Nelec * Ndim] + :return: energy expectation of corresponding walkers (only take real part) with shape [Batch] + """ + ke, ew = batch_local_energy(params, data) + e_l = ke + ew + mean_e_l = jnp.mean(e_l) + + pmean_loss = constants.pmean_if_pmap(mean_e_l, axis_name=constants.PMAP_AXIS_NAME) + variance = constants.pmean_if_pmap(jnp.mean(jnp.abs(e_l)**2) - jnp.abs(mean_e_l.real) ** 2, + axis_name=constants.PMAP_AXIS_NAME) + loss = pmean_loss.real + imaginary = pmean_loss.imag + + return loss, AuxiliaryLossData(variance=variance, + local_energy=e_l, + imaginary=imaginary, + kinetic=ke, + ewald=ew, + ) + + @total_energy.defjvp + def total_energy_jvp(primals, tangents): + params, data = primals + loss, aux_data = total_energy(params, data) + diff = (aux_data.local_energy - loss) + if clip_local_energy > 0.0: + if clip_type == 'complex': + radius, phase = jnp.abs(diff), jnp.angle(diff) + radius_tv = constants.pmean_if_pmap(radius.std(), axis_name=constants.PMAP_AXIS_NAME) + radius_mean = jnp.median(radius) + radius_mean = constants.pmean_if_pmap(radius_mean, axis_name=constants.PMAP_AXIS_NAME) + clip_radius = jnp.clip(radius, + radius_mean - radius_tv * clip_local_energy, + radius_mean + radius_tv * clip_local_energy) + clip_diff = clip_radius * jnp.exp(1j * phase) + elif clip_type == 'real': + tv_re = jnp.mean(jnp.abs(diff.real)) + tv_re = constants.pmean_if_pmap(tv_re, axis_name=constants.PMAP_AXIS_NAME) + tv_im = jnp.mean(jnp.abs(diff.imag)) + tv_im = constants.pmean_if_pmap(tv_im, axis_name=constants.PMAP_AXIS_NAME) + clip_diff_re = jnp.clip(diff.real, + -clip_local_energy * tv_re, + clip_local_energy * tv_re) + clip_diff_im = jnp.clip(diff.imag, + -clip_local_energy * tv_im, + clip_local_energy * tv_im) + clip_diff = clip_diff_re + clip_diff_im * 1j + else: + raise ValueError('Unrecognized clip type.') + else: + clip_diff = diff + + psi_primal, psi_tangent = jax.jvp(batch_network, primals, tangents) + conj_psi_tangent = jnp.conjugate(psi_tangent) + conj_psi_primal = jnp.conjugate(psi_primal) + + loss_functions.register_normal_predictive_distribution(conj_psi_primal[:, None]) + + primals_out = loss, aux_data + # tangents_dot = jnp.dot(clip_diff, conj_psi_tangent).real + # dot causes the gradient to be extensive with batch size, which does matter for KFAC. 
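+ # Sketch of the estimator below: the tangent is the batch mean of
+ #   Re[(E_L - <E_L>) * conj(tangent of log psi)],
+ # i.e. the covariance between (clipped) local-energy fluctuations and the
+ # tangent of the log wavefunction; taking a mean rather than a sum keeps the
+ # gradient estimate independent of the batch size.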
+ tangents_dot = jnp.mean((clip_diff * conj_psi_tangent).real) + + tangents_out = (tangents_dot, aux_data) + + return primals_out, tangents_out + + return total_energy + + +def make_training_step(mcmc_step, val_and_grad, opt_update): + + @functools.partial(constants.pmap, donate_argnums=(1, 2, 3, 4)) + def step(t, data, params, state, key, mcmc_width): + data, pmove = mcmc_step(params, data, key, mcmc_width) + + # Optimization step + (loss, aux_data), search_direction = val_and_grad(params, data) + search_direction = constants.pmean_if_pmap(search_direction, + axis_name=constants.PMAP_AXIS_NAME) + state, params = opt_update(t, search_direction, params, state) + return data, params, state, loss, aux_data, pmove, search_direction + + return step + + +@functools.partial(jax.vmap, in_axes=(0, 0), out_axes=0) +def direct_product(x, y): + return x.ravel()[:, None] * y.ravel()[None, :] + + +def make_sr_matrix(network): + ''' + which is used to calculate the fisher matrix, abandoned now. + :param network: + :return: + ''' + network_grad = jax.grad(network.apply, argnums=0, holomorphic=True) + batch_network_grad = jax.vmap(network_grad, in_axes=(None, 0)) + + def sr_matrix(params, data): + complex_params = jax.tree_map(lambda x: x+0j, params) + batch_diffs = batch_network_grad(complex_params, data) + + s1 = jax.tree_map(lambda x: jnp.mean(direct_product(jnp.conjugate(x), x), + axis=0), + batch_diffs) + s2 = jax.tree_map(lambda x: (jnp.mean(jnp.conjugate(x), axis=0).ravel()[:, None] * + jnp.mean(x, axis=0).ravel()[None, :] + ), + batch_diffs) + s1 = constants.pmean_if_pmap(s1, axis_name=constants.PMAP_AXIS_NAME) + s2 = constants.pmean_if_pmap(s2, axis_name=constants.PMAP_AXIS_NAME) + matrix = jax.tree_multimap(lambda x, y: x - y, s1, s2) + return matrix + + return sr_matrix + + + + + diff --git a/DeepSolid/utils/elements.py b/DeepSolid/utils/elements.py new file mode 100644 index 0000000..a3c3f32 --- /dev/null +++ b/DeepSolid/utils/elements.py @@ -0,0 +1,250 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”). +# All Bytedance Modifications are Copyright 2022 Bytedance Inc. + +import collections +from typing import Optional +import attr + + +@attr.s +class Element(object): + """Chemical element. + + Attributes: + symbol: official symbol of element. + atomic_number: atomic number of element. + period: period to which the element belongs. + spin: overrides default ground-state spin-configuration based on the + element's group (main groups only). + """ + symbol: str = attr.ib() + atomic_number: int = attr.ib() + period: int = attr.ib() + _spin: Optional[int] = attr.ib(default=None, repr=False) + + @property + def group(self) -> int: + """Group to which element belongs. 
Set to -1 for actinides and lanthanides."""
+ is_lanthanide = (58 <= self.atomic_number <= 71)
+ is_actinide = (90 <= self.atomic_number <= 103)
+ if is_lanthanide or is_actinide:
+ return -1
+ if self.symbol == 'He':
+ # n=1 shell only has s orbital -> He is a noble gas.
+ return 18
+ period_starts = (1, 3, 11, 19, 37, 55, 87)
+ period_start = period_starts[self.period - 1]
+ group_ = self.atomic_number - period_start + 1
+ # Adjust for absence of d block in periods 2 and 3.
+ if self.period < 4 and group_ > 2:
+ group_ += 10
+ # Adjust for Lanthanides and Actinides in periods 6 and 7.
+ if self.period >= 6 and group_ > 3:
+ group_ -= 14
+ return group_
+
+ @property
+ def spin_config(self) -> int:
+ """Canonical spin configuration (via Hund's rules) of neutral atom.
+
+ Returns:
+ Number of unpaired electrons (as required by PySCF) in the neutral atom's
+ ground state.
+
+ Raises:
+ NotImplementedError: if element is a transition metal and the spin
+ configuration is not set at initialization.
+ """
+ if self._spin is not None:
+ return self._spin
+ unpaired = {1: 1, 2: 0, 3: 1, 13: 1, 14: 2, 15: 3, 16: 2, 17: 1, 18: 0}
+ if self.group in unpaired:
+ return unpaired[self.group]
+ else:
+ raise NotImplementedError(
+ 'Spin configuration for transition metals not set.')
+
+ @property
+ def nalpha(self) -> int:
+ """Returns the number of alpha electrons of the ground state neutral atom.
+
+ Without loss of generality, the number of alpha electrons is taken to be
+ equal to or greater than the number of beta electrons.
+ """
+ electrons = self.atomic_number
+ unpaired = self.spin_config
+ return (electrons + unpaired) // 2
+
+ @property
+ def nbeta(self) -> int:
+ """Returns the number of beta electrons of the ground state neutral atom.
+
+ Without loss of generality, the number of alpha electrons is taken to be
+ equal to or greater than the number of beta electrons.
+ """
+ electrons = self.atomic_number
+ unpaired = self.spin_config
+ return (electrons - unpaired) // 2
+
+
+# Atomic symbols for all known elements
+# Generated using
+# def _element(symbol, atomic_number):
+# # period_start[n] = atomic number of group 1 element in (n+1)-th period.
+# period_start = (1, 3, 11, 19, 37, 55, 87)
+# for p, group1_no in enumerate(period_start):
+# if atomic_number < group1_no:
+# # In previous period but n is 0-based.
+# period = p
+# break
+# else:
+# period = p + 1
+# return Element(symbol=symbol, atomic_number=atomic_number, period=period)
+# [_element(s, n+1) for n, s in enumerate(symbols)]
+# where symbols is the list of chemical symbols of all elements.
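+# Example lookups (illustrative only; SYMBOLS and ATOMIC_NUMS are built at the
+# end of this module):
+#   SYMBOLS['C'].spin_config                 # 2 unpaired electrons for neutral C
+#   SYMBOLS['C'].nalpha, SYMBOLS['C'].nbeta  # (4, 2)
+#   ATOMIC_NUMS[26].symbol                   # 'Fe'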
+_ELEMENTS = ( + Element(symbol='H', atomic_number=1, period=1), + Element(symbol='He', atomic_number=2, period=1), + Element(symbol='Li', atomic_number=3, period=2), + Element(symbol='Be', atomic_number=4, period=2), + Element(symbol='B', atomic_number=5, period=2), + Element(symbol='C', atomic_number=6, period=2), + Element(symbol='N', atomic_number=7, period=2), + Element(symbol='O', atomic_number=8, period=2), + Element(symbol='F', atomic_number=9, period=2), + Element(symbol='Ne', atomic_number=10, period=2), + Element(symbol='Na', atomic_number=11, period=3), + Element(symbol='Mg', atomic_number=12, period=3), + Element(symbol='Al', atomic_number=13, period=3), + Element(symbol='Si', atomic_number=14, period=3), + Element(symbol='P', atomic_number=15, period=3), + Element(symbol='S', atomic_number=16, period=3), + Element(symbol='Cl', atomic_number=17, period=3), + Element(symbol='Ar', atomic_number=18, period=3), + Element(symbol='K', atomic_number=19, period=4), + Element(symbol='Ca', atomic_number=20, period=4), + Element(symbol='Sc', atomic_number=21, period=4, spin=1), + Element(symbol='Ti', atomic_number=22, period=4, spin=2), + Element(symbol='V', atomic_number=23, period=4, spin=3), + Element(symbol='Cr', atomic_number=24, period=4, spin=6), + Element(symbol='Mn', atomic_number=25, period=4, spin=5), + Element(symbol='Fe', atomic_number=26, period=4, spin=4), + Element(symbol='Co', atomic_number=27, period=4, spin=3), + Element(symbol='Ni', atomic_number=28, period=4, spin=2), + Element(symbol='Cu', atomic_number=29, period=4, spin=1), + Element(symbol='Zn', atomic_number=30, period=4, spin=0), + Element(symbol='Ga', atomic_number=31, period=4), + Element(symbol='Ge', atomic_number=32, period=4), + Element(symbol='As', atomic_number=33, period=4), + Element(symbol='Se', atomic_number=34, period=4), + Element(symbol='Br', atomic_number=35, period=4), + Element(symbol='Kr', atomic_number=36, period=4), + Element(symbol='Rb', atomic_number=37, period=5), + Element(symbol='Sr', atomic_number=38, period=5), + Element(symbol='Y', atomic_number=39, period=5, spin=1), + Element(symbol='Zr', atomic_number=40, period=5, spin=2), + Element(symbol='Nb', atomic_number=41, period=5, spin=5), + Element(symbol='Mo', atomic_number=42, period=5, spin=6), + Element(symbol='Tc', atomic_number=43, period=5, spin=5), + Element(symbol='Ru', atomic_number=44, period=5, spin=4), + Element(symbol='Rh', atomic_number=45, period=5, spin=3), + Element(symbol='Pd', atomic_number=46, period=5, spin=0), + Element(symbol='Ag', atomic_number=47, period=5, spin=1), + Element(symbol='Cd', atomic_number=48, period=5, spin=0), + Element(symbol='In', atomic_number=49, period=5), + Element(symbol='Sn', atomic_number=50, period=5), + Element(symbol='Sb', atomic_number=51, period=5), + Element(symbol='Te', atomic_number=52, period=5), + Element(symbol='I', atomic_number=53, period=5), + Element(symbol='Xe', atomic_number=54, period=5), + Element(symbol='Cs', atomic_number=55, period=6), + Element(symbol='Ba', atomic_number=56, period=6), + Element(symbol='La', atomic_number=57, period=6), + Element(symbol='Ce', atomic_number=58, period=6), + Element(symbol='Pr', atomic_number=59, period=6), + Element(symbol='Nd', atomic_number=60, period=6), + Element(symbol='Pm', atomic_number=61, period=6), + Element(symbol='Sm', atomic_number=62, period=6), + Element(symbol='Eu', atomic_number=63, period=6), + Element(symbol='Gd', atomic_number=64, period=6), + Element(symbol='Tb', atomic_number=65, period=6), + 
Element(symbol='Dy', atomic_number=66, period=6), + Element(symbol='Ho', atomic_number=67, period=6), + Element(symbol='Er', atomic_number=68, period=6), + Element(symbol='Tm', atomic_number=69, period=6), + Element(symbol='Yb', atomic_number=70, period=6), + Element(symbol='Lu', atomic_number=71, period=6), + Element(symbol='Hf', atomic_number=72, period=6), + Element(symbol='Ta', atomic_number=73, period=6), + Element(symbol='W', atomic_number=74, period=6), + Element(symbol='Re', atomic_number=75, period=6), + Element(symbol='Os', atomic_number=76, period=6), + Element(symbol='Ir', atomic_number=77, period=6), + Element(symbol='Pt', atomic_number=78, period=6), + Element(symbol='Au', atomic_number=79, period=6), + Element(symbol='Hg', atomic_number=80, period=6), + Element(symbol='Tl', atomic_number=81, period=6), + Element(symbol='Pb', atomic_number=82, period=6), + Element(symbol='Bi', atomic_number=83, period=6), + Element(symbol='Po', atomic_number=84, period=6), + Element(symbol='At', atomic_number=85, period=6), + Element(symbol='Rn', atomic_number=86, period=6), + Element(symbol='Fr', atomic_number=87, period=7), + Element(symbol='Ra', atomic_number=88, period=7), + Element(symbol='Ac', atomic_number=89, period=7), + Element(symbol='Th', atomic_number=90, period=7), + Element(symbol='Pa', atomic_number=91, period=7), + Element(symbol='U', atomic_number=92, period=7), + Element(symbol='Np', atomic_number=93, period=7), + Element(symbol='Pu', atomic_number=94, period=7), + Element(symbol='Am', atomic_number=95, period=7), + Element(symbol='Cm', atomic_number=96, period=7), + Element(symbol='Bk', atomic_number=97, period=7), + Element(symbol='Cf', atomic_number=98, period=7), + Element(symbol='Es', atomic_number=99, period=7), + Element(symbol='Fm', atomic_number=100, period=7), + Element(symbol='Md', atomic_number=101, period=7), + Element(symbol='No', atomic_number=102, period=7), + Element(symbol='Lr', atomic_number=103, period=7), + Element(symbol='Rf', atomic_number=104, period=7), + Element(symbol='Db', atomic_number=105, period=7), + Element(symbol='Sg', atomic_number=106, period=7), + Element(symbol='Bh', atomic_number=107, period=7), + Element(symbol='Hs', atomic_number=108, period=7), + Element(symbol='Mt', atomic_number=109, period=7), + Element(symbol='Ds', atomic_number=110, period=7), + Element(symbol='Rg', atomic_number=111, period=7), + Element(symbol='Cn', atomic_number=112, period=7), + Element(symbol='Nh', atomic_number=113, period=7), + Element(symbol='Fl', atomic_number=114, period=7), + Element(symbol='Mc', atomic_number=115, period=7), + Element(symbol='Lv', atomic_number=116, period=7), + Element(symbol='Ts', atomic_number=117, period=7), + Element(symbol='Og', atomic_number=118, period=7), +) + +ATOMIC_NUMS = {element.atomic_number: element for element in _ELEMENTS} + +# Lookup by symbol instead of atomic number. +SYMBOLS = {element.symbol: element for element in _ELEMENTS} + +# Lookup by period. +PERIODS = collections.defaultdict(list) +for element in _ELEMENTS: + PERIODS[element.period].append(element) +PERIODS = {period: tuple(elements) for period, elements in PERIODS.items()} diff --git a/DeepSolid/utils/kfac_ferminet_alpha/__init__.py b/DeepSolid/utils/kfac_ferminet_alpha/__init__.py new file mode 100644 index 0000000..c6ded3d --- /dev/null +++ b/DeepSolid/utils/kfac_ferminet_alpha/__init__.py @@ -0,0 +1,22 @@ +# Copyright 2020 DeepMind Technologies Limited. 
+# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”). +# All Bytedance Modifications are Copyright 2022 Bytedance Inc. +"""Module for anything that an end user would use.""" + +from DeepSolid.utils.kfac_ferminet_alpha.loss_functions import register_normal_predictive_distribution +from DeepSolid.utils.kfac_ferminet_alpha.loss_functions import register_squared_error_loss +from DeepSolid.utils.kfac_ferminet_alpha.optimizer import Optimizer diff --git a/DeepSolid/utils/kfac_ferminet_alpha/curvature_blocks.py b/DeepSolid/utils/kfac_ferminet_alpha/curvature_blocks.py new file mode 100644 index 0000000..011f46c --- /dev/null +++ b/DeepSolid/utils/kfac_ferminet_alpha/curvature_blocks.py @@ -0,0 +1,502 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”). +# All Bytedance Modifications are Copyright 2022 Bytedance Inc. 
+ +"""Module for all of the different curvature blocks.""" +import abc +from typing import Any, Callable, Dict, Mapping, MutableMapping, Optional, Sequence, Union +import jax +from jax import core +import jax.numpy as jnp + +from DeepSolid.utils.kfac_ferminet_alpha import tag_graph_matcher as tgm +from DeepSolid.utils.kfac_ferminet_alpha import utils + +_Arrays = Sequence[jnp.ndarray] +_BlockInfo = Mapping[str, Any] + + +class CurvatureBlock(utils.Stateful, abc.ABC): + """Top level class.""" + + def __init__(self, layer_tag_eq: tgm.jax_core.JaxprEqn): + super(CurvatureBlock, self).__init__() + self._layer_tag_eq = layer_tag_eq + + @property + def layer_tag_primitive(self) -> tgm.tags.LayerTag: + assert isinstance(self._layer_tag_eq.primitive, tgm.tags.LayerTag) + return self._layer_tag_eq.primitive + + @property + def outputs_shapes(self) -> Sequence[Sequence[int]]: + output_vars = self.layer_tag_primitive.split_all_inputs( + self._layer_tag_eq.invars)[0] + return jax.tree_map(lambda x: x.aval.shape, output_vars) + + @property + def inputs_shapes(self) -> Sequence[Sequence[int]]: + input_vars = self.layer_tag_primitive.split_all_inputs( + self._layer_tag_eq.invars)[1] + return jax.tree_map(lambda x: x.aval.shape, input_vars) + + @property + def params_shapes(self) -> Sequence[Sequence[int]]: + params_vars = self.layer_tag_primitive.split_all_inputs( + self._layer_tag_eq.invars)[2] + return jax.tree_map(lambda x: x.aval.shape, params_vars) + + @abc.abstractmethod + def init(self, rng: jnp.ndarray) -> MutableMapping[str, Any]: + """This initializes/creates all of the arrays for the state of the block. + + Usually this would include the arrays used for storing the curvature + approximation, as well as the arrays for storing any approximate + inverses/powers of the curvature block. + + Args: + rng: The Jax PRNG key to use if any of the state is supposed to be + initialized randomly. + Returns: + A mutable mapping of the state. 
+ """ + + @abc.abstractmethod + def update_curvature_matrix_estimate( + self, + info: _BlockInfo, + batch_size: int, + ema_old: Union[float, jnp.ndarray], + ema_new: Union[float, jnp.ndarray], + pmap_axis_name: str + ) -> None: + pass + + @abc.abstractmethod + def update_curvature_inverse_estimate( + self, + diagonal_weight: Union[float, jnp.ndarray], + pmap_axis_name: str + ) -> None: + pass + + @abc.abstractmethod + def multiply_matpower( + self, + vec: _Arrays, + exp: Union[float, int], + diagonal_weight: Union[float, jnp.ndarray] + ) -> _Arrays: + pass + + +CurvatureBlockCtor = Callable[[core.JaxprEqn], CurvatureBlock] + + +@utils.Stateful.infer_class_state +class NaiveDiagonal(CurvatureBlock): + """The naively estimated diagonal block.""" + diagonal_factor: utils.WeightedMovingAverage + + def init(self, rng: jnp.ndarray) -> Dict[str, Any]: + del rng + return dict( + diagonal_factor=utils.WeightedMovingAverage.zero( + self.outputs_shapes[0]) + ) + + def update_curvature_matrix_estimate( + self, + info: _BlockInfo, + batch_size: int, + ema_old: Union[float, jnp.ndarray], + ema_new: Union[float, jnp.ndarray], + pmap_axis_name: str + ) -> None: + dw, = info["outputs_tangent"] + diagonal_update = dw * dw / batch_size + self.diagonal_factor.update(diagonal_update, ema_old, ema_new) + self.diagonal_factor.sync(pmap_axis_name) + + def update_curvature_inverse_estimate( + self, + diagonal_weight: Union[float, jnp.ndarray], + pmap_axis_name: str + ) -> None: + pass + + def multiply_matpower( + self, + vec: _Arrays, + exp: Union[float, int], + diagonal_weight: Union[float, jnp.ndarray] + ) -> _Arrays: + w, = vec + if exp == 1: + return w * (self.diagonal_factor.value + diagonal_weight), + elif exp == -1: + return w / (self.diagonal_factor.value + diagonal_weight), + else: + raise NotImplementedError() + + +@utils.Stateful.infer_class_state +class TwoKroneckerFactored(CurvatureBlock, abc.ABC): + """A factor that is the Kronecker product of two matrices.""" + inputs_factor: utils.WeightedMovingAverage + inputs_factor_inverse: jnp.ndarray + outputs_factor: utils.WeightedMovingAverage + outputs_factor_inverse: jnp.ndarray + extra_scale: Optional[Union[int, float, jnp.ndarray]] + + @property + def has_bias(self) -> bool: + return len(self._layer_tag_eq.invars) == 4 + + @abc.abstractmethod + def input_size(self) -> int: + pass + + @abc.abstractmethod + def output_size(self) -> int: + pass + + def compute_extra_scale(self) -> Optional[Union[int, float, jnp.ndarray]]: + return 1 + + def init(self, rng: jnp.ndarray) -> Dict[str, Any]: + # The extra scale is technically a constant, but in general it could be + # useful for anyone examining the state to know it explicitly, + # hence we actually keep it as part of the state. + d_in = self.input_size() + d_out = self.output_size() + return dict( + inputs_factor=utils.WeightedMovingAverage.zero([d_in, d_in]), + inputs_factor_inverse=jnp.zeros([d_in, d_in]), + outputs_factor=utils.WeightedMovingAverage.zero([d_out, d_out]), + outputs_factor_inverse=jnp.zeros([d_out, d_out]), + extra_scale=self.compute_extra_scale() + ) + + def update_curvature_inverse_estimate( + self, + diagonal_weight: Union[float, jnp.ndarray], + pmap_axis_name: str + ) -> None: + self.inputs_factor.sync(pmap_axis_name) + self.outputs_factor.sync(pmap_axis_name) + + # This computes the approximate inverse factor using the pi-adjusted + # inversion from the original KFAC paper. 
+ # Note that the damping is divided by extra_scale since: + # (s * A kron B + lambda I)^-1 = s^-1 (A kron B + s^-1 * lambda I)^-1 + # And the extra division by the scale is included in `multiply_matpower`. + (self.inputs_factor_inverse, + self.outputs_factor_inverse) = utils.pi_adjusted_inverse( + factor_0=self.inputs_factor.value, + factor_1=self.outputs_factor.value, + damping=diagonal_weight / self.extra_scale, + pmap_axis_name=pmap_axis_name) + + def multiply_matpower( + self, + vec: _Arrays, + exp: Union[float, int], + diagonal_weight: Union[float, jnp.ndarray] + ) -> _Arrays: + if self.has_bias: + w, b = vec + vec = jnp.concatenate([w.reshape([-1, w.shape[-1]]), b[None]], axis=0) + else: + w, = vec + vec = w.reshape([-1, w.shape[-1]]) + if exp == 1: + inputs_factor, outputs_factor = (self.inputs_factor.value, + self.outputs_factor.value) + scale = self.extra_scale + elif exp == -1: + inputs_factor, outputs_factor = (self.inputs_factor_inverse, + self.outputs_factor_inverse) + scale = 1.0 / self.extra_scale + diagonal_weight = 0 + else: + raise NotImplementedError() + + result = jnp.matmul(inputs_factor, vec) + result = jnp.matmul(result, outputs_factor) + result = result * scale + diagonal_weight * vec + + if self.has_bias: + w_new, b_new = result[:-1], result[-1] + return w_new.reshape(w.shape), b_new + else: + return result.reshape(w.shape), + + +class DenseTwoKroneckerFactored(TwoKroneckerFactored): + """Factor for a standard dense layer.""" + + def input_size(self) -> int: + if self.has_bias: + return self.params_shapes[0][0] + 1 + else: + return self.params_shapes[0][0] + + def output_size(self) -> int: + return self.params_shapes[0][1] + + def update_curvature_matrix_estimate( + self, + info: _BlockInfo, + batch_size: int, + ema_old: Union[float, jnp.ndarray], + ema_new: Union[float, jnp.ndarray], + pmap_axis_name: str + ) -> None: + del pmap_axis_name + (x,), (dy,) = info["inputs"], info["outputs_tangent"] + utils.check_first_dim_is_batch_size(batch_size, x, dy) + + if self.has_bias: + x_one = jnp.ones_like(x[:, :1]) + x = jnp.concatenate([x, x_one], axis=1) + input_stats = jnp.matmul(jnp.conjugate(x).T, x) / batch_size + output_stats = jnp.matmul(jnp.conjugate(dy).T, dy) / batch_size + output_stats = output_stats.real + self.inputs_factor.update(input_stats, ema_old, ema_new) + self.outputs_factor.update(output_stats, ema_old, ema_new) + + +@utils.Stateful.infer_class_state +class ScaleAndShiftDiagonal(CurvatureBlock): + """A scale and shift block with a diagonal approximation to the curvature.""" + scale_factor: Optional[utils.WeightedMovingAverage] + shift_factor: Optional[utils.WeightedMovingAverage] + + @property + def has_scale(self) -> bool: + return self._layer_tag_eq.params["has_scale"] + + @property + def has_shift(self) -> bool: + return self._layer_tag_eq.params["has_shift"] + + def init(self, rng: jnp.ndarray) -> Dict[str, Any]: + del rng + if self.has_scale and self.has_shift: + return dict( + scale_factor=utils.WeightedMovingAverage.zero( + self.params_shapes[0] + ), + shift_factor=utils.WeightedMovingAverage.zero( + self.params_shapes[1] + ) + ) + elif self.has_scale: + return dict( + scale_factor=utils.WeightedMovingAverage.zero( + self.params_shapes[0] + ), + shift_factor=None + ) + elif self.has_shift: + return dict( + scale_factor=None, + shift_factor=utils.WeightedMovingAverage.zero( + self.params_shapes[0] + ), + ) + else: + raise ValueError("Neither `has_scale` nor `has_shift`.") + + def update_curvature_matrix_estimate( + self, + info: _BlockInfo, + 
batch_size: int, + ema_old: Union[float, jnp.ndarray], + ema_new: Union[float, jnp.ndarray], + pmap_axis_name: str + ) -> None: + (x,), (dy,) = info["inputs"], info["outputs_tangent"] + utils.check_first_dim_is_batch_size(batch_size, x, dy) + + if self.has_scale: + assert self.scale_factor is not None + scale_shape = info["params"][0].shape + full_scale_shape = (1,) * (len(x.shape) - len(scale_shape)) + scale_shape + axis = [i for i, s in enumerate(full_scale_shape) if s == 1 and i != 0] + d_scale = jnp.sum(x * dy, axis=axis) + scale_diag_update = jnp.sum(jnp.conjugate(d_scale) * d_scale, axis=0) / batch_size + scale_diag_update = scale_diag_update.real + self.scale_factor.update(scale_diag_update, ema_old, ema_new) + self.scale_factor.sync(pmap_axis_name) + + if self.has_shift: + assert self.shift_factor is not None + shift_shape = info["params"][1].shape + full_shift_shape = (1,) * (len(x.shape) - len(shift_shape)) + shift_shape + axis = [i for i, s in enumerate(full_shift_shape) if s == 1 and i != 0] + d_shift = jnp.sum(dy, axis=axis) + shift_diag_update = jnp.sum(d_shift * d_shift, axis=0) / batch_size + self.shift_factor.update(shift_diag_update, ema_old, ema_new) + self.shift_factor.sync(pmap_axis_name) + + def update_curvature_inverse_estimate( + self, + diagonal_weight: Union[float, jnp.ndarray], + pmap_axis_name: str + ) -> None: + pass + + def multiply_matpower( + self, + vec: _Arrays, + exp: Union[float, int], + diagonal_weight: Union[float, jnp.ndarray] + ) -> _Arrays: + if self.has_scale and self.has_shift: + factors = (self.scale_factor.value, self.shift_factor.value) + elif self.has_scale: + factors = (self.scale_factor.value,) + elif self.has_shift: + factors = (self.shift_factor.value,) + else: + raise ValueError("Neither `has_scale` nor `has_shift`.") + factors = jax.tree_map(lambda x: x + diagonal_weight, factors) + if exp == 1: + return jax.tree_multimap(jnp.multiply, vec, factors) + elif exp == -1: + return jax.tree_multimap(jnp.divide, vec, factors) + else: + raise NotImplementedError() + + +@utils.Stateful.infer_class_state +class ScaleAndShiftFull(CurvatureBlock): + """A scale and shift block with full approximation to the curvature.""" + factor: utils.WeightedMovingAverage + inverse_factor: jnp.ndarray + + @property + def _has_scale(self) -> bool: + return self._layer_tag_eq.params["has_scale"] + + @property + def _has_shift(self) -> bool: + return self._layer_tag_eq.params["has_shift"] + + def init(self, rng: jnp.ndarray) -> Dict[str, Any]: + del rng + dims = sum(utils.product(shape) for shape in self.params_shapes) + return dict( + factor=utils.WeightedMovingAverage.zero([dims, dims]), + inverse_factor=jnp.zeros([dims, dims]) + ) + + def update_curvature_matrix_estimate( + self, + info: _BlockInfo, + batch_size: int, + ema_old: Union[float, jnp.ndarray], + ema_new: Union[float, jnp.ndarray], + pmap_axis_name: str + ) -> None: + del pmap_axis_name + (x,), (dy,) = info["inputs"], info["outputs_tangent"] + utils.check_first_dim_is_batch_size(batch_size, x, dy) + + grads = list() + if self._has_scale: + # Scale gradients + scale_shape = info["params"][0].shape + full_scale_shape = (1,) * (len(x.shape) - len(scale_shape)) + scale_shape + axis = [i for i, s in enumerate(full_scale_shape) if s == 1 and i != 0] + d_scale = jnp.sum(x * dy, axis=axis) + d_scale = d_scale.reshape([batch_size, -1]) + grads.append(d_scale) + + if self._has_shift: + # Shift gradients + shift_shape = info["params"][1].shape + full_shift_shape = (1,) * (len(x.shape) - len(shift_shape)) + 
shift_shape + axis = [i for i, s in enumerate(full_shift_shape) if s == 1 and i != 0] + d_shift = jnp.sum(dy, axis=axis) + d_shift = d_shift.reshape([batch_size, -1]) + grads.append(d_shift) + + grads = jnp.concatenate(grads, axis=1) + factor_update = jnp.matmul(grads.T, grads) / batch_size + self.factor.update(factor_update, ema_old, ema_new) + + def update_curvature_inverse_estimate( + self, + diagonal_weight: Union[float, jnp.ndarray], + pmap_axis_name: str + ) -> None: + self.factor.sync(pmap_axis_name) + self.inverse_factor = utils.psd_inv_cholesky(self.factor.value, + diagonal_weight) + + def multiply_matpower( + self, + vec: _Arrays, + exp: Union[float, int], + diagonal_weight: Union[float, jnp.ndarray] + ) -> _Arrays: + # Remember the vector is a tuple of all parameters + if self._has_scale and self._has_shift: + flat_vec = jnp.concatenate([v.flatten() for v in vec]) + else: + flat_vec = vec[0].flatten() + + if exp == 1: + flat_result = ( + jnp.matmul(self.factor.value, flat_vec) + diagonal_weight * flat_vec) + elif exp == -1: + flat_result = jnp.matmul(self.inverse_factor, flat_vec) + else: + raise NotImplementedError() + + if self._has_scale and self._has_shift: + scale_dims = int(vec[0].size) + scale_result = flat_result[:scale_dims].reshape(vec[0].shape) + shift_result = flat_result[scale_dims:].reshape(vec[1].shape) + return scale_result, shift_result + else: + return flat_vec.reshape(vec[0].shape), + + +_default_tag_to_block: MutableMapping[str, CurvatureBlockCtor] = dict( + dense_tag=DenseTwoKroneckerFactored, + generic_tag=NaiveDiagonal, + scale_and_shift_tag=ScaleAndShiftDiagonal, +) + + +def copy_default_tag_to_block() -> MutableMapping[str, CurvatureBlockCtor]: + return dict(_default_tag_to_block) + + +def get_default_tag_to_block(tag_name: str) -> CurvatureBlockCtor: + return _default_tag_to_block[tag_name] + + +def set_default_tag_to_block( + tag_name: str, + block_class: CurvatureBlockCtor, +) -> None: + _default_tag_to_block[tag_name] = block_class diff --git a/DeepSolid/utils/kfac_ferminet_alpha/distributions.py b/DeepSolid/utils/kfac_ferminet_alpha/distributions.py new file mode 100644 index 0000000..8565be3 --- /dev/null +++ b/DeepSolid/utils/kfac_ferminet_alpha/distributions.py @@ -0,0 +1,78 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”). +# All Bytedance Modifications are Copyright 2022 Bytedance Inc. +"""Module for all distribution implementations needed for the loss functions.""" +import math +import jax +import jax.numpy as jnp + + +class MultivariateNormalDiag: + """Multivariate normal distribution on `R^k`.""" + + def __init__( + self, + loc: jnp.ndarray, + scale_diag: jnp.ndarray): + """Initializes a MultivariateNormalDiag distribution. + + Args: + loc: Mean vector of the distribution. Can also be a batch of vectors. + scale_diag: Vector of standard deviations. 
+ """ + super().__init__() + self._loc = loc + self._scale_diag = scale_diag + + @property + def loc(self) -> jnp.ndarray: + """Mean of the distribution.""" + return self._loc + + @property + def scale_diag(self) -> jnp.ndarray: + """Scale of the distribution.""" + return self._scale_diag + + def _num_dims(self) -> int: + """Dimensionality of the events.""" + return self._scale_diag.shape[-1] + + def _standardize(self, value: jnp.ndarray) -> jnp.ndarray: + return (value - self._loc) / self._scale_diag + + def log_prob(self, value: jnp.ndarray) -> jnp.ndarray: + """See `Distribution.log_prob`.""" + log_unnormalized = -0.5 * jnp.square(self._standardize(value)) + log_normalization = 0.5 * math.log(2 * math.pi) + jnp.log(self._scale_diag) + return jnp.sum(log_unnormalized - log_normalization, axis=-1) + + def mean(self) -> jnp.ndarray: + """Calculates the mean.""" + return self.loc + + def sample(self, seed: jnp.ndarray) -> jnp.ndarray: + """Samples an event. + + Args: + seed: PRNG key or integer seed. + + Returns: + A sample. + """ + eps = jax.random.normal(seed, self.loc.shape) + return self.loc + eps * self.scale_diag diff --git a/DeepSolid/utils/kfac_ferminet_alpha/estimator.py b/DeepSolid/utils/kfac_ferminet_alpha/estimator.py new file mode 100644 index 0000000..7b4510a --- /dev/null +++ b/DeepSolid/utils/kfac_ferminet_alpha/estimator.py @@ -0,0 +1,343 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”). +# All Bytedance Modifications are Copyright 2022 Bytedance Inc. +"""Defines the high-level Fisher estimator class.""" +import collections +from typing import Any, Callable, Mapping, Optional, Sequence, Union, TypeVar + +import jax +import jax.numpy as jnp +import jax.random as jnr +import numpy as np + +from DeepSolid.utils.kfac_ferminet_alpha import curvature_blocks +from DeepSolid.utils.kfac_ferminet_alpha import tracer +from DeepSolid.utils.kfac_ferminet_alpha import utils + +_CurvatureBlock = curvature_blocks.CurvatureBlock +TagMapping = Mapping[str, curvature_blocks.CurvatureBlockCtor] +BlockVector = Sequence[jnp.ndarray] + +_StructureT = TypeVar("_StructureT") +_OptionalStateT = TypeVar("_OptionalStateT", bound=Optional[Mapping[str, Any]]) + + +@utils.Stateful.infer_class_state +class CurvatureEstimator(utils.Stateful): + """Curvature estimator class supporting various curvature approximations.""" + blocks: "collections.OrderedDict[str, _CurvatureBlock]" + damping: Optional[jnp.ndarray] + + def __init__(self, + tagged_func: Callable[[Any], jnp.ndarray], + func_args: Sequence[Any], + l2_reg: Union[float, jnp.ndarray], + estimation_mode: str = "fisher_gradients", + params_index: int = 0, + layer_tag_to_block_cls: Optional[TagMapping] = None): + """Create a FisherEstimator object. + + Args: + tagged_func: The function which evaluates the model, in which layer and + loss tags has already been registered. 
+ func_args: Arguments to trace the function for layer and loss tags. + l2_reg: Scalar. The L2 regularization coefficient, which represents + the following regularization function: `coefficient/2 ||theta||^2`. + estimation_mode: The type of curvature estimator to use. One of: * + 'fisher_gradients' - the basic estimation approach from the original + K-FAC paper. (Default) * 'fisher_curvature_prop' - method which + estimates the Fisher using self-products of random 1/-1 vectors times + "half-factors" of the + Fisher, as described here: https://arxiv.org/abs/1206.6464 * + 'fisher_exact' - is the obvious generalization of Curvature + Propagation to compute the exact Fisher (modulo any additional + diagonal or Kronecker approximations) by looping over one-hot + vectors for each coordinate of the output instead of using 1/-1 + vectors. It is more expensive to compute than the other three + options by a factor equal to the output dimension, roughly + speaking. * 'fisher_empirical' - computes the 'empirical' Fisher + information matrix (which uses the data's distribution for the + targets, as opposed to the true Fisher which uses the model's + distribution) and requires that each registered loss have + specified targets. * 'ggn_curvature_prop' - Analogous to + fisher_curvature_prop, but estimates the Generalized + Gauss-Newton matrix (GGN). * 'ggn_exact'- Analogous to + fisher_exact, but estimates the Generalized Gauss-Newton matrix + (GGN). + params_index: The index of the arguments accepted by `func` which + correspond to parameters. + layer_tag_to_block_cls: An optional dict mapping tags to specific classes + of block approximations, which to override the default ones. + """ + if estimation_mode not in ("fisher_gradients", "fisher_empirical", + "fisher_exact", "fisher_curvature_prop", + "ggn_exact", "ggn_curvature_prop"): + raise ValueError(f"Unrecognised estimation_mode={estimation_mode}.") + super().__init__() + self.tagged_func = tagged_func + self.l2_reg = l2_reg + self.estimation_mode = estimation_mode + self.params_index = params_index + self.vjp = tracer.trace_estimator_vjp(self.tagged_func) + + # Figure out the mapping from layer + self.layer_tag_to_block_cls = curvature_blocks.copy_default_tag_to_block() + if layer_tag_to_block_cls is None: + layer_tag_to_block_cls = dict() + layer_tag_to_block_cls = dict(**layer_tag_to_block_cls) + self.layer_tag_to_block_cls.update(layer_tag_to_block_cls) + + # Create the blocks + self._in_tree = jax.tree_structure(func_args) + self._jaxpr = jax.make_jaxpr(self.tagged_func)(*func_args).jaxpr + self._layer_tags, self._loss_tags = tracer.extract_tags(self._jaxpr) + self.blocks = collections.OrderedDict() + counters = dict() + for eqn in self._layer_tags: + cls = self.layer_tag_to_block_cls[eqn.primitive.name] + c = counters.get(cls.__name__, 0) + self.blocks[cls.__name__ + "_" + str(c)] = cls(eqn) + counters[cls.__name__] = c + 1 + + @property + def diagonal_weight(self) -> jnp.ndarray: + return self.l2_reg + self.damping + + def vectors_to_blocks( + self, + parameter_structured_vector: Any, + ) -> Sequence[BlockVector]: + """Splits the parameters to values for the corresponding blocks.""" + in_vars = jax.tree_unflatten(self._in_tree, self._jaxpr.invars) + params_vars = in_vars[self.params_index] + params_vars_flat = jax.tree_flatten(params_vars)[0] + params_values_flat = jax.tree_flatten(parameter_structured_vector)[0] + assert len(params_vars_flat) == len(params_values_flat) + params_dict = dict(zip(params_vars_flat, params_values_flat)) + 
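+ # Walk the registered layer tags in order and pop, for each block, the
+ # parameter values that belong to it; anything left over afterwards is not
+ # reachable from the losses and triggers the error below.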
per_block_vectors = [] + for eqn in self._layer_tags: + if eqn.primitive.name == "generic_tag": + block_vars = eqn.invars + else: + block_vars = eqn.primitive.split_all_inputs(eqn.invars)[2] + per_block_vectors.append(tuple(params_dict.pop(v) for v in block_vars)) + if params_dict: + raise ValueError(f"From the parameters the following structure is not " + f"assigned to any block: {params_dict}. Most likely " + f"this part of the parameters is not part of the graph " + f"reaching the losses.") + return tuple(per_block_vectors) + + def blocks_to_vectors(self, per_block_vectors: Sequence[BlockVector]) -> Any: + """Reverses the function self.vectors_to_blocks.""" + in_vars = jax.tree_unflatten(self._in_tree, self._jaxpr.invars) + params_vars = in_vars[self.params_index] + assigned_dict = dict() + for eqn, block_values in zip(self._layer_tags, per_block_vectors): + if eqn.primitive.name == "generic_tag": + block_params = eqn.invars + else: + block_params = eqn.primitive.split_all_inputs(eqn.invars)[2] + assigned_dict.update(zip(block_params, block_values)) + params_vars_flat, params_tree = jax.tree_flatten(params_vars) + params_values_flat = [assigned_dict[v] for v in params_vars_flat] + assert len(params_vars_flat) == len(params_values_flat) + return jax.tree_unflatten(params_tree, params_values_flat) + + def init( + self, + rng: jnp.ndarray, + init_damping: Optional[jnp.ndarray], + ) -> Mapping[str, Any]: + """Returns an initialized variables for the curvature approximations and the inverses..""" + return dict( + blocks=collections.OrderedDict( + (name, block.init(block_rng)) # + for (name, block), block_rng # + in zip(self.blocks.items(), jnr.split(rng, len(self.blocks)))), + damping=init_damping) + + @property + def mat_type(self) -> str: + return self.estimation_mode.split("_")[0] + + def vec_block_apply( + self, + func: Callable[[_CurvatureBlock, BlockVector], BlockVector], + parameter_structured_vector: Any, + ) -> Any: + """Executes func for each approximation block on vectors.""" + per_block_vectors = self.vectors_to_blocks(parameter_structured_vector) + assert len(per_block_vectors) == len(self.blocks) + results = jax.tree_multimap(func, tuple(self.blocks.values()), + per_block_vectors) + parameter_structured_result = self.blocks_to_vectors(results) + utils.check_structure_shapes_and_dtype(parameter_structured_vector, + parameter_structured_result) + return parameter_structured_result + + def multiply_inverse(self, parameter_structured_vector: Any) -> Any: + """Multiplies the vectors by the corresponding (damped) inverses of the blocks. + + Args: + parameter_structured_vector: Structure equivalent to the parameters of the + model. + + Returns: + A structured identical to `vectors` containing the product. + """ + return self.multiply_matpower(parameter_structured_vector, -1) + + def multiply(self, parameter_structured_vector: Any) -> Any: + """Multiplies the vectors by the corresponding (damped) blocks. + + Args: + parameter_structured_vector: A vector in the same structure as the + parameters of the model. + + Returns: + A structured identical to `vectors` containing the product. + """ + return self.multiply_matpower(parameter_structured_vector, 1) + + def multiply_matpower( + self, + parameter_structured_vector: _StructureT, + exp: int, + ) -> _StructureT: + """Multiplies the vectors by the corresponding matrix powers of the blocks. + + Args: + parameter_structured_vector: A vector in the same structure as the + parameters of the model. 
+ exp: A float representing the power to raise the blocks by before + multiplying it by the vector. + + Returns: + A structured identical to `vectors` containing the product. + """ + + def func(block: _CurvatureBlock, vec: BlockVector) -> BlockVector: + return block.multiply_matpower(vec, exp, self.diagonal_weight) + + return self.vec_block_apply(func, parameter_structured_vector) + + def update_curvature_matrix_estimate( + self, + ema_old: Union[float, jnp.ndarray], + ema_new: Union[float, jnp.ndarray], + batch_size: int, + rng: jnp.ndarray, + func_args: Sequence[Any], + pmap_axis_name: str, + ) -> None: + """Updates the curvature estimate.""" + + # Compute the losses and the VJP function from the function inputs + losses, losses_vjp = self.vjp(func_args) + + # Helper function that updates the blocks given a vjp vector + def _update_blocks(vjp_vec_, ema_old_, ema_new_): + blocks_info_ = losses_vjp(vjp_vec_) + for block_, block_info_ in zip(self.blocks.values(), blocks_info_): + block_.update_curvature_matrix_estimate( + info=block_info_, + batch_size=batch_size, + ema_old=ema_old_, + ema_new=ema_new_, + pmap_axis_name=pmap_axis_name) + + if self.estimation_mode == "fisher_gradients": + keys = jnr.split(rng, len(losses)) if len(losses) > 1 else [rng] + vjp_vec = tuple( + loss.grad_of_evaluate_on_sample(key, coefficient_mode="sqrt") + for loss, key in zip(losses, keys)) + _update_blocks(vjp_vec, ema_old, ema_new) + + elif self.estimation_mode in ("fisher_curvature_prop", + "ggn_curvature_prop"): + keys = jnr.split(rng, len(losses)) if len(losses) > 1 else [rng] + vjp_vec = [] + for loss, key in zip(losses, keys): + if self.estimation_mode == "fisher_curvature_prop": + random_b = jnr.bernoulli(key, shape=loss.fisher_factor_inner_shape()) + vjp_vec.append(loss.multiply_fisher_factor(random_b * 2.0 - 1.0)) + else: + random_b = jnr.bernoulli(key, shape=loss.ggn_factor_inner_shape()) + vjp_vec.append(loss.multiply_ggn_factor(random_b * 2.0 - 1.0)) + _update_blocks(tuple(vjp_vec), ema_old, ema_new) + + elif self.estimation_mode in ("fisher_exact", "ggn_exact"): + # We use the following trick to simulate summation. The equation is: + # estimate = ema_old * estimate + ema_new * (sum_i estimate_index_i) + # weight = ema_old * weight + ema_new + # Instead we update the estimate n times with the following updates: + # for k = 1 + # estimate_k = ema_old * estimate + (ema_new/n) * (n*estimate_index_k) + # weight_k = ema_old * weight + (ema_new/n) + # for k > 1: + # estimate_k = 1.0 * estimate_k-1 + (ema_new/n) * (n*estimate_index_k) + # weight_k = 1.0 * weight_k-1 + (ema_new/n) + # Which is mathematically equivalent to the original version. + zero_tangents = jax.tree_map(jnp.zeros_like, + list(loss.inputs for loss in losses)) + if self.estimation_mode == "fisher_exact": + num_indices = [ + (l, int(np.prod(l.fisher_factor_inner_shape[1:]))) for l in losses + ] + else: + num_indices = [ + (l, int(np.prod(l.ggn_factor_inner_shape()))) for l in losses + ] + total_num_indices = sum(n for _, n in num_indices) + for i, (loss, loss_num_indices) in enumerate(num_indices): + for index in range(loss_num_indices): + vjp_vec = zero_tangents.copy() + if self.estimation_mode == "fisher_exact": + vjp_vec[i] = loss.multiply_fisher_factor_replicated_one_hot([index]) + else: + vjp_vec[i] = loss.multiply_ggn_factor_replicated_one_hot([index]) + if isinstance(vjp_vec[i], jnp.ndarray): + # In the special case of only one parameter, it still needs to be a + # tuple for the tangents. 
+ vjp_vec[i] = (vjp_vec[i],) + vjp_vec[i] = jax.tree_map(lambda x: x * total_num_indices, vjp_vec[i]) + _update_blocks(tuple(vjp_vec), ema_old, ema_new / total_num_indices) + ema_old = 1.0 + + elif self.estimation_mode == "fisher_empirical": + raise NotImplementedError() + else: + raise ValueError(f"Unrecognised estimation_mode={self.estimation_mode}") + + def update_curvature_estimate_inverse( + self, + pmap_axis_name: str, + state: _OptionalStateT, + ) -> _OptionalStateT: + if state is not None: + old_state = self.get_state() + self.set_state(state) + for block in self.blocks.values(): + block.update_curvature_inverse_estimate(self.diagonal_weight, + pmap_axis_name) + if state is None: + return None + else: + state = self.pop_state() + self.set_state(old_state) + return state diff --git a/DeepSolid/utils/kfac_ferminet_alpha/example.py b/DeepSolid/utils/kfac_ferminet_alpha/example.py new file mode 100644 index 0000000..c97bfc0 --- /dev/null +++ b/DeepSolid/utils/kfac_ferminet_alpha/example.py @@ -0,0 +1,171 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Example of running KFAC.""" +from absl import app +from absl import flags +import jax +import jax.numpy as jnp + +import numpy as np +import DeepSolid.utils.kfac_ferminet_alpha as kfac_ferminet_alpha +from DeepSolid.utils.kfac_ferminet_alpha import utils + + +TRAINING_STEPS = flags.DEFINE_integer( + name="training_steps", + default=100, + help="Number of training steps to perform") +BATCH_SIZE = flags.DEFINE_integer( + name="batch_size", default=128, help="Batch size") +LEARNING_RATE = flags.DEFINE_float( + name="learning_rate", default=1e-3, help="Learning rate") +L2_REG = flags.DEFINE_float( + name="l2_reg", default=1e-3, help="L2 regularization coefficient") +MOMENTUM = flags.DEFINE_float( + name="momentum", default=0.8, help="Momentum coefficient") +DAMPING = flags.DEFINE_float( + name="damping", default=1e-2, help="Damping coefficient") +MULTI_DEVICE = flags.DEFINE_bool( + name="multi_device", + default=False, + help="Whether the computation should be replicated across multiple devices") +SEED = flags.DEFINE_integer(name="seed", default=12412321, help="JAX RNG seed") + + +def glorot_uniform(shape, key): + dim_in = np.prod(shape[:-1]) + dim_out = shape[-1] + c = jnp.sqrt(6 / (dim_in + dim_out)) + return jax.random.uniform(key, shape=shape, minval=-c, maxval=c) + + +def fully_connected_layer(params, x): + w, b = params + return jnp.matmul(x, w) + b[None] + + +def model_init(rng_key, batch, encoder_sizes=(1000, 500, 250, 30)): + """Initialize the standard autoencoder.""" + x_size = batch.shape[-1] + decoder_sizes = encoder_sizes[len(encoder_sizes) - 2::-1] + sizes = (x_size,) + encoder_sizes + decoder_sizes + (x_size,) + keys = jax.random.split(rng_key, len(sizes) - 1) + params = [] + for rng_key, dim_in, dim_out in zip(keys, sizes, sizes[1:]): + # Glorot uniform initialization + w = glorot_uniform((dim_in, dim_out), rng_key) + b = jnp.zeros([dim_out]) + params.append((w, b)) + 
return params, None + + +def model_loss(params, inputs, l2_reg): + """Evaluate the standard autoencoder.""" + h = inputs.reshape([inputs.shape[0], -1]) + for i, layer_params in enumerate(params): + h = fully_connected_layer(layer_params, h) + # Last layer does not have a nonlinearity + if i % 4 != 3: + h = jnp.tanh(h) + l2_value = 0.5 * sum(jnp.square(p).sum() for p in jax.tree_leaves(params)) + error = jax.nn.sigmoid(h) - inputs.reshape([inputs.shape[0], -1]) + mean_squared_error = jnp.mean(jnp.sum(error * error, axis=1), axis=0) + regularized_loss = mean_squared_error + l2_reg * l2_value + + return regularized_loss, dict(mean_squared_error=mean_squared_error) + + +def random_data(multi_device, batch_shape, rng): + if multi_device: + shape = (multi_device,) + tuple(batch_shape) + else: + shape = tuple(batch_shape) + while True: + rng, key = jax.random.split(rng) + yield jax.random.normal(key, shape) + + +def main(argv): + del argv # Unused. + + learning_rate = jnp.asarray([LEARNING_RATE.value]) + momentum = jnp.asarray([MOMENTUM.value]) + damping = jnp.asarray([DAMPING.value]) + + # RNG keys + global_step = jnp.zeros([]) + rng = jax.random.PRNGKey(SEED.value) + params_key, opt_key, step_key, data_key = jax.random.split(rng, 4) + dataset = random_data(MULTI_DEVICE.value, (BATCH_SIZE.value, 20), data_key) + example_batch = next(dataset) + + if MULTI_DEVICE.value: + global_step = utils.replicate_all_local_devices(global_step) + learning_rate = utils.replicate_all_local_devices(learning_rate) + momentum = utils.replicate_all_local_devices(momentum) + damping = utils.replicate_all_local_devices(damping) + params_key, opt_key = utils.replicate_all_local_devices( + (params_key, opt_key)) + step_key = utils.make_different_rng_key_on_all_devices(step_key) + split_key = jax.pmap(lambda x: tuple(jax.random.split(x))) + jit_init_parameters_func = jax.pmap(model_init) + else: + split_key = jax.random.split + jit_init_parameters_func = jax.jit(model_init) + + # Initialize or load parameters + params, func_state = jit_init_parameters_func(params_key, example_batch) + + # Make optimizer + optim = kfac_ferminet_alpha.Optimizer( + value_and_grad_func=jax.value_and_grad( + lambda p, x: model_loss(p, x, L2_REG.value), has_aux=True), + l2_reg=L2_REG.value, + value_func_has_aux=True, + value_func_has_state=False, + value_func_has_rng=False, + learning_rate_schedule=None, + momentum_schedule=None, + damping_schedule=None, + norm_constraint=1.0, + num_burnin_steps=10, + ) + + # Initialize optimizer + opt_state = optim.init(params, opt_key, example_batch, func_state) + + for t in range(TRAINING_STEPS.value): + step_key, key_t = split_key(step_key) + params, opt_state, stats = optim.step( + params, + opt_state, + key_t, + dataset, + learning_rate=learning_rate, + momentum=momentum, + damping=damping) + global_step = global_step + 1 + + # Log any of the statistics + print(f"iteration: {t}") + print(f"mini-batch loss = {stats['loss']}") + if "aux" in stats: + for k, v in stats["aux"].items(): + print(f"{k} = {v}") + print("----") + + +if __name__ == "__main__": + app.run(main) diff --git a/DeepSolid/utils/kfac_ferminet_alpha/layers_and_loss_tags.py b/DeepSolid/utils/kfac_ferminet_alpha/layers_and_loss_tags.py new file mode 100644 index 0000000..331136b --- /dev/null +++ b/DeepSolid/utils/kfac_ferminet_alpha/layers_and_loss_tags.py @@ -0,0 +1,357 @@ +# Copyright 2020 DeepMind Technologies Limited. 
+# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”). +# All Bytedance Modifications are Copyright 2022 Bytedance Inc. +"""A module for registering already known functions for tagging patterns.""" +import functools + +from typing import Sequence, Tuple, TypeVar + +import jax +from jax import core as jax_core +from jax import lax +from jax import lib as jax_lib +from jax.interpreters import batching as jax_batching +import jax.numpy as jnp + +_T = TypeVar("_T") + + +class LossTag(jax_core.Primitive): + """A tagging primitive specifically for losses.""" + multiple_results = True + + def __init__(self, cls, num_inputs: int, num_targets: int = 1): + super().__init__(cls.__name__ + "_tag") + self._cls = cls + self._num_inputs = num_inputs + self._num_targets = num_targets + jax.xla.translations[self] = self.xla_translation + jax.ad.primitive_jvps[self] = self.jvp + # This line defines how does the tag behave under vmap. It is required for + # any primitive that can be used inside a vmap. The reason why we want to + # allow this is two fold - one to not break user code when the tags are not + # used at all, and two - to be able to define a network with code for a + # single example which is the vmap-ed for a batch. 
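+    # The batching rule simply re-binds the primitive on the batched arguments
+    # and keeps the batch dimension of the first argument (see `batching`
+    # below).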
+ jax_batching.primitive_batchers[self] = self.batching + + @property + def num_inputs(self) -> int: + return self._num_inputs + + @property + def num_targets(self) -> int: + return self._num_targets + + def loss(self, *args, weight: float = 1.0, **kwargs): + return self._cls(*args, weight=weight, **kwargs) + + def loss_evaluate(self, *args, weight: float = 1.0, **kwargs): + return self.loss(*args, weight=weight, **kwargs).evaluate() + + def get_outputs(self, *args, weight: float, return_loss: bool, **kwargs): + if len(args) < self.num_inputs: + raise ValueError("Inputs to the tag are not enough.") + if len(args) < self.num_inputs + self.num_targets: + if len(args) != self.num_inputs: + raise ValueError("Inputs to the tag are not quite enough.") + if return_loss: + raise ValueError("Can not have return_loss=True when there are no " + "targets.") + return args + if len(args) > self.num_inputs + self.num_targets: + raise ValueError("Inputs to the tag are too many.") + if return_loss: + return self.loss(*args, weight=weight, **kwargs).evaluate() + else: + return args + + def impl(self, *args, weight: float, return_loss: bool, **kwargs): + return self.get_outputs(*args, weight=weight, return_loss=return_loss) + + def abstract_eval(self, *args, weight: float, return_loss: bool, **kwargs): + return self.get_outputs(*args, weight=weight, return_loss=return_loss) + + def xla_translation( + self, + c, + *args, + weight: float = 1.0, + return_loss: bool = False, + **kwargs, + ): + outputs = self.get_outputs( + *args, weight=weight, return_loss=return_loss, **kwargs) + if isinstance(outputs, tuple): + return jax_lib.xla_client.ops.Tuple(c, outputs) + return outputs + + def jvp( + self, + arg_values, + arg_tangents, + weight: float, + return_loss: bool, + **kwargs, + ): + if len(arg_values) != len(arg_tangents): + raise ValueError("Values and tangents are not the same length.") + primal_output = self.bind( + *arg_values, weight=weight, return_loss=return_loss, **kwargs) + if len(arg_values) == self.num_inputs: + tangents_out = self.get_outputs( + *arg_tangents, weight=weight, return_loss=return_loss, **kwargs) + elif return_loss: + tangents_out = jax.jvp( + functools.partial(self.loss_evaluate, weight=weight, **kwargs), + arg_tangents, arg_tangents)[1] + else: + tangents_out = arg_tangents + return primal_output, tangents_out + + def batching(self, batched_args, batched_dims, **kwargs): + return self.bind(*batched_args, **kwargs), batched_dims[0] + + +class LayerTag(jax_core.Primitive): + """A tagging primitive that is used to mark/tag computation.""" + + def __init__(self, name: str, num_inputs: int, num_outputs: int): + super().__init__(name) + if num_outputs > 1: + raise NotImplementedError( + f"Only single outputs are supported, got: num_outputs={num_outputs}") + self._num_outputs = num_outputs + self._num_inputs = num_inputs + jax.xla.translations[self] = self.xla_translation + jax.ad.deflinear(self, self.transpose) + jax.ad.primitive_transposes[self] = self.transpose + # This line defines how does the tag behave under vmap. It is required for + # any primitive that can be used inside a vmap. The reason why we want to + # allow this is two fold - one to not break user code when the tags are not + # used at all, and two - to be able to define a network with code for a + # single example which is the vmap-ed for a batch. 
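+    # Like the loss tag, a layer tag acts as an identity on its first operand
+    # (the layer output); `transpose` below routes the cotangent straight
+    # through to that operand, so the tag is invisible to autodiff.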
+ jax_batching.primitive_batchers[self] = self.batching + + @property + def num_outputs(self) -> int: + return self._num_outputs + + @property + def num_inputs(self) -> int: + return self._num_inputs + + def split_all_inputs( + self, + all_inputs: Sequence[_T], + ) -> Tuple[Sequence[_T], Sequence[_T], Sequence[_T]]: + outputs = tuple(all_inputs[:self.num_outputs]) + inputs = tuple(all_inputs[self.num_outputs:self.num_outputs + + self.num_inputs]) + params = tuple(all_inputs[self.num_outputs + self.num_inputs:]) + return outputs, inputs, params + + def get_outputs(self, *operands: _T, **kwargs) -> _T: + assert self.num_outputs == 1 + return operands[0] + + def xla_translation(self, c, *operands: _T, **kwargs) -> _T: + return self.get_outputs(*operands, **kwargs) + + @staticmethod + def transpose(cotangent, *operands, **kwargs): + return (cotangent,) + (None,) * (len(operands) - 1) + + def impl(self, *operands, **kwargs): + return self.get_outputs(*operands, **kwargs) + + def abstract_eval(self, *abstract_operands, **kwargs): + return self.get_outputs(*abstract_operands, **kwargs) + + def batching(self, batched_operands, batched_dims, **kwargs): + return self.bind(*batched_operands, **kwargs), batched_dims[0] + + +# _____ _ +# / ____| (_) +# | | __ ___ _ __ ___ _ __ _ ___ +# | | |_ |/ _ \ '_ \ / _ \ '__| |/ __| +# | |__| | __/ | | | __/ | | | (__ +# \_____|\___|_| |_|\___|_| |_|\___| +# +# + +generic_tag = LayerTag(name="generic_tag", num_inputs=0, num_outputs=1) + + +def register_generic(parameter: _T) -> _T: + return generic_tag.bind(parameter) + + +# _____ +# | __ \ +# | | | | ___ _ __ ___ ___ +# | | | |/ _ \ '_ \/ __|/ _ \ +# | |__| | __/ | | \__ \ __/ +# |_____/ \___|_| |_|___/\___| +# + +dense_tag = LayerTag(name="dense_tag", num_inputs=1, num_outputs=1) + + +def register_dense(y, x, w, b=None): + if b is None: + return dense_tag.bind(y, x, w) + return dense_tag.bind(y, x, w, b) + + +def dense_func(x, params): + """Example of a dense layer function.""" + w = params[0] + y = jnp.matmul(x, w) + if len(params) == 1: + # No bias + return y + # Add bias + return y + params[1] + + +def dense_tagging(jaxpr, inverse_map, values_map): + """Correctly registers a dense layer pattern.""" + del inverse_map + in_values = [values_map[v] for v in jaxpr.invars] + out_values = [values_map[v] for v in jaxpr.outvars] + return register_dense(out_values[0], *in_values) + + +# ___ _____ _____ _ _ _ +# |__ \| __ \ / ____| | | | | (_) +# ) | | | | | | ___ _ ____ _____ | |_ _| |_ _ ___ _ __ +# / /| | | | | | / _ \| '_ \ \ / / _ \| | | | | __| |/ _ \| "_ \ +# / /_| |__| | | |___| (_) | | | \ V / (_) | | |_| | |_| | (_) | | | | +# |____|_____/ \_____\___/|_| |_|\_/ \___/|_|\__,_|\__|_|\___/|_| |_| +# + +conv2d_tag = LayerTag(name="conv2d_tag", num_inputs=1, num_outputs=1) + + +def register_conv2d(y, x, w, b=None, **kwargs): + if b is None: + return conv2d_tag.bind(y, x, w, **kwargs) + return conv2d_tag.bind(y, x, w, b, **kwargs) + + +def conv2d_func(x, params): + """Example of a conv2d layer function.""" + w = params[0] + y = lax.conv_general_dilated( + x, + w, + window_strides=(2, 2), + padding="SAME", + dimension_numbers=("NHWC", "HWIO", "NHWC")) + if len(params) == 1: + # No bias + return y + # Add bias + return y + params[1][None, None, None] + + +def conv2d_tagging(jaxpr, inverse_map, values_map): + """Correctly registers a conv2d layer pattern.""" + in_values = [values_map[v] for v in jaxpr.invars] + out_values = [values_map[v] for v in jaxpr.outvars] + keys = [k for k in inverse_map.keys() if isinstance(k, 
str)] + keys = [k for k in keys if k.startswith("conv_general_dilated")] + if len(keys) != 1: + raise ValueError("Did not find any conv_general_dilated!") + kwargs = inverse_map[keys[0]].params + return register_conv2d(out_values[0], *in_values, **kwargs) + + +# _____ _ _ _____ _ _ __ _ +# / ____| | | | | / ____| | (_)/ _| | +# | (___ ___ __ _| | ___ __ _ _ __ __| | | (___ | |__ _| |_| |_ +# \___ \ / __/ _` | |/ _ \ / _` | '_ \ / _` | \___ \| '_ \| | _| __| +# ____) | (_| (_| | | __/ | (_| | | | | (_| | ____) | | | | | | | |_ +# |_____/ \___\__,_|_|\___| \__,_|_| |_|\__,_| |_____/|_| |_|_|_| \__| +# + +scale_and_shift_tag = LayerTag( + name="scale_and_shift_tag", num_inputs=1, num_outputs=1) + + +def register_scale_and_shift(y, args, has_scale: bool, has_shift: bool): + assert has_scale or has_shift + x, args = args[0], args[1:] + return scale_and_shift_tag.bind( + y, x, *args, has_scale=has_scale, has_shift=has_shift) + + +def scale_and_shift_func(x, params, has_scale: bool, has_shift: bool): + """Example of a scale and shift function.""" + if has_scale and has_shift: + scale, shift = params + return x * scale + shift + elif has_scale: + return x * params[0] + elif has_shift: + return x + params[0] + else: + raise ValueError() + + +def scale_and_shift_tagging( + jaxpr, + inverse_map, + values_map, + has_scale: bool, + has_shift: bool, +): + """Correctly registers a scale and shift layer pattern.""" + del inverse_map + in_values = [values_map[v] for v in jaxpr.invars] + out_values = [values_map[v] for v in jaxpr.outvars] + return register_scale_and_shift(out_values[0], in_values, has_scale, + has_shift) + + +def batch_norm_func( + inputs: Tuple[jnp.ndarray, jnp.ndarray], + params: Tuple[jnp.ndarray, jnp.ndarray], +) -> jnp.ndarray: + """Example of batch norm as is defined in Haiku.""" + x, y = inputs + scale, shift = params + inv = scale * y + return x * inv + shift + + +def batch_norm_tagging_func( + jaxpr, + inverse_map, + values_map, + has_scale: bool, + has_shift: bool, +): + """Correctly registers a batch norm layer pattern as is defined in Haiku.""" + del inverse_map + in_values = [values_map[v] for v in jaxpr.invars] + out_values = [values_map[v] for v in jaxpr.outvars] + # The first two are both multipliers with the scale so we merge them + in_values = [in_values[0] * in_values[1]] + in_values[2:] + return register_scale_and_shift(out_values[0], in_values, has_scale, + has_shift) diff --git a/DeepSolid/utils/kfac_ferminet_alpha/loss_functions.py b/DeepSolid/utils/kfac_ferminet_alpha/loss_functions.py new file mode 100644 index 0000000..740eaf1 --- /dev/null +++ b/DeepSolid/utils/kfac_ferminet_alpha/loss_functions.py @@ -0,0 +1,656 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”). +# All Bytedance Modifications are Copyright 2022 Bytedance Inc. 
+"""Loss functions to be used by LayerCollection.""" +import abc +from typing import Tuple, Optional, Union, Sequence + +import jax +import jax.numpy as jnp + +from DeepSolid.utils.kfac_ferminet_alpha import distributions +from DeepSolid.utils.kfac_ferminet_alpha import layers_and_loss_tags as tags +from DeepSolid.utils.kfac_ferminet_alpha import utils + +ArrayPair = Tuple[jnp.ndarray, jnp.ndarray] +FloatArray = Union[float, jnp.ndarray] +Index = Tuple[int] + + +class LossFunction(abc.ABC): + """Abstract base class for loss functions. + + Note that unlike typical loss functions used in neural networks these are + neither summed nor averaged over the batch and hence the output of evaluate() + will not be a scalar. It is up to the user to then to correctly manipulate + them as needed. + """ + + def __init__(self, weight: FloatArray): + self._weight = weight + + @property + def weight(self) -> FloatArray: + return self._weight + + @property + @abc.abstractmethod + def targets(self) -> Optional[jnp.ndarray]: + """The targets being predicted by the model. + + Returns: + None or Tensor of appropriate shape for calling self._evaluate() on. + """ + pass + + @property + @abc.abstractmethod + def inputs(self) -> Sequence[jnp.ndarray]: + """The inputs to the loss function (excluding the targets).""" + pass + + @abc.abstractmethod + def copy_with_different_inputs(self, inputs: Sequence[jnp.ndarray]): + pass + + def evaluate( + self, + targets: Optional[jnp.ndarray] = None, + coefficient_mode: str = "regular", + ) -> jnp.ndarray: + """Evaluate the loss function on the targets.""" + if targets is None and self.targets is None: + raise ValueError("Cannot evaluate losses with unspecified targets.") + elif targets is None: + targets = self.targets + if coefficient_mode == "regular": + multiplier = self.weight + elif coefficient_mode == "sqrt": + multiplier = jnp.sqrt(self.weight) + elif coefficient_mode == "off": + multiplier = 1.0 + else: + raise ValueError(f"Unrecognized coefficient_mode={coefficient_mode}.") + return self._evaluate(targets) * multiplier + + @abc.abstractmethod + def _evaluate(self, targets: jnp.ndarray) -> jnp.ndarray: + """Evaluates the negative log probability of the targets. + + Args: + targets: Tensor that distribution can calculate log_prob() of. + + Returns: + negative log probability of each target, summed across all targets. + """ + pass + + def grad_of_evaluate( + self, + targets: Optional[jnp.ndarray], + coefficient_mode: str, + ) -> Sequence[jnp.ndarray]: + """Evaluates the gradient of the loss function. + + Note that the targets of the loss must not be `None`. + + Args: + targets: The potential targets on which to evaluate the gradient. + coefficient_mode: The coefficient mode to use for evaluation. + + Returns: + The gradient of the loss evaluation function with respect to the inputs. + """ + def evaluate_sum(inputs: Sequence[jnp.ndarray]) -> jnp.ndarray: + instance = self.copy_with_different_inputs(inputs) + return jnp.sum(instance.evaluate(targets, coefficient_mode)) + return jax.grad(evaluate_sum)(self.inputs) + + def multiply_ggn(self, vector: jnp.ndarray) -> jnp.ndarray: + """Right-multiply a vector by the GGN. + + Here the 'GGN' is the GGN matrix (whose definition is slightly flexible) + of the loss function with respect to its inputs. + + Args: + vector: The vector to multiply. Must be the same shape(s) as the 'inputs' + property. + + Returns: + The vector right-multiplied by the GGN. Will be of the same shape(s) + as the 'inputs' property. 
+ """ + return utils.scalar_mul(self.multiply_ggn_unweighted(vector), self.weight) + + @abc.abstractmethod + def multiply_ggn_unweighted(self, vector: jnp.ndarray) -> jnp.ndarray: + """Same as `multiply_ggn`, but without taking into account the weight.""" + pass + + def multiply_ggn_factor(self, vector: jnp.ndarray) -> jnp.ndarray: + """Right-multiply a vector by a factor B of the GGN. + + Here the 'GGN' is the GGN matrix (whose definition is slightly flexible) + of the loss function with respect to its inputs. Typically this will be + block-diagonal across different cases in the batch, since the loss function + is typically summed across cases. + + Note that B can be any matrix satisfying B * B^T = G where G is the GGN, + but will agree with the one used in the other methods of this class. + + Args: + vector: The vector to multiply. Must be of the shape given by the + 'ggn_factor_inner_shape' property. + + Returns: + The vector right-multiplied by B. Will be of the same shape(s) as the + 'inputs' property. + """ + return utils.scalar_mul( + self.multiply_ggn_factor_unweighted(vector), jnp.sqrt(self.weight)) + + @abc.abstractmethod + def multiply_ggn_factor_unweighted(self, vector: jnp.ndarray) -> jnp.ndarray: + """Same as `multiply_ggn_factor`, but without taking into account the weight.""" + pass + + def multiply_ggn_factor_transpose(self, vector: jnp.ndarray) -> jnp.ndarray: + """Right-multiply a vector by the transpose of a factor B of the GGN. + + Here the 'GGN' is the GGN matrix (whose definition is slightly flexible) + of the loss function with respect to its inputs. Typically this will be + block-diagonal across different cases in the batch, since the loss function + is typically summed across cases. + + Note that B can be any matrix satisfying B * B^T = G where G is the GGN, + but will agree with the one used in the other methods of this class. + + Args: + vector: The vector to multiply. Must be the same shape(s) as the 'inputs' + property. + + Returns: + The vector right-multiplied by B^T. Will be of the shape given by the + 'ggn_factor_inner_shape' property. + """ + return utils.scalar_mul( + self.multiply_ggn_factor_transpose_unweighted(vector), + jnp.sqrt(self.weight)) + + @abc.abstractmethod + def multiply_ggn_factor_transpose_unweighted( + self, + vector: jnp.ndarray + ) -> jnp.ndarray: + """Same as `multiply_ggn_factor_transpose`, but without taking into account the weight.""" + pass + + def multiply_ggn_factor_replicated_one_hot(self, index: Index) -> jnp.ndarray: + """Right-multiply a replicated-one-hot vector by a factor B of the GGN. + + Here the 'GGN' is the GGN matrix (whose definition is slightly flexible) + of the loss function with respect to its inputs. Typically this will be + block-diagonal across different cases in the batch, since the loss function + is typically summed across cases. + + A 'replicated-one-hot' vector means a tensor which, for each slice along the + batch dimension (assumed to be dimension 0), is 1.0 in the entry + corresponding to the given index and 0 elsewhere. + + Note that B can be any matrix satisfying B * B^T = G where G is the GGN, + but will agree with the one used in the other methods of this class. + + Args: + index: A tuple representing in the index of the entry in each slice that + is 1.0. Note that len(index) must be equal to the number of elements of + the 'ggn_factor_inner_shape' tensor minus one. + + Returns: + The vector right-multiplied by B^T. Will be of the same shape(s) as the + 'inputs' property. 
+ """ + return utils.scalar_mul( + self.multiply_ggn_factor_replicated_one_hot_unweighted(index), + jnp.sqrt(self.weight)) + + @abc.abstractmethod + def multiply_ggn_factor_replicated_one_hot_unweighted( + self, + index: Index + ) -> jnp.ndarray: + pass + + @property + @abc.abstractmethod + def ggn_factor_inner_shape(self) -> Sequence[int]: + """The shape of the tensor returned by multiply_ggn_factor.""" + pass + + +class NegativeLogProbLoss(LossFunction): + """Abstract base class for loss functions that are negative log probs.""" + + @property + def inputs(self): + return self.params + + @property + @abc.abstractmethod + def params(self): + """Parameters to the underlying distribution.""" + pass + + def multiply_fisher(self, vector: jnp.ndarray) -> jnp.ndarray: + """Right-multiply a vector by the Fisher. + + Args: + vector: The vector to multiply. Must be the same shape(s) as the 'inputs' + property. + + Returns: + The vector right-multiplied by the Fisher. Will be of the same shape(s) + as the 'inputs' property. + """ + return utils.scalar_mul( + self.multiply_fisher_unweighted(vector), self.weight) + + @abc.abstractmethod + def multiply_fisher_unweighted(self, vector: jnp.ndarray) -> jnp.ndarray: + pass + + def multiply_fisher_factor(self, vector: jnp.ndarray) -> jnp.ndarray: + """Right-multiply a vector by a factor B of the Fisher. + + Here the 'Fisher' is the Fisher information matrix (i.e. expected outer- + product of gradients) with respect to the parameters of the underlying + probability distribution (whose log-prob defines the loss). Typically this + will be block-diagonal across different cases in the batch, since the + distribution is usually (but not always) conditionally iid across different + cases. + + Note that B can be any matrix satisfying B * B^T = F where F is the Fisher, + but will agree with the one used in the other methods of this class. + + Args: + vector: The vector to multiply. Must be of the shape given by the + 'fisher_factor_inner_shape' property. + + Returns: + The vector right-multiplied by B. Will be of the same shape(s) as the + 'inputs' property. + """ + return utils.scalar_mul( + self.multiply_fisher_factor_unweighted(vector), jnp.sqrt(self.weight)) + + @abc.abstractmethod + def multiply_fisher_factor_unweighted( + self, + vector: jnp.ndarray + ) -> jnp.ndarray: + pass + + def multiply_fisher_factor_transpose( + self, + vector: jnp.ndarray + ) -> jnp.ndarray: + """Right-multiply a vector by the transpose of a factor B of the Fisher. + + Here the 'Fisher' is the Fisher information matrix (i.e. expected outer- + product of gradients) with respect to the parameters of the underlying + probability distribution (whose log-prob defines the loss). Typically this + will be block-diagonal across different cases in the batch, since the + distribution is usually (but not always) conditionally iid across different + cases. + + Note that B can be any matrix satisfying B * B^T = F where F is the Fisher, + but will agree with the one used in the other methods of this class. + + Args: + vector: The vector to multiply. Must be the same shape(s) as the 'inputs' + property. + + Returns: + The vector right-multiplied by B^T. Will be of the shape given by the + 'fisher_factor_inner_shape' property. 
+ """ + return utils.scalar_mul( + self.multiply_fisher_factor_transpose_unweighted(vector), + jnp.sqrt(self.weight)) + + @abc.abstractmethod + def multiply_fisher_factor_transpose_unweighted( + self, + vector: jnp.ndarray + ) -> jnp.ndarray: + pass + + def multiply_fisher_factor_replicated_one_hot( + self, + index: Index + ) -> jnp.ndarray: + """Right-multiply a replicated-one-hot vector by a factor B of the Fisher. + + Here the 'Fisher' is the Fisher information matrix (i.e. expected outer- + product of gradients) with respect to the parameters of the underlying + probability distribution (whose log-prob defines the loss). Typically this + will be block-diagonal across different cases in the batch, since the + distribution is usually (but not always) conditionally iid across different + cases. + + A 'replicated-one-hot' vector means a tensor which, for each slice along the + batch dimension (assumed to be dimension 0), is 1.0 in the entry + corresponding to the given index and 0 elsewhere. + + Note that B can be any matrix satisfying B * B^T = H where H is the Fisher, + but will agree with the one used in the other methods of this class. + + Args: + index: A tuple representing in the index of the entry in each slice that + is 1.0. Note that len(index) must be equal to the number of elements of + the 'fisher_factor_inner_shape' tensor minus one. + + Returns: + The vector right-multiplied by B. Will be of the same shape(s) as the + 'inputs' property. + """ + return utils.scalar_mul( + self.multiply_fisher_factor_replicated_one_hot_unweighted(index), + jnp.sqrt(self.weight)) + + @abc.abstractmethod + def multiply_fisher_factor_replicated_one_hot_unweighted( + self, + index: Index + ) -> jnp.ndarray: + pass + + @property + @abc.abstractmethod + def fisher_factor_inner_shape(self) -> Sequence[int]: + """The shape of the tensor returned by multiply_fisher_factor.""" + pass + + @abc.abstractmethod + def sample(self, rng_key: jnp.ndarray) -> jnp.ndarray: + """Sample 'targets' from the underlying distribution.""" + pass + + def grad_of_evaluate_on_sample( + self, + rng_key: jnp.ndarray, + coefficient_mode: str, + ) -> Sequence[jnp.ndarray]: + """Evaluates the gradient of the log probability on a random sample. + + Args: + rng_key: Jax PRNG key for sampling. + coefficient_mode: The coefficient mode to use for evaluation. + + Returns: + The gradient of the log probability of targets sampled from the + distribution. + """ + return self.grad_of_evaluate(self.sample(rng_key), coefficient_mode) + + +class NaturalParamsNegativeLogProbLoss(NegativeLogProbLoss, abc.ABC): + """Base class for neg log prob losses whose inputs are 'natural' parameters. + + We will take the GGN of the loss to be the Fisher associated with the + distribution, which also happens to be equal to the Hessian for this class + of loss functions. See here: https://arxiv.org/abs/1412.1193 + + 'Natural parameters' are defined for exponential-family models. 
See for + example: https://en.wikipedia.org/wiki/Exponential_family + """ + + def multiply_ggn_unweighted(self, vector: jnp.ndarray) -> jnp.ndarray: + return self.multiply_fisher_unweighted(vector) + + def multiply_ggn_factor_unweighted(self, vector: jnp.ndarray) -> jnp.ndarray: + return self.multiply_fisher_factor_unweighted(vector) + + def multiply_ggn_factor_transpose_unweighted( + self, + vector: jnp.ndarray + ) -> jnp.ndarray: + return self.multiply_fisher_factor_transpose_unweighted(vector) + + def multiply_ggn_factor_replicated_one_hot_unweighted( + self, + index: Index + ) -> jnp.ndarray: + return self.multiply_fisher_factor_replicated_one_hot_unweighted(index) + + @property + def ggn_factor_inner_shape(self) -> Sequence[int]: + return self.fisher_factor_inner_shape + + +class DistributionNegativeLogProbLoss(NegativeLogProbLoss): + """Base class for neg log prob losses that use the distribution classes.""" + + @property + @abc.abstractmethod + def dist(self): + """The underlying distribution instance.""" + pass + + def _evaluate(self, targets: jnp.ndarray): + return -self.dist.log_prob(targets) + + def sample(self, rng_key: jnp.ndarray): + return self.dist.sample(seed=rng_key) + + @property + def fisher_factor_inner_shape(self) -> Sequence[int]: + return self.dist.mean().shape + + +class NormalMeanNegativeLogProbLoss(DistributionNegativeLogProbLoss, + NaturalParamsNegativeLogProbLoss): + """Neg log prob loss for a normal distribution parameterized by a mean vector. + + + Note that the covariance is treated as the identity divided by 2. + Also note that the Fisher for such a normal distribution with respect the mean + parameter is given by: + + F = (1 / variance) * I + + See for example https://www.ii.pwr.edu.pl/~tomczak/PDF/[JMT]Fisher_inf.pdf. + """ + + def __init__( + self, + mean: jnp.ndarray, + targets: Optional[jnp.ndarray] = None, + variance: float = 0.5, + weight: float = 1.0, + ): + super().__init__(weight=weight) + self._mean = mean + self._targets = targets + self._variance = variance + if not isinstance(variance, float): + raise ValueError("The `variance` argument should be python float.") + + @property + def targets(self) -> Optional[jnp.ndarray]: + return self._targets + + @property + def dist(self): + scale_diag = jnp.full_like(self._mean, jnp.sqrt(self._variance)) + return distributions.MultivariateNormalDiag(self._mean, scale_diag) + + @property + def params(self): + return self._mean, + + def copy_with_different_inputs(self, inputs: Sequence[jnp.ndarray]): + [mean] = inputs + return NormalMeanNegativeLogProbLoss( + mean=mean, + targets=self.targets, + variance=self._variance, + weight=self.weight, + ) + + def multiply_fisher_unweighted(self, vector: jnp.ndarray) -> jnp.ndarray: + return vector / self._variance + + def multiply_fisher_factor_unweighted( + self, + vector: jnp.ndarray, + ) -> jnp.ndarray: + return vector / jnp.sqrt(self._variance) + + def multiply_fisher_factor_transpose_unweighted( + self, + vector: jnp.ndarray, + ) -> jnp.ndarray: + return self.multiply_fisher_factor_unweighted(vector) # it's symmetric + + def multiply_fisher_factor_replicated_one_hot_unweighted( + self, + index: Index, + ) -> jnp.ndarray: + assert len(index) == 1, f"Length of index was {len(index)}." 
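+    # The (unweighted) Fisher factor is B = I / sqrt(variance), so the result
+    # is a [batch_size, num_features] array of zeros whose column at feature
+    # position `index` is set to 1 / sqrt(variance).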
+ index = index[0] + ones_slice = jnp.ones([self._mean.shape[0]])[..., None] + output_slice = ones_slice / jnp.sqrt(self._variance) + return insert_slice_in_zeros(output_slice, 1, self._mean.shape[1], index) + + +def insert_slice_in_zeros( + slice_to_insert: jnp.ndarray, + dim: int, + dim_size: int, + position: int, +) -> jnp.ndarray: + """Inserts slice into a larger tensor of zeros. + + Forms a new tensor which is the same shape as slice_to_insert, except that + the dimension given by 'dim' is expanded to the size given by 'dim_size'. + 'position' determines the position (index) at which to insert the slice within + that dimension. + + Assumes slice_to_insert.shape[dim] = 1. + + Args: + slice_to_insert: The slice to insert. + dim: The dimension which to expand with zeros. + dim_size: The new size of the 'dim' dimension. + position: The position of 'slice_to_insert' in the new tensor. + + Returns: + The new tensor. + + Raises: + ValueError: If the slice's shape at the given dim is not 1. + """ + slice_shape = slice_to_insert.shape + if slice_shape[dim] != 1: + raise ValueError(f"Expected slice_to_insert.shape to have {dim} dim of 1," + f" but was {slice_to_insert.shape[dim]}.") + + before = [0] * int(len(slice_shape)) + after = before[:] + before[dim] = position + after[dim] = dim_size - position - 1 + return jnp.pad(slice_to_insert, list(zip(before, after))) + + +# _______ _____ _ _ _ _ +# |__ __| | __ \ (_) | | | | (_) +# | | __ _ __ _ | |__) |___ __ _ _ ___| |_ _ __ __ _| |_ _ ___ _ __ +# | |/ _` |/ _` | | _ // _ \/ _` | / __| __| '__/ _` | __| |/ _ \| '_ \ +# | | (_| | (_| | | | \ \ __/ (_| | \__ \ |_| | | (_| | |_| | (_) | | | | +# |_|\__,_|\__, | |_| \_\___|\__, |_|___/\__|_| \__,_|\__|_|\___/|_| |_| +# __/ | __/ | +# |___/ |___/ + + +NormalMeanNegativeLogProbLoss_tag = tags.LossTag( + NormalMeanNegativeLogProbLoss, num_inputs=1) + + +def register_normal_predictive_distribution( + mean: jnp.ndarray, + targets: Optional[jnp.ndarray] = None, + variance: float = 0.5, + weight: float = 1.0, +): + """Registers a normal predictive distribution. + + This corresponds to a squared error loss of the form + weight/(2*var) * ||target - mean||^2 + + Args: + mean: A tensor defining the mean vector of the distribution. The first + dimension must be the batch size. + targets: (OPTIONAL) The targets for the loss function. Only required if one + wants to use the "empirical Fisher" instead of the true Fisher (which is + controlled by the 'estimation_mode' to the optimizer). + (Default: None) + variance: float. The variance of the distribution. Note that the default + value of 0.5 corresponds to a standard squared error loss weight * + ||target - prediction||^2. If you want your squared error loss to be of + the form 0.5*coeff*||target - prediction||^2 you should use + variance=1.0. + (Default: 0.5) + weight: A scalar coefficient to multiply the log prob loss associated with + this distribution. The Fisher will be multiplied by the corresponding + factor. In general this is NOT equivalent to changing the temperature of + the distribution, but in the ase of normal distributions it may be. + (Default: 1.0) + + Returns: + The mean and targets as dependable on the tag. 
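+
+  Example (an illustrative sketch only; `apply_fn`, `params`, `x` and `y` are
+  placeholders for a model, its parameters and a data batch):
+
+    preds = apply_fn(params, x)   # shape [batch_size, output_dim]
+    preds, _ = register_normal_predictive_distribution(preds, targets=y)
+    loss = jnp.mean(jnp.sum((preds - y) ** 2, axis=-1))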
+ """ + if targets is None: + targets = jnp.zeros_like(mean) + return NormalMeanNegativeLogProbLoss_tag.bind( + mean, targets, variance=variance, weight=weight, return_loss=False) + + +def register_squared_error_loss( + prediction: jnp.ndarray, + targets: Optional[jnp.ndarray] = None, + weight: float = 1.0, +): + """Registers a squared error loss function. + + This assumes the squared error loss of the form ||target - prediction||^2, + averaged across the mini-batch. If your loss uses a coefficient of 0.5 + you need to set the "weight" argument to reflect this. + + Args: + prediction: The prediction made by the network (i.e. its output). The first + dimension must be the batch size. + targets: (OPTIONAL) The targets for the loss function. Only required if one + wants to use the "empirical Fisher" instead of the true Fisher (which is + controlled by the 'estimation_mode' to the optimizer). + (Default: None) + weight: A float coefficient to multiply the loss function by. + (Default: 1.0) + Returns: + The mean and targets as dependable on the tag. + """ + return register_normal_predictive_distribution( + prediction, targets=targets, variance=0.5, weight=weight) diff --git a/DeepSolid/utils/kfac_ferminet_alpha/optimizer.py b/DeepSolid/utils/kfac_ferminet_alpha/optimizer.py new file mode 100644 index 0000000..42784c2 --- /dev/null +++ b/DeepSolid/utils/kfac_ferminet_alpha/optimizer.py @@ -0,0 +1,614 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”). +# All Bytedance Modifications are Copyright 2022 Bytedance Inc. 
+"""A module for the main curvature optimizer class.""" +from typing import Any, Callable, Iterator, Mapping, Optional, Sequence, Tuple, Union + +import jax +import jax.lax as lax +import jax.numpy as jnp +import jax.random as jnr + +from DeepSolid.utils.kfac_ferminet_alpha import estimator +from DeepSolid.utils.kfac_ferminet_alpha import tag_graph_matcher as tgm +from DeepSolid.utils.kfac_ferminet_alpha import utils + +ScheduleType = Callable[[jnp.ndarray], Optional[jnp.ndarray]] +Parameters = Any +Batch = Any +FuncState = Any +State = Mapping[str, Any] + + +@utils.Stateful.infer_class_state +class Optimizer(utils.Stateful): + """The default optimizer class.""" + velocities: Parameters + estimator: estimator.CurvatureEstimator + step_counter: jnp.ndarray + + def __init__( + self, + value_and_grad_func, + l2_reg: Union[float, jnp.ndarray], + value_func_has_aux: bool = False, + value_func_has_state: bool = False, + value_func_has_rng: bool = False, + learning_rate_schedule: Optional[ScheduleType] = None, + momentum_schedule: Optional[ScheduleType] = None, + damping_schedule: Optional[ScheduleType] = None, + min_damping: Union[float, jnp.ndarray] = 1e-8, + max_damping: Union[float, jnp.ndarray] = jnp.inf, + norm_constraint: Optional[Union[float, jnp.ndarray]] = None, + num_burnin_steps: int = 10, + estimation_mode: str = "fisher_gradients", + curvature_ema: Union[float, jnp.ndarray] = 0.95, + inverse_update_period: int = 5, + register_only_generic: bool = False, + layer_tag_to_block_cls: Optional[estimator.TagMapping] = None, + patterns_to_skip: Sequence[str] = (), + donate_parameters: bool = False, + donate_optimizer_state: bool = False, + donate_batch_inputs: bool = False, + donate_func_state: bool = False, + batch_process_func: Optional[Callable[[Any], Any]] = None, + multi_device: bool = False, + use_jax_cond: bool = True, + debug: bool = False, + pmap_axis_name="kfac_axis", + ): + """Initializes the K-FAC optimizer with the given settings. + + Args: + value_and_grad_func: Python callable. The function should return the value + of the loss to be optimized and its gradients. If the argument + `value_func_has_aux` is `False` then the interface should be: loss, + loss_grads = value_and_grad_func(params, batch) + If `value_func_has_aux` is `True` then the interface should be: (loss, + aux), loss_grads = value_and_grad_func(params, batch) + l2_reg: Scalar. Set this value to tell the optimizer what L2 + regularization coefficient you are using (if any). Note the coefficient + appears in the regularizer as coeff / 2 * sum(param**2). Note that the + user is still responsible for adding regularization to the loss. + value_func_has_aux: Boolean. Specifies whether the provided callable + `value_and_grad_func` returns the loss value only, or also some + auxiliary data. (Default: False) + value_func_has_state: Boolean. Specifies whether the provided callable + `value_and_grad_func` has a persistent state that is inputed and + it also outputs an update version of it. (Default: False) + value_func_has_rng: Boolean. Specifies whether the provided callable + `value_and_grad_func` additionally takes as input an rng key. + (Default: False) + learning_rate_schedule: Callable. A schedule for the learning rate. This + should take as input the current step number and return a single + `jnp.ndarray` that represents the learning rate. (Default: None) + momentum_schedule: Callable. A schedule for the momentum. 
This should take
+        as input the current step number and return a single `jnp.ndarray`
+        that represents the momentum. (Default: None)
+      damping_schedule: Callable. A schedule for the damping. This should take
+        as input the current step number and return a single `jnp.ndarray`
+        that represents the damping. (Default: None)
+      min_damping: Scalar. Minimum value the damping parameter can take. Note
+        that the default value of 1e-8 is quite arbitrary, and you may have to
+        adjust this up or down for your particular problem. If you are using a
+        non-zero value of l2_reg you *may* be able to set this to
+        zero. (Default: 1e-8)
+      max_damping: Scalar. Maximum value the damping parameter can take.
+        (Default: Infinity)
+      norm_constraint: Scalar. If specified, the update is scaled down so that
+        its approximate squared Fisher norm `v^T F v` is at most the specified
+        value. (Note that here `F` is the approximate curvature matrix, not the
+        exact.) (Default: None)
+      num_burnin_steps: Int. At the start of optimization, e.g. the first step,
+        before performing the actual step the optimizer will perform this many
+        updates to the curvature approximation without updating the actual
+        parameters. (Default: 10)
+      estimation_mode: String. The type of estimator to use for the curvature
+        matrix. Can be one of: 'fisher_empirical', 'fisher_exact',
+        'fisher_gradients', 'fisher_curvature_prop', 'ggn_exact' or
+        'ggn_curvature_prop'. See the doc-string for CurvatureEstimator (in
+        estimator.py) for a more detailed description of these options.
+        (Default: 'fisher_gradients')
+      curvature_ema: The decay factor used when calculating the covariance
+        estimate moving averages. (Default: 0.95)
+      inverse_update_period: Int. The number of steps in between updating the
+        computation of the inverse curvature approximation. (Default: 5)
+      register_only_generic: Boolean. Whether, when running the auto-tagger, to
+        register only generic parameters, or to allow it to use the graph
+        matcher to automatically pick up any kind of layer tags.
+        (Default: False)
+      layer_tag_to_block_cls: Dictionary. A mapping from layer tags to block
+        classes, used to override the default choices of block approximation
+        for that specific tag. See the doc-string for CurvatureEstimator (in
+        estimator.py) for a more detailed description of this.
+      patterns_to_skip: Tuple. A list of any patterns that should be skipped by
+        the graph matcher when auto-tagging.
+      donate_parameters: Boolean. Whether to use jax's `donate_argnums` to
+        donate the parameter values of each call to `step`. Note that this
+        implies that you will not be able to access the old parameter values'
+        buffers after calling into `step`.
+      donate_optimizer_state: Boolean. Whether to use jax's `donate_argnums` to
+        donate the optimizer state of each call to `step`. Note that this
+        implies that you will not be able to access the old optimizer state
+        values' buffers after calling into `step`.
+      donate_batch_inputs: Boolean. Whether to use jax's `donate_argnums` to
+        donate the batch values of each call to `step`. Note that this implies
+        that you will not be able to access the old batch values' buffers after
+        calling into `step`.
+      donate_func_state: Boolean. Whether to use jax's `donate_argnums` to
+        donate the persistent function state of each call to `step`. Note that
+        this implies that you will not be able to access the old function state
+        values' buffers after calling into `step`.
+      batch_process_func: Callable. 
A function which to be called on each batch + before feeding to the KFAC on device. This could be useful for specific + device input optimizations. + multi_device: Boolean. Whether to use `pmap` and run the optimizer on + multiple devices. (Default: False) + use_jax_cond: Not used for the moment. + debug: Boolean. If non of the step or init functions would be jitted. Note + that this also overrides `multi_device` and prevents using `pmap`. + (Default: False) + pmap_axis_name: String. The name of the `pmap` axis to use when + `multi_device` is set to True. (Default: curvature_axis) + """ + super().__init__() + self.value_and_grad_func = value_and_grad_func + self.value_func_has_aux = value_func_has_aux + self.value_func_has_state = value_func_has_state + self.value_func_has_rng = value_func_has_rng + self.value_func = utils.convert_value_and_grad_to_value_func( + value_and_grad_func, has_aux=value_func_has_aux) + self.l2_reg = l2_reg + self.learning_rate_schedule = learning_rate_schedule + if momentum_schedule is not None: + + def schedule_with_first_step_zero(global_step: jnp.ndarray): + value = momentum_schedule(global_step) + check = jnp.equal(global_step, 0) + return check * jnp.zeros_like(value) + (1 - check) * value + + self.momentum_schedule = schedule_with_first_step_zero + else: + self.momentum_schedule = None + self.damping_schedule = damping_schedule + self.min_damping = min_damping + self.max_damping = max_damping + self.norm_constraint = norm_constraint + self.num_burnin_steps = num_burnin_steps + self.estimation_mode = estimation_mode + self.curvature_ema = curvature_ema + self.inverse_update_period = inverse_update_period + self.register_only_generic = register_only_generic + self.layer_tag_to_block_cls = layer_tag_to_block_cls + self.patterns_to_skip = patterns_to_skip + self.donate_parameters = donate_parameters + self.donate_optimizer_state = donate_optimizer_state + self.donate_batch_inputs = donate_batch_inputs + self.donate_func_state = donate_func_state + self.batch_process_func = batch_process_func or (lambda x: x) + self.multi_device = multi_device + self.use_jax_cond = use_jax_cond + self.debug = debug + self.pmap_axis_name = pmap_axis_name if multi_device else None + self._rng_split = utils.p_split if multi_device else jnr.split + + # Attributes filled in during self.init() + self.finalized = False + self.tagged_func = None + self.flat_params_shapes = None + self.params_treedef = None + # Special attributes related to jitting/pmap + self._jit_init = None + self._jit_burnin = None + self._jit_step = None + + def finalize( + self, + params: Parameters, + rng: jnp.ndarray, + batch: Batch, + func_state: Optional[FuncState] = None, + ) -> None: + """Finalizes the optimizer by tracing the model function with the params and batch.""" + if self.finalized: + raise ValueError("Optimizer has already been finalized.") + if self.multi_device: + # We assume that the parameters and batch are replicated, while tracing + # must happen with parameters for a single device call + params, rng, batch = jax.tree_map(lambda x: x[0], (params, rng, batch)) + if func_state is not None: + func_state = jax.tree_map(lambda x: x[0], func_state) + batch = self.batch_process_func(batch) + # These are all tracing operations and we can run them with abstract values + func_args = utils.make_func_args(params, func_state, rng, batch, + self.value_func_has_state, + self.value_func_has_rng) + # Run all tracing with abstract values so no computation is done + flat_params, self.params_treedef = 
jax.tree_flatten(params) + self.flat_params_shapes = tuple(p.shape for p in flat_params) + self.tagged_func = tgm.auto_register_tags( + func=self.value_func, + func_args=func_args, + params_index=0, + register_only_generic=self.register_only_generic, + patterns_to_skip=self.patterns_to_skip) + self.estimator = estimator.CurvatureEstimator( + self.tagged_func, + func_args, + self.l2_reg, + self.estimation_mode, + layer_tag_to_block_cls=self.layer_tag_to_block_cls) + # Arguments: params, opt_state, rng, batch, func_state + donate_argnums = [] + if self.donate_parameters: + donate_argnums.append(0) + if self.donate_optimizer_state: + donate_argnums.append(1) + if self.donate_batch_inputs: + donate_argnums.append(3) + if self.donate_func_state and self.value_func_has_state: + donate_argnums.append(4) + donate_argnums = tuple(donate_argnums) + + if self.debug: + self._jit_init = self._init + self._jit_burnin = self._burnin + self._jit_step = self._step + elif self.multi_device: + self._jit_init = jax.pmap( + self._init, axis_name=self.pmap_axis_name, donate_argnums=[0]) + # batch size is static argnum and is at index 5 + self._jit_burnin = jax.pmap( + self._burnin, + axis_name=self.pmap_axis_name, + static_broadcasted_argnums=[5]) + self._jit_step = jax.pmap( + self._step, + axis_name=self.pmap_axis_name, + donate_argnums=donate_argnums, + static_broadcasted_argnums=[5]) + else: + self._jit_init = jax.jit(self._init, donate_argnums=[0]) + # batch size is static argnum and is at index 5 + self._jit_burnin = jax.jit(self._burnin, static_argnums=[5]) + self._jit_step = jax.jit( + self._step, donate_argnums=donate_argnums, static_argnums=[5]) + self.finalized = True + + def _init(self, rng: jnp.ndarray) -> State: + """This is the non-jitted version of initializing the state.""" + flat_velocities = [jnp.zeros(shape) for shape in self.flat_params_shapes] + return dict( + velocities=jax.tree_unflatten(self.params_treedef, flat_velocities), + estimator=self.estimator.init(rng, None), + step_counter=jnp.asarray(0)) + + def verify_args_and_get_step_counter( + self, + params: Parameters, + state: State, + rng: jnp.ndarray, + data_iterator: Iterator[Batch], + func_state: Optional[FuncState] = None, + learning_rate: Optional[jnp.ndarray] = None, + momentum: Optional[jnp.ndarray] = None, + damping: Optional[jnp.ndarray] = None, + global_step_int: Optional[int] = None, + ) -> int: + """Verifies that the arguments passed to `Optimizer.step` are correct.""" + if not self.finalized: + rng, rng_finalize = self._rng_split(rng) + self.finalize(params, rng_finalize, next(data_iterator), func_state) + # Verify correct arguments invocation + if self.learning_rate_schedule is not None and learning_rate is not None: + raise ValueError("When you have passed a `learning_rate_schedule` you " + "should not pass a value to the step function.") + if self.momentum_schedule is not None and momentum is not None: + raise ValueError("When you have passed a `momentum_schedule` you should " + "not pass a value to the step function.") + if self.damping_schedule is not None and damping is not None: + raise ValueError("When you have passed a `damping_schedule` you should " + "not pass a value to the step function.") + # Do a bunrnin on the first iteration + if global_step_int is None: + if self.multi_device: + return int(utils.get_first(state["step_counter"])) + else: + return int(state["step_counter"]) + return global_step_int + + def _burnin( + self, + params: Parameters, + state: State, + rng: jnp.ndarray, + batch: Batch, + 
func_state: Optional[FuncState], + batch_size: Optional[int], + ) -> Tuple[State, Optional[FuncState]]: + """This is the non-jitted version of a single burnin step.""" + self.set_state(state) + batch = self.batch_process_func(batch) + rng, func_rng = jnr.split(rng) if self.value_func_has_rng else (rng, None) + func_args = utils.make_func_args(params, func_state, func_rng, batch, + self.value_func_has_state, + self.value_func_has_rng) + + # Compute batch size + if batch_size is None: + batch_size = jax.tree_flatten(batch)[0][0].shape[0] + + # Update curvature estimate + ema_old, ema_new = 1.0, 1.0 / self.num_burnin_steps + self.estimator.update_curvature_matrix_estimate(ema_old, ema_new, + batch_size, rng, func_args, + self.pmap_axis_name) + + if func_state is not None: + out, _ = self.value_and_grad_func(*func_args) + _, func_state, _ = utils.extract_func_outputs(out, + self.value_func_has_aux, + self.value_func_has_state) + + return self.pop_state(), func_state + + def _step( + self, + params: Parameters, + state: State, + rng: jnp.ndarray, + batch: Batch, + func_state: Optional[FuncState], + batch_size: Optional[int], + learning_rate: Optional[jnp.ndarray], + momentum: Optional[jnp.ndarray], + damping: Optional[jnp.ndarray], + ) -> Union[Tuple[Parameters, State, FuncState, Mapping[str, jnp.ndarray]], + Tuple[Parameters, State, Mapping[str, jnp.ndarray]]]: + """This is the non-jitted version of a single step.""" + # Unpack and set the state + self.set_state(state) + if damping is not None: + assert self.estimator.damping is None + self.estimator.damping = damping + else: + assert self.estimator.damping is not None + + # Preprocess the batch and construct correctly the function arguments + batch = self.batch_process_func(batch) + rng, func_rng = jnr.split(rng) if self.value_func_has_rng else (rng, None) + func_args = utils.make_func_args(params, func_state, func_rng, batch, + self.value_func_has_state, + self.value_func_has_rng) + + # Compute the batch size + if batch_size is None: + batch_size = jax.tree_flatten(batch)[0][0].shape[0] + + # Compute schedules if applicable + if self.learning_rate_schedule is not None: + assert learning_rate is None + learning_rate = self.learning_rate_schedule(self.step_counter) + else: + assert learning_rate is not None + if self.momentum_schedule is not None: + assert momentum is None + momentum = self.momentum_schedule(self.step_counter) + else: + assert momentum is not None + if self.damping_schedule is not None: + assert damping is None + damping = self.damping_schedule(self.step_counter) + else: + assert damping is not None + + # Compute current loss and gradients + out, grads = self.value_and_grad_func(*func_args) + loss, new_func_state, aux = utils.extract_func_outputs( + out, self.value_func_has_aux, self.value_func_has_state) + # Sync loss and grads + loss, grads = utils.pmean_if_pmap((loss, grads), self.pmap_axis_name) + + # Update curvature estimate + self.estimator.update_curvature_matrix_estimate( + self.curvature_ema, + 1.0, + batch_size, + rng, + func_args, + self.pmap_axis_name, + ) + + # Optionally update the inverse estimate + self.estimator.set_state( + lax.cond( + self.step_counter % self.inverse_update_period == 0, + lambda s: self.estimator.update_curvature_estimate_inverse( # pylint: disable=g-long-lambda + self.pmap_axis_name, s), + lambda s: s, + self.estimator.pop_state())) + + # Compute proposed directions + vectors = self.propose_directions( + grads, + self.velocities, + learning_rate, + momentum, + ) + + # The learning rate is 
defined as the negative of the coefficient by which + # we multiply the gradients, while the momentum is the coefficient by + # which we multiply the velocities. + neg_learning_rate = -learning_rate + # Compute the coefficients of the update vectors + assert neg_learning_rate is not None and momentum is not None + coefficients = (neg_learning_rate, momentum) + + # Update velocities and compute new delta + self.velocities, delta = self.velocities_and_delta( + self.velocities, + vectors, + coefficients, + ) + + # Update parameters: params = params + delta + params = jax.tree_multimap(jnp.add, params, delta) + + # Optionally compute the reduction ratio and update the damping + self.estimator.damping = None + rho = jnp.nan + + # Statistics with useful information + stats = dict() + stats["step"] = self.step_counter + stats["loss"] = loss + stats["learning_rate"] = -coefficients[0] + stats["momentum"] = coefficients[1] + stats["damping"] = damping + stats["rho"] = rho + if self.value_func_has_aux: + stats["aux"] = aux + self.step_counter = self.step_counter + 1 + + if self.value_func_has_state: + return params, self.pop_state(), new_func_state, stats + else: + assert new_func_state is None + return params, self.pop_state(), stats + + def init( + self, + params: Parameters, + rng: jnp.ndarray, + batch: Batch, + func_state: Optional[FuncState] = None, + ) -> State: + """Initializes the optimizer and returns the appropriate optimizer state.""" + if not self.finalized: + self.finalize(params, rng, batch, func_state) + return self._jit_init(rng) + + def step( + self, + params: Parameters, + state: Mapping[str, Any], + rng: jnp.ndarray, + data_iterator: Iterator[Any], + func_state: Any = None, + learning_rate: Optional[jnp.ndarray] = None, + momentum: Optional[jnp.ndarray] = None, + damping: Optional[jnp.ndarray] = None, + batch_size: Optional[int] = None, + global_step_int: Optional[int] = None, + ) -> Union[Tuple[Parameters, State, FuncState, Mapping[str, jnp.ndarray]], + Tuple[Parameters, State, Mapping[str, jnp.ndarray]]]: + """Performs a single update step using the optimizer. + + Args: + params: The parameters of the model. + state: The state of the optimizer. + rng: A Jax PRNG key. + data_iterator: An iterator that returns a batch of data. + func_state: Any function state that gets passed in and returned. + learning_rate: This must be provided when + `use_adaptive_learning_rate=False` and `learning_rate_schedule=None`. + momentum: This must be provided when + `use_adaptive_momentum=False` and `momentum_schedule=None`. + damping: This must be provided when + `use_adaptive_damping=False` and `damping_schedule=None`. + batch_size: The batch size to use for KFAC. The default behaviour when it + is None is to use the leading dimension of the first data array. + global_step_int: The global step as a python int. Note that this must + match the step inte rnal to the optimizer that is part of its state. + + Returns: + (params, state, stats) + where: + params: The updated model parameters. + state: The updated optimizer state. + stats: A dictionary of key statistics provided to be logged. 
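+
+      If `value_func_has_state` is True, the returned tuple is instead
+      (params, state, func_state, stats), where func_state is the updated
+      function state returned by `value_and_grad_func`.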
+ """ + step_counter_int = self.verify_args_and_get_step_counter( + params=params, + state=state, + rng=rng, + data_iterator=data_iterator, + func_state=func_state, + learning_rate=learning_rate, + momentum=momentum, + damping=damping, + global_step_int=global_step_int) + + if step_counter_int == 0: + for _ in range(self.num_burnin_steps): + rng, rng_burn = self._rng_split(rng) + batch = next(data_iterator) + state, func_state = self._jit_burnin(params, state, rng_burn, batch, + func_state, batch_size) + + # On the first step we always treat the momentum as 0.0 + if self.momentum_schedule is None: + momentum = jnp.zeros([]) + if self.multi_device: + momentum = utils.replicate_all_local_devices(momentum) + + batch = next(data_iterator) + return self._jit_step(params, state, rng, batch, func_state, batch_size, + learning_rate, momentum, damping) + + def propose_directions( + self, + grads: Parameters, + velocities: Parameters, + learning_rate: Optional[jnp.ndarray], + momentum: Optional[jnp.ndarray], + ) -> Tuple[Parameters, Parameters]: + """Computes the vector proposals for the next step.""" + del momentum # not used in this, but could be used in subclasses + preconditioned_grads = self.estimator.multiply_matpower(grads, -1) + + if self.norm_constraint is not None: + assert learning_rate is not None + sq_norm_grads = utils.inner_product(preconditioned_grads, grads) + sq_norm_scaled_grads = sq_norm_grads * learning_rate**2 + + # We need to sync the norms here, because reduction can be + # non-deterministic. They specifically are on GPUs by default for better + # performance. Hence although grads and preconditioned_grads are synced, + # the inner_product operation can still produce different answers on + # different devices. + sq_norm_scaled_grads = utils.pmean_if_pmap(sq_norm_scaled_grads, + self.pmap_axis_name) + + max_coefficient = jnp.sqrt(self.norm_constraint / sq_norm_scaled_grads) + coefficient = jnp.minimum(max_coefficient, 1) + preconditioned_grads = utils.scalar_mul(preconditioned_grads, coefficient) + + return preconditioned_grads, velocities + + def velocities_and_delta( + self, + velocities: Parameters, + vectors: Sequence[Parameters], + coefficients: Sequence[jnp.ndarray], + ) -> Sequence[Parameters]: + """Computes the new velocities and delta (update to parameters).""" + del velocities + assert len(vectors) == len(coefficients) + delta = utils.scalar_mul(vectors[0], coefficients[0]) + for vi, wi in zip(vectors[1:], coefficients[1:]): + delta = jax.tree_multimap(jnp.add, delta, utils.scalar_mul(vi, wi)) + return delta, delta diff --git a/DeepSolid/utils/kfac_ferminet_alpha/tag_graph_matcher.py b/DeepSolid/utils/kfac_ferminet_alpha/tag_graph_matcher.py new file mode 100644 index 0000000..96357b2 --- /dev/null +++ b/DeepSolid/utils/kfac_ferminet_alpha/tag_graph_matcher.py @@ -0,0 +1,755 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”). 
+# All Bytedance Modifications are Copyright 2022 Bytedance Inc. +"""A module for tagging and graph manipulation.""" +import collections +import functools +import itertools +from typing import Any, NamedTuple, Sequence + +from absl import logging +import jax +from jax import core as jax_core +from jax import lax +from jax import util as jax_util +from jax.interpreters import partial_eval as pe +import jax.numpy as jnp +import networkx as nx +from networkx.algorithms import isomorphism +import numpy as np +import ordered_set + +from DeepSolid.utils.kfac_ferminet_alpha import layers_and_loss_tags as tags + +USE_NETWORKX = False + + +def match_nodes(g1, g2, mapping, node1, node2): + """Matching nodes when doing graph search.""" + + if not kfac_node_match(g1.nodes[node1], g2.nodes[node2]): + return False + # Check predecessors + p1 = set(n for n in g1.predecessors(node1) if n in mapping.keys()) + p2 = set(n for n in g2.predecessors(node2) if n in mapping.values()) + if len(p1) != len(p2): + return False + for p1_i in p1: + if mapping[p1_i] not in p2: + return False + # Check successors + s1 = set(n for n in g1.successors(node1) if n in mapping.keys()) + s2 = set(n for n in g2.successors(node2) if n in mapping.values()) + if len(s1) != len(s2): + return False + for s1_i in s1: + if mapping[s1_i] not in s2: + return False + return True + + +def generate_candidates(g1, g2, mapping, node1, node2): + """Generates the initial candidates for graph search.""" + # Check predecessors + p1 = set(n for n in g1.predecessors(node1) if n not in mapping.keys()) + p2 = set(n for n in g2.predecessors(node2) if n not in mapping.values()) + candidates = ordered_set.OrderedSet(itertools.product(p1, p2)) + s1 = set(n for n in g1.successors(node1) if n not in mapping.keys()) + s2 = set(n for n in g2.successors(node2) if n not in mapping.values()) + candidates.update(list(itertools.product(s1, s2))) + return candidates + + +def find_mappings(pattern, graph, mapping, terminals): + """Finds all mappings from graph search of the pattern.""" + if len(mapping) == len(pattern): + for k, v in terminals.items(): + v.add(mapping[k]) + return [frozenset(mapping.items())] + mappings = set() + nodes_list = list(mapping.keys()) + for node1 in reversed(nodes_list): + for s1 in pattern.successors(node1): + if s1 not in mapping.keys(): + for s2 in graph.successors(mapping[node1]): + if s2 not in mapping.values(): + if s1 not in terminals or s2 not in terminals[s1]: + if match_nodes(pattern, graph, mapping, s1, s2): + mapping[s1] = s2 + mappings.update( + find_mappings(pattern, graph, mapping, terminals)) + mapping.pop(s1) + for p1 in pattern.predecessors(node1): + if p1 not in mapping.keys(): + for p2 in graph.predecessors(mapping[node1]): + if p2 not in mapping.values(): + if p1 not in terminals or p2 not in terminals[p1]: + if match_nodes(pattern, graph, mapping, p1, p2): + mapping[p1] = p2 + mappings.update( + find_mappings(pattern, graph, mapping, terminals)) + mapping.pop(p1) + return mappings + + +def match_pattern(pattern, graph): + """Given a pattern returns all matches inside the graph.""" + if USE_NETWORKX: + matcher = isomorphism.GraphMatcher( + graph, pattern, node_match=kfac_node_match) + mappings = list( + dict((k, v) + for v, k in mapping.items()) + for mapping in matcher.subgraph_isomorphisms_iter()) + else: + mapping = collections.OrderedDict() + params1 = [n for n in pattern.nodes if pattern.nodes[n]["op"] == "param"] + params2 = [n for n in graph.nodes if graph.nodes[n]["op"] == "param"] + terminals = { + n: set() 
for n in pattern.nodes if not list(pattern.successors(n)) + } + + mappings = set() + for node1, node2 in itertools.product(params1, params2): + mapping[node1] = node2 + mappings.update(find_mappings(pattern, graph, mapping, terminals)) + mapping.pop(node1) + for v in terminals.values(): + v.clear() + mappings = list(dict(mapping) for mapping in mappings) + + var_mappings = [] + for mapping in mappings: + var_mappings.append(dict()) + for k, v in mapping.items(): + cond = pattern.nodes[k]["op"] in ("param", "array") + source = pattern.nodes[k]["var"] if cond else k + target = graph.nodes[v]["var"] if cond else graph.nodes[v]["eqn"] + var_mappings[-1][source] = target + + return var_mappings + + +def read_env(env, var): + # Literals are values baked into the Jaxpr + if isinstance(var, jax.core.Literal): + return var.val + return env[var] + + +def write_env(env, var, val): + env[var] = val + + +def abstract_single_value(value): + if isinstance(value, jnp.ndarray): + value = jax.ShapedArray(np.shape(value), np.result_type(value)) + return pe.PartialVal.unknown(value) + else: + return value + + +def abstract_args(args): + return jax.tree_map(abstract_single_value, args) + + +def evaluate_eqn(eqn, in_values, write_func): + """Evaluate a single Jax equation and writes the outputs.""" + in_values = list(in_values) + # This is logic specifically to handle `xla_call` + call_jaxpr, params = jax.core.extract_call_jaxpr(eqn.primitive, eqn.params) + if call_jaxpr: + subfuns = [ + jax.core.lu.wrap_init( + functools.partial(jax.core.eval_jaxpr, call_jaxpr, ())) + ] + else: + subfuns = [] + ans = eqn.primitive.bind(*(subfuns + in_values), **params) + if eqn.primitive.multiple_results: + jax_util.safe_map(write_func, eqn.outvars, ans) + else: + write_func(eqn.outvars[0], ans) + return ans + + +def clean_jaxpr_eqns(jaxpr, preserve_tags=True): + """Performs dead code elimination on the jaxpr, preserving loss and layer tags.""" + eqns = [] + dependants = set(jaxpr.outvars) + for eqn in reversed(jaxpr.eqns): + check = False + for v in eqn.outvars: + if v in dependants: + dependants.remove(v) + check = True + if isinstance(eqn.primitive, (tags.LossTag, tags.LayerTag)): + check = check or preserve_tags + if check: + eqns.append(eqn) + new_dependants = set( + v for v in eqn.invars if not isinstance(v, jax_core.Literal)) + dependants = dependants.union(new_dependants) + # Dependants should only be invars + dependants = dependants - set(jaxpr.invars + jaxpr.constvars) + + if dependants: + raise ValueError("Something went wrong with the dead code elimination.") + return reversed(eqns) + + +def broadcast_merger(f): + """Transforms `f` into a function where all consecutive broadcasts are merged.""" + + def merged_func(*func_args): + typed_jaxpr, out_avals = jax.make_jaxpr(f, return_shape=True)(*func_args) + out_tree = jax.tree_structure(out_avals) + jaxpr, consts = typed_jaxpr.jaxpr, typed_jaxpr.literals + + # Mapping from variable -> value + env = dict() + read = functools.partial(read_env, env) + write = functools.partial(write_env, env) + + # Bind args and consts to environment + flat_args = jax.tree_flatten(func_args)[0] + write(jax.core.unitvar, jax.core.unit) + jax_util.safe_map(write, jaxpr.invars, flat_args) + jax_util.safe_map(write, jaxpr.constvars, consts) + + # Bind args and consts to environment + write(jax.core.unitvar, jax.core.unit) + jax_util.safe_map(write, jaxpr.invars, flat_args) + jax_util.safe_map(write, jaxpr.constvars, consts) + + # Loop through equations and evaluate primitives using `bind` + 
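+    # An illustrative sketch of the merging performed below (shapes are
+    # hypothetical): if `x` was broadcast to shape (3, 4) and that result is
+    # broadcast again to (2, 3, 4), the second equation is re-evaluated as a
+    # single lax.broadcast_in_dim of the original `x` into (2, 3, 4), so
+    # consecutive broadcasts collapse into one.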
broadcasts_outputs = dict() + for eqn in clean_jaxpr_eqns(jaxpr): + # We ignore broadcasting of constants + if (eqn.primitive.name == "broadcast_in_dim" and + not all(isinstance(v, jax_core.Literal) for v in eqn.invars)): + if eqn.invars[0] in broadcasts_outputs: + x, dims = broadcasts_outputs[eqn.invars[0]] + kept_dims = eqn.params["broadcast_dimensions"] + kept_dims = [kept_dims[d] for d in dims] + y = lax.broadcast_in_dim(x, eqn.params["shape"], kept_dims) + jax_util.safe_map(write, eqn.outvars, [y]) + broadcasts_outputs[eqn.outvars[0]] = (x, kept_dims) + else: + inputs = jax_util.safe_map(read, eqn.invars) + evaluate_eqn(eqn, inputs, write) + broadcasts_outputs[eqn.outvars[0]] = ( + inputs[0], eqn.params["broadcast_dimensions"]) + else: + evaluate_eqn(eqn, jax_util.safe_map(read, eqn.invars), write) + return jax.tree_unflatten(out_tree, jax_util.safe_map(read, jaxpr.outvars)) + + return merged_func + + +class JaxGraph(NamedTuple): + jaxpr: Any + consts: Any + params: Any + params_tree: Any + in_tree: Any + out_tree: Any + digraph: nx.DiGraph + tagging_func: Any + + +SPECIAL_OP_COMPARE_RULES = dict() + + +def default_compare(node1, node2): + if node1["op"] != node2["op"]: + return False + params1, params2 = node1["eqn"].params, node2["eqn"].params + if set(params1.keys()) != set(params2.keys()): + return False + for k in params1.keys(): + if params1[k] != params2[k]: + return False + return True + + +def reshape_compare(node1, node2): + """Compares two reshape nodes.""" + assert node1["op"] == node2["op"] == "reshape" + params1, params2 = node1["eqn"].params, node2["eqn"].params + if params1["dimensions"] != params2["dimensions"]: + return False + return True + + +def broadcast_in_dim_compare(node1, node2): + """Compares two reshape nodes.""" + assert node1["op"] == node2["op"] == "broadcast_in_dim" + return True + + +def conv_compare(node1, node2): + """Compares two conv_general_dialted nodes.""" + assert node1["op"] == node2["op"] == "conv_general_dilated" + params1, params2 = node1["eqn"].params, node2["eqn"].params + for k in ("window_strides", "padding", "lhs_dilation", "rhs_dilation", + "lhs_shape", "rhs_shape"): + if len(params1[k]) != len(params2[k]): + return False + if (len(params1["dimension_numbers"].lhs_spec) != # + len(params2["dimension_numbers"].lhs_spec)): + return False + if (len(params1["dimension_numbers"].rhs_spec) != # + len(params2["dimension_numbers"].rhs_spec)): + return False + if (len(params1["dimension_numbers"].out_spec) != # + len(params2["dimension_numbers"].out_spec)): + return False + if ((params1["feature_group_count"] > 1) != # + (params2["feature_group_count"] > 1)): + return False + if ((params1["batch_group_count"] > 1) != # + (params2["batch_group_count"] > 1)): + return False + return True + + +SPECIAL_OP_COMPARE_RULES["reshape"] = reshape_compare +SPECIAL_OP_COMPARE_RULES["broadcast_in_dim"] = broadcast_in_dim_compare +SPECIAL_OP_COMPARE_RULES["conv_general_dilated"] = conv_compare + + +def kfac_node_match(node1, node2): + """Checks if two nodes are equivalent.""" + # Parameters match with each other and nothing else + if node1["op"] == "param" and node2["op"] == "param": + return True + # return node1["rank"] == node2["rank"] + if node1["op"] == "param" or node2["op"] == "param": + return False + # Arrays always match each other and nothing else + if node1["op"] == "array" and node2["op"] == "array": + return True + if node1["op"] == "array" or node2["op"] == "array": + return False + # Operators match first on name + if node1["op"] != 
node2["op"]: + return False + compare = SPECIAL_OP_COMPARE_RULES.get(node1["op"], default_compare) + return compare(node1, node2) + + +def var_to_str(var): + """Returns a string representation of the variable of a Jax expression.""" + if isinstance(var, jax.core.Literal): + return str(var) + elif isinstance(var, jax.core.UnitVar): + return "*" + elif not isinstance(var, jax.core.Var): + raise ValueError(f"Idk what to do with this {type(var)}?") + c = int(var.count) + if c == -1: + return "_" + str_rep = "" + while c > 25: + str_rep += chr(c % 26 + ord("a")) + c = c // 26 + str_rep += chr(c + ord("a")) + return str_rep[::-1] + + +def extract_param_vars_flat(jaxpr, in_tree, params_index): + if params_index is None: + params_index = [] + elif isinstance(params_index, int): + params_index = [params_index] + in_vars = jax.tree_unflatten(in_tree, jaxpr.invars) + return jax.tree_flatten([in_vars[i] for i in params_index]) + + +def fill_jaxpr_to_graph(graph, jaxpr, in_vars=None, out_vars=None): + """Fills the graph with the jaxpr.""" + in_vars = in_vars or [var_to_str(v) for v in jaxpr.invars + jaxpr.constvars] + in_map = dict(zip(jaxpr.invars + jaxpr.constvars, in_vars)) + out_vars = out_vars or [var_to_str(v) for v in jaxpr.outvars] + out_map = dict(zip(jaxpr.outvars, out_vars)) + + for eqn in jaxpr.eqns: + in_vars = [] + for v in eqn.invars: + if isinstance(v, (jax.core.Literal, jax.core.UnitVar)): + in_vars.append(var_to_str(v)) + else: + in_vars.append(in_map.get(v, var_to_str(v))) + out_vars = [out_map.get(v, var_to_str(v)) for v in eqn.outvars] + in_str = ",".join(in_vars) + out_str = ",".join(out_vars) + if isinstance(eqn.primitive, tags.LossTag): + func_name = "__loss_tag" + elif isinstance(eqn.primitive, tags.LayerTag): + func_name = "__layer_tag" + else: + func_name = eqn.primitive.name + node_c = f"{func_name}({in_str})->{out_str}" + graph.add_node(node_c, op=eqn.primitive.name, eqn=eqn) + + # Create incoming edges + for v, name in zip(eqn.invars, in_vars): + if (not isinstance(v, jax.core.Literal) and + not isinstance(v, jax.core.UnitVar)): + graph.add_edge(name, node_c) + + # Create output nodes and edges + for v, name in zip(eqn.outvars, out_vars): + graph.add_node(name, op="array", var=v) + graph.add_edge(node_c, name) + + +def create_digraph(jaxpr, params): + """Creates a directed graph from the given jaxpr and parameters.""" + graph = nx.DiGraph() + # Create input nodes + for v in jaxpr.invars + jaxpr.constvars: + if v in params: + graph.add_node(var_to_str(v), op="param", var=v) + else: + graph.add_node(var_to_str(v), op="array", var=v) + fill_jaxpr_to_graph(graph, jaxpr) + + return graph + + +def function_to_jax_graph(func, args, params_index, tagging_func=None): + """Creates a `JaxGraph` instance from the provided function.""" + in_tree = jax.tree_structure(args) + typed_jaxpr = jax.make_jaxpr(func)(*args) + jaxpr, consts = typed_jaxpr.jaxpr, typed_jaxpr.literals + params, params_tree = extract_param_vars_flat(jaxpr, in_tree, params_index) + + digraph = create_digraph(jaxpr, params) + if tagging_func is not None: + tagging_func = functools.partial(tagging_func, jaxpr) + return JaxGraph( + jaxpr=jaxpr, + consts=consts, + params=params, + params_tree=params_tree, + in_tree=in_tree, + out_tree=None, + digraph=digraph, + tagging_func=tagging_func) + + +def print_nice_jaxpr(jaxpr): + for eqn in jaxpr.eqns: + print(tuple(eqn.invars), "->", eqn.primitive.name, tuple(eqn.outvars)) + + +def auto_register_tags(func, + func_args, + params_index: int = 0, + register_only_generic: bool = 
False, + compute_only_loss_tags: bool = True, + patterns_to_skip: Sequence[str] = ()): + """Transform the function to one that is populated with tags.""" + func = broadcast_merger(func) + graph = function_to_jax_graph(func, func_args, params_index=params_index) + matches = dict() + + # Extract the tagged losses variables and all their ancestors + loss_output_vars = [] + num_losses = 0 + loss_ancestors = set() + for node in graph.digraph.nodes: + if node.startswith("__loss_tag"): + num_losses += 1 + ancestors = nx.ancestors(graph.digraph, node) + ancestors.add(node) + for output_node in node.split("->")[-1].split(","): + ancestors.add(output_node) + loss_output_vars.append(graph.digraph.nodes[output_node]["var"]) + loss_ancestors = loss_ancestors.union(ancestors) + loss_output_vars = tuple(loss_output_vars) + + # Extract the sub-graph that leads to losses + sub_graph = nx.induced_subgraph(graph.digraph, loss_ancestors) + + # First collect all parameters that are already part of a layer tag + tagged_params = dict() + pattern_counters = dict() + for tag_node in ( + node for node in sub_graph.nodes if node.startswith("__layer_tag")): + inputs = graph.digraph.nodes[tag_node]["eqn"].invars + tag_instance = graph.digraph.nodes[tag_node]["eqn"].primitive + if tag_instance.name == "generic_tag": + tag_params = tag_instance.split_all_inputs(inputs)[0] + else: + tag_params = tag_instance.split_all_inputs(inputs)[2] + pattern_number = pattern_counters.get(tag_instance.name, 0) + for param in tag_params: + if param not in graph.params: + raise ValueError(f"You have registered a layer tag with parameter " + f"that is not part of the parameters at index " + f"{params_index}.") + if param in tagged_params: + raise ValueError(f"You have registered twice the parameter {param}.") + tagged_params[param] = f"Manual[{tag_instance.name}_{pattern_number}]" + if tag_instance.name not in pattern_counters: + pattern_counters[tag_instance.name] = 1 + else: + pattern_counters[tag_instance.name] += 1 + + if not register_only_generic: + for pattern_name, patterns in get_graph_patterns(): + if pattern_name in patterns_to_skip: + logging.info("Skipping graph pattern %s", pattern_name) + continue + logging.info("Matching graph pattern %s", pattern_name) + for pattern in patterns: + for match_map in match_pattern(pattern.digraph, sub_graph): + if len(pattern.jaxpr.outvars) > 1: + raise NotImplementedError() + output = pattern.jaxpr.outvars[0] + if matches.get(match_map[output]) is not None: + raise ValueError(f"Found more than one match for equation " + f"{match_map[output]}. 
Examine the jaxpr:\n " + f"{graph.jaxpr}") + # Mark the parameters as already tagged + match_params = set() + match_params_already_tagged = False + for param in match_map.values(): + if param in graph.params: + match_params.add(param) + if param in tagged_params.keys(): + match_params_already_tagged = True + # Register the match only if no parameters are already registered + if not match_params_already_tagged: + matches[match_map[output]] = (match_map, pattern.tagging_func) + pattern_number = pattern_counters.get(pattern_name, 0) + for param in match_params: + tagged_params[param] = f"Auto[{pattern_name}_{pattern_number}]" + if pattern_name not in pattern_counters: + pattern_counters[pattern_name] = 1 + else: + pattern_counters[pattern_name] += 1 + + # Mark remaining parameters as orphans + orphan_params = sorted( + set(graph.params) - set(tagged_params.keys()), key=lambda v: v.count) + params_regs = [tagged_params.get(p, "Orphan") for p in graph.params] + params_regs = jax.tree_unflatten(graph.params_tree, params_regs) + logging.info("=" * 50) + logging.info("Graph parameter registrations:") + logging.info(params_regs) + logging.info("=" * 50) + + # Construct a function with all of the extra tag registrations + @functools.wraps(func) + def wrapped_auto_registered(*args): + flat_args, _ = jax.tree_flatten(args) + # Mapping from variable -> value + env = {} + + read = functools.partial(read_env, env) + write = functools.partial(write_env, env) + + def tag(var): + if matches.get(var) is not None: + inv_map, tagging_func = matches[var] + var_map = {k: v for k, v in inv_map.items() if not isinstance(k, str)} + val_map = jax.tree_map(read, var_map) + val = tagging_func(inv_map, val_map) + env[var] = val + + # Bind args and consts to environment + write(jax.core.unitvar, jax.core.unit) + jax_util.safe_map(write, graph.jaxpr.invars, flat_args) + jax_util.safe_map(write, graph.jaxpr.constvars, graph.consts) + + # Register any orphan parameters as generic + for param_var in orphan_params: + write(param_var, tags.register_generic(read(param_var))) + + # Set the correct output variables + if compute_only_loss_tags: + output_vars = loss_output_vars + out_tree = jax.tree_structure(loss_output_vars) + else: + output_vars = graph.jaxpr.outvars + out_tree = graph.out_tree + + # Loop through equations and evaluate primitives using `bind` + losses_evaluated = 0 + for eqn in graph.jaxpr.eqns: + evaluate_eqn(eqn, jax_util.safe_map(read, eqn.invars), write) + jax_util.safe_map(tag, eqn.outvars) + + # If we want to output only tagged losses + if isinstance(eqn.primitive, tags.LossTag): + losses_evaluated += 1 + if compute_only_loss_tags and num_losses == losses_evaluated: + break + + outputs = jax_util.safe_map(read, output_vars) + return jax.tree_unflatten(out_tree, outputs) + + return wrapped_auto_registered + + +# Registered graphs +NAME_TO_JAX_GRAPH = dict() +DEFERRED_REGISTRATIONS = [] + + +def register_function(name, func, tagging_func, example_args, params_index, + precedence): + """Registers a function as a pattern in the graph matcher registry. + + The graph matcher needs to trace at least once the full function, which means + you need to provide it with dummy arguments. The shapes of the arguments do + not matter, as the graph matcher ignores their values, however the rank does. + Especially if there is some broadcasting happening you should register with + every possible broadcast pattern. 
As a general advice avoid using a shape to + be 1, unless you want the pattern to specifically match that, as some + operations, like squeeze for example, can have special behaviour then. + + Args: + name: The name of the pattern that is being registered to. + func: The function that performs the computation. + tagging_func: Function that correctly creates the tag. + example_args: Example arguments that can be inputted into `func`. + params_index: Specifies at which index of the `example_args` are considered + a parameter. + precedence: This specifies what precedence the graph matcher is going to + assign to the provided pattern. The graph matcher will go from lowest to + highest precedence, randomly breaking ties, when matching. Note that the + pattern that matches a parameter with the lowest precedence will get + registered and no other will. Specifically useful when there is a pattern + for a layer with and without bias, in which case the with bias + registration always should go with lower precedence. + """ + + # This is required because we can not use Jax before InitGoogle() runs + def register(): + jnp_args = jax.tree_map(jnp.asarray, example_args) + graph = function_to_jax_graph( + func, jnp_args, params_index=params_index, tagging_func=tagging_func) + if NAME_TO_JAX_GRAPH.get(name) is None: + NAME_TO_JAX_GRAPH[name] = (precedence, []) + assert precedence == NAME_TO_JAX_GRAPH[name][0] + NAME_TO_JAX_GRAPH[name][1].append(graph) + + DEFERRED_REGISTRATIONS.append(register) + + +def get_graph_patterns(): + """Returns all graph patterns sorted by their precedence.""" + while DEFERRED_REGISTRATIONS: + DEFERRED_REGISTRATIONS.pop()() + return [(name, pattern) for name, (_, pattern) in sorted( + NAME_TO_JAX_GRAPH.items(), key=lambda pair: pair[1][0])] + + +# Dense with bias +register_function( + "dense_with_bias", + tags.dense_func, + tags.dense_tagging, + [np.zeros([11, 13]), [np.zeros([13, 7]), np.zeros([7])]], + params_index=1, + precedence=0) + +# Dense without bias +register_function( + "dense_no_bias", + tags.dense_func, + tags.dense_tagging, [np.zeros([11, 13]), [np.zeros([13, 7])]], + params_index=1, + precedence=1) + +# Conv2d with bias +register_function( + "conv2d_with_bias", + tags.conv2d_func, + tags.conv2d_tagging, + [np.zeros([2, 8, 8, 5]), [np.zeros([3, 3, 5, 4]), + np.zeros([4])]], + params_index=1, + precedence=0) + +# Conv2d without bias +register_function( + "conv2d_no_bias", + tags.conv2d_func, + tags.conv2d_tagging, [np.zeros([2, 8, 8, 5]), [np.zeros([3, 3, 5, 4])]], + params_index=1, + precedence=1) + +# Standard scale and shift with both scale and shift +register_function( + "scale_and_shift", + functools.partial( + tags.scale_and_shift_func, has_scale=True, has_shift=True), + functools.partial( + tags.scale_and_shift_tagging, has_scale=True, has_shift=True), + [np.zeros([2, 13]), [np.zeros([13]), np.zeros([13])]], + params_index=1, + precedence=0) + +# Same but no broadcasting +register_function( + "scale_and_shift", + functools.partial( + tags.scale_and_shift_func, has_scale=True, has_shift=True), + functools.partial( + tags.scale_and_shift_tagging, has_scale=True, has_shift=True), + [np.zeros([13]), [np.zeros([13]), np.zeros([13])]], + params_index=1, + precedence=0) + +# Scale and shift as implemented in batch norm layers in Haiku +register_function( + "scale_and_shift", + tags.batch_norm_func, + functools.partial( + tags.batch_norm_tagging_func, has_scale=True, has_shift=True), + [[np.zeros([2, 13]), np.zeros([13])], [np.zeros([13]), + np.zeros([13])]], + 
params_index=1, + precedence=0) + +# Same but no broadcasting +register_function( + "scale_and_shift", + tags.batch_norm_func, + functools.partial( + tags.batch_norm_tagging_func, has_scale=True, has_shift=True), + [[np.zeros([13]), np.zeros([13])], [np.zeros([13]), + np.zeros([13])]], + params_index=1, + precedence=0) + +# Only scale +register_function( + "scale_only", + functools.partial( + tags.scale_and_shift_func, has_scale=True, has_shift=False), + functools.partial( + tags.scale_and_shift_tagging, has_scale=True, has_shift=False), + [np.zeros([2, 13]), [np.zeros([13])]], + params_index=1, + precedence=1) diff --git a/DeepSolid/utils/kfac_ferminet_alpha/tracer.py b/DeepSolid/utils/kfac_ferminet_alpha/tracer.py new file mode 100644 index 0000000..7c35330 --- /dev/null +++ b/DeepSolid/utils/kfac_ferminet_alpha/tracer.py @@ -0,0 +1,332 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”). +# All Bytedance Modifications are Copyright 2022 Bytedance Inc. +"""Module for the Jax tracer functionality for tags.""" +import functools +from typing import Any, Callable, Sequence, Tuple + +import jax +from jax import core +from jax import util as jax_util +import jax.numpy as jnp + +from DeepSolid.utils.kfac_ferminet_alpha import layers_and_loss_tags as tags +from DeepSolid.utils.kfac_ferminet_alpha import tag_graph_matcher as tgm +from DeepSolid.utils.kfac_ferminet_alpha import utils +from DeepSolid.utils.kfac_ferminet_alpha import vjp_rc + +_Function = Callable[[Any], Any] +_Loss = tags.LossTag + + +def extract_tags( + jaxpr: core.Jaxpr +) -> Tuple[Sequence[core.JaxprEqn], Sequence[core.JaxprEqn]]: + """Extracts all of the tag equations.""" + # Loop through equations and evaluate primitives using `bind` + layer_tags = [] + loss_tags = [] + for eqn in jaxpr.eqns: + if isinstance(eqn.primitive, tags.LossTag): + loss_tags.append(eqn) + elif isinstance(eqn.primitive, tags.LayerTag): + layer_tags.append(eqn) + return tuple(layer_tags), tuple(loss_tags) + + +def construct_compute_losses_inputs( + jaxpr: core.Jaxpr, + consts: Tuple[Any], + num_losses: int, + primals: Any, + params_index: int) -> Callable[[Any], Sequence[Sequence[jnp.ndarray]]]: + """Constructs a function that computes all of the inputs to all losses.""" + primals_ = list(primals) + + def forward_compute_losses( + params_primals: Any, + ) -> Sequence[Sequence[jnp.ndarray]]: + primals_[params_index] = params_primals + flat_args = jax.tree_flatten(primals_)[0] + # Mapping from variable -> value + env = dict() + read = functools.partial(tgm.read_env, env) + write = functools.partial(tgm.write_env, env) + + # Bind args and consts to environment + write(jax.core.unitvar, jax.core.unit) + jax_util.safe_map(write, jaxpr.invars, flat_args) + jax_util.safe_map(write, jaxpr.constvars, consts) + + # Loop through equations and evaluate primitives using `bind` + losses_so_far = 0 + loss_tags = [] + for eqn in jaxpr.eqns: 
+ tgm.evaluate_eqn(eqn, jax_util.safe_map(read, eqn.invars), write) + if isinstance(eqn.primitive, tags.LossTag): + loss_tags.append(eqn) + losses_so_far += 1 + if num_losses is not None and losses_so_far == num_losses: + break + return tuple(tuple(read(v) for v in tag.invars) for tag in loss_tags) + # return tuple(jax_util.safe_map(read, tag.invars) for tag in loss_tags) + return forward_compute_losses + + +# We know when `.primitive` will be either a `LossTag` or a `LayerTag`, however +# pytype cannot infer its subclass, so we need to unbox it. + + +def _unbox_loss_tag(jaxpr_eqn: core.JaxprEqn) -> tags.LossTag: + assert isinstance(jaxpr_eqn.primitive, tags.LossTag) + return jaxpr_eqn.primitive + + +def _unbox_layer_tag(jaxpr_eqn: core.JaxprEqn) -> tags.LayerTag: + assert isinstance(jaxpr_eqn.primitive, tags.LayerTag) + return jaxpr_eqn.primitive + + +def trace_losses_matrix_vector_vjp(tagged_func: _Function, + params_index: int = 0): + """Returns the Jacobian-transposed vector product (backward mode) function in equivalent form to jax.vjp.""" + def vjp(*primals): + typed_jaxpr = jax.make_jaxpr(tagged_func)(*primals) + jaxpr, consts = typed_jaxpr.jaxpr, typed_jaxpr.literals + _, loss_jaxpr_eqns = extract_tags(jaxpr) + n = len(loss_jaxpr_eqns) + losses_func = construct_compute_losses_inputs( + jaxpr, consts, n, primals, params_index) + losses_inputs, full_vjp_func = jax.vjp(losses_func, primals[params_index]) + losses = [] + for jaxpr_eqn, inputs in zip(loss_jaxpr_eqns, losses_inputs): + loss_tag = _unbox_loss_tag(jaxpr_eqn) + losses.append(loss_tag.loss(*inputs, weight=jaxpr_eqn.params["weight"])) + losses = tuple(losses) + + def vjp_func(tangents): + flat_tangents = jax.tree_flatten(tangents)[0] + loss_invars = [] + loss_targets = [] + for jaxpr_eqn, inputs in zip(loss_jaxpr_eqns, losses_inputs): + num_inputs = _unbox_loss_tag(jaxpr_eqn).num_inputs + loss_invars.append(tuple(jaxpr_eqn.invars[:num_inputs])) + loss_targets.append(inputs[num_inputs:]) + treedef = jax.tree_structure(loss_invars) + tangents = jax.tree_unflatten(treedef, flat_tangents) + # Since the losses can also take targets as inputs, and we don't want + # this function to compute the vjp w.r.t. those (e.g. the user should not + # be providing tangent vectors for the targets, only for the inputs), we + # have to manually fill in these "extra" tangents with zeros.
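+        # Illustrative example (hypothetical names): if a loss tag was
+        # registered with inputs `(logits,)` and targets `(labels,)`, the
+        # caller only provides a tangent for `logits`, so a zero tangent
+        # `jnp.zeros_like(labels)` is appended for `labels` before calling
+        # the full VJP below.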
+ targets_tangents = jax.tree_map(jnp.zeros_like, loss_targets) + tangents = tuple(ti + tti for ti, tti in zip(tangents, targets_tangents)) + input_tangents = full_vjp_func(tangents)[0] + return input_tangents, + return losses, vjp_func + return vjp + + +def trace_losses_matrix_vector_jvp( + tagged_func: _Function, + params_index: int = 0): + """Returns the Jacobian vector product (forward mode) function in equivalent form to jax.jvp.""" + def jvp(primals, params_tangents): + typed_jaxpr = jax.make_jaxpr(tagged_func)(*primals) + jaxpr, consts = typed_jaxpr.jaxpr, typed_jaxpr.literals + _, loss_tags = extract_tags(jaxpr) + n = len(loss_tags) + losses_func = construct_compute_losses_inputs(jaxpr, consts, n, + primals, params_index) + primals = (primals[params_index],) + tangents = (params_tangents,) + (primals_out, tangents_out) = jax.jvp(losses_func, primals, tangents) + tangents_out = tuple(tuple(t[:tag.primitive.num_inputs]) + for t, tag in zip(tangents_out, loss_tags)) + losses = tuple(tag.primitive.loss(*inputs, weight=tag.params["weight"]) + for tag, inputs in zip(loss_tags, primals_out)) + return losses, tangents_out + return jvp + + +def trace_losses_matrix_vector_hvp(tagged_func, params_index=0): + """Returns the Hessian vector product function of **the tagged losses**, rather than the output value of `tagged_func`.""" + # The function uses backward-over-forward mode. + + def hvp(primals, params_tangents): + typed_jaxpr = jax.make_jaxpr(tagged_func)(*primals) + jaxpr, consts = typed_jaxpr.jaxpr, typed_jaxpr.literals + _, loss_tags = extract_tags(jaxpr) + n = len(loss_tags) + losses_func = construct_compute_losses_inputs( + jaxpr, consts, n, primals, params_index) + + def losses_sum(param_primals): + loss_inputs = losses_func(param_primals) + losses = [ + _unbox_loss_tag(jaxpr_eqn).loss( + *inputs, weight=jaxpr_eqn.params["weight"]) + for jaxpr_eqn, inputs in zip(loss_tags, loss_inputs) + ] + # This computes the sum of losses evaluated. Makes it easier as we can + # now use jax.grad rather than jax.vjp for taking derivatives. + return sum(jnp.sum(loss.evaluate(None)) for loss in losses) + + def grads_times_tangents(params_primals): + grads = jax.grad(losses_sum)(params_primals) + return utils.inner_product(grads, params_tangents) + + return jax.grad(grads_times_tangents)(primals[params_index]) + return hvp + + +def trace_estimator_vjp(tagged_func: _Function) -> _Function: + """Creates the function needed for an estimator of curvature matrices. + + Args: + tagged_func: A function that has been annotated with tags both for layers + and losses. + + Returns: + A function with the same signature as `tagged_func`, which when provided + with inputs returns two things: + 1. The instances of all loss objects that are tagged. + 2. A second function, which when provided with tangent vectors for each + of the loss instances' parameters, returns for every tagged layer a + dictionary containing the following elements: + inputs - The primal values of the inputs to the layer. + outputs - The primal values of the outputs to the layer. + params - The primal values of the layer. + inputs_tangent - The tangent value of the layer inputs, given the + provided tangents of the losses. + outputs_tangent - The tangent value of the layer outputs, given the + provided tangents of the losses. + params_tangent - The tangent value of the layer parameters, given the + provided tangents of the losses.
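+
+    Example (an illustrative sketch only; `tagged_f`, `func_args` and
+    `loss_tangents` are hypothetical placeholders):
+      losses, layers_vjp_func = trace_estimator_vjp(tagged_f)(func_args)
+      layers_info = layers_vjp_func(loss_tangents)
+      first_layer_inputs = layers_info[0]["inputs"]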
+ """ + def full_vjp_func(func_args): + # Trace the tagged function + typed_jaxpr = jax.make_jaxpr(tagged_func)(*func_args) + jaxpr, consts = typed_jaxpr.jaxpr, typed_jaxpr.literals + layer_tags, loss_tags = extract_tags(jaxpr) + + layer_vars_flat = jax.tree_flatten([tag.invars for tag in layer_tags])[0] + layer_input_vars = tuple(set(layer_vars_flat)) + + def forward(): + own_func_args = func_args + # Mapping from variable -> value + env = dict() + read = functools.partial(tgm.read_env, env) + write = functools.partial(tgm.write_env, env) + + # Bind args and consts to environment + write(jax.core.unitvar, jax.core.unit) + jax_util.safe_map(write, jaxpr.invars, jax.tree_flatten(own_func_args)[0]) + jax_util.safe_map(write, jaxpr.constvars, consts) + + # Loop through equations and evaluate primitives using `bind` + num_losses_passed = 0 + for eqn in jaxpr.eqns: + tgm.evaluate_eqn(eqn, jax_util.safe_map(read, eqn.invars), write) + if isinstance(eqn.primitive, tags.LossTag): + num_losses_passed += 1 + if num_losses_passed == len(loss_tags): + break + if num_losses_passed != len(loss_tags): + raise ValueError("This should be unreachable.") + + return jax_util.safe_map(read, layer_input_vars) + + def forward_aux(aux): + own_func_args = func_args + # Mapping from variable -> value + env = dict() + read = functools.partial(tgm.read_env, env) + def write(var, val): + if not isinstance(var, (jax.core.Literal, jax.core.UnitVar)): + val = val + aux[var] if var in aux else val + env[var] = val + + # Bind args and consts to environment + write(jax.core.unitvar, jax.core.unit) + jax_util.safe_map(write, jaxpr.invars, jax.tree_flatten(own_func_args)[0]) + jax_util.safe_map(write, jaxpr.constvars, consts) + + # Loop through equations and evaluate primitives using `bind` + num_losses_passed = 0 + losses_inputs_values = [] + losses_kwargs_values = [] + for eqn in jaxpr.eqns: + input_values = jax_util.safe_map(read, eqn.invars) + tgm.evaluate_eqn(eqn, input_values, write) + if isinstance(eqn.primitive, tags.LossTag): + loss = eqn.primitive.loss(*input_values, weight=eqn.params["weight"]) + losses_inputs_values.append(loss.inputs) + losses_kwargs_values.append(dict( + targets=loss.targets, + weight=eqn.params["weight"] + )) + num_losses_passed += 1 + if num_losses_passed == len(loss_tags): + break + if num_losses_passed != len(loss_tags): + raise ValueError("This should be unreachable.") + # Read the inputs to the loss functions, but also return the target values + return tuple(losses_inputs_values), tuple(losses_kwargs_values) + + layer_input_values = forward() + primals_dict = dict(zip(layer_input_vars, layer_input_values)) + primals_dict.update(zip(jaxpr.invars, jax.tree_flatten(func_args)[0])) + aux_values = jax.tree_map(lambda x:jnp.zeros_like(x), layer_input_values) + aux_dict = dict(zip(layer_input_vars, aux_values)) + + losses_args, aux_vjp, losses_kwargs = vjp_rc.vjp_rc(forward_aux, aux_dict, + has_aux=True) + losses = tuple(tag.primitive.loss(*inputs, **kwargs) + for tag, inputs, kwargs in + zip(loss_tags, losses_args, losses_kwargs)) + + def vjp_func(tangents): + tangents = jax.tree_map(lambda x:x+0j, tangents) + all_tangents = aux_vjp(tangents) + tangents_dict, inputs_tangents = all_tangents[0], all_tangents[1:] + inputs_tangents = jax.tree_flatten(inputs_tangents)[0] + tangents_dict.update(zip(jaxpr.invars, inputs_tangents)) + + read_primals = functools.partial(tgm.read_env, primals_dict) + read_tangents = functools.partial(tgm.read_env, tangents_dict) + layers_info = [] + for jaxpr_eqn in 
layer_tags: + layer_tag = _unbox_layer_tag(jaxpr_eqn) + info = dict() + primals = jax_util.safe_map(read_primals, tuple(jaxpr_eqn.invars)) + ( + info["outputs"], + info["inputs"], + info["params"], + ) = layer_tag.split_all_inputs(primals) + tangents = jax_util.safe_map(read_tangents, tuple(jaxpr_eqn.invars)) + ( + info["outputs_tangent"], + info["inputs_tangent"], + info["params_tangent"], + ) = layer_tag.split_all_inputs(tangents) + layers_info.append(info) + return tuple(layers_info) + + return losses, vjp_func + return full_vjp_func diff --git a/DeepSolid/utils/kfac_ferminet_alpha/utils.py b/DeepSolid/utils/kfac_ferminet_alpha/utils.py new file mode 100644 index 0000000..c684263 --- /dev/null +++ b/DeepSolid/utils/kfac_ferminet_alpha/utils.py @@ -0,0 +1,458 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”). +# All Bytedance Modifications are Copyright 2022 Bytedance Inc. +"""Utilities related to multi-device operations.""" +import collections +from typing import Any, Mapping, Optional, Sequence, Tuple, TypeVar, Union +import dataclasses +import jax +from jax import core +from jax import lax +import jax.numpy as jnp +from jax.scipy import linalg +import jax.tree_util as tree_util + +T = TypeVar("T") + + +def wrap_if_pmap(p_func): + + def p_func_if_pmap(obj, axis_name): + try: + core.axis_frame(axis_name) + return p_func(obj, axis_name) + except NameError: + return obj + + return p_func_if_pmap + + +pmean_if_pmap = wrap_if_pmap(lax.pmean) +psum_if_pmap = wrap_if_pmap(lax.psum) +compute_mean = jax.pmap(lambda x: lax.pmean(x, "i"), axis_name="i") +compute_sum = jax.pmap(lambda x: lax.psum(x, "i"), axis_name="i") + + +def get_first(obj: T) -> T: + return jax.tree_map(lambda x: x[0], obj) + + +def get_mean(obj: T) -> T: + return get_first(compute_mean(obj)) + + +def get_sum(obj: T) -> T: + return get_first(compute_sum(obj)) + + +broadcast_all_local_devices = jax.pmap(lambda x: x) + + +def replicate_all_local_devices(obj: T) -> T: + n = jax.local_device_count() + obj_stacked = jax.tree_map(lambda x: jnp.stack([x] * n, axis=0), obj) + return broadcast_all_local_devices(obj_stacked) + + +def make_different_rng_key_on_all_devices(rng: jnp.ndarray) -> jnp.ndarray: + rng = jax.random.fold_in(rng, jax.host_id()) + rng = jax.random.split(rng, jax.local_device_count()) + return broadcast_all_local_devices(rng) + + +p_split = jax.pmap(lambda key: tuple(jax.random.split(key))) + + +def scalar_mul(obj: T, scalar: Union[float, jnp.ndarray]) -> T: + return jax.tree_map(lambda x: x * scalar, obj) + + +def scalar_div(obj: T, scalar: Union[float, jnp.ndarray]) -> T: + return jax.tree_map(lambda x: x / scalar, obj) + + +def make_func_args(params, func_state, rng, batch, has_state: bool, + has_rng: bool): + """Correctly puts all arguments to the function together.""" + func_args = (params,) + if has_state: + if func_state is None: + raise ValueError("The `func_state` is None, 
but the argument `has_state` " + "is True.") + func_args += (func_state,) + if has_rng: + if rng is None: + raise ValueError("The `rng` is None, but the argument `has_rng` is True.") + func_args += (rng,) + func_args += (batch,) + return func_args + + +def extract_func_outputs( + raw_outputs: Any, + has_aux: bool, + has_state: bool, +) -> Tuple[jnp.ndarray, Any, Any]: + """Given the function output returns separately the loss, func_state, aux.""" + if not has_aux and not has_state: + return raw_outputs, None, None + loss, other = raw_outputs + if has_aux and has_state: + func_state, aux = other + elif has_aux: + func_state, aux = None, other + else: + func_state, aux = other, None + return loss, func_state, aux + + +def inner_product(obj1: T, obj2: T) -> jnp.ndarray: + if jax.tree_structure(obj1) != jax.tree_structure(obj2): + raise ValueError("The two structures are not identical.") + elements_product = jax.tree_multimap(lambda x, y: jnp.sum(x * y), obj1, obj2) + return sum(jax.tree_flatten(elements_product)[0]) + + +def psd_inv_cholesky(matrix: jnp.ndarray, damping: jnp.ndarray) -> jnp.ndarray: + assert matrix.ndim == 2 + identity = jnp.eye(matrix.shape[0]) + matrix = matrix + damping * identity + return linalg.solve(matrix, identity, sym_pos=True) + + +def solve_maybe_small(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray: + """Computes a^-1 b more efficiently for small matrices.""" + assert a.shape[-1] == a.shape[-2] == b.shape[-1] + d = a.shape[-1] + if d == 0: + return a + elif d == 1: + return b / a[..., 0] + elif d == 2: + det = a[..., 0, 0] * a[..., 1, 1] - a[..., 0, 1] * a[..., 1, 0] + b_0 = a[..., 1, 1] * b[..., 0] - a[..., 0, 1] * b[..., 1] + b_1 = a[..., 0, 0] * b[..., 1] - a[..., 1, 0] * b[..., 0] + return jnp.stack([b_0, b_1], axis=-1) / det + elif d == 3: + raise NotImplementedError() + return jnp.linalg.solve(a, b) + + +def pi_adjusted_inverse( + factor_0: jnp.ndarray, + factor_1: jnp.ndarray, + damping: jnp.ndarray, + pmap_axis_name: str, +) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Performs inversion with pi-adjusted damping.""" + # Compute the norms of each factor + norm_0 = jnp.trace(factor_0) + norm_1 = jnp.trace(factor_1) + + # We need to sync the norms here, because reduction can be non-deterministic. + # They specifically are on GPUs by default for better performance. + # Hence although factor_0 and factor_1 are synced, the trace operation above + # can still produce different answers on different devices. 
+ norm_0, norm_1 = pmean_if_pmap((norm_0, norm_1), axis_name=pmap_axis_name) + + # Compute the overall scale + scale = norm_0 * norm_1 + + def regular_inverse( + operand: Sequence[jnp.ndarray]) -> Tuple[jnp.ndarray, jnp.ndarray]: + factor0, factor1, norm0, norm1, s, d = operand + # Special cases with one or two scalar factors + if factor0.size == 1 and factor1.size == 1: + value = jnp.ones_like(factor0) / jnp.sqrt(s) + return value, value + if factor0.size == 1: + factor1_normed = factor1 / norm1 + damping1 = d / norm1 + factor1_inv = psd_inv_cholesky(factor1_normed, damping1) + return jnp.full((1, 1), s), factor1_inv + if factor1.size == 1: + factor0_normed = factor0 / norm0 + damping0 = d / norm0 + factor0_inv = psd_inv_cholesky(factor0_normed, damping0) + return factor0_inv, jnp.full((1, 1), s) + + # Invert first factor + factor0_normed = factor0 / norm0 + damping0 = jnp.sqrt(d * factor1.shape[0] / (s * factor0.shape[0])) + factor0_inv = psd_inv_cholesky(factor0_normed, damping0) / jnp.sqrt(s) + + # Invert second factor + factor1_normed = factor1 / norm1 + damping1 = jnp.sqrt(d * factor0.shape[0] / (s * factor1.shape[0])) + factor1_inv = psd_inv_cholesky(factor1_normed, damping1) / jnp.sqrt(s) + return factor0_inv, factor1_inv + + def zero_inverse( + operand: Sequence[jnp.ndarray]) -> Tuple[jnp.ndarray, jnp.ndarray]: + return (jnp.eye(factor_0.shape[0]) / jnp.sqrt(operand[-1]), + jnp.eye(factor_1.shape[0]) / jnp.sqrt(operand[-1])) + + # In the special case where for some reason one of the factors is zero, then + # the correct inverse of `(0 kron A + lambda I)` is + # `(I/sqrt(lambda) kron (I/sqrt(lambda)`. However, because one of the norms is + # zero, then `pi` and `1/pi` would be 0 and infinity leading to NaN values. + # Hence, we need to make this check explicitly. 
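+  # For reference, a derivation sketch (not used by the computation): in the
+  # general branch of `regular_inverse` below, with
+  #   pi = sqrt((norm_0 / dim(factor_0)) / (norm_1 / dim(factor_1))),
+  # the returned pair satisfies
+  #   factor_0_inv kron factor_1_inv
+  #     = inverse((factor_0 + pi * sqrt(damping) * I) kron
+  #               (factor_1 + sqrt(damping) / pi * I)),
+  # i.e. the usual pi-adjusted factored approximation of
+  # inverse(factor_0 kron factor_1 + damping * I).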
+ return lax.cond( + jnp.greater(scale, 0.0), + regular_inverse, + zero_inverse, + operand=(factor_0, factor_1, norm_0, norm_1, scale, damping)) + + +def convert_value_and_grad_to_value_func( + value_and_grad_func, + has_aux: bool = False, +): + """Converts a value_and_grad function to value_func only.""" + + def value_func(*args, **kwargs): + out, _ = value_and_grad_func(*args, **kwargs) + if has_aux: + return out[0] + else: + return out + + return value_func + + +def check_structure_shapes_and_dtype(obj1: T, obj2: T) -> None: + """Verifies that the two objects have the same pytree structure.""" + assert jax.tree_structure(obj1) == jax.tree_structure(obj2) + for v1, v2 in zip(jax.tree_flatten(obj1)[0], jax.tree_flatten(obj2)[0]): + assert v1.shape == v2.shape + assert v1.dtype == v2.dtype + + +def check_first_dim_is_batch_size(batch_size: int, *args: jnp.ndarray) -> None: + for i, arg in enumerate(args): + if arg.shape[0] != batch_size: + raise ValueError(f"Expecting first dimension of arg[{i}] with shape " + f"{arg.shape} to be equal to the batch size " + f"{batch_size}.") + + +def py_tree_registered_dataclass(cls, *args, **kwargs): + """Creates a new dataclass type and registers it as a pytree node.""" + dcls = dataclasses.dataclass(cls, *args, **kwargs) + tree_util.register_pytree_node( + dcls, + lambda instance: ( # pylint: disable=g-long-lambda + [getattr(instance, f.name) + for f in dataclasses.fields(instance)], None), + lambda _, instance_args: dcls(*instance_args)) + return dcls + + +class WeightedMovingAverage: + """A wrapped class for a variable for which we keep exponential moving average.""" + + def __init__(self, weight: jnp.ndarray, array: jnp.ndarray): + self._weight = weight + self._array = array + + @staticmethod + def zero(shape: Sequence[int]) -> "WeightedMovingAverage": + return WeightedMovingAverage(weight=jnp.zeros([]), array=jnp.zeros(shape)) + + @property + def weight(self) -> jnp.ndarray: + return self._weight + + @property + def value(self) -> jnp.ndarray: + return self._array / self._weight + + @property + def raw_value(self) -> jnp.ndarray: + return self._array + + def update(self, value: jnp.ndarray, old_weight_multiplier: float, + new_weight: float) -> None: + self._weight = old_weight_multiplier * self._weight + new_weight + self._array = old_weight_multiplier * self._array + new_weight * value + + def sync(self, pmap_axis_name: str) -> None: + self._array = pmean_if_pmap(self._array, pmap_axis_name) + + def __str__(self) -> str: + return (f"ExponentialMovingAverage(weight={self._weight}, " + f"array={self._array})") + + def __repr__(self) -> str: + return self.__str__() + + +tree_util.register_pytree_node( + WeightedMovingAverage, + lambda instance: ((instance.weight, instance.raw_value), None), + lambda _, instance_args: WeightedMovingAverage(*instance_args), +) + + +class Stateful: + """A class for stateful objects.""" + + def __init__(self, stateful_fields_names: Optional[Sequence[str]] = ()): + self.__stateful_fields_names = stateful_fields_names + + def _add_stateful_fields_names(self, value: Sequence[str]) -> None: + self.__stateful_fields_names += tuple(value) + + def get_state(self) -> Mapping[str, Any]: + """Returns the state of the object.""" + state = dict() + for name in self.__stateful_fields_names: + state[name] = Stateful._get_state_from_instance(getattr(self, name)) + return state + + def set_state(self, value): + """Sets the state of the object with the provided value and returns the object.""" + assert isinstance(value, dict) + for name 
in self.__stateful_fields_names: + setattr(self, name, + Stateful._set_state_to_instance(getattr(self, name), value[name])) + return self + + def clear_state(self) -> None: + """Clears the state of the object.""" + for name in self.__stateful_fields_names: + setattr(self, name, + Stateful._clear_state_from_instance(getattr(self, name))) + + def pop_state(self) -> Mapping[str, Any]: + """Returns the current state of the object, while simultaneously clearing it.""" + state = self.get_state() + self.clear_state() + return state + + @staticmethod + def _get_state_from_instance(obj): + """Recursively gets the state of the object and returns it.""" + if isinstance(obj, Stateful): + return obj.get_state() + if isinstance(obj, list): + return [Stateful._get_state_from_instance(i) for i in obj] + if isinstance(obj, tuple): + return tuple(Stateful._get_state_from_instance(i) for i in obj) + if isinstance(obj, collections.OrderedDict): + return collections.OrderedDict( + (k, Stateful._get_state_from_instance(v)) for k, v in obj.items()) + if isinstance(obj, dict): + return dict( + (k, Stateful._get_state_from_instance(v)) for k, v in obj.items()) + return obj + + @staticmethod + def _set_state_to_instance(obj, value): + """Recursively sets the state of the object and returns it.""" + if isinstance(obj, Stateful): + obj.set_state(value) + return obj + if isinstance(value, list): + if obj is None: + obj = [None] * len(value) + return [ + Stateful._set_state_to_instance(obj_i, value_i) + for obj_i, value_i in zip(obj, value) + ] + if isinstance(value, tuple): + if obj is None: + obj = [None] * len(value) + return tuple( + Stateful._set_state_to_instance(obj_i, value_i) + for obj_i, value_i in zip(obj, value)) + if isinstance(value, collections.OrderedDict): + if obj is None: + obj = dict((k, None) for k in value) + return collections.OrderedDict( + (k, Stateful._set_state_to_instance(obj[k], value[k])) for k in obj) + if isinstance(value, dict): + obj = dict((k, None) for k in value) + return dict( + (k, Stateful._set_state_to_instance(obj[k], value[k])) for k in obj) + return value + + @staticmethod + def _clear_state_from_instance(obj): + """Recursively clears the state of the object and returns it.""" + if isinstance(obj, Stateful): + obj.clear_state() + return obj + if isinstance(obj, list): + return [Stateful._clear_state_from_instance(obj_i) for obj_i in obj] + if isinstance(obj, tuple): + return tuple(Stateful._clear_state_from_instance(obj_i) for obj_i in obj) + if isinstance(obj, collections.OrderedDict): + return collections.OrderedDict( + (k, Stateful._clear_state_from_instance(obj[k])) for k in obj) + if isinstance(obj, dict): + return dict((k, Stateful._clear_state_from_instance(obj[k])) for k in obj) + return None + + @staticmethod + def infer_class_state(class_type): + """Infers a stateful class state attributes from class annotations.""" + if not issubclass(class_type, Stateful): + raise ValueError( + f"In order to annotate a class as stateful it must inherit " + f"{Stateful!r}") + + class_type = dataclasses.dataclass( + class_type, init=False, repr=False, eq=False) # pytype: disable=wrong-keyword-args + fields_names = tuple(field.name for field in dataclasses.fields(class_type)) + original_init = getattr(class_type, "__init__", None) + if original_init is None: + + def injected_init(self, *args, **kwargs): + super(self.__class__, self).__init__(*args, **kwargs) # pylint: disable=bad-super-call + Stateful._add_stateful_fields_names(self, fields_names) + for field_name in fields_names: + 
if getattr(self, field_name, None) is None: + setattr(self, field_name, None) + + setattr(class_type, "__init__", injected_init) + else: + + def injected_init(self, *args, **kwargs): + original_init(self, *args, **kwargs) + Stateful._add_stateful_fields_names(self, fields_names) + for field_name in fields_names: + if getattr(self, field_name, None) is None: + setattr(self, field_name, None) + + setattr(class_type, "__init__", injected_init) + return class_type + + +def compute_sq_norm_relative_abs_diff(obj, pmap_axis_name): + sq_norm = inner_product(obj, obj) + synced_sq_norm = psum_if_pmap(sq_norm, pmap_axis_name) + synced_sq_norm = (synced_sq_norm - sq_norm) / (jax.device_count() - 1.0) + sq_norm_abs_diff = jnp.abs(sq_norm - synced_sq_norm) + return sq_norm_abs_diff / sq_norm + + +def product(iterable_object): + x = 1 + for element in iterable_object: + x *= element + return x diff --git a/DeepSolid/utils/kfac_ferminet_alpha/vjp_rc.py b/DeepSolid/utils/kfac_ferminet_alpha/vjp_rc.py new file mode 100644 index 0000000..57c7761 --- /dev/null +++ b/DeepSolid/utils/kfac_ferminet_alpha/vjp_rc.py @@ -0,0 +1,93 @@ +# Copyright 2020, 2021 The NetKet Authors - All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# enable x64 on jax +# must be done at startup. + +# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”). +# All Bytedance Modifications are Copyright 2022 Bytedance Inc. + +import jax +import jax.numpy as jnp +from jax.tree_util import ( + tree_map, + tree_multimap, +) + +def vjp_rc( + fun, *primals, has_aux: bool = False, conjugate: bool = False): + ''' + realize the vjp of R->C function + :param fun: + :param primals: + :param has_aux: + :param conjugate: + :return: + ''' + if has_aux: + + def real_fun(*primals): + val, aux = fun(*primals) + real_val = jax.tree_map(lambda x:x.real, val) + return real_val, aux + + def imag_fun(*primals): + val, aux = fun(*primals) + imag_val = jax.tree_map(lambda x: x.imag, val) + return imag_val, aux + + vals_r, vjp_r_fun, aux = jax.vjp(real_fun, *primals, has_aux=True) + vals_j, vjp_j_fun, _ = jax.vjp(imag_fun, *primals, has_aux=True) + + else: + real_fun = lambda *primals: fun(*primals).real + imag_fun = lambda *primals: fun(*primals).imag + + vals_r, vjp_r_fun = jax.vjp(real_fun, *primals, has_aux=False) + vals_j, vjp_j_fun = jax.vjp(imag_fun, *primals, has_aux=False) + + primals_out = jax.tree_multimap(lambda x,y:x + 1j*y, vals_r, vals_j) + + def vjp_fun(ȳ): + """ + function computing the vjp product for a R->C function. 
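+        In other words (a sketch, assuming the primals are real): writing
+        fun = u + 1j * v with u and v real-valued, the returned cotangent is
+            J_u^T @ ȳ + 1j * (J_v^T @ ȳ),
+        i.e. the unconjugated transposed-Jacobian product, and its complex
+        conjugate when `conjugate=True`.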
+ """ + ȳ_r = jax.tree_map(lambda x:x.real, ȳ) + # ȳ_r = jax.tree_map(lambda x:jnp.asarray(x, dtype=vals_r.dtype), ȳ_r) + ȳ_j = jax.tree_map(lambda x:x.imag, ȳ) + # ȳ_j = jax.tree_map(lambda x:jnp.asarray(x, dtype=vals_j.dtype), ȳ_j) + + # val = vals_r + vals_j + vr_jr = vjp_r_fun(jax.tree_map(lambda x,v:jnp.asarray(x, dtype=v.dtype), ȳ_r, vals_r)) + vj_jr = vjp_r_fun(jax.tree_map(lambda x,v:jnp.asarray(x, dtype=v.dtype), ȳ_j, vals_r)) + vr_jj = vjp_j_fun(jax.tree_map(lambda x,v:jnp.asarray(x, dtype=v.dtype), ȳ_r, vals_j)) + vj_jj = vjp_j_fun(jax.tree_map(lambda x,v:jnp.asarray(x, dtype=v.dtype), ȳ_j, vals_j)) + + r = tree_multimap( + lambda re, im: re + 1j * im, + vr_jr, + vj_jr, + ) + i = tree_multimap(lambda re, im: re + 1j * im, vr_jj, vj_jj) + out = tree_multimap(lambda re, im: re + 1j * im, r, i) + + if conjugate: + out = tree_map(jnp.conjugate, out) + + return out + + if has_aux: + return primals_out, vjp_fun, aux + else: + return primals_out, vjp_fun \ No newline at end of file diff --git a/DeepSolid/utils/poscar_to_cell.py b/DeepSolid/utils/poscar_to_cell.py new file mode 100644 index 0000000..a214c87 --- /dev/null +++ b/DeepSolid/utils/poscar_to_cell.py @@ -0,0 +1,92 @@ +""" +I/O routines for crystal structure. +Author: + Zhi-Hao Cui + Bo-Xiao Zheng +""" +# modified from libdmet_preview: +# https://github.com/gkclab/libdmet_preview/blob/faee119f18755314d945393595301f66baf40ae5/libdmet/utils/iotools.py + +# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”). +# All Bytedance Modifications are Copyright 2022 Bytedance Inc. + + +import numpy as np +from pyscf.data.nist import BOHR +from collections import OrderedDict +import scipy.linalg as la +import sys +import os + + +def Frac2Real(cellsize, coord): + assert cellsize.ndim == 2 and cellsize.shape[0] == cellsize.shape[1] + return np.dot(coord, cellsize) + +def Real2Frac(cellsize, coord): + assert cellsize.ndim == 2 and cellsize.shape[0] == cellsize.shape[1] + return np.dot(coord, la.inv(cellsize)) + + +def read_poscar(fname="POSCAR"): + """ + Read cell structure from a VASP POSCAR file. + + Args: + fname: file name. + Returns: + cell: cell, without build, unit in A. + """ + from pyscf.pbc import gto + with open(fname, 'r') as f: + lines = f.readlines() + + # 1 line scale factor + line = lines[1].split() + assert len(line) == 1 + factor = float(line[0]) + + # 2-4 line, lattice vector + a = np.array([np.fromstring(lines[i], dtype=np.double, sep=' ') \ + for i in range(2, 5)]) * factor + a = a / BOHR + + # 5, 6 line, species names and numbers + sp_names = lines[5].split() + if all([name.isdigit() for name in sp_names]): + # 5th line can be number of atoms not names. + sp_nums = np.fromstring(lines[5], dtype=int, sep=' ') + sp_names = ["X" for i in range(len(sp_nums))] + line_no = 6 + else: + sp_nums = np.fromstring(lines[6], dtype=int, sep=' ') + line_no = 7 + + # 7, cartisian or fraction or direct + line = lines[line_no].split() + assert len(line) == 1 + use_cart = line[0].startswith(('C', 'K', 'c', 'k')) + line_no += 1 + + # 8-end, coords + atom_col = [] + for sp_name, sp_num in zip(sp_names, sp_nums): + for i in range(sp_num): + # there may be 4th element for comments or fixation. 
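+            # Only the first three whitespace-separated fields are read, so a
+            # selective-dynamics line such as "0.25 0.25 0.25 T T F" yields the
+            # coordinate (0.25, 0.25, 0.25); any trailing flags are ignored.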
+ coord = np.array(list(map(float, + \ + lines[line_no].split()[:3]))) + if use_cart: + coord *= factor + coord = coord / BOHR + else: + coord = Frac2Real(a, coord) + atom_col.append((sp_name, coord)) + line_no += 1 + + cell = gto.Cell() + cell.a = a + cell.atom = atom_col + cell.unit = 'Bohr' + return cell + diff --git a/DeepSolid/utils/system.py b/DeepSolid/utils/system.py new file mode 100644 index 0000000..56a7250 --- /dev/null +++ b/DeepSolid/utils/system.py @@ -0,0 +1,87 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”). +# All Bytedance Modifications are Copyright 2022 Bytedance Inc. + +from typing import Sequence +import attr +import numpy as np + +from DeepSolid.utils import elements +from DeepSolid.utils import units as unit_conversion + + +@attr.s +class Atom: + """Atom information for Hamiltonians. + + The nuclear charge is inferred from the symbol if not given, in which case the + symbol must be the IUPAC symbol of the desired element. + + Attributes: + symbol: Element symbol. + coords: An iterable of atomic coordinates. Always a list of floats and in + bohr after initialisation. Default: place atom at origin. + charge: Nuclear charge. Default: nuclear charge (atomic number) of atom of + the given name. + atomic_number: Atomic number associated with element. Default: atomic number + of element of the given symbol. Should match charge unless fractional + nuclear charges are being used. + units: String giving units of coords. Either bohr or angstrom. Default: + bohr. If angstrom, coords are converted to be in bohr and units to the + string 'bohr'. + coords_angstrom: list of atomic coordinates in angstrom. + coords_array: Numpy array of atomic coordinates in bohr. + element: elements.Element corresponding to the symbol. 
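+
+    For example, Atom(symbol='H', coords=(0., 0., 1.0), units='angstrom')
+    stores its coordinates in bohr after initialisation.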
+ """ + symbol = attr.ib(type=str) + coords = attr.ib( + type=Sequence[float], + converter=lambda xs: tuple(float(x) for x in xs), + default=(0.0, 0.0, 0.0)) + charge = attr.ib(type=float, converter=float) + atomic_number = attr.ib(type=int, converter=int) + units = attr.ib( + type=str, + default='bohr', + validator=attr.validators.in_(['bohr', 'angstrom'])) + + @charge.default + def _set_default_charge(self): + return self.element.atomic_number + + @atomic_number.default + def _set_default_atomic_number(self): + return self.element.atomic_number + + def __attrs_post_init__(self): + if self.units == 'angstrom': + self.coords = [unit_conversion.angstrom2bohr(x) for x in self.coords] + self.units = 'bohr' + + @property + def coords_angstrom(self): + return [unit_conversion.bohr2angstrom(x) for x in self.coords] + + @property + def coords_array(self): + if not hasattr(self, '_coords_arr'): + self._coords_arr = np.array(self.coords) + return self._coords_arr + + @property + def element(self): + return elements.SYMBOLS[self.symbol] diff --git a/DeepSolid/utils/units.py b/DeepSolid/utils/units.py new file mode 100644 index 0000000..a84a301 --- /dev/null +++ b/DeepSolid/utils/units.py @@ -0,0 +1,49 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”). +# All Bytedance Modifications are Copyright 2022 Bytedance Inc. + +from typing import TypeVar +import numpy as np + +# 1 Bohr = 0.52917721067 (12) x 10^{-10} m +# https://physics.nist.gov/cgi-bin/cuu/Value?bohrrada0 +# Note: pyscf uses a slightly older definition of 0.52917721092 angstrom. +ANGSTROM_BOHR = 0.52917721067 +BOHR_ANGSTROM = 1. / ANGSTROM_BOHR + +# 1 Hartree = 627.509474 kcal/mol +# https://en.wikipedia.org/wiki/Hartree +KCAL_HARTREE = 627.509474 +HARTREE_KCAL = 1. / KCAL_HARTREE + +NumericalLike = TypeVar('NumericalLike', float, np.ndarray) + + +def bohr2angstrom(x_b: NumericalLike) -> NumericalLike: + return x_b * ANGSTROM_BOHR + + +def angstrom2bohr(x_a: NumericalLike) -> NumericalLike: + return x_a * BOHR_ANGSTROM + + +def hartree2kcal(x_b: NumericalLike) -> NumericalLike: + return x_b * KCAL_HARTREE + + +def kcal2hartree(x_a: NumericalLike) -> NumericalLike: + return x_a * HARTREE_KCAL diff --git a/DeepSolid/utils/writers.py b/DeepSolid/utils/writers.py new file mode 100644 index 0000000..3f42c92 --- /dev/null +++ b/DeepSolid/utils/writers.py @@ -0,0 +1,158 @@ +# Copyright 2020 DeepMind Technologies Limited. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# This file may have been modified by Bytedance Inc. (“Bytedance Modifications”). +# All Bytedance Modifications are Copyright 2022 Bytedance Inc. + +import contextlib +import os +from typing import Mapping, Optional, Sequence + +from absl import logging +import tables + + +class Writer(contextlib.AbstractContextManager): + """Write data to CSV, as well as logging data to stdout if desired.""" + + def __init__(self, + name: str, + schema: Sequence[str], + directory: str = 'logs/', + iteration_key: Optional[str] = 't', + log: bool = True): + """Initialise Writer. + + Args: + name: file name for CSV. + schema: sequence of keys, corresponding to each data item. + directory: directory path to write file to. + iteration_key: if not None or a null string, also include the iteration + index as the first column in the CSV output with the given key. + log: Also log each entry to stdout. + """ + self._schema = schema + if not os.path.isdir(directory): + os.mkdir(directory) + self._filename = os.path.join(directory, name + '.csv') + self._iteration_key = iteration_key + self._log = log + + def __enter__(self): + should_add_header = not os.path.exists(self._filename) + + self._file = open(self._filename, 'a+') + + if should_add_header: + # write top row of csv + if self._iteration_key: + self._file.write(f'{self._iteration_key},') + self._file.write(','.join(self._schema) + '\n') + return self + + def write(self, t: int, **data): + """Writes to file and stdout. + + Args: + t: iteration index. + **data: data items with keys as given in schema. + """ + row = [str(data.get(key, '')) for key in self._schema] + if self._iteration_key: + row.insert(0, str(t)) + for key in data: + if key not in self._schema: + raise ValueError('Not a recognized key for writer: %s' % key) + + # write the data to csv + self._file.write(','.join(row) + '\n') + + # write the data to abseil logs + if self._log: + logging.info('Iteration %s: %s', t, data) + + def flush(self): + self._file.flush() + + def __exit__(self, exc_type, exc_val, exc_tb): + self.flush() + self._file.close() + + +class H5Writer(contextlib.AbstractContextManager): + """Write data to HDF5 files.""" + + def __init__(self, + name: str, + schema: Mapping[str, Sequence[int]], + directory: str = '', + index_key: str = 't', + compression_level: int = 5): + """Initialise H5Writer. + + Args: + name: file name for CSV. + schema: dict of keys, corresponding to each data item . All data is + assumed ot be 32-bit floats. + directory: directory path to write file to. + index_key: name of (integer) key used to index each entry. + compression_level: compression level (0-9) used to compress HDF5 file. 
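+
+      A minimal usage sketch (hypothetical file name and schema):
+        with H5Writer('stats.h5', {'energy': (1,)}, directory='logs') as writer:
+          writer.write(0, {'energy': [0.0]})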
+ """ + self._path = os.path.join(directory, name) + self._schema = schema + self._index_key = index_key + self._description = {} + self._file = None + self._complevel = compression_level + + def __enter__(self): + if not self._schema: + return self + pos = 1 + self._description[self._index_key] = tables.Int32Col(pos=pos) + for key, shape in self._schema.items(): + pos += 1 + self._description[key] = tables.Float32Col(pos=pos, shape=shape) + if not os.path.isdir(os.path.dirname(self._path)): + os.mkdir(os.path.dirname(self._path)) + self._file = tables.open_file( + self._path, + mode='w', + title='Fermi Net Data', + filters=tables.Filters(complevel=self._complevel)) + self._table = self._file.create_table( + where=self._file.root, name='data', description=self._description) + return self + + def write(self, index: int, data): + """Write data to HDF5 file. + + Args: + index: iteration index. + data: dict of arrays to write to file. Only elements with keys in the + schema are written. + """ + if self._file: + h5_data = (index,) + for key in self._description: + if key != self._index_key: + h5_data += (data[key],) + self._table.append([h5_data]) + self._table.flush() + + def __exit__(self, exc_type, exc_val, exc_tb): + if self._file: + self._file.close() + self._file = None diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..b09cd78 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ +Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. 
For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 0000000..e7072a2 --- /dev/null +++ b/README.md @@ -0,0 +1,121 @@ +# DeepSolid + +An implementation of the algorithm given in +["Ab initio calculation of real solids via neural network ansatz"](https://arxiv.org/abs/2203.15472). +A periodic neural network is proposed as wavefunction ansatz for solid quantum Monte Carlo and achieves +unprecedented accuracy compared with other state-of-the-art methods. +This repository is developed upon [FermiNet](https://github.com/deepmind/ferminet/tree/jax) +and [PyQMC](https://github.com/WagnerGroup/pyqmc). + +## Installation + +DeepSolid can be installed via the supplied setup.py file. +```shell +pip3 install -e. +``` + +If GPU is available, we recommend you to install jax and jaxlib with cuda 11.4+. +Our experiments were carried out with jax==0.2.26 and jaxlib==0.1.75. + +## Usage + +[Ml_collection](https://github.com/google/ml_collections) package is used for system definition. 
Below is a simple example of an H10 chain in PBC:
+```
+deepsolid --config=PATH/TO/DeepSolid/config/two_hydrogen_cell.py:H,5,1,1,2.0,0,ccpvdz --config.batch_size 4096
+```
+
+### Customize your system
+The simulated system can be customized in a config.py file, for example:
+
+```python
+import numpy as np
+from pyscf.pbc import gto
+from DeepSolid import base_config
+from DeepSolid import supercell
+
+
+def get_config(input_str):
+    symbol, S = input_str.split(',')
+    cfg = base_config.default()
+
+    # Set up the cell.
+    cell = gto.Cell()
+
+    # Define the atoms in the primitive cell.
+    cell.atom = f"""
+    {symbol} 0.000000000000 0.000000000000 0.000000000000
+    """
+
+    # Define the basis set used for pretraining.
+    cell.basis = "ccpvdz"
+
+    # Define the lattice vectors of the primitive cell.
+    # In this example it is simple cubic.
+    cell.a = """
+    3.00, 0.00, 0.00
+    0.00, 3.00, 0.00
+    0.00, 0.00, 3.00"""
+
+    # Define the unit used in the cell definition. Only Bohr is supported for now.
+    cell.unit = "B"
+    cell.verbose = 5
+
+    # Define the exponent threshold below which Gaussian basis functions used in pretraining are discarded.
+    cell.exp_to_discard = 0.1
+    cell.build()
+
+    # Define the supercell for QMC; S specifies how the primitive cell is tiled.
+    S = np.eye(3) * int(S)
+    simulation_cell = supercell.get_supercell(cell, S)
+
+    # Assign the defined supercell to cfg.
+    cfg.system.pyscf_cell = simulation_cell
+
+    return cfg
+```
+After defining the config file, launch the simulation with the following command:
+
+```shell
+deepsolid --config=PATH/TO/config.py:He,1 --config.batch_size 4096
+```
+
+
+### Read structure from a POSCAR file
+
+Reading the structure from a VASP POSCAR file, a commonly used format, is also supported:
+```shell
+deepsolid --config=DeepSolid/config/read_poscar.py:PATH/TO/POSCAR/bcc_li.vasp,1,ccpvdz
+```
+## Distributed training
+The released code does not currently support multi-node training. See [this link](https://github.com/google/jax/pull/8364)
+for help.
+
+## Tricks to accelerate
+The bottleneck of DeepSolid is the evaluation of the Laplacian of the neural network. We recommend
+using partition mode, which is enabled by adding two more flags:
+```shell
+deepsolid --config=PATH/TO/config.py --config.optim.laplacian_mode=partition --config.optim.partition_number=3
+```
+Partition mode parallelizes the calculation of the Laplacian, and the partition number must be a factor of
+(number of electrons * 3). Note that partition mode requires a large amount of GPU memory.
+
+## Precision
+DeepSolid supports both FP32 and FP64. If FP32 is chosen, we recommend turning off TF32 mode,
+which is enabled automatically on A100 GPUs. TF32 can be turned off with the following command:
+
+```shell
+NVIDIA_TF32_OVERRIDE=0 deepsolid --config.use_x64=False
+```
+
+## Giving Credit
+
+If you use this code in your work, please cite the associated paper.
+
+```
+@article{li2022ab,
+  title={Ab initio calculation of real solids via neural network ansatz},
+  author={Li, Xiang and Li, Zhe and Chen, Ji},
+  journal={arXiv preprint arXiv:2203.15472},
+  year={2022}
+}
+```
diff --git a/bin/deepsolid b/bin/deepsolid
new file mode 100644
index 0000000..7750dac
--- /dev/null
+++ b/bin/deepsolid
@@ -0,0 +1,37 @@
+# Copyright (c) ByteDance, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+ +# modified from FermiNet:https://github.com/deepmind/ferminet + +import sys +import os + +from absl import app +from absl import flags +from absl import logging +from jax.config import config as jax_config +from DeepSolid import process +from ml_collections.config_flags import config_flags + +logging.get_absl_handler().python_handler.stream = sys.stdout +logging.set_verbosity(logging.INFO) + +# internal imports + +FLAGS = flags.FLAGS + +config_flags.DEFINE_config_file('config', None, 'Path to config file.') + + +def main(_): + cfg = FLAGS.config + if cfg.use_x64: + jax_config.update("jax_enable_x64", True) + process.process(cfg) + + +if __name__ == '__main__': + app.run(main) \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..e2440ab --- /dev/null +++ b/setup.py @@ -0,0 +1,40 @@ +# Copyright (c) ByteDance, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from setuptools import setup, find_packages + +REQUIRED_PACKAGES = ( + "absl-py", + 'attrs', + "dataclasses", + "networkx", + "scipy", + "numpy", + "ordered-set", + "typing", + "chex", + "jax", + "jaxlib", + "pandas", + "ml_collections", + "pyscf", + "tables", + 'h5py==3.2.1', + 'optax==0.0.9', + +) + +setup( + name="DeepSolid", + version="1.0", + description="A library combining solid quantum Monte Carlo and neural network.", + author='ByteDance', + author_email='lixiang.62770689@bytedance.com', + install_requires=REQUIRED_PACKAGES, + packages=find_packages(), + scripts=['bin/deepsolid'], + license='Apache 2.0', +)
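+
+# After an editable install (`pip3 install -e .`), the `deepsolid` launcher from
+# bin/deepsolid is placed on the PATH and can be invoked as in the README, e.g.
+#   deepsolid --config=PATH/TO/config.py:He,1 --config.batch_size 4096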