Skip to content

Commit

Permalink
neighbor index constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
spozdn committed Feb 8, 2024
1 parent 0afdb9c commit bffd80c
Showing 1 changed file with 72 additions and 1 deletion.
73 changes: 72 additions & 1 deletion src/molecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,4 +166,75 @@ def batch_to_dict(batch):
if hasattr(batch, 'central_scalar_attributes'):
batch_dict['central_scalar_attributes'] = batch.central_scalar_attributes

return batch_dict
return batch_dict

class NeighborIndexConstructor:
def __init__(self, i_list, j_list, S_list, species):
n_atoms = len(species)
self.neighbors_index = [[] for i in range(n_atoms)]
self.neighbors_shift = [[] for i in range(n_atoms)]

for i, j, index, S in zip(i_list, j_list, range(len(i_list)), S_list):
self.neighbors_index[i].append(j)
self.neighbors_shift[i].append(S)

self.relative_positions = [[] for i in range(n_atoms)]
self.neighbor_species = [[] for i in range(n_atoms)]
self.neighbors_pos = [[] for i in range(n_atoms)]

def is_same(first, second):
for i in range(len(first)):
if first[i] != second[i]:
return False
return True

for i, j, index, S in zip(i_list, j_list, range(len(i_list)), S_list):
self.relative_positions[i].append(index)
self.neighbor_species[i].append(species[j])
for k in range(len(self.neighbors_index[j])):
if (self.neighbors_index[j][k] == i) and is_same(self.neighbors_shift[j][k], -S):
self.neighbors_pos[i].append(k)

def get_max_num(self):
maximum = None
for chunk in self.relative_positions:
if (maximum is None) or (len(chunk) > maximum):
maximum = len(chunk)
return maximum

def get_neighbor_index(self, max_num, all_species):
nums = []
mask = []
relative_positions = np.zeros([len(self.relative_positions), max_num])
neighbors_pos = np.zeros([len(self.relative_positions), max_num], dtype = int)
neighbors_index = np.zeros([len(self.relative_positions), max_num], dtype = int)

for i in range(len(self.relative_positions)):
now = np.array(self.relative_positions[i])
if len(now) > 0:
relative_positions[i, :len(now)] = now
neighbors_pos[i, :len(now)] = self.neighbors_pos[i]
neighbors_index[i, :len(now)] = self.neighbors_index[i]

nums.append(len(self.relative_positions[i]))
current_mask = np.zeros(max_num)
current_mask[len(self.relative_positions[i]):] = True
mask.append(current_mask[np.newaxis, :])


mask = np.concatenate(mask, axis = 0)
relative_positions = torch.LongTensor(relative_positions)
nums = torch.FloatTensor(nums)
mask = torch.BoolTensor(mask)

neighbors_pos = torch.LongTensor(neighbors_pos)
neighbors_index = torch.LongTensor(neighbors_index)

neighbor_species = len(all_species) * np.ones([len(self.neighbor_species), max_num], dtype = int)
for i in range(len(self.neighbor_species)):
now = np.array(self.neighbor_species[i])
now = np.array([np.where(all_species == specie)[0][0] for specie in now])
neighbor_species[i, :len(now)] = now
neighbor_species = torch.LongTensor(neighbor_species)

return neighbors_pos, neighbors_index, nums, mask, neighbor_species, relative_positions

0 comments on commit bffd80c

Please sign in to comment.