-
Notifications
You must be signed in to change notification settings - Fork 1.3k
/
Copy pathCrossformer.py
145 lines (125 loc) · 6.21 KB
/
Crossformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from layers.Crossformer_EncDec import scale_block, Encoder, Decoder, DecoderLayer
from layers.Embed import PatchEmbedding
from layers.SelfAttention_Family import AttentionLayer, FullAttention, TwoStageAttentionLayer
from models.PatchTST import FlattenHead
from math import ceil
class Model(nn.Module):
    """
    Crossformer model wired for several time-series tasks.

    Paper link: https://openreview.net/pdf?id=vSVLM2j9eie

    Tasks are selected by ``configs.task_name`` (see ``forward``):
    'long_term_forecast' / 'short_term_forecast', 'imputation',
    'anomaly_detection' and 'classification'. All tasks share the same
    segment embedding + hierarchical encoder; forecasting additionally runs
    the decoder, the other tasks put a head on the last encoder output.
    """

    def __init__(self, configs):
        super().__init__()
        self.enc_in = configs.enc_in
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        # Fixed segment (patch) length and the window size handed to every
        # encoder layer after the first.
        self.seg_len = 12
        self.win_size = 2
        self.task_name = configs.task_name

        # Pad input/output lengths up to a whole number of segments so that
        # lengths which are not multiples of seg_len can still be segmented.
        self.pad_in_len = ceil(1.0 * configs.seq_len / self.seg_len) * self.seg_len
        self.pad_out_len = ceil(1.0 * configs.pred_len / self.seg_len) * self.seg_len
        self.in_seg_num = self.pad_in_len // self.seg_len
        # Number of segments expected at the LAST encoder layer (layer i is
        # built for ceil(in_seg_num / win_size ** i) segments, see below).
        self.out_seg_num = ceil(self.in_seg_num / (self.win_size ** (configs.e_layers - 1)))
        self.head_nf = configs.d_model * self.out_seg_num
        # Number of segments the decoder produces (covers pad_out_len steps).
        dec_seg_num = self.pad_out_len // self.seg_len

        # Embedding: non-overlapping patches of seg_len (stride == seg_len),
        # right-padded by pad_in_len - seq_len, no dropout.
        self.enc_value_embedding = PatchEmbedding(
            configs.d_model, self.seg_len, self.seg_len,
            self.pad_in_len - configs.seq_len, 0)
        self.enc_pos_embedding = nn.Parameter(
            torch.randn(1, configs.enc_in, self.in_seg_num, configs.d_model))
        self.pre_norm = nn.LayerNorm(configs.d_model)

        # Encoder: layer 0 uses window 1 on all in_seg_num segments; layer i
        # uses win_size and is sized for ceil(in_seg_num / win_size ** i)
        # segments (for i == 0 this equals in_seg_num).
        self.encoder = Encoder(
            [
                scale_block(configs, 1 if layer_idx == 0 else self.win_size,
                            configs.d_model, configs.n_heads, configs.d_ff,
                            1, configs.dropout,
                            ceil(self.in_seg_num / self.win_size ** layer_idx),
                            configs.factor)
                for layer_idx in range(configs.e_layers)
            ]
        )

        # Decoder: learned positional queries, one decoder layer per element
        # of e_layers + 1 (presumably one per encoder output; verify against
        # Encoder / Decoder in layers.Crossformer_EncDec).
        self.dec_pos_embedding = nn.Parameter(
            torch.randn(1, configs.enc_in, dec_seg_num, configs.d_model))
        self.decoder = Decoder(
            [
                DecoderLayer(
                    TwoStageAttentionLayer(configs, dec_seg_num, configs.factor,
                                           configs.d_model, configs.n_heads,
                                           configs.d_ff, configs.dropout),
                    AttentionLayer(
                        FullAttention(False, configs.factor,
                                      attention_dropout=configs.dropout,
                                      output_attention=False),
                        configs.d_model, configs.n_heads),
                    self.seg_len,
                    configs.d_model,
                    configs.d_ff,
                    dropout=configs.dropout,
                )
                for _ in range(configs.e_layers + 1)
            ],
        )

        # Task-specific heads on top of the last encoder output.
        if self.task_name in ('imputation', 'anomaly_detection'):
            self.head = FlattenHead(configs.enc_in, self.head_nf, configs.seq_len,
                                    head_dropout=configs.dropout)
        elif self.task_name == 'classification':
            self.flatten = nn.Flatten(start_dim=-2)
            self.dropout = nn.Dropout(configs.dropout)
            self.projection = nn.Linear(
                self.head_nf * configs.enc_in, configs.num_class)

    def _encode(self, x_enc):
        """Segment-embed ``x_enc`` and run it through the encoder.

        Returns ``(x_emb, enc_out)``: ``x_emb`` is the normalised encoder
        input of shape (batch, n_vars, seg_num, d_model) and ``enc_out`` is
        the first item returned by ``self.encoder`` (indexed per layer by the
        callers). The encoder's second return value is discarded.
        """
        # NOTE(review): x_enc appears to be (batch, seq_len, n_vars); the
        # patch embedding is fed the transposed (batch, n_vars, seq_len).
        x_emb, n_vars = self.enc_value_embedding(x_enc.permute(0, 2, 1))
        # (batch * n_vars, seg_num, d_model) -> (batch, n_vars, seg_num, d_model)
        x_emb = x_emb.reshape(-1, n_vars, x_emb.shape[-2], x_emb.shape[-1])
        # Out-of-place add: don't mutate the embedding module's output.
        x_emb = self.pre_norm(x_emb + self.enc_pos_embedding)
        enc_out, _ = self.encoder(x_emb)
        return x_emb, enc_out

    def _reconstruct(self, x_enc):
        """Map the last encoder output back to the input length via ``self.head``.

        Shared by imputation and anomaly detection; returns [B, L, D].
        """
        _, enc_out = self._encode(x_enc)
        # (batch, n_vars, seg, d_model) -> (batch, n_vars, d_model, seg) so the
        # head sees d_model * seg features per variable.
        dec_out = self.head(enc_out[-1].permute(0, 1, 3, 2))
        return dec_out.permute(0, 2, 1)

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        """Encode ``x_enc`` and decode from learned positional queries.

        ``x_mark_enc``, ``x_dec`` and ``x_mark_dec`` are unused. Returns the
        raw decoder output; ``forward`` trims it to ``pred_len`` steps.
        """
        x_emb, enc_out = self._encode(x_enc)
        # One copy of the learned decoder queries per batch element
        # (copies, like the original einops.repeat, rather than a view).
        dec_in = self.dec_pos_embedding.repeat(x_emb.shape[0], 1, 1, 1)
        return self.decoder(dec_in, enc_out)

    def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
        """Reconstruct the input series; only ``x_enc`` is used (``mask`` is ignored)."""
        return self._reconstruct(x_enc)

    def anomaly_detection(self, x_enc):
        """Reconstruct the input series for anomaly scoring."""
        return self._reconstruct(x_enc)

    def classification(self, x_enc, x_mark_enc):
        """Classify ``x_enc`` from its last encoder output; returns [B, num_class]."""
        _, enc_out = self._encode(x_enc)
        # Flatten (d_model, seg) per variable, then all variables together.
        output = self.flatten(enc_out[-1].permute(0, 1, 3, 2))
        output = self.dropout(output)
        output = output.reshape(output.shape[0], -1)
        return self.projection(output)

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        """Dispatch on ``self.task_name``; returns ``None`` for an unknown task."""
        if self.task_name in ('long_term_forecast', 'short_term_forecast'):
            dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
            # The decoder covers pad_out_len steps; keep the last pred_len.
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
        if self.task_name == 'imputation':
            return self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask)  # [B, L, D]
        if self.task_name == 'anomaly_detection':
            return self.anomaly_detection(x_enc)  # [B, L, D]
        if self.task_name == 'classification':
            return self.classification(x_enc, x_mark_enc)  # [B, N]
        return None