-
Notifications
You must be signed in to change notification settings - Fork 64
/
Copy pathmace_block.py
749 lines (666 loc) · 24.6 KB
/
mace_block.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
###########################################################################################
# Elementary Block for Building O(3) Equivariant Higher Order Message Passing Neural Network
# Authors: Ilyes Batatia, Gregor Simm
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################
from abc import abstractmethod
from typing import Callable, List, Optional, Tuple, Union
import numpy as np
import torch.nn.functional
from e3nn import nn, o3
from e3nn.util.jit import compile_mode
from mace_irreps_tools import (
linear_out_irreps,
reshape_irreps,
tp_out_irreps_with_instructions,
)
from mace_radial import BesselBasis, GaussianBasis, PolynomialCutoff
from mace_symmetric_contraction import SymmetricContraction
def _broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
    """Expand ``src`` to the shape of ``other`` so it can index along ``dim``.

    A negative ``dim`` is interpreted relative to ``other.dim()``. If ``src``
    is 1-D it first receives ``dim`` leading singleton axes, so its values run
    along ``dim``. Trailing singleton axes are then appended until ``src`` has
    the same rank as ``other``, and the result is expanded to ``other``'s shape.
    """
    if dim < 0:
        dim += other.dim()
    if src.dim() == 1:
        # Leading singleton axes place the 1-D values on axis ``dim``.
        for _ in range(dim):
            src = src.unsqueeze(0)
    # Pad on the right until the ranks agree.
    for _ in range(src.dim(), other.dim()):
        src = src.unsqueeze(-1)
    return src.expand_as(other)
@torch.jit.script
def scatter_sum(
    src: torch.Tensor,
    index: torch.Tensor,
    dim: int = -1,
    out: Optional[torch.Tensor] = None,
    dim_size: Optional[int] = None,
    reduce: str = "sum",
) -> torch.Tensor:
    """Sum the entries of ``src`` into the slots named by ``index`` along ``dim``.

    ``index`` is broadcast to ``src``'s shape with ``_broadcast``. When ``out``
    is given, the sums are accumulated into it in place and it is returned.
    Otherwise a zero tensor is allocated whose size along ``dim`` is
    ``dim_size`` if given, 0 for an empty index, or ``index.max() + 1``.

    Only ``reduce="sum"`` is supported (checked with ``assert``).
    """
    assert reduce == "sum"  # for now, TODO
    expanded_index = _broadcast(index, src, dim)
    if out is not None:
        return out.scatter_add_(dim, expanded_index, src)
    out_shape = list(src.size())
    if dim_size is not None:
        out_shape[dim] = dim_size
    elif expanded_index.numel() == 0:
        out_shape[dim] = 0
    else:
        out_shape[dim] = int(expanded_index.max()) + 1
    result = torch.zeros(out_shape, dtype=src.dtype, device=src.device)
    return result.scatter_add_(dim, expanded_index, src)
@compile_mode("script")
class LinearNodeEmbeddingBlock(torch.nn.Module):
    """Embed node attributes with a single equivariant ``o3.Linear`` layer."""

    def __init__(self, irreps_in: o3.Irreps, irreps_out: o3.Irreps):
        super().__init__()
        self.linear = o3.Linear(irreps_in=irreps_in, irreps_out=irreps_out)

    def forward(
        self,
        node_attrs: torch.Tensor,
    ) -> torch.Tensor:
        embedded = self.linear(node_attrs)
        return embedded  # [n_nodes, irreps]
@compile_mode("script")
class LinearReadoutBlock(torch.nn.Module):
    """Project node features onto one scalar (``1x0e``) per node."""

    def __init__(self, irreps_in: o3.Irreps):
        super().__init__()
        scalar_irreps = o3.Irreps("1x0e")
        self.linear = o3.Linear(irreps_in=irreps_in, irreps_out=scalar_irreps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [n_nodes, irreps_in] -> [n_nodes, 1]
        return self.linear(x)
@compile_mode("script")
class NonLinearReadoutBlock(torch.nn.Module):
    """Readout MLP: ``o3.Linear`` -> ``nn.Activation(gate)`` -> ``o3.Linear`` to ``9x0e``.

    NOTE(review): the output irreps are ``9x0e`` (nine scalars per node), while
    the sibling ``LinearReadoutBlock`` emits ``1x0e``; confirm the 9-channel
    output is intended.
    """

    def __init__(
        self, irreps_in: o3.Irreps, MLP_irreps: o3.Irreps, gate: Optional[Callable]
    ):
        super().__init__()
        self.hidden_irreps = MLP_irreps
        self.linear_1 = o3.Linear(irreps_in=irreps_in, irreps_out=MLP_irreps)
        self.non_linearity = nn.Activation(irreps_in=MLP_irreps, acts=[gate])
        self.linear_2 = o3.Linear(irreps_in=MLP_irreps, irreps_out=o3.Irreps("9x0e"))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [n_nodes, irreps_in]
        hidden = self.linear_1(x)
        hidden = self.non_linearity(hidden)
        return self.linear_2(hidden)  # [n_nodes, 9]
@compile_mode("script")
class LinearDipoleReadoutBlock(torch.nn.Module):
    """Linear readout to a vector (``1x1o``), optionally preceded by a scalar (``1x0e``).

    With ``dipole_only=True`` the output irreps are ``1x1o``; otherwise they
    are ``1x0e + 1x1o``.
    """

    def __init__(self, irreps_in: o3.Irreps, dipole_only: bool = False):
        super().__init__()
        self.irreps_out = o3.Irreps("1x1o" if dipole_only else "1x0e + 1x1o")
        self.linear = o3.Linear(irreps_in=irreps_in, irreps_out=self.irreps_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [n_nodes, irreps_in] -> [n_nodes, irreps_out.dim] (3, or 1 + 3)
        return self.linear(x)
@compile_mode("script")
class NonLinearDipoleReadoutBlock(torch.nn.Module):
    """Gated readout MLP producing a vector (``1x1o``) and optionally a scalar.

    ``linear_1`` maps the input onto the irreps expected by an ``e3nn.nn.Gate``.
    The Gate is built from the parts of ``MLP_irreps`` that also appear in the
    output irreps: ``l == 0`` parts are activated scalars, and ``l > 0`` parts
    are gated by one extra ``0e`` channel each. ``linear_2`` maps the gate
    output to ``irreps_out``.

    Args:
        irreps_in: irreps of the incoming node features.
        MLP_irreps: hidden irreps of the MLP (stored as ``hidden_irreps``).
        gate: activation used for both the scalars and the gates.
        dipole_only: if True the output is ``1x1o``; else ``1x0e + 1x1o``.
    """

    def __init__(
        self,
        irreps_in: o3.Irreps,
        MLP_irreps: o3.Irreps,
        gate: Callable,
        dipole_only: bool = False,
    ):
        super().__init__()
        self.hidden_irreps = MLP_irreps
        if dipole_only:
            self.irreps_out = o3.Irreps("1x1o")
        else:
            self.irreps_out = o3.Irreps("1x0e + 1x1o")
        # Only hidden irreps that also occur in the output survive the gate.
        irreps_scalars = o3.Irreps(
            [(mul, ir) for mul, ir in MLP_irreps if ir.l == 0 and ir in self.irreps_out]
        )
        irreps_gated = o3.Irreps(
            [(mul, ir) for mul, ir in MLP_irreps if ir.l > 0 and ir in self.irreps_out]
        )
        # One scalar gate per gated multiplicity.
        irreps_gates = o3.Irreps([mul, "0e"] for mul, _ in irreps_gated)
        self.equivariant_nonlin = nn.Gate(
            irreps_scalars=irreps_scalars,
            act_scalars=[gate for _, ir in irreps_scalars],
            irreps_gates=irreps_gates,
            act_gates=[gate] * len(irreps_gates),
            irreps_gated=irreps_gated,
        )
        self.irreps_nonlin = self.equivariant_nonlin.irreps_in.simplify()
        self.linear_1 = o3.Linear(irreps_in=irreps_in, irreps_out=self.irreps_nonlin)
        # linear_2 consumes the gate's output, so declare exactly those irreps.
        # Using ``self.hidden_irreps`` here only matched when MLP_irreps listed
        # its scalars first and contained nothing the filters above drop
        # (e.g. a 0e part with dipole_only=True). Otherwise forward() either
        # failed on a size mismatch or mis-routed irrep components.
        self.linear_2 = o3.Linear(
            irreps_in=self.equivariant_nonlin.irreps_out, irreps_out=self.irreps_out
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [n_nodes, irreps_in] -> [n_nodes, irreps_out.dim]
        x = self.equivariant_nonlin(self.linear_1(x))
        return self.linear_2(x)
@compile_mode("script")
class RadialEmbeddingBlock(torch.nn.Module):
    """Expand edge lengths in a radial basis and multiply by a polynomial cutoff.

    Args:
        r_max: cutoff radius passed to both the basis and ``PolynomialCutoff``.
        num_bessel: number of basis functions; also exposed as ``out_dim``.
        num_polynomial_cutoff: the ``p`` argument of ``PolynomialCutoff``.
        radial_type: ``"bessel"`` (``BesselBasis``) or ``"gaussian"``
            (``GaussianBasis``).

    Raises:
        ValueError: if ``radial_type`` is not one of the supported names.
    """

    def __init__(
        self,
        r_max: float,
        num_bessel: int,
        num_polynomial_cutoff: int,
        radial_type: str = "bessel",
    ):
        super().__init__()
        if radial_type == "bessel":
            self.bessel_fn = BesselBasis(r_max=r_max, num_basis=num_bessel)
        elif radial_type == "gaussian":
            self.bessel_fn = GaussianBasis(r_max=r_max, num_basis=num_bessel)
        else:
            # Previously an unknown name silently left ``bessel_fn`` unset and
            # only failed later, in forward(); fail fast with a clear message.
            raise ValueError(
                f"Unknown radial_type {radial_type!r}; expected 'bessel' or 'gaussian'"
            )
        self.cutoff_fn = PolynomialCutoff(r_max=r_max, p=num_polynomial_cutoff)
        self.out_dim = num_bessel

    def forward(
        self,
        edge_lengths: torch.Tensor,  # [n_edges, 1]
    ) -> torch.Tensor:
        radial = self.bessel_fn(edge_lengths)  # [n_edges, n_basis]
        cutoff = self.cutoff_fn(edge_lengths)  # [n_edges, 1]
        return radial * cutoff  # [n_edges, n_basis]
@compile_mode("script")
class EquivariantProductBasisBlock(torch.nn.Module):
    """Symmetric contraction of node features followed by an equivariant linear layer.

    If ``use_sc`` is True and a skip tensor ``sc`` is passed to ``forward``,
    it is added to the output of the linear layer.
    """

    def __init__(
        self,
        node_feats_irreps: o3.Irreps,
        target_irreps: o3.Irreps,
        correlation: int,
        use_sc: bool = True,
        num_elements: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.use_sc = use_sc
        self.symmetric_contractions = SymmetricContraction(
            irreps_in=node_feats_irreps,
            irreps_out=target_irreps,
            correlation=correlation,
            num_elements=num_elements,
        )
        # Mixing layer applied after the contraction (target -> target).
        self.linear = o3.Linear(
            target_irreps,
            target_irreps,
            internal_weights=True,
            shared_weights=True,
        )

    def forward(
        self,
        node_feats: torch.Tensor,
        sc: Optional[torch.Tensor],
        node_attrs: torch.Tensor,
    ) -> torch.Tensor:
        contracted = self.symmetric_contractions(node_feats, node_attrs)
        mixed = self.linear(contracted)
        if self.use_sc and sc is not None:
            return mixed + sc
        return mixed
@compile_mode("script")
class InteractionBlock(torch.nn.Module):
    """Shared base for the message-passing interaction blocks.

    Stores the irreps and hyper-parameters common to every variant, then calls
    ``_setup()`` so the subclass can build its layers. Subclasses override
    ``_setup`` and ``forward``.

    NOTE(review): the class does not use ``ABCMeta``, so ``@abstractmethod`` is
    informational only. Instantiating the base class still fails because
    ``__init__`` calls ``_setup``, which raises ``NotImplementedError``.
    """

    def __init__(
        self,
        node_attrs_irreps: o3.Irreps,
        node_feats_irreps: o3.Irreps,
        edge_attrs_irreps: o3.Irreps,
        edge_feats_irreps: o3.Irreps,
        target_irreps: o3.Irreps,
        hidden_irreps: o3.Irreps,
        avg_num_neighbors: float,
        radial_MLP: Optional[List[int]] = None,
    ) -> None:
        super().__init__()
        self.node_attrs_irreps = node_attrs_irreps
        self.node_feats_irreps = node_feats_irreps
        self.edge_attrs_irreps = edge_attrs_irreps
        self.edge_feats_irreps = edge_feats_irreps
        self.target_irreps = target_irreps
        self.hidden_irreps = hidden_irreps
        self.avg_num_neighbors = avg_num_neighbors
        # Hidden widths of the radial MLP; a fresh default list per instance.
        self.radial_MLP = [64, 64, 64] if radial_MLP is None else radial_MLP
        self._setup()

    @abstractmethod
    def _setup(self) -> None:
        """Build the subclass's layers; invoked at the end of ``__init__``."""
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        node_attrs: torch.Tensor,
        node_feats: torch.Tensor,
        edge_attrs: torch.Tensor,
        edge_feats: torch.Tensor,
        edge_index: torch.Tensor,
    ) -> torch.Tensor:
        """Compute updated node features from one round of message passing."""
        raise NotImplementedError
# Activation lookup keyed by +1 / -1.
# NOTE(review): the keys look like irrep parities (even -> SiLU, odd -> tanh);
# confirm against the callers, which are not visible here.
nonlinearities = {
    1: torch.nn.functional.silu,
    -1: torch.tanh,
}
@compile_mode("script")
class TensorProductWeightsBlock(torch.nn.Module):
    """Element-dependent linear map from edge features to tensor-product weights.

    Holds a learnable ``[num_elements, num_edge_feats, num_feats_out]`` tensor,
    Xavier-uniform initialised, that is contracted with the edge features and
    the node attributes.
    """

    def __init__(self, num_elements: int, num_edge_feats: int, num_feats_out: int):
        super().__init__()
        shape = (num_elements, num_edge_feats, num_feats_out)
        init_weights = torch.empty(shape, dtype=torch.get_default_dtype())
        torch.nn.init.xavier_uniform_(init_weights)
        self.weights = torch.nn.Parameter(init_weights)

    def forward(
        self,
        sender_or_receiver_node_attrs: torch.Tensor,  # assumes that the node attributes are one-hot encoded
        edge_feats: torch.Tensor,
    ):
        # out[b, k] = sum_{a, e} edge_feats[b, e] * node_attrs[b, a] * weights[a, e, k]
        return torch.einsum(
            "be, ba, aek -> bk", edge_feats, sender_or_receiver_node_attrs, self.weights
        )

    def __repr__(self):
        dims = ", ".join(str(d) for d in self.weights.shape)
        n_weights = np.prod(self.weights.shape)
        return f"{self.__class__.__name__}(shape=({dims}), weights={n_weights})"
@compile_mode("script")
class ResidualElementDependentInteractionBlock(InteractionBlock):
    """Interaction with element-dependent tensor-product weights and a residual skip.

    The per-edge weights of ``conv_tp`` come from ``TensorProductWeightsBlock``
    applied to the sender's node attributes and the edge features. The skip
    path is a fully connected tensor product of the incoming node features
    with the node attributes, added to the normalised message.
    """

    def _setup(self) -> None:
        self.linear_up = o3.Linear(
            self.node_feats_irreps,
            self.node_feats_irreps,
            internal_weights=True,
            shared_weights=True,
        )
        # Edge tensor product: node features x edge attributes -> irreps_mid.
        irreps_mid, instructions = tp_out_irreps_with_instructions(
            self.node_feats_irreps, self.edge_attrs_irreps, self.target_irreps
        )
        self.conv_tp = o3.TensorProduct(
            self.node_feats_irreps,
            self.edge_attrs_irreps,
            irreps_mid,
            instructions=instructions,
            shared_weights=False,
            internal_weights=False,
        )
        # Per-edge weights for conv_tp, conditioned on the node attributes.
        self.conv_tp_weights = TensorProductWeightsBlock(
            num_elements=self.node_attrs_irreps.num_irreps,
            num_edge_feats=self.edge_feats_irreps.num_irreps,
            num_feats_out=self.conv_tp.weight_numel,
        )
        # Output projection.
        irreps_mid = irreps_mid.simplify()
        self.irreps_out = linear_out_irreps(irreps_mid, self.target_irreps).simplify()
        self.linear = o3.Linear(
            irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True
        )
        # Residual "selector" tensor product with the node attributes.
        self.skip_tp = o3.FullyConnectedTensorProduct(
            self.node_feats_irreps, self.node_attrs_irreps, self.irreps_out
        )

    def forward(
        self,
        node_attrs: torch.Tensor,
        node_feats: torch.Tensor,
        edge_attrs: torch.Tensor,
        edge_feats: torch.Tensor,
        edge_index: torch.Tensor,
    ) -> torch.Tensor:
        senders = edge_index[0]
        receivers = edge_index[1]
        n_nodes = node_feats.shape[0]
        # The skip path sees the features before linear_up.
        skip = self.skip_tp(node_feats, node_attrs)
        lifted = self.linear_up(node_feats)
        weights = self.conv_tp_weights(node_attrs[senders], edge_feats)
        edge_messages = self.conv_tp(lifted[senders], edge_attrs, weights)  # [n_edges, irreps]
        # Sum the messages arriving at each receiving node.
        aggregated = scatter_sum(
            src=edge_messages, index=receivers, dim=0, dim_size=n_nodes
        )  # [n_nodes, irreps]
        aggregated = self.linear(aggregated) / self.avg_num_neighbors
        return aggregated + skip  # [n_nodes, irreps]
@compile_mode("script")
class AgnosticNonlinearInteractionBlock(InteractionBlock):
    """Interaction with radial-MLP tensor-product weights and no residual skip.

    The per-edge weights of ``conv_tp`` come from a ``FullyConnectedNet`` (SiLU)
    on the edge features. The normalised message is then mixed with the node
    attributes by ``skip_tp``, and that result is returned.
    """

    def _setup(self) -> None:
        self.linear_up = o3.Linear(
            self.node_feats_irreps,
            self.node_feats_irreps,
            internal_weights=True,
            shared_weights=True,
        )
        # Edge tensor product: node features x edge attributes -> irreps_mid.
        irreps_mid, instructions = tp_out_irreps_with_instructions(
            self.node_feats_irreps, self.edge_attrs_irreps, self.target_irreps
        )
        self.conv_tp = o3.TensorProduct(
            self.node_feats_irreps,
            self.edge_attrs_irreps,
            irreps_mid,
            instructions=instructions,
            shared_weights=False,
            internal_weights=False,
        )
        # Radial MLP: edge features -> conv_tp weights.
        mlp_widths = (
            [self.edge_feats_irreps.num_irreps]
            + self.radial_MLP
            + [self.conv_tp.weight_numel]
        )
        self.conv_tp_weights = nn.FullyConnectedNet(
            mlp_widths,
            torch.nn.functional.silu,
        )
        # Output projection.
        irreps_mid = irreps_mid.simplify()
        self.irreps_out = linear_out_irreps(irreps_mid, self.target_irreps).simplify()
        self.linear = o3.Linear(
            irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True
        )
        # Element-dependent mixing of the projected message.
        self.skip_tp = o3.FullyConnectedTensorProduct(
            self.irreps_out, self.node_attrs_irreps, self.irreps_out
        )

    def forward(
        self,
        node_attrs: torch.Tensor,
        node_feats: torch.Tensor,
        edge_attrs: torch.Tensor,
        edge_feats: torch.Tensor,
        edge_index: torch.Tensor,
    ) -> torch.Tensor:
        senders = edge_index[0]
        receivers = edge_index[1]
        n_nodes = node_feats.shape[0]
        weights = self.conv_tp_weights(edge_feats)
        lifted = self.linear_up(node_feats)
        edge_messages = self.conv_tp(lifted[senders], edge_attrs, weights)  # [n_edges, irreps]
        aggregated = scatter_sum(
            src=edge_messages, index=receivers, dim=0, dim_size=n_nodes
        )  # [n_nodes, irreps]
        aggregated = self.linear(aggregated) / self.avg_num_neighbors
        return self.skip_tp(aggregated, node_attrs)  # [n_nodes, irreps]
@compile_mode("script")
class AgnosticResidualNonlinearInteractionBlock(InteractionBlock):
    """Interaction with radial-MLP tensor-product weights plus a residual skip.

    The per-edge weights of ``conv_tp`` come from a ``FullyConnectedNet`` (SiLU)
    on the edge features. The skip path is a fully connected tensor product of
    the incoming node features with the node attributes, added to the
    normalised message.
    """

    def _setup(self) -> None:
        # Feature lift before the convolution.
        self.linear_up = o3.Linear(
            self.node_feats_irreps,
            self.node_feats_irreps,
            internal_weights=True,
            shared_weights=True,
        )
        # Edge tensor product: node features x edge attributes -> irreps_mid.
        irreps_mid, instructions = tp_out_irreps_with_instructions(
            self.node_feats_irreps, self.edge_attrs_irreps, self.target_irreps
        )
        self.conv_tp = o3.TensorProduct(
            self.node_feats_irreps,
            self.edge_attrs_irreps,
            irreps_mid,
            instructions=instructions,
            shared_weights=False,
            internal_weights=False,
        )
        # Radial MLP: edge features -> conv_tp weights.
        mlp_widths = (
            [self.edge_feats_irreps.num_irreps]
            + self.radial_MLP
            + [self.conv_tp.weight_numel]
        )
        self.conv_tp_weights = nn.FullyConnectedNet(
            mlp_widths,
            torch.nn.functional.silu,
        )
        # Output projection.
        irreps_mid = irreps_mid.simplify()
        self.irreps_out = linear_out_irreps(irreps_mid, self.target_irreps).simplify()
        self.linear = o3.Linear(
            irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True
        )
        # Residual "selector" tensor product with the node attributes.
        self.skip_tp = o3.FullyConnectedTensorProduct(
            self.node_feats_irreps, self.node_attrs_irreps, self.irreps_out
        )

    def forward(
        self,
        node_attrs: torch.Tensor,
        node_feats: torch.Tensor,
        edge_attrs: torch.Tensor,
        edge_feats: torch.Tensor,
        edge_index: torch.Tensor,
    ) -> torch.Tensor:
        senders = edge_index[0]
        receivers = edge_index[1]
        n_nodes = node_feats.shape[0]
        # The skip path sees the features before linear_up.
        skip = self.skip_tp(node_feats, node_attrs)
        lifted = self.linear_up(node_feats)
        weights = self.conv_tp_weights(edge_feats)
        edge_messages = self.conv_tp(lifted[senders], edge_attrs, weights)  # [n_edges, irreps]
        aggregated = scatter_sum(
            src=edge_messages, index=receivers, dim=0, dim_size=n_nodes
        )  # [n_nodes, irreps]
        aggregated = self.linear(aggregated) / self.avg_num_neighbors
        return aggregated + skip  # [n_nodes, irreps]
@compile_mode("script")
class RealAgnosticInteractionBlock(InteractionBlock):
    """Interaction projecting onto ``target_irreps``, with no residual output.

    The per-edge weights of ``conv_tp`` come from a ``FullyConnectedNet`` (SiLU)
    on the edge features. The normalised message is mixed with the node
    attributes by ``skip_tp`` and reshaped. ``forward`` returns
    ``(reshaped_message, None)``.
    """

    def _setup(self) -> None:
        # Feature lift before the convolution.
        self.linear_up = o3.Linear(
            self.node_feats_irreps,
            self.node_feats_irreps,
            internal_weights=True,
            shared_weights=True,
        )
        # Edge tensor product: node features x edge attributes -> irreps_mid.
        irreps_mid, instructions = tp_out_irreps_with_instructions(
            self.node_feats_irreps,
            self.edge_attrs_irreps,
            self.target_irreps,
        )
        self.conv_tp = o3.TensorProduct(
            self.node_feats_irreps,
            self.edge_attrs_irreps,
            irreps_mid,
            instructions=instructions,
            shared_weights=False,
            internal_weights=False,
        )
        # Radial MLP: edge features -> conv_tp weights.
        mlp_widths = (
            [self.edge_feats_irreps.num_irreps]
            + self.radial_MLP
            + [self.conv_tp.weight_numel]
        )
        self.conv_tp_weights = nn.FullyConnectedNet(
            mlp_widths,
            torch.nn.functional.silu,
        )
        # Output projection straight onto the target irreps.
        irreps_mid = irreps_mid.simplify()
        self.irreps_out = self.target_irreps
        self.linear = o3.Linear(
            irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True
        )
        # Element-dependent mixing of the projected message.
        self.skip_tp = o3.FullyConnectedTensorProduct(
            self.irreps_out, self.node_attrs_irreps, self.irreps_out
        )
        self.reshape = reshape_irreps(self.irreps_out)

    def forward(
        self,
        node_attrs: torch.Tensor,
        node_feats: torch.Tensor,
        edge_attrs: torch.Tensor,
        edge_feats: torch.Tensor,
        edge_index: torch.Tensor,
    ) -> Tuple[torch.Tensor, None]:
        senders = edge_index[0]
        receivers = edge_index[1]
        n_nodes = node_feats.shape[0]
        lifted = self.linear_up(node_feats)
        weights = self.conv_tp_weights(edge_feats)
        edge_messages = self.conv_tp(lifted[senders], edge_attrs, weights)  # [n_edges, irreps]
        aggregated = scatter_sum(
            src=edge_messages, index=receivers, dim=0, dim_size=n_nodes
        )  # [n_nodes, irreps]
        aggregated = self.linear(aggregated) / self.avg_num_neighbors
        mixed = self.skip_tp(aggregated, node_attrs)
        reshaped = self.reshape(mixed)  # [n_nodes, channels, (lmax + 1)**2]
        return (reshaped, None)
@compile_mode("script")
class RealAgnosticResidualInteractionBlock(InteractionBlock):
    """Interaction projecting onto ``target_irreps`` that also returns a skip tensor.

    The per-edge weights of ``conv_tp`` come from a ``FullyConnectedNet`` (SiLU)
    on the edge features. ``forward`` returns ``(reshaped_message, sc)``, where
    ``sc`` is a fully connected tensor product of the incoming node features
    with the node attributes, in ``hidden_irreps`` and not reshaped.
    """

    def _setup(self) -> None:
        # Feature lift before the convolution.
        self.linear_up = o3.Linear(
            self.node_feats_irreps,
            self.node_feats_irreps,
            internal_weights=True,
            shared_weights=True,
        )
        # Edge tensor product: node features x edge attributes -> irreps_mid.
        irreps_mid, instructions = tp_out_irreps_with_instructions(
            self.node_feats_irreps,
            self.edge_attrs_irreps,
            self.target_irreps,
        )
        self.conv_tp = o3.TensorProduct(
            self.node_feats_irreps,
            self.edge_attrs_irreps,
            irreps_mid,
            instructions=instructions,
            shared_weights=False,
            internal_weights=False,
        )
        # Radial MLP: edge features -> conv_tp weights.
        mlp_widths = (
            [self.edge_feats_irreps.num_irreps]
            + self.radial_MLP
            + [self.conv_tp.weight_numel]
        )
        self.conv_tp_weights = nn.FullyConnectedNet(
            mlp_widths,
            torch.nn.functional.silu,
        )
        # Output projection straight onto the target irreps.
        irreps_mid = irreps_mid.simplify()
        self.irreps_out = self.target_irreps
        self.linear = o3.Linear(
            irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True
        )
        # Skip tensor, produced in hidden_irreps and returned to the caller.
        self.skip_tp = o3.FullyConnectedTensorProduct(
            self.node_feats_irreps, self.node_attrs_irreps, self.hidden_irreps
        )
        self.reshape = reshape_irreps(self.irreps_out)

    def forward(
        self,
        node_attrs: torch.Tensor,
        node_feats: torch.Tensor,
        edge_attrs: torch.Tensor,
        edge_feats: torch.Tensor,
        edge_index: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        senders = edge_index[0]
        receivers = edge_index[1]
        n_nodes = node_feats.shape[0]
        # The skip path sees the features before linear_up.
        skip = self.skip_tp(node_feats, node_attrs)
        lifted = self.linear_up(node_feats)
        weights = self.conv_tp_weights(edge_feats)
        edge_messages = self.conv_tp(lifted[senders], edge_attrs, weights)  # [n_edges, irreps]
        aggregated = scatter_sum(
            src=edge_messages, index=receivers, dim=0, dim_size=n_nodes
        )  # [n_nodes, irreps]
        aggregated = self.linear(aggregated) / self.avg_num_neighbors
        reshaped = self.reshape(aggregated)  # [n_nodes, channels, (lmax + 1)**2]
        return (reshaped, skip)
@compile_mode("script")
class RealAgnosticAttResidualInteractionBlock(InteractionBlock):
    """Interaction whose radial MLP also sees down-projected sender/receiver features.

    The edge features are concatenated with a ``64x0e`` linear projection of
    the sender's and the receiver's node features. A ``FullyConnectedNet``
    (SiLU) then turns this into the per-edge weights of ``conv_tp``.
    ``forward`` returns ``(reshaped_message, sc)``, where ``sc`` is a linear
    skip projection of the incoming node features onto ``hidden_irreps``.
    """

    def _setup(self) -> None:
        # Width of the scalar node summary fed to the radial MLP (hard-coded).
        self.node_feats_down_irreps = o3.Irreps("64x0e")
        # Feature lift before the convolution.
        self.linear_up = o3.Linear(
            self.node_feats_irreps,
            self.node_feats_irreps,
            internal_weights=True,
            shared_weights=True,
        )
        # Edge tensor product: node features x edge attributes -> irreps_mid.
        irreps_mid, instructions = tp_out_irreps_with_instructions(
            self.node_feats_irreps,
            self.edge_attrs_irreps,
            self.target_irreps,
        )
        self.conv_tp = o3.TensorProduct(
            self.node_feats_irreps,
            self.edge_attrs_irreps,
            irreps_mid,
            instructions=instructions,
            shared_weights=False,
            internal_weights=False,
        )
        # Convolution weights: edge features + sender/receiver scalar summaries.
        self.linear_down = o3.Linear(
            self.node_feats_irreps,
            self.node_feats_down_irreps,
            internal_weights=True,
            shared_weights=True,
        )
        input_dim = (
            self.edge_feats_irreps.num_irreps
            + 2 * self.node_feats_down_irreps.num_irreps
        )
        # NOTE(review): three hard-coded 256-wide hidden layers are used here
        # instead of self.radial_MLP.
        self.conv_tp_weights = nn.FullyConnectedNet(
            [input_dim] + 3 * [256] + [self.conv_tp.weight_numel],
            torch.nn.functional.silu,
        )
        # Output projection straight onto the target irreps.
        irreps_mid = irreps_mid.simplify()
        self.irreps_out = self.target_irreps
        self.linear = o3.Linear(
            irreps_mid,
            self.irreps_out,
            internal_weights=True,
            shared_weights=True,
        )
        self.reshape = reshape_irreps(self.irreps_out)
        # Skip connection.
        self.skip_linear = o3.Linear(self.node_feats_irreps, self.hidden_irreps)

    def forward(
        self,
        node_attrs: torch.Tensor,
        node_feats: torch.Tensor,
        edge_attrs: torch.Tensor,
        edge_feats: torch.Tensor,
        edge_index: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # The second element is the skip tensor ``sc``, never None. The former
        # ``Tuple[torch.Tensor, None]`` annotation contradicted this return
        # value, which TorchScript rejects when the module is scripted.
        sender = edge_index[0]
        receiver = edge_index[1]
        num_nodes = node_feats.shape[0]
        sc = self.skip_linear(node_feats)
        node_feats_up = self.linear_up(node_feats)
        node_feats_down = self.linear_down(node_feats)
        augmented_edge_feats = torch.cat(
            [
                edge_feats,
                node_feats_down[sender],
                node_feats_down[receiver],
            ],
            dim=-1,
        )
        tp_weights = self.conv_tp_weights(augmented_edge_feats)
        mji = self.conv_tp(
            node_feats_up[sender], edge_attrs, tp_weights
        )  # [n_edges, irreps]
        message = scatter_sum(
            src=mji, index=receiver, dim=0, dim_size=num_nodes
        )  # [n_nodes, irreps]
        message = self.linear(message) / self.avg_num_neighbors
        return (
            self.reshape(message),
            sc,
        )  # [n_nodes, channels, (lmax + 1)**2]
@compile_mode("script")
class ScaleShiftBlock(torch.nn.Module):
    """Affine map ``x -> scale * x + shift`` with ``scale``/``shift`` stored as buffers.

    Both values are registered as buffers, not parameters, in the default dtype.
    """

    def __init__(self, scale: float, shift: float):
        super().__init__()
        dtype = torch.get_default_dtype()
        self.register_buffer("scale", torch.tensor(scale, dtype=dtype))
        self.register_buffer("shift", torch.tensor(shift, dtype=dtype))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.scale * x + self.shift

    def __repr__(self):
        name = self.__class__.__name__
        return f"{name}(scale={self.scale:.6f}, shift={self.shift:.6f})"