final tweaks to flow version, still not good enough, will try another day
lucidrains committed Sep 26, 2024
1 parent 3dcd6fa commit 6be13c4
Showing 2 changed files with 20 additions and 47 deletions.
65 changes: 19 additions & 46 deletions autoregressive_diffusion_pytorch/autoregressive_flow.py
@@ -61,41 +61,7 @@ def unpack_one(to_unpack, unpack_pattern = None):
 
     return packed, unpack_one
 
-# sinusoidal embedding
-
-class AdaptiveLayerNorm(Module):
-    def __init__(
-        self,
-        dim,
-        dim_condition = None
-    ):
-        super().__init__()
-        dim_condition = default(dim_condition, dim)
-
-        self.ln = nn.LayerNorm(dim, elementwise_affine = False)
-        self.to_gamma = nn.Linear(dim_condition, dim, bias = False)
-        nn.init.zeros_(self.to_gamma.weight)
-
-    def forward(self, x, *, condition):
-        normed = self.ln(x)
-        gamma = self.to_gamma(condition)
-        return normed * (gamma + 1.)
-
-class LearnedSinusoidalPosEmb(Module):
-    def __init__(self, dim):
-        super().__init__()
-        assert divisible_by(dim, 2)
-        half_dim = dim // 2
-        self.weights = nn.Parameter(torch.randn(half_dim))
-
-    def forward(self, x):
-        x = rearrange(x, 'b -> b 1')
-        freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * pi
-        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
-        fouriered = torch.cat((x, fouriered), dim = -1)
-        return fouriered
-
-# gaussian diffusion
+# rectified flow
 
 class Flow(Module):
     def __init__(
@@ -184,7 +150,7 @@ def __init__(
         dim_head = 64,
         heads = 8,
         mlp_depth = 3,
-        mlp_width = None,
+        mlp_width = 1024,
         dim_input = None,
         decoder_kwargs: dict = dict(),
         mlp_kwargs: dict = dict(),
@@ -211,11 +177,13 @@ def __init__(
             **decoder_kwargs
         )
 
+        self.to_cond_emb = nn.Linear(dim, dim, bias = False)
+
         self.denoiser = MLP(
             dim_cond = dim,
             dim_input = dim_input,
             depth = mlp_depth,
-            width = default(mlp_width, dim),
+            width = mlp_width,
             **mlp_kwargs
         )
 
@@ -229,20 +197,15 @@ def __init__(
     def device(self):
         return next(self.transformer.parameters()).device
 
-    def add_abs_pos_emb(self, seq):
-        seq_len = seq.shape[1]
-
+    def axial_pos_emb(self):
         # prepare maybe axial positional embedding
 
         pos_emb, *rest_pos_embs = self.abs_pos_emb
 
         for rest_pos_emb in rest_pos_embs:
             pos_emb = einx.add('i d, j d -> (i j) d', pos_emb, rest_pos_emb)
 
-        pos_emb = F.pad(pos_emb, (0, 0, 1, 0), value = 0.) # account for start token
-
-        seq = seq + pos_emb[:seq_len]
-        return seq
+        return F.pad(pos_emb, (0, 0, 1, 0), value = 0.)
 
     @torch.no_grad()
     def sample(
@@ -266,12 +229,18 @@ def sample(
             cond = self.proj_in(out)
 
             cond = torch.cat((start_tokens, cond), dim = 1)
-            cond = self.add_abs_pos_emb(cond)
+
+            seq_len = cond.shape[-2]
+            axial_pos_emb = self.axial_pos_emb()
+            cond += axial_pos_emb[:seq_len]
 
             cond, cache = self.transformer(cond, cache = cache, return_hiddens = True)
 
             last_cond = cond[:, -1]
 
+            last_cond += axial_pos_emb[seq_len]
+            last_cond = self.to_cond_emb(last_cond)
+
             denoised_pred = self.flow.sample(cond = last_cond)
 
             denoised_pred = rearrange(denoised_pred, 'b d -> b 1 d')
@@ -303,10 +272,14 @@ def forward(
 
         seq = torch.cat((start_token, seq), dim = 1)
 
-        seq = self.add_abs_pos_emb(seq)
+        axial_pos_emb = self.axial_pos_emb()
+        seq = seq + axial_pos_emb[:seq_len]
 
         cond = self.transformer(seq)
 
+        cond = cond + axial_pos_emb[1:(seq_len + 1)]
+        cond = self.to_cond_emb(cond)
+
         # pack batch and sequence dimensions, so to train each token with different noise levels
 
         target, _ = pack_one(target, '* d')
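For context on what the diff above does, here is a minimal, self-contained sketch (not the repository's exact code) of the revised conditioning path: the per-axis positional tables in abs_pos_emb are outer-summed into one flattened table, a zero row is prepended for the start token, and the table is added both to the transformer input and to its output before the new to_cond_emb projection. It uses plain torch in place of einx, and the AxialPosEmbSketch class name, the axial_shape argument, and the transformer callable are illustrative assumptions only.

import torch
import torch.nn.functional as F
from torch import nn

class AxialPosEmbSketch(nn.Module):
    # hypothetical stand-in for the decoder's positional handling after this commit
    def __init__(self, dim, axial_shape = (8, 8)):
        super().__init__()
        # one learned embedding table per axis, e.g. image height and width
        self.abs_pos_emb = nn.ParameterList([
            nn.Parameter(torch.zeros(axis_len, dim)) for axis_len in axial_shape
        ])
        self.to_cond_emb = nn.Linear(dim, dim, bias = False)

    def axial_pos_emb(self):
        pos_emb, *rest_pos_embs = self.abs_pos_emb

        # outer-sum the per-axis embeddings: (i, d) + (j, d) -> (i * j, d),
        # mirroring einx.add('i d, j d -> (i j) d', ...)
        for rest_pos_emb in rest_pos_embs:
            pos_emb = (pos_emb[:, None, :] + rest_pos_emb[None, :, :]).reshape(-1, pos_emb.shape[-1])

        # prepend a zero position to account for the start token
        return F.pad(pos_emb, (0, 0, 1, 0), value = 0.)

    def forward(self, seq, transformer):
        # seq: (batch, seq_len, dim); transformer: any sequence-to-sequence module
        seq_len = seq.shape[-2]
        axial_pos_emb = self.axial_pos_emb()

        seq = seq + axial_pos_emb[:seq_len]           # positions of the input tokens
        cond = transformer(seq)
        cond = cond + axial_pos_emb[1:(seq_len + 1)]  # positions of the tokens being predicted
        return self.to_cond_emb(cond)                 # conditioning fed to the flow denoiser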
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "autoregressive-diffusion-pytorch"
-version = "0.2.5"
+version = "0.2.7"
 description = "Autoregressive Diffusion - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
