From 714978fb67aca25054a93ff52cc19d34b50d7ce0 Mon Sep 17 00:00:00 2001 From: lx Date: Wed, 12 Feb 2025 17:21:14 +0800 Subject: [PATCH 1/2] 'add-card' --- baselines/CARD/ETTm2.py | 158 +++++++++++++++++++++++++++++++ baselines/CARD/Electricity.py | 158 +++++++++++++++++++++++++++++++ baselines/CARD/Weather.py | 158 +++++++++++++++++++++++++++++++ baselines/CARD/arch/__init__.py | 1 + baselines/CARD/arch/attention.py | 134 ++++++++++++++++++++++++++ baselines/CARD/arch/card_arch.py | 110 +++++++++++++++++++++ baselines/CARD/loss/__init__.py | 1 + baselines/CARD/loss/loss.py | 49 ++++++++++ 8 files changed, 769 insertions(+) create mode 100644 baselines/CARD/ETTm2.py create mode 100644 baselines/CARD/Electricity.py create mode 100644 baselines/CARD/Weather.py create mode 100644 baselines/CARD/arch/__init__.py create mode 100644 baselines/CARD/arch/attention.py create mode 100644 baselines/CARD/arch/card_arch.py create mode 100644 baselines/CARD/loss/__init__.py create mode 100644 baselines/CARD/loss/loss.py diff --git a/baselines/CARD/ETTm2.py b/baselines/CARD/ETTm2.py new file mode 100644 index 0000000..3e6624c --- /dev/null +++ b/baselines/CARD/ETTm2.py @@ -0,0 +1,158 @@ +import os +import sys +from easydict import EasyDict +sys.path.append(os.path.abspath(__file__ + '/../../..')) +from basicts.metrics import masked_mae, masked_mse +from basicts.data import TimeSeriesForecastingDataset +from basicts.runners import SimpleTimeSeriesForecastingRunner +from basicts.scaler import ZScoreScaler +from basicts.utils import get_regular_settings + +from .arch import CARD +from .loss import card_loss + +############################## Hot Parameters ############################## +# Dataset & Metrics configuration +DATA_NAME = 'ETTm2' # Dataset name +regular_settings = get_regular_settings(DATA_NAME) +INPUT_LEN = regular_settings['INPUT_LEN'] # Length of input sequence +OUTPUT_LEN = regular_settings['OUTPUT_LEN'] # Length of output sequence +TRAIN_VAL_TEST_RATIO = regular_settings['TRAIN_VAL_TEST_RATIO'] # Train/Validation/Test split ratios +NORM_EACH_CHANNEL = regular_settings['NORM_EACH_CHANNEL'] # Whether to normalize each channel of the data +RESCALE = regular_settings['RESCALE'] # Whether to rescale the data +NULL_VAL = regular_settings['NULL_VAL'] # Null value in the data +# Model architecture and parameters +MODEL_ARCH = CARD +NUM_NODES = 7 +MODEL_PARAM = { + "enc_in": NUM_NODES, # num nodes + "dec_in": NUM_NODES, + "c_out": NUM_NODES, + "seq_len": INPUT_LEN, + "pred_len": OUTPUT_LEN, # prediction sequence length + "patch_len": 16, + "stride": 8, + "d_model": 16, + "d_ff": 32, + "use_statistic": False, + "e_layers": 2, + "n_heads": 2, + "dropout": 0.2, + "momentum": 0.1, + "dp_rank": 8, + "merge_size": 2, + "alpha": 0.5, + "over_hidden": False, + "trianable_smooth": False, + "untoken": False, + "use_h_loss": True, + "model_token_number": 0, + "num_time_features": 4, # number of used time features + "time_of_day_size": 96, + "day_of_week_size": 7, + "day_of_month_size": 31, + "day_of_year_size": 366 + } +NUM_EPOCHS = 100 + +############################## General Configuration ############################## +CFG = EasyDict() +# General settings +CFG.DESCRIPTION = 'An Example Config' +CFG.GPU_NUM = 1 # Number of GPUs to use (0 for CPU mode) +# Runner +CFG.RUNNER = SimpleTimeSeriesForecastingRunner + +############################## Dataset Configuration ############################## +CFG.DATASET = EasyDict() +# Dataset settings +CFG.DATASET.NAME = DATA_NAME +CFG.DATASET.TYPE = 
TimeSeriesForecastingDataset +CFG.DATASET.PARAM = EasyDict({ + 'dataset_name': DATA_NAME, + 'train_val_test_ratio': TRAIN_VAL_TEST_RATIO, + 'input_len': INPUT_LEN, + 'output_len': OUTPUT_LEN, + # 'mode' is automatically set by the runner +}) + +############################## Scaler Configuration ############################## +CFG.SCALER = EasyDict() +# Scaler settings +CFG.SCALER.TYPE = ZScoreScaler # Scaler class +CFG.SCALER.PARAM = EasyDict({ + 'dataset_name': DATA_NAME, + 'train_ratio': TRAIN_VAL_TEST_RATIO[0], + 'norm_each_channel': NORM_EACH_CHANNEL, + 'rescale': RESCALE, +}) + +############################## Model Configuration ############################## +CFG.MODEL = EasyDict() +# Model settings +CFG.MODEL.NAME = MODEL_ARCH.__name__ +CFG.MODEL.ARCH = MODEL_ARCH +CFG.MODEL.PARAM = MODEL_PARAM +CFG.MODEL.FORWARD_FEATURES = [0, 1, 2, 3, 4] +CFG.MODEL.TARGET_FEATURES = [0] + +############################## Metrics Configuration ############################## + +CFG.METRICS = EasyDict() +# Metrics settings +CFG.METRICS.FUNCS = EasyDict({ + 'MAE': masked_mae, + 'MSE': masked_mse + }) +CFG.METRICS.TARGET = 'MSE' +CFG.METRICS.NULL_VAL = NULL_VAL + +############################## Training Configuration ############################## +CFG.TRAIN = EasyDict() +CFG.TRAIN.NUM_EPOCHS = NUM_EPOCHS +CFG.TRAIN.CKPT_SAVE_DIR = os.path.join( + 'checkpoints', + MODEL_ARCH.__name__, + '_'.join([DATA_NAME, str(CFG.TRAIN.NUM_EPOCHS), str(INPUT_LEN), str(OUTPUT_LEN)]) +) +CFG.TRAIN.LOSS = card_loss +# Optimizer settings +CFG.TRAIN.OPTIM = EasyDict() +CFG.TRAIN.OPTIM.TYPE = "Adam" +CFG.TRAIN.OPTIM.PARAM = { + "lr": 0.0001, +} +# Learning rate scheduler settings +CFG.TRAIN.LR_SCHEDULER = EasyDict() +CFG.TRAIN.LR_SCHEDULER.TYPE = "MultiStepLR" +CFG.TRAIN.LR_SCHEDULER.PARAM = { + "milestones": [1, 25, 50], + "gamma": 0.5 +} +CFG.TRAIN.CLIP_GRAD_PARAM = { + 'max_norm': 5.0 +} +# Train data loader settings +CFG.TRAIN.DATA = EasyDict() +CFG.TRAIN.DATA.BATCH_SIZE = 32 +CFG.TRAIN.DATA.SHUFFLE = True + +############################## Validation Configuration ############################## +CFG.VAL = EasyDict() +CFG.VAL.INTERVAL = 1 +CFG.VAL.DATA = EasyDict() +CFG.VAL.DATA.BATCH_SIZE = 32 + +############################## Test Configuration ############################## +CFG.TEST = EasyDict() +CFG.TEST.INTERVAL = 1 +CFG.TEST.DATA = EasyDict() +CFG.TEST.DATA.BATCH_SIZE = 32 + +############################## Evaluation Configuration ############################## + +CFG.EVAL = EasyDict() + +# Evaluation parameters +CFG.EVAL.HORIZONS = [12, 24, 48, 96] +CFG.EVAL.USE_GPU = True # Whether to use GPU for evaluation. 
Default: True diff --git a/baselines/CARD/Electricity.py b/baselines/CARD/Electricity.py new file mode 100644 index 0000000..7e05b7f --- /dev/null +++ b/baselines/CARD/Electricity.py @@ -0,0 +1,158 @@ +import os +import sys +from easydict import EasyDict +sys.path.append(os.path.abspath(__file__ + '/../../..')) +from basicts.metrics import masked_mae, masked_mse +from basicts.data import TimeSeriesForecastingDataset +from basicts.runners import SimpleTimeSeriesForecastingRunner +from basicts.scaler import ZScoreScaler +from basicts.utils import get_regular_settings + +from .arch import CARD +from .loss import card_loss + +############################## Hot Parameters ############################## +# Dataset & Metrics configuration +DATA_NAME = 'Electricity' # Dataset name +regular_settings = get_regular_settings(DATA_NAME) +INPUT_LEN = regular_settings['INPUT_LEN'] # Length of input sequence +OUTPUT_LEN = regular_settings['OUTPUT_LEN'] # Length of output sequence +TRAIN_VAL_TEST_RATIO = regular_settings['TRAIN_VAL_TEST_RATIO'] # Train/Validation/Test split ratios +NORM_EACH_CHANNEL = regular_settings['NORM_EACH_CHANNEL'] # Whether to normalize each channel of the data +RESCALE = regular_settings['RESCALE'] # Whether to rescale the data +NULL_VAL = regular_settings['NULL_VAL'] # Null value in the data +# Model architecture and parameters +MODEL_ARCH = CARD +NUM_NODES = 321 +MODEL_PARAM = { + "enc_in": NUM_NODES, # num nodes + "dec_in": NUM_NODES, + "c_out": NUM_NODES, + "seq_len": INPUT_LEN, + "pred_len": OUTPUT_LEN, # prediction sequence length + "patch_len": 16, + "stride": 8, + "d_model": 128, + "d_ff": 256, + "use_statistic": False, + "e_layers": 2, + "n_heads": 16, + "dropout": 0.2, + "momentum": 0.1, + "dp_rank": 8, + "merge_size": 2, + "alpha": 0.5, + "over_hidden": False, + "trianable_smooth": False, + "untoken": False, + "use_h_loss": True, + "model_token_number": 0, + "num_time_features": 4, # number of used time features + "time_of_day_size": 24, + "day_of_week_size": 7, + "day_of_month_size": 31, + "day_of_year_size": 366 + } +NUM_EPOCHS = 100 + +############################## General Configuration ############################## +CFG = EasyDict() +# General settings +CFG.DESCRIPTION = 'An Example Config' +CFG.GPU_NUM = 1 # Number of GPUs to use (0 for CPU mode) +# Runner +CFG.RUNNER = SimpleTimeSeriesForecastingRunner + +############################## Dataset Configuration ############################## +CFG.DATASET = EasyDict() +# Dataset settings +CFG.DATASET.NAME = DATA_NAME +CFG.DATASET.TYPE = TimeSeriesForecastingDataset +CFG.DATASET.PARAM = EasyDict({ + 'dataset_name': DATA_NAME, + 'train_val_test_ratio': TRAIN_VAL_TEST_RATIO, + 'input_len': INPUT_LEN, + 'output_len': OUTPUT_LEN, + # 'mode' is automatically set by the runner +}) + +############################## Scaler Configuration ############################## +CFG.SCALER = EasyDict() +# Scaler settings +CFG.SCALER.TYPE = ZScoreScaler # Scaler class +CFG.SCALER.PARAM = EasyDict({ + 'dataset_name': DATA_NAME, + 'train_ratio': TRAIN_VAL_TEST_RATIO[0], + 'norm_each_channel': NORM_EACH_CHANNEL, + 'rescale': RESCALE, +}) + +############################## Model Configuration ############################## +CFG.MODEL = EasyDict() +# Model settings +CFG.MODEL.NAME = MODEL_ARCH.__name__ +CFG.MODEL.ARCH = MODEL_ARCH +CFG.MODEL.PARAM = MODEL_PARAM +CFG.MODEL.FORWARD_FEATURES = [0, 1, 2, 3, 4] +CFG.MODEL.TARGET_FEATURES = [0] + +############################## Metrics Configuration ############################## + +CFG.METRICS = 
EasyDict() +# Metrics settings +CFG.METRICS.FUNCS = EasyDict({ + 'MAE': masked_mae, + 'MSE': masked_mse + }) +CFG.METRICS.TARGET = 'MSE' +CFG.METRICS.NULL_VAL = NULL_VAL + +############################## Training Configuration ############################## +CFG.TRAIN = EasyDict() +CFG.TRAIN.NUM_EPOCHS = NUM_EPOCHS +CFG.TRAIN.CKPT_SAVE_DIR = os.path.join( + 'checkpoints', + MODEL_ARCH.__name__, + '_'.join([DATA_NAME, str(CFG.TRAIN.NUM_EPOCHS), str(INPUT_LEN), str(OUTPUT_LEN)]) +) +CFG.TRAIN.LOSS = card_loss +# Optimizer settings +CFG.TRAIN.OPTIM = EasyDict() +CFG.TRAIN.OPTIM.TYPE = "Adam" +CFG.TRAIN.OPTIM.PARAM = { + "lr": 0.0001, +} +# Learning rate scheduler settings +CFG.TRAIN.LR_SCHEDULER = EasyDict() +CFG.TRAIN.LR_SCHEDULER.TYPE = "MultiStepLR" +CFG.TRAIN.LR_SCHEDULER.PARAM = { + "milestones": [1, 25, 50], + "gamma": 0.5 +} +CFG.TRAIN.CLIP_GRAD_PARAM = { + 'max_norm': 5.0 +} +# Train data loader settings +CFG.TRAIN.DATA = EasyDict() +CFG.TRAIN.DATA.BATCH_SIZE = 32 +CFG.TRAIN.DATA.SHUFFLE = True + +############################## Validation Configuration ############################## +CFG.VAL = EasyDict() +CFG.VAL.INTERVAL = 1 +CFG.VAL.DATA = EasyDict() +CFG.VAL.DATA.BATCH_SIZE = 32 + +############################## Test Configuration ############################## +CFG.TEST = EasyDict() +CFG.TEST.INTERVAL = 1 +CFG.TEST.DATA = EasyDict() +CFG.TEST.DATA.BATCH_SIZE = 32 + +############################## Evaluation Configuration ############################## + +CFG.EVAL = EasyDict() + +# Evaluation parameters +CFG.EVAL.HORIZONS = [12, 24, 48, 96] +CFG.EVAL.USE_GPU = True # Whether to use GPU for evaluation. Default: True diff --git a/baselines/CARD/Weather.py b/baselines/CARD/Weather.py new file mode 100644 index 0000000..fec96f1 --- /dev/null +++ b/baselines/CARD/Weather.py @@ -0,0 +1,158 @@ +import os +import sys +from easydict import EasyDict +sys.path.append(os.path.abspath(__file__ + '/../../..')) +from basicts.metrics import masked_mae, masked_mse +from basicts.data import TimeSeriesForecastingDataset +from basicts.runners import SimpleTimeSeriesForecastingRunner +from basicts.scaler import ZScoreScaler +from basicts.utils import get_regular_settings + +from .arch import CARD +from .loss import card_loss + +############################## Hot Parameters ############################## +# Dataset & Metrics configuration +DATA_NAME = 'Weather' # Dataset name +regular_settings = get_regular_settings(DATA_NAME) +INPUT_LEN = regular_settings['INPUT_LEN'] # Length of input sequence +OUTPUT_LEN = regular_settings['OUTPUT_LEN'] # Length of output sequence +TRAIN_VAL_TEST_RATIO = regular_settings['TRAIN_VAL_TEST_RATIO'] # Train/Validation/Test split ratios +NORM_EACH_CHANNEL = regular_settings['NORM_EACH_CHANNEL'] # Whether to normalize each channel of the data +RESCALE = regular_settings['RESCALE'] # Whether to rescale the data +NULL_VAL = regular_settings['NULL_VAL'] # Null value in the data +# Model architecture and parameters +MODEL_ARCH = CARD +NUM_NODES = 21 +MODEL_PARAM = { + "enc_in": NUM_NODES, # num nodes + "dec_in": NUM_NODES, + "c_out": NUM_NODES, + "seq_len": INPUT_LEN, + "pred_len": OUTPUT_LEN, # prediction sequence length + "patch_len": 16, + "stride": 8, + "d_model": 128, + "d_ff": 256, + "use_statistic": False, + "e_layers": 2, + "n_heads": 16, + "dropout": 0.2, + "momentum": 0.1, + "dp_rank": 8, + "merge_size": 2, + "alpha": 0.5, + "over_hidden": False, + "trianable_smooth": False, + "untoken": False, + "use_h_loss": True, + "model_token_number": 0, + "num_time_features": 4, 
# number of used time features + "time_of_day_size": 144, + "day_of_week_size": 7, + "day_of_month_size": 31, + "day_of_year_size": 366 + } +NUM_EPOCHS = 100 + +############################## General Configuration ############################## +CFG = EasyDict() +# General settings +CFG.DESCRIPTION = 'An Example Config' +CFG.GPU_NUM = 1 # Number of GPUs to use (0 for CPU mode) +# Runner +CFG.RUNNER = SimpleTimeSeriesForecastingRunner + +############################## Dataset Configuration ############################## +CFG.DATASET = EasyDict() +# Dataset settings +CFG.DATASET.NAME = DATA_NAME +CFG.DATASET.TYPE = TimeSeriesForecastingDataset +CFG.DATASET.PARAM = EasyDict({ + 'dataset_name': DATA_NAME, + 'train_val_test_ratio': TRAIN_VAL_TEST_RATIO, + 'input_len': INPUT_LEN, + 'output_len': OUTPUT_LEN, + # 'mode' is automatically set by the runner +}) + +############################## Scaler Configuration ############################## +CFG.SCALER = EasyDict() +# Scaler settings +CFG.SCALER.TYPE = ZScoreScaler # Scaler class +CFG.SCALER.PARAM = EasyDict({ + 'dataset_name': DATA_NAME, + 'train_ratio': TRAIN_VAL_TEST_RATIO[0], + 'norm_each_channel': NORM_EACH_CHANNEL, + 'rescale': RESCALE, +}) + +############################## Model Configuration ############################## +CFG.MODEL = EasyDict() +# Model settings +CFG.MODEL.NAME = MODEL_ARCH.__name__ +CFG.MODEL.ARCH = MODEL_ARCH +CFG.MODEL.PARAM = MODEL_PARAM +CFG.MODEL.FORWARD_FEATURES = [0, 1, 2, 3, 4] +CFG.MODEL.TARGET_FEATURES = [0] + +############################## Metrics Configuration ############################## + +CFG.METRICS = EasyDict() +# Metrics settings +CFG.METRICS.FUNCS = EasyDict({ + 'MAE': masked_mae, + 'MSE': masked_mse + }) +CFG.METRICS.TARGET = 'MSE' +CFG.METRICS.NULL_VAL = NULL_VAL + +############################## Training Configuration ############################## +CFG.TRAIN = EasyDict() +CFG.TRAIN.NUM_EPOCHS = NUM_EPOCHS +CFG.TRAIN.CKPT_SAVE_DIR = os.path.join( + 'checkpoints', + MODEL_ARCH.__name__, + '_'.join([DATA_NAME, str(CFG.TRAIN.NUM_EPOCHS), str(INPUT_LEN), str(OUTPUT_LEN)]) +) +CFG.TRAIN.LOSS = card_loss +# Optimizer settings +CFG.TRAIN.OPTIM = EasyDict() +CFG.TRAIN.OPTIM.TYPE = "Adam" +CFG.TRAIN.OPTIM.PARAM = { + "lr": 0.0001, +} +# Learning rate scheduler settings +CFG.TRAIN.LR_SCHEDULER = EasyDict() +CFG.TRAIN.LR_SCHEDULER.TYPE = "MultiStepLR" +CFG.TRAIN.LR_SCHEDULER.PARAM = { + "milestones": [1, 25, 50], + "gamma": 0.5 +} +CFG.TRAIN.CLIP_GRAD_PARAM = { + 'max_norm': 5.0 +} +# Train data loader settings +CFG.TRAIN.DATA = EasyDict() +CFG.TRAIN.DATA.BATCH_SIZE = 32 +CFG.TRAIN.DATA.SHUFFLE = True + +############################## Validation Configuration ############################## +CFG.VAL = EasyDict() +CFG.VAL.INTERVAL = 1 +CFG.VAL.DATA = EasyDict() +CFG.VAL.DATA.BATCH_SIZE = 32 + +############################## Test Configuration ############################## +CFG.TEST = EasyDict() +CFG.TEST.INTERVAL = 1 +CFG.TEST.DATA = EasyDict() +CFG.TEST.DATA.BATCH_SIZE = 32 + +############################## Evaluation Configuration ############################## + +CFG.EVAL = EasyDict() + +# Evaluation parameters +CFG.EVAL.HORIZONS = [12, 24, 48, 96] +CFG.EVAL.USE_GPU = True # Whether to use GPU for evaluation. 
Default: True diff --git a/baselines/CARD/arch/__init__.py b/baselines/CARD/arch/__init__.py new file mode 100644 index 0000000..0ae2e02 --- /dev/null +++ b/baselines/CARD/arch/__init__.py @@ -0,0 +1 @@ +from .card_arch import CARD \ No newline at end of file diff --git a/baselines/CARD/arch/attention.py b/baselines/CARD/arch/attention.py new file mode 100644 index 0000000..dc8cda0 --- /dev/null +++ b/baselines/CARD/arch/attention.py @@ -0,0 +1,134 @@ +import torch +from torch import nn +import torch.nn.functional as F +from einops import rearrange + + +class Transpose(nn.Module): + def __init__(self, *dims, contiguous=False): + super().__init__() + self.dims, self.contiguous = dims, contiguous + def forward(self, x): + if self.contiguous: return x.transpose(*self.dims).contiguous() + else: return x.transpose(*self.dims) + + +class Attenion(nn.Module): + def __init__(self,config, over_hidden = False,trianable_smooth = False,untoken = False): + super().__init__() + + self.over_hidden = over_hidden + self.untoken = untoken + self.n_heads = config.n_heads + self.c_in = config.enc_in + self.qkv = nn.Linear(config.d_model, config.d_model * 3, bias=True) + + self.attn_dropout = nn.Dropout(config.dropout) + self.head_dim = config.d_model // config.n_heads + + self.dropout_mlp = nn.Dropout(config.dropout) + self.mlp = nn.Linear( config.d_model, config.d_model) + + self.norm_post1 = nn.Sequential(Transpose(1,2), nn.BatchNorm1d(config.d_model,momentum = config.momentum), Transpose(1,2)) + self.norm_post2 = nn.Sequential(Transpose(1,2), nn.BatchNorm1d(config.d_model,momentum = config.momentum), Transpose(1,2)) + + self.norm_attn = nn.Sequential(Transpose(1,2), nn.BatchNorm1d(config.d_model,momentum = config.momentum), Transpose(1,2)) + + self.dp_rank = config.dp_rank + self.dp_k = nn.Linear(self.head_dim, self.dp_rank) + self.dp_v = nn.Linear(self.head_dim, self.dp_rank) + + self.ff_1 = nn.Sequential(nn.Linear(config.d_model, config.d_ff, bias=True), + nn.GELU(), + nn.Dropout(config.dropout), + nn.Linear(config.d_ff, config.d_model, bias=True) + ) + + self.ff_2= nn.Sequential(nn.Linear(config.d_model, config.d_ff, bias=True), + nn.GELU(), + nn.Dropout(config.dropout), + nn.Linear(config.d_ff, config.d_model, bias=True) + ) + self.merge_size = config.merge_size + + ema_size = max(config.enc_in,config.total_token_number,config.dp_rank) + ema_matrix = torch.zeros((ema_size,ema_size)) + alpha = config.alpha + ema_matrix[0][0] = 1 + for i in range(1,config.total_token_number): + for j in range(i): + ema_matrix[i][j] = ema_matrix[i-1][j]*(1-alpha) + ema_matrix[i][i] = alpha + self.register_buffer('ema_matrix',ema_matrix) + + + def ema(self,src): + return torch.einsum('bnhad,ga ->bnhgd',src,self.ema_matrix[:src.shape[-2],:src.shape[-2]]) + + def ema_trianable(self,src): + alpha = F.sigmoid(self.alpha) + weights = alpha * (1 - alpha) ** self.arange[-src.shape[-2]:] + w_f = torch.fft.rfft(weights,n = src.shape[-2]*2) + src_f = torch.fft.rfft(src.float(),dim = -2,n = src.shape[-2]*2) + src_f = (src_f.permute(0,1,2,4,3)*w_f) + src1 =torch.fft.irfft(src_f.float(),dim = -1,n=src.shape[-2]*2)[...,:src.shape[-2]].permute(0,1,2,4,3)#.half() + return src1 + + def dynamic_projection(self,src,mlp): + src_dp = mlp(src) + src_dp = F.softmax(src_dp,dim = -1) + src_dp = torch.einsum('bnhef,bnhec -> bnhcf',src,src_dp) + return src_dp + + def forward(self, src): + B,nvars, H, C, = src.shape + + qkv = self.qkv(src).reshape(B,nvars, H, 3, self.n_heads, C // self.n_heads).permute(3, 0, 1,4, 2, 5) + + q, k, v = qkv[0], qkv[1], 
qkv[2] + + if not self.over_hidden: + + attn_score_along_token = torch.einsum('bnhed,bnhfd->bnhef', self.ema(q), self.ema(k))/ self.head_dim ** -0.5 + attn_along_token = self.attn_dropout(F.softmax(attn_score_along_token, dim=-1) ) + output_along_token = torch.einsum('bnhef,bnhfd->bnhed', attn_along_token, v) + + else: + v_dp,k_dp = self.dynamic_projection(v,self.dp_v) , self.dynamic_projection(k,self.dp_k) + attn_score_along_token = torch.einsum('bnhed,bnhfd->bnhef', self.ema(q), self.ema(k_dp))/ self.head_dim ** -0.5 + + + attn_along_token = self.attn_dropout(F.softmax(attn_score_along_token, dim=-1) ) + output_along_token = torch.einsum('bnhef,bnhfd->bnhed', attn_along_token, v_dp) + + attn_score_along_hidden = torch.einsum('bnhae,bnhaf->bnhef', q,k)/ q.shape[-2] ** -0.5 + attn_along_hidden = self.attn_dropout(F.softmax(attn_score_along_hidden, dim=-1) ) + output_along_hidden = torch.einsum('bnhef,bnhaf->bnhae', attn_along_hidden, v) + + merge_size = self.merge_size + if not self.untoken: + output1 = rearrange(output_along_token.reshape(B*nvars,-1,self.head_dim), + 'bn (hl1 hl2 hl3) d -> bn hl2 (hl3 hl1) d', + hl1 = self.n_heads//merge_size, hl2 = output_along_token.shape[-2] ,hl3 = merge_size + ).reshape(B*nvars,-1,self.head_dim*self.n_heads) + + + output2 = rearrange(output_along_hidden.reshape(B*nvars,-1,self.head_dim), + 'bn (hl1 hl2 hl3) d -> bn hl2 (hl3 hl1) d', + hl1 = self.n_heads//merge_size, hl2 = output_along_token.shape[-2] ,hl3 = merge_size + ).reshape(B*nvars,-1,self.head_dim*self.n_heads) + + + output1 = self.norm_post1(output1) + output1 = output1.reshape(B,nvars, -1, self.n_heads * self.head_dim) + output2 = self.norm_post2(output2) + output2 = output2.reshape(B,nvars, -1, self.n_heads * self.head_dim) + + src2 = self.ff_1(output1)+self.ff_2(output2) + + src = src + src2 + src = src.reshape(B*nvars, -1, self.n_heads * self.head_dim) + src = self.norm_attn(src) + + src = src.reshape(B,nvars, -1, self.n_heads * self.head_dim) + return src \ No newline at end of file diff --git a/baselines/CARD/arch/card_arch.py b/baselines/CARD/arch/card_arch.py new file mode 100644 index 0000000..588b866 --- /dev/null +++ b/baselines/CARD/arch/card_arch.py @@ -0,0 +1,110 @@ +import torch +from torch import nn +import torch.nn.functional as F +from einops import rearrange +import numpy as np + +from basicts.utils import data_transformation_4_xformer +from .attention import Attenion, Transpose + +from argparse import Namespace + +import pdb + +class CARD(nn.Module): + + def __init__(self, **config): + super().__init__() + config = Namespace(**config) + self.patch_len = config.patch_len + self.pred_len = config.pred_len + self.stride = config.stride + self.d_model = config.d_model + patch_num = int((config.seq_len - self.patch_len)/self.stride + 1) + self.patch_num = patch_num + self.W_pos_embed = nn.Parameter(torch.randn(patch_num,config.d_model)*1e-2) + self.model_token_number = config.model_token_number + + if self.model_token_number > 0: + self.model_token = nn.Parameter(torch.randn(config.enc_in,self.model_token_number,config.d_model)*1e-2) + + self.total_token_number = (self.patch_num + self.model_token_number + 1) + config.total_token_number = self.total_token_number + + self.W_input_projection = nn.Linear(self.patch_len, config.d_model) + self.input_dropout = nn.Dropout(config.dropout) + + self.use_statistic = config.use_statistic + self.use_h_loss = config.use_h_loss + self.W_statistic = nn.Linear(2,config.d_model) + self.cls = nn.Parameter(torch.randn(1,config.d_model)*1e-2) + + 
self.W_out = nn.Linear((patch_num+1+self.model_token_number)*config.d_model, config.pred_len) + + self.Attentions_over_token = nn.ModuleList([Attenion(config) for i in range(config.e_layers)]) + self.Attentions_over_channel = nn.ModuleList([Attenion(config,over_hidden = True) for i in range(config.e_layers)]) + self.Attentions_mlp = nn.ModuleList([nn.Linear(config.d_model,config.d_model) for i in range(config.e_layers)]) + self.Attentions_dropout = nn.ModuleList([nn.Dropout(config.dropout) for i in range(config.e_layers)]) + self.Attentions_norm = nn.ModuleList([nn.Sequential(Transpose(1,2), nn.BatchNorm1d(config.d_model,momentum = config.momentum), Transpose(1,2)) for i in range(config.e_layers)]) + + self.time_of_day_size = config.time_of_day_size + + def forward_xformer(self, x_enc: torch.Tensor, x_mark_enc: torch.Tensor, x_dec: torch.Tensor, x_mark_dec: torch.Tensor) -> torch.Tensor: + + z = x_enc.transpose(1,2) + b,c,s = z.shape + # use-norm + z_mean = torch.mean(z,dim = (-1),keepdims = True) + z_std = torch.std(z,dim = (-1),keepdims = True) + z = (z - z_mean)/(z_std + 1e-4) + + zcube = z.unfold(dimension=-1, size=self.patch_len, step=self.stride) + z_embed = self.input_dropout(self.W_input_projection(zcube))+ self.W_pos_embed + + if self.use_statistic: + z_stat = torch.cat((z_mean,z_std),dim = -1) + if z_stat.shape[-2]>1: + z_stat = (z_stat - torch.mean(z_stat,dim =-2,keepdims = True))/( torch.std(z_stat,dim =-2,keepdims = True)+1e-4) + z_stat = self.W_statistic(z_stat) + z_embed = torch.cat((z_stat.unsqueeze(-2),z_embed),dim = -2) + + else: + cls_token = self.cls.repeat(z_embed.shape[0],z_embed.shape[1],1,1) + z_embed = torch.cat((cls_token,z_embed),dim = -2) + + inputs = z_embed + b,c,t,h = inputs.shape + for a_2,a_1,mlp,drop,norm in zip(self.Attentions_over_token, self.Attentions_over_channel,self.Attentions_mlp ,self.Attentions_dropout,self.Attentions_norm ): + output_1 = a_1(inputs.permute(0,2,1,3)).permute(0,2,1,3) + output_2 = a_2(output_1) + outputs = drop(mlp(output_1+output_2))+inputs + outputs = norm(outputs.reshape(b*c,t,-1)).reshape(b,c,t,-1) + inputs = outputs + + # de-norm + z_out = self.W_out(outputs.reshape(b,c,-1)) + z = z_out *(z_std+1e-4) + z_mean + + return z + + def forward(self, history_data: torch.Tensor, future_data: torch.Tensor, batch_seen: int, epoch: int, train: bool, + **kwargs) -> torch.Tensor: + """ + + Args: + history_data (Tensor): Input data with shape: [B, L1, N, C] + future_data (Tensor): Future data with shape: [B, L2, N, C] + + Returns: + torch.Tensor: outputs with shape [B, L2, N, 1] + """ + history_data[..., 1] = history_data[..., 1] * self.time_of_day_size // (self.time_of_day_size / 24) / 23.0 + x_enc, x_mark_enc, x_dec, x_mark_dec = data_transformation_4_xformer(history_data=history_data, + future_data=future_data, + start_token_len=0) + #print(x_mark_enc.shape, x_mark_dec.shape) + prediction = self.forward_xformer(x_enc=x_enc, x_mark_enc=x_mark_enc, x_dec=x_dec, x_mark_dec=x_mark_dec) + + return {"prediction": prediction.transpose(1,2).unsqueeze(-1), + "use_h_loss": self.use_h_loss, + "pred_len": self.pred_len} \ No newline at end of file diff --git a/baselines/CARD/loss/__init__.py b/baselines/CARD/loss/__init__.py new file mode 100644 index 0000000..42f72fc --- /dev/null +++ b/baselines/CARD/loss/__init__.py @@ -0,0 +1 @@ +from .loss import card_loss \ No newline at end of file diff --git a/baselines/CARD/loss/loss.py b/baselines/CARD/loss/loss.py new file mode 100644 index 0000000..bbc887d --- /dev/null +++ b/baselines/CARD/loss/loss.py 
@@ -0,0 +1,49 @@
+import torch
+import torch.nn.functional as F
+import numpy as np
+
+
+def h_loss(outputs, batch_y, ratio):
+    """Multi-scale (hierarchical) L1 loss over average-pooled horizons."""
+    h_level_range = [4, 8, 16, 24, 48, 96]
+    total = 0.
+    for h_level in h_level_range:
+        batch, length, channel = outputs.shape
+        if length % h_level != 0:
+            continue  # skip scales that do not evenly divide the horizon
+        # average-pool predictions and targets over windows of size h_level
+        h_outputs = outputs.transpose(-1, -2).reshape(batch, channel, -1, h_level)
+        h_outputs = torch.mean(h_outputs, dim=-1, keepdim=True)
+        h_batch_y = batch_y.transpose(-1, -2).reshape(batch, channel, -1, h_level)
+        h_batch_y = torch.mean(h_batch_y, dim=-1, keepdim=True)
+        h_ratio = ratio[:h_outputs.shape[-2], :]
+        # channel-aggregated versions of the pooled sequences
+        h_outputs_agg = torch.mean(h_outputs, dim=1, keepdim=True)
+        h_batch_y_agg = torch.mean(h_batch_y, dim=1, keepdim=True)
+
+        h_outputs = h_outputs * h_ratio
+        h_batch_y = h_batch_y * h_ratio
+        h_outputs_agg = h_outputs_agg * h_ratio
+        h_batch_y_agg = h_batch_y_agg * h_ratio
+
+        # accumulate the per-channel and channel-aggregated terms across scales
+        total = total + F.l1_loss(h_outputs, h_batch_y) * np.sqrt(h_level) / 2
+        total = total + F.l1_loss(h_outputs_agg, h_batch_y_agg) * np.sqrt(h_level) / 2
+
+    return total
+
+
+def card_loss(prediction, target, use_h_loss, pred_len):
+    """CARD loss: horizon-weighted L1, optionally plus the hierarchical term."""
+    outputs, batch_y = prediction.squeeze(-1), target.squeeze(-1)
+    # weight step i of the horizon by 1/sqrt(i+1) so near-term errors dominate
+    ratio = np.array([1 / np.sqrt(i + 1) for i in range(pred_len)])
+    ratio = torch.tensor(ratio).unsqueeze(-1).to(prediction.device)
+    outputs = outputs * ratio
+    batch_y = batch_y * ratio
+    loss = F.l1_loss(outputs, batch_y)
+
+    if not use_h_loss:
+        return loss
+    return loss + h_loss(outputs, batch_y, ratio) * 1e-2
+

From d036e6ba868b79be31fcc336ddb4adb178e92810 Mon Sep 17 00:00:00 2001
From: lx
Date: Wed, 12 Feb 2025 17:26:36 +0800
Subject: [PATCH 2/2] 'update-card'

---
 baselines/CARD/arch/card_arch.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/baselines/CARD/arch/card_arch.py b/baselines/CARD/arch/card_arch.py
index 588b866..35d21b3 100644
--- a/baselines/CARD/arch/card_arch.py
+++ b/baselines/CARD/arch/card_arch.py
@@ -107,4 +107,4 @@ def forward(self, history_data: torch.Tensor, future_data: torch.Tensor, batch_s
 
     return {"prediction": prediction.transpose(1,2).unsqueeze(-1),
             "use_h_loss": self.use_h_loss,
-            "pred_len": self.pred_len}
\ No newline at end of file
+            "pred_len": self.pred_len}
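Note on the loss above (not part of the patch): card_loss down-weights distant forecast steps by 1/sqrt(i+1). A minimal standalone sketch, assuming only numpy, makes the decay concrete:

import numpy as np

pred_len = 96
# same weighting as card_loss: step i of the horizon is scaled by 1/sqrt(i+1)
ratio = np.array([1 / np.sqrt(i + 1) for i in range(pred_len)])
print(ratio[:4])   # approx. [1.0, 0.707, 0.577, 0.5] -> early steps dominate
print(ratio[-1])   # approx. 0.102 -> the 96th step carries about a tenth of the weight

Once merged, these configs should run through BasicTS's usual entry point, e.g. python experiments/train.py --cfg baselines/CARD/ETTm2.py --gpus '0' (flag names assume the standard BasicTS CLI).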