-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnetworks.py
1648 lines (1320 loc) · 58.3 KB
/
networks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# This file contains the network architectures and models used in the project's scripts.
# It includes the implementations of the 3D U-Net, our Conditional 3D U-Net,
# the Video Diffusion Models (VDM), and our Video to Video Diffusion Model (Vid2Vid-DM).
import math
from functools import partial
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from einops import rearrange
from einops_exts import check_shape, rearrange_many
from rotary_embedding_torch import RotaryEmbedding
from torch import einsum, nn
from tqdm import tqdm
class EMA:
    def __init__(self, beta: float):
        """
        Keep an exponential moving average (EMA) of model weights.

        Args:
            beta (float): Decay factor, i.e. the weight given to the previous
                average at each update (``1 - beta`` goes to the new value).
        """
        super().__init__()
        self.beta = beta

    def update_model_average(self, ma_model: nn.Module, current_model: nn.Module) -> None:
        """
        Blend the parameters of ``current_model`` into ``ma_model`` in place.

        Parameters are paired positionally via ``.parameters()``, so both models
        are expected to share the same architecture.

        Args:
            ma_model (nn.Module): Model holding the running-average weights; its
                parameters' ``.data`` is replaced.
            current_model (nn.Module): Model holding the latest weights.
        """
        paired = zip(current_model.parameters(), ma_model.parameters())
        for live, averaged in paired:
            averaged.data = self.update_average(averaged.data, live.data)

    def update_average(self, old: Optional[torch.Tensor], new: torch.Tensor) -> torch.Tensor:
        """
        Compute one EMA step for a single tensor.

        Args:
            old (Optional[torch.Tensor]): Current average, or None if there is none yet.
            new (torch.Tensor): Latest value.

        Returns:
            torch.Tensor: ``new`` when ``old`` is None, otherwise
            ``old * beta + (1 - beta) * new``.
        """
        return new if old is None else old * self.beta + (1 - self.beta) * new
def normalize_img(img: torch.Tensor) -> torch.Tensor:
    """
    Map values from the [0, 1] range to [-1, 1].

    Args:
        img (torch.Tensor): Image tensor, expected in [0, 1].

    Returns:
        torch.Tensor: ``img * 2 - 1``.
    """
    doubled = img * 2
    return doubled - 1
def unnormalize_img(img: torch.Tensor) -> torch.Tensor:
    """
    Map values from the [-1, 1] range back to [0, 1] (inverse of ``normalize_img``).

    Args:
        img (torch.Tensor): Image tensor, expected in [-1, 1].

    Returns:
        torch.Tensor: ``(img + 1) * 0.5``.
    """
    shifted = img + 1
    return shifted * 0.5
def extract(a: torch.Tensor, t: torch.Tensor, x_shape: Tuple[int, ...]) -> torch.Tensor:
    """
    Gather per-sample entries of ``a`` and reshape them to broadcast against ``x_shape``.

    Args:
        a (torch.Tensor): Tensor indexed along its last dimension (e.g. a 1-D
            per-timestep schedule).
        t (torch.Tensor): Integer index tensor whose first dimension is the batch
            size; must be on the same device as ``a`` (required by ``gather``).
        x_shape (Tuple[int, ...]): Full shape of the tensor the result will be
            broadcast against, INCLUDING the batch dimension.

    Returns:
        torch.Tensor: The gathered values reshaped to ``(batch, 1, ..., 1)``,
        with ``len(x_shape)`` dimensions in total.
    """
    batch, *_ = t.shape
    gathered = a.gather(-1, t)
    # One leading batch axis plus singleton axes for every remaining dim of x.
    return gathered.reshape(batch, *((1,) * (len(x_shape) - 1)))
def cosine_beta_schedule(timesteps: int, s: float = 0.008) -> torch.Tensor:
    """
    Build the cosine noise schedule as per-step betas.

    ``f(x) = cos(((x / timesteps) + s) / (1 + s) * pi / 2) ** 2`` is evaluated on
    ``timesteps + 1`` evenly spaced points and normalised so that it starts at 1
    (cumulative alphas). Each beta is ``1 - alpha_bar[i+1] / alpha_bar[i]``.

    Args:
        timesteps (int): Number of diffusion steps (length of the result).
        s (float): Small offset added inside the cosine.

    Returns:
        torch.Tensor: float64 tensor of shape ``(timesteps,)``, clipped to [0, 0.9999].
    """
    grid = torch.linspace(0, timesteps, timesteps + 1, dtype=torch.float64)
    f = torch.cos(((grid / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alpha_bar = f / f[0]
    step_ratio = alpha_bar[1:] / alpha_bar[:-1]
    return torch.clip(1 - step_ratio, 0, 0.9999)
class RelativePositionBias(nn.Module):
    def __init__(
        self,
        heads: int = 8,
        num_buckets: int = 32,
        max_distance: int = 128
    ):
        """
        Learned per-head additive attention bias indexed by bucketed relative position.

        Args:
            heads (int): Number of attention heads (one bias value per head and bucket).
            num_buckets (int): Total number of relative-position buckets.
            max_distance (int): Distance at which the logarithmic bucketing saturates.
        """
        super().__init__()
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _relative_position_bucket(
        relative_position: torch.Tensor,
        num_buckets: int = 32,
        max_distance: int = 128
    ) -> torch.Tensor:
        """
        Map relative positions to bucket indices.

        Half of the buckets are used for each sign of the relative position.
        Within a half, small distances get one bucket each; larger distances are
        binned logarithmically, with everything beyond ``max_distance`` sharing the
        last bucket of that half.

        Args:
            relative_position (torch.Tensor): Integer tensor of ``key - query`` offsets.
            num_buckets (int): Total number of buckets.
            max_distance (int): Distance at which bucketing saturates.

        Returns:
            torch.Tensor: Long tensor of bucket indices, same shape as the input.
        """
        half = num_buckets // 2
        distance = -relative_position
        # Positive relative positions (key after query) go to the upper half.
        direction_offset = (distance < 0).long() * half
        distance = torch.abs(distance)
        exact_limit = half // 2
        log_scaled = torch.log(distance.float() / exact_limit) / math.log(max_distance / exact_limit)
        far_bucket = exact_limit + (log_scaled * (half - exact_limit)).long()
        far_bucket = torch.min(far_bucket, torch.full_like(far_bucket, half - 1))
        return direction_offset + torch.where(distance < exact_limit, distance, far_bucket)

    def forward(self, n: int, device: torch.device) -> torch.Tensor:
        """
        Produce the bias for a sequence of length ``n``.

        Args:
            n (int): Sequence length.
            device (torch.device): Device on which the position indices are built.

        Returns:
            torch.Tensor: Bias tensor of shape ``(heads, n, n)``.
        """
        positions = torch.arange(n, dtype=torch.long, device=device)
        # rel[i, j] = j - i: key position minus query position.
        rel = positions.unsqueeze(0) - positions.unsqueeze(1)
        buckets = self._relative_position_bucket(
            rel, num_buckets=self.num_buckets, max_distance=self.max_distance)
        bias = self.relative_attention_bias(buckets)  # (i, j, heads)
        return bias.permute(2, 0, 1)
class Residual(nn.Module):
    def __init__(self, fn: nn.Module):
        """
        Skip-connection wrapper: ``forward(x) = fn(x, ...) + x``.

        Args:
            fn (nn.Module): Wrapped module; its output must be addable to its input.
        """
        super().__init__()
        self.fn = fn

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        Apply the wrapped module and add the input back.

        Args:
            x (torch.Tensor): Input tensor.
            *args, **kwargs: Forwarded unchanged to ``fn``.

        Returns:
            torch.Tensor: ``fn(x, *args, **kwargs) + x``.
        """
        branch = self.fn(x, *args, **kwargs)
        return branch + x
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim: int):
        """
        Sinusoidal embedding of scalar positions (e.g. diffusion timesteps).

        Args:
            dim (int): Width of the produced embedding. Must be at least 4 (the
                frequency spacing divides by ``dim // 2 - 1``). Odd widths are
                supported by zero-padding the last column.

        Raises:
            ValueError: If ``dim < 4``.
        """
        super().__init__()
        if dim < 4:
            raise ValueError(f"SinusoidalPosEmb requires dim >= 4, got {dim}")
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Embed a 1-D tensor of positions.

        Args:
            x (torch.Tensor): 1-D tensor of positions, shape ``(batch,)``.

        Returns:
            torch.Tensor: Shape ``(batch, dim)``: ``dim // 2`` sine features
            followed by ``dim // 2`` cosine features, plus one zero column when
            ``dim`` is odd.
        """
        device = x.device
        half_dim = self.dim // 2
        # Geometric frequency ladder from 1 down to 1/10000.
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        if self.dim % 2 == 1:
            # Keep the output width equal to `dim` so downstream layers sized by
            # `dim` (e.g. nn.Linear(dim, ...)) accept it.
            emb = F.pad(emb, (0, 1))
        return emb
def Upsample(dim: int) -> nn.Module:
    """
    Double the height and width of a 5-D feature map, leaving the frame axis alone.

    Args:
        dim (int): Channel count (input and output).

    Returns:
        nn.Module: ``ConvTranspose3d`` with kernel (1, 4, 4), stride (1, 2, 2)
        and padding (0, 1, 1).
    """
    return nn.ConvTranspose3d(
        dim,
        dim,
        kernel_size=(1, 4, 4),
        stride=(1, 2, 2),
        padding=(0, 1, 1),
    )
def Downsample(dim: int) -> nn.Module:
    """
    Halve the height and width of a 5-D feature map, leaving the frame axis alone.

    Args:
        dim (int): Channel count (input and output).

    Returns:
        nn.Module: ``Conv3d`` with kernel (1, 4, 4), stride (1, 2, 2)
        and padding (0, 1, 1).
    """
    return nn.Conv3d(
        dim,
        dim,
        kernel_size=(1, 4, 4),
        stride=(1, 2, 2),
        padding=(0, 1, 1),
    )
class LayerNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        """
        Normalise over the channel axis (dim 1) of a 5-D tensor with a learned gain.

        There is no bias term; only ``gamma`` (shape ``(1, dim, 1, 1, 1)``) is learned.

        Args:
            dim (int): Number of channels.
            eps (float): Added to the variance before the square root.
        """
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Input whose dim 1 is the channel axis; ``gamma``
                broadcasts against a 5-D layout.

        Returns:
            torch.Tensor: ``(x - mean) / sqrt(var + eps) * gamma``, with the
            (biased) statistics taken over dim 1.
        """
        mu = x.mean(dim=1, keepdim=True)
        sigma2 = x.var(dim=1, unbiased=False, keepdim=True)
        centered = x - mu
        return centered / (sigma2 + self.eps).sqrt() * self.gamma
class PreNorm(nn.Module):
    def __init__(self, dim: int, fn: nn.Module):
        """
        Apply the channel-wise ``LayerNorm`` of this file before calling ``fn``.

        Args:
            dim (int): Channel count passed to ``LayerNorm``.
            fn (nn.Module): Module called on the normalised tensor.
        """
        super().__init__()
        self.fn = fn
        self.norm = LayerNorm(dim)

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Input tensor.
            **kwargs: Forwarded unchanged to ``fn``.

        Returns:
            torch.Tensor: ``fn(norm(x), **kwargs)``.
        """
        normed = self.norm(x)
        return self.fn(normed, **kwargs)
class Block(nn.Module):
    def __init__(self, dim: int, dim_out: int, groups: int = 8):
        """
        Per-frame spatial convolution, GroupNorm, optional scale/shift, then SiLU.

        Args:
            dim (int): Number of input channels.
            dim_out (int): Number of output channels.
            groups (int): Number of groups for ``nn.GroupNorm``.
        """
        super().__init__()
        # Kernel (1, 3, 3): convolves height/width only, never across frames.
        self.proj = nn.Conv3d(dim, dim_out, (1, 3, 3), padding=(0, 1, 1))
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(
        self,
        x: torch.Tensor,
        scale_shift: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
    ) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Input tensor with ``dim`` channels on axis 1.
            scale_shift (Optional[Tuple[torch.Tensor, torch.Tensor]]): A
                ``(scale, shift)`` pair that must broadcast against the normalised
                features; applied as ``x * (scale + 1) + shift``. None skips it.

        Returns:
            torch.Tensor: Activated output with ``dim_out`` channels.
        """
        x = self.norm(self.proj(x))
        if scale_shift is not None:
            scale, shift = scale_shift
            # `scale + 1` so that a zero scale leaves the features unchanged.
            x = x * (scale + 1) + shift
        return self.act(x)
class ResnetBlock(nn.Module):
    def __init__(self, dim: int, dim_out: int, time_emb_dim: Optional[int] = None, groups: int = 8):
        """
        Two ``Block``s with a residual connection and optional time conditioning.

        Args:
            dim (int): Number of input channels.
            dim_out (int): Number of output channels.
            time_emb_dim (Optional[int]): Width of the time embedding. When given,
                the embedding is projected to a per-channel ``(scale, shift)`` pair
                applied inside the first block; when None, no conditioning is used.
            groups (int): Number of GroupNorm groups in each ``Block``.
        """
        super().__init__()
        self.mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_emb_dim, dim_out * 2)
        ) if time_emb_dim is not None else None
        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        # 1x1x1 conv only when the channel count changes; otherwise a pass-through.
        self.res_conv = nn.Conv3d(
            dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x: torch.Tensor, time_emb: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Input tensor with ``dim`` channels on axis 1.
            time_emb (Optional[torch.Tensor]): Time embedding of shape
                ``(batch, time_emb_dim)``; required when the block was built with
                ``time_emb_dim``.

        Returns:
            torch.Tensor: ``block2(block1(x)) + res_conv(x)``.

        Raises:
            ValueError: If the block is time-conditioned but ``time_emb`` is None.
        """
        scale_shift = None
        if self.mlp is not None:
            # Explicit check (an `assert` would be stripped under `python -O`).
            if time_emb is None:
                raise ValueError('time emb must be passed in')
            time_emb = self.mlp(time_emb)
            time_emb = rearrange(time_emb, 'b c -> b c 1 1 1')
            scale_shift = time_emb.chunk(2, dim=1)
        h = self.block1(x, scale_shift=scale_shift)
        h = self.block2(h)
        return h + self.res_conv(x)
class SpatialLinearAttention(nn.Module):
    def __init__(self, dim: int, heads: int = 4, dim_head: int = 32):
        """
        Linear attention over the spatial positions of each frame independently.

        Args:
            dim (int): Number of input/output channels.
            heads (int): Number of attention heads.
            dim_head (int): Channels per head.
        """
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(inner_dim, dim, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): 5-D input laid out as ``(b, c, f, h, w)``.

        Returns:
            torch.Tensor: Tensor of the same layout as ``x``.
        """
        batch, _, _, height, width = x.shape
        # Fold frames into the batch so each frame is attended over separately.
        per_frame = rearrange(x, 'b c f h w -> (b f) c h w')
        q, k, v = rearrange_many(
            self.to_qkv(per_frame).chunk(3, dim=1),
            'b (h c) x y -> b h c (x y)', h=self.heads)
        # Queries: softmax over the feature axis; keys: softmax over positions.
        q = q.softmax(dim=-2) * self.scale
        k = k.softmax(dim=-1)
        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
        attended = torch.einsum('b h d e, b h d n -> b h e n', context, q)
        attended = rearrange(attended, 'b h c (x y) -> b (h c) x y',
                             h=self.heads, x=height, y=width)
        projected = self.to_out(attended)
        return rearrange(projected, '(b f) c h w -> b c f h w', b=batch)
class EinopsToAndFrom(nn.Module):
    def __init__(self, from_einops: str, to_einops: str, fn):
        """
        Rearrange the input into another layout, apply ``fn``, and rearrange back.

        Args:
            from_einops (str): Space-separated axis names of the input layout
                (each name must be a single axis; sizes are read from ``x.shape``).
            to_einops (str): einops pattern of the layout ``fn`` expects.
            fn: Callable/module applied in the ``to_einops`` layout.
        """
        super().__init__()
        self.from_einops = from_einops
        self.to_einops = to_einops
        self.fn = fn

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Input in the ``from_einops`` layout.
            **kwargs: Forwarded unchanged to ``fn``.

        Returns:
            torch.Tensor: ``fn``'s output rearranged back to the ``from_einops`` layout.
        """
        # Remember every axis size so merged axes can be split again afterwards.
        axis_sizes = dict(zip(self.from_einops.split(' '), x.shape))
        forward_pattern = f'{self.from_einops} -> {self.to_einops}'
        backward_pattern = f'{self.to_einops} -> {self.from_einops}'
        transformed = self.fn(rearrange(x, forward_pattern), **kwargs)
        return rearrange(transformed, backward_pattern, **axis_sizes)
class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        heads: int = 4,
        dim_head: int = 32,
        rotary_emb=None
    ):
        """
        Multi-head dot-product self-attention over the second-to-last axis of the input.

        Args:
            dim (int): Feature width of the last input axis.
            heads (int): Number of attention heads.
            dim_head (int): Width of each head.
            rotary_emb: Optional module exposing ``rotate_queries_or_keys``,
                applied to queries and keys before the dot product.
        """
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.rotary_emb = rotary_emb
        self.to_qkv = nn.Linear(dim, hidden_dim * 3, bias=False)
        self.to_out = nn.Linear(hidden_dim, dim, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        pos_bias: torch.Tensor = None,
        focus_present_mask: torch.Tensor = None
    ) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Input of shape ``(..., n, dim)``.
            pos_bias (torch.Tensor): Optional additive bias on the attention logits.
            focus_present_mask (torch.Tensor): Optional per-batch boolean mask.
                Batch entries set to True attend only to their own position. When
                every entry is True, attention is skipped and the value projection
                is returned directly. (When only some are True, the mask is
                broadcast as ``b 1 1 1 1``, so the logits must be 5-D.)

        Returns:
            torch.Tensor: Output of shape ``(..., n, dim)``.
        """
        seq_len, device = x.shape[-2], x.device
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        if focus_present_mask is not None and focus_present_mask.all():
            # Self-only attention everywhere is just the (projected) values.
            return self.to_out(qkv[-1])

        q, k, v = rearrange_many(qkv, '... n (h d) -> ... h n d', h=self.heads)
        q = q * self.scale
        if self.rotary_emb is not None:
            q = self.rotary_emb.rotate_queries_or_keys(q)
            k = self.rotary_emb.rotate_queries_or_keys(k)

        logits = einsum('... h i d, ... h j d -> ... h i j', q, k)
        if pos_bias is not None:
            logits = logits + pos_bias

        if focus_present_mask is not None and not (~focus_present_mask).all():
            self_only = torch.eye(seq_len, device=device, dtype=torch.bool)
            everyone = torch.ones((seq_len, seq_len), device=device, dtype=torch.bool)
            allowed = torch.where(
                rearrange(focus_present_mask, 'b -> b 1 1 1 1'),
                rearrange(self_only, 'i j -> 1 1 1 i j'),
                rearrange(everyone, 'i j -> 1 1 1 i j'),
            )
            logits = logits.masked_fill(~allowed, -torch.finfo(logits.dtype).max)

        # Subtract the (gradient-detached) row max for numerical stability.
        logits = logits - logits.amax(dim=-1, keepdim=True).detach()
        weights = logits.softmax(dim=-1)
        mixed = einsum('... h i j, ... h j d -> ... h i d', weights, v)
        merged = rearrange(mixed, '... h n d -> ... n (h d)')
        return self.to_out(merged)
# Cond_UNet model
class Cond_UNet(nn.Module):
    def __init__(
        self,
        im_size: int,
        dim_mults: tuple = (1, 2, 4, 8),
        channels: int = 3,
        cond_channels: int = 3,
        attn_heads: int = 8,
        attn_dim_head: int = 32,
        init_kernel_size: int = 7,
        resnet_groups: int = 8
    ):
        """
        Conditional 3D U-Net, used as the Video UNet of the Vid2Vid-DM.

        The raw input is concatenated with a conditioning tensor along the
        channel axis, and so is the input of every encoder level. That is why the
        first block of each level takes twice its usual input width.

        Args:
            im_size (int): Size of the input frame (assumed to be square). Also
                used as the base channel width of the network (``init_conv``
                outputs ``im_size`` channels).
            dim_mults (tuple): Channel multipliers for the U-Net levels.
            channels (int): Number of input/output video channels.
            cond_channels (int): Channels of the conditioning tensor concatenated
                with the raw input.
            attn_heads (int): Number of attention heads.
            attn_dim_head (int): Dimension of each attention head.
            init_kernel_size (int): Odd spatial kernel size of the initial convolution.
            resnet_groups (int): Number of GroupNorm groups in the ResNet blocks.

        Raises:
            ValueError: If ``init_kernel_size`` is even.
        """
        super().__init__()
        self.im_size = im_size
        self.dim_mults = dim_mults
        self.channels = channels
        self.cond_channels = cond_channels
        self.in_channels = self.channels + self.cond_channels
        self.attn_heads = attn_heads
        self.attn_dim_head = attn_dim_head
        self.init_kernel_size = init_kernel_size
        self.resnet_groups = resnet_groups
        # temporal attention and its relative positional encoding
        rotary_emb = RotaryEmbedding(min(32, self.attn_dim_head))

        def temporal_attn(dim: int) -> EinopsToAndFrom:
            """
            Create a temporal attention module (attention across frames at each pixel).

            Args:
                dim (int): Input dimension.

            Returns:
                EinopsToAndFrom: Temporal attention module wrapped in an EinopsToAndFrom module.
            """
            return EinopsToAndFrom(
                'b c f h w', 'b (h w) f c',
                Attention(dim, heads=self.attn_heads, dim_head=self.attn_dim_head,
                          rotary_emb=rotary_emb))

        self.time_rel_pos_bias = RelativePositionBias(
            heads=self.attn_heads, max_distance=32)
        # initial conv (odd kernel so 'same' padding is symmetric)
        if self.init_kernel_size % 2 != 1:
            raise ValueError(
                f'init_kernel_size must be odd, got {self.init_kernel_size}')
        self.init_padding = self.init_kernel_size // 2
        self.init_conv = nn.Conv3d(
            self.in_channels, self.im_size,
            (1, self.init_kernel_size, self.init_kernel_size),
            padding=(0, self.init_padding, self.init_padding))
        self.init_temporal_attn = Residual(
            PreNorm(self.im_size, temporal_attn(self.im_size)))
        # dimensions
        self.dims = [self.im_size, *map(lambda m: self.im_size * m, self.dim_mults)]
        self.in_out = list(zip(self.dims[:-1], self.dims[1:]))
        # time conditioning
        self.time_dim = self.im_size * 4
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(self.im_size),
            nn.Linear(self.im_size, self.time_dim),
            nn.GELU(),
            nn.Linear(self.time_dim, self.time_dim)
        )
        # layers
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        self.num_resolutions = len(self.in_out)
        # block type
        block_klass = partial(ResnetBlock, groups=resnet_groups)
        block_klass_cond = partial(block_klass, time_emb_dim=self.time_dim)
        # modules for all downstream path layers; block1 takes dim_in * 2 because
        # the level input is concatenated with a conditioning tensor of dim_in channels
        for ind, (dim_in, dim_out) in enumerate(self.in_out):
            is_last = ind >= (self.num_resolutions - 1)
            self.downs.append(nn.ModuleList([
                block_klass_cond(dim_in * 2, dim_out),
                block_klass_cond(dim_out, dim_out),
                Residual(PreNorm(dim_out, SpatialLinearAttention(
                    dim_out, heads=self.attn_heads))),
                Residual(PreNorm(dim_out, temporal_attn(dim_out))),
                Downsample(dim_out) if not is_last else nn.Identity()
            ]))
        self.mid_dim = self.dims[-1]
        self.mid_block1 = block_klass_cond(self.mid_dim, self.mid_dim)
        self.spatial_attn = EinopsToAndFrom(
            'b c f h w', 'b f (h w) c', Attention(self.mid_dim, heads=self.attn_heads))
        self.mid_spatial_attn = Residual(
            PreNorm(self.mid_dim, self.spatial_attn))
        self.mid_temporal_attn = Residual(
            PreNorm(self.mid_dim, temporal_attn(self.mid_dim)))
        self.mid_block2 = block_klass_cond(self.mid_dim, self.mid_dim)
        # modules for all upstream path layers
        for ind, (dim_in, dim_out) in enumerate(reversed(self.in_out)):
            is_last = ind >= (self.num_resolutions - 1)
            self.ups.append(nn.ModuleList([
                block_klass_cond(dim_out * 2, dim_in),
                block_klass_cond(dim_in, dim_in),
                Residual(PreNorm(dim_in, SpatialLinearAttention(
                    dim_in, heads=self.attn_heads))),
                Residual(PreNorm(dim_in, temporal_attn(dim_in))),
                Upsample(dim_in) if not is_last else nn.Identity()
            ]))
        self.final_conv = nn.Sequential(
            block_klass(self.im_size * 2, self.im_size),
            nn.Conv3d(self.im_size, self.channels, 1)
        )

    def forward(
        self,
        x: torch.Tensor,
        time: torch.Tensor,
        cond_list: Optional[List[torch.Tensor]] = None
    ) -> torch.Tensor:
        """
        Forward pass of the Cond_UNet model.

        Args:
            x (torch.Tensor): Input video tensor laid out ``(b, c, f, h, w)``.
            time (torch.Tensor): Diffusion timesteps, one per batch element.
            cond_list (Optional[List[torch.Tensor]]): Conditioning feature maps,
                e.g. the list returned by ``UNet`` with ``out_list=True``. It is
                consumed from the END: the last element is concatenated with the
                input video, then one element per encoder level in encoder order.
                The caller's list is not modified.

        Returns:
            torch.Tensor: Output tensor with ``channels`` channels.

        Raises:
            ValueError: If fewer than ``num_resolutions + 1`` conditioning tensors
                are supplied.
        """
        # Pop from a private copy so the caller's list stays intact and reusable.
        cond_stack = list(cond_list) if cond_list is not None else []
        required = self.num_resolutions + 1
        if len(cond_stack) < required:
            raise ValueError(
                f'cond_list must contain at least {required} tensors, got {len(cond_stack)}')
        batch, device = x.shape[0], x.device
        focus_present_mask = torch.zeros(
            (batch,), device=device, dtype=torch.bool)
        time_rel_pos_bias = self.time_rel_pos_bias(x.shape[2], device=x.device)
        x = torch.cat((x, cond_stack.pop()), dim=1)
        x = self.init_conv(x)
        x = self.init_temporal_attn(x, pos_bias=time_rel_pos_bias)
        r = x.clone()  # long skip connection into final_conv
        t = self.time_mlp(time) if self.time_mlp is not None else None
        h = []
        # Downstream: concatenate this level's conditioning tensor first.
        for block1, block2, spatial_attn, temporal_attn, downsample in self.downs:
            x = torch.cat((x, cond_stack.pop()), dim=1)
            x = block1(x, t)
            x = block2(x, t)
            x = spatial_attn(x)
            x = temporal_attn(x, pos_bias=time_rel_pos_bias,
                              focus_present_mask=focus_present_mask)
            h.append(x)
            x = downsample(x)
        x = self.mid_block1(x, t)
        x = self.mid_spatial_attn(x)
        x = self.mid_temporal_attn(
            x, pos_bias=time_rel_pos_bias, focus_present_mask=focus_present_mask)
        x = self.mid_block2(x, t)
        # Upstream
        for block1, block2, spatial_attn, temporal_attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t)
            x = block2(x, t)
            x = spatial_attn(x)
            x = temporal_attn(x, pos_bias=time_rel_pos_bias,
                              focus_present_mask=focus_present_mask)
            x = upsample(x)
        x = torch.cat((x, r), dim=1)
        return self.final_conv(x)
# UNet model
class UNet(nn.Module):
    def __init__(
        self,
        im_size: int,
        out_list: bool = False,
        dim_mults: Tuple[int] = (1, 2, 4, 8),
        channels: int = 3,
        attn_heads: int = 8,
        attn_dim_head: int = 32,
        init_kernel_size: int = 7,
        resnet_groups: int = 8
    ):
        """
        3D U-Net with spatial and temporal attention.

        Args:
            im_size (int): Input frame size. Also used as the base channel width
                of the network (``init_conv`` outputs ``im_size`` channels).
            out_list (bool): If True, ``forward`` returns the list of intermediate
                decoder feature maps (used as the condition of Vid2Vid-DM) instead
                of only the final output.
            dim_mults (Tuple[int]): Channel multipliers for each U-Net level.
            channels (int): Number of input/output video channels.
            attn_heads (int): Number of attention heads.
            attn_dim_head (int): Dimension of each attention head.
            init_kernel_size (int): Odd spatial kernel size of the initial convolution.
            resnet_groups (int): Number of GroupNorm groups in the residual blocks.

        Raises:
            ValueError: If ``init_kernel_size`` is even.
        """
        super().__init__()
        self.im_size = im_size
        self.out_list = out_list
        self.dim_mults = dim_mults
        self.channels = channels
        self.attn_heads = attn_heads
        self.attn_dim_head = attn_dim_head
        self.init_kernel_size = init_kernel_size
        self.resnet_groups = resnet_groups
        # temporal attention and its relative positional encoding
        rotary_emb = RotaryEmbedding(min(32, self.attn_dim_head))

        def temporal_attn(dim):
            """Attention across frames at each spatial position."""
            return EinopsToAndFrom('b c f h w', 'b (h w) f c', Attention(
                dim, heads=self.attn_heads, dim_head=self.attn_dim_head, rotary_emb=rotary_emb))

        # realistically will not be able to generate that many frames of video... yet
        self.time_rel_pos_bias = RelativePositionBias(
            heads=self.attn_heads, max_distance=32)
        # initial conv (odd kernel so 'same' padding is symmetric)
        if self.init_kernel_size % 2 != 1:
            raise ValueError(
                f'init_kernel_size must be odd, got {self.init_kernel_size}')
        self.init_padding = self.init_kernel_size // 2
        self.init_conv = nn.Conv3d(
            self.channels, self.im_size,
            (1, self.init_kernel_size, self.init_kernel_size),
            padding=(0, self.init_padding, self.init_padding))
        self.init_temporal_attn = Residual(
            PreNorm(self.im_size, temporal_attn(self.im_size)))
        # dimensions
        self.dims = [self.im_size, *map(lambda m: self.im_size * m, dim_mults)]
        self.in_out = list(zip(self.dims[:-1], self.dims[1:]))
        # time conditioning
        self.time_dim = self.im_size * 4
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(self.im_size),
            nn.Linear(self.im_size, self.time_dim),
            nn.GELU(),
            nn.Linear(self.time_dim, self.time_dim)
        )
        # layers
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        self.num_resolutions = len(self.in_out)
        # block type
        block_klass = partial(ResnetBlock, groups=resnet_groups)
        block_klass_cond = partial(block_klass, time_emb_dim=self.time_dim)
        # modules for all downstream path layers
        for ind, (dim_in, dim_out) in enumerate(self.in_out):
            is_last = ind >= (self.num_resolutions - 1)
            self.downs.append(nn.ModuleList([
                block_klass_cond(dim_in, dim_out),
                block_klass_cond(dim_out, dim_out),
                Residual(PreNorm(dim_out, SpatialLinearAttention(
                    dim_out, heads=self.attn_heads))),
                Residual(PreNorm(dim_out, temporal_attn(dim_out))),
                Downsample(dim_out) if not is_last else nn.Identity()
            ]))
        self.mid_dim = self.dims[-1]
        self.mid_block1 = block_klass_cond(self.mid_dim, self.mid_dim)
        self.spatial_attn = EinopsToAndFrom(
            'b c f h w', 'b f (h w) c', Attention(self.mid_dim, heads=self.attn_heads))
        self.mid_spatial_attn = Residual(
            PreNorm(self.mid_dim, self.spatial_attn))
        self.mid_temporal_attn = Residual(
            PreNorm(self.mid_dim, temporal_attn(self.mid_dim)))
        self.mid_block2 = block_klass_cond(self.mid_dim, self.mid_dim)
        # modules for all upstream path layers
        for ind, (dim_in, dim_out) in enumerate(reversed(self.in_out)):
            is_last = ind >= (self.num_resolutions - 1)
            self.ups.append(nn.ModuleList([
                block_klass_cond(dim_out * 2, dim_in),
                block_klass_cond(dim_in, dim_in),
                Residual(PreNorm(dim_in, SpatialLinearAttention(
                    dim_in, heads=attn_heads))),
                Residual(PreNorm(dim_in, temporal_attn(dim_in))),
                Upsample(dim_in) if not is_last else nn.Identity()
            ]))
        # Retained as public attributes for backward compatibility. forward()
        # now derives the dropped-condition shapes from the actual input instead
        # (out_res assumes the frame size equals im_size).
        self.out_dims = [self.channels] + self.dims[:self.num_resolutions]
        self.out_res = [int(self.im_size * self.im_size / list_dim)
                        for list_dim in self.dims]
        self.final_conv = nn.Sequential(
            block_klass(self.im_size * 2, self.im_size),
            nn.Conv3d(self.im_size, self.channels, 1)
        )

    def _dropped_condition(self, x: torch.Tensor) -> List[torch.Tensor]:
        """
        Build the all-zero feature list returned when the condition is dropped.

        The shapes mirror the ``out_list`` output of ``forward`` for input ``x``:
        one tensor per decoder stage (deepest level first, ``dims[level]``
        channels), then one with ``channels`` channels at full resolution.
        Spatial sizes are halved (floor) per level, as ``Downsample`` does.
        """
        batch, frames = x.shape[0], x.shape[2]
        height, width = x.shape[-2], x.shape[-1]
        level_sizes = []
        for _ in range(self.num_resolutions):
            level_sizes.append((height, width))
            height, width = height // 2, width // 2
        shapes = [(self.dims[level], *level_sizes[level])
                  for level in reversed(range(self.num_resolutions))]
        shapes.append((self.channels, *level_sizes[0]))
        return [torch.zeros((batch, ch, frames, h, w), device=x.device)
                for ch, h, w in shapes]

    def forward(
        self,
        x: torch.Tensor,
        time: torch.Tensor,
        cond_dropout: float = 0.
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """
        Forward pass of the UNet model.

        Args:
            x (torch.Tensor): Input video tensor laid out ``(b, c, f, h, w)``.
            time (torch.Tensor): Diffusion timesteps, one per batch element.
            cond_dropout (float): Probability of dropping the whole output and
                returning all-zero feature maps instead. 0 never drops; 1 always drops.

        Returns:
            Union[torch.Tensor, List[torch.Tensor]]: With ``out_list=True``, the
            list of each decoder stage's features (taken before its upsampling,
            deepest stage first) followed by the final output. Otherwise only the
            final output. When the condition is dropped, a list of zero tensors
            with the ``out_list`` shapes is returned regardless of ``out_list``.
        """
        # Exactly one RNG draw per call, as before. Strict `<` guarantees that
        # cond_dropout == 0 never drops (torch.rand can return exactly 0.0).
        if torch.rand(1).item() < cond_dropout:
            return self._dropped_condition(x)
        batch, device = x.shape[0], x.device
        focus_present_mask = torch.zeros(
            (batch,), device=device, dtype=torch.bool)
        time_rel_pos_bias = self.time_rel_pos_bias(x.shape[2], device=x.device)
        x = self.init_conv(x)
        x = self.init_temporal_attn(x, pos_bias=time_rel_pos_bias)
        r = x.clone()  # long skip connection into final_conv
        t = self.time_mlp(time) if self.time_mlp is not None else None
        h = []
        # Downstream
        for block1, block2, spatial_attn, temporal_attn, downsample in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            x = spatial_attn(x)
            x = temporal_attn(x, pos_bias=time_rel_pos_bias,
                              focus_present_mask=focus_present_mask)
            h.append(x)
            x = downsample(x)
        x = self.mid_block1(x, t)
        x = self.mid_spatial_attn(x)
        x = self.mid_temporal_attn(
            x, pos_bias=time_rel_pos_bias, focus_present_mask=focus_present_mask)
        x = self.mid_block2(x, t)
        output = []
        # Upstream
        for block1, block2, spatial_attn, temporal_attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t)
            x = block2(x, t)
            x = spatial_attn(x)
            x = temporal_attn(x, pos_bias=time_rel_pos_bias,
                              focus_present_mask=focus_present_mask)
            output.append(x)
            x = upsample(x)
        x = torch.cat((x, r), dim=1)
        x = self.final_conv(x)
        output.append(x)
        if self.out_list:
            return output
        return x
# VDM gaussian diffusion model
class GaussianDiffusion(nn.Module):
def __init__(
self,
denoise_fn: nn.Module,
image_size: int,
num_frames: int,
channels: int = 3,
timesteps: int = 1000
):