diff --git a/configs/climate_projection.yaml b/configs/climate_projection.yaml new file mode 100644 index 0000000..633d772 --- /dev/null +++ b/configs/climate_projection.yaml @@ -0,0 +1,121 @@ +seed_everything: 42 + +# ---------------------------- TRAINER ------------------------------------------- +trainer: + default_root_dir: ${oc.env:AMLT_OUTPUT_DIR,/home/tungnd/ClimaX/exps/climate_projection_climax} + + precision: 16 + + gpus: null + num_nodes: 1 + accelerator: gpu + strategy: ddp + + min_epochs: 1 + max_epochs: 50 + enable_progress_bar: true + + sync_batchnorm: True + enable_checkpointing: True + resume_from_checkpoint: null + + # debugging + fast_dev_run: false + + logger: + class_path: pytorch_lightning.loggers.tensorboard.TensorBoardLogger + init_args: + save_dir: ${trainer.default_root_dir}/logs + name: null + version: null + log_graph: False + default_hp_metric: True + prefix: "" + + callbacks: + - class_path: pytorch_lightning.callbacks.LearningRateMonitor + init_args: + logging_interval: "step" + + - class_path: pytorch_lightning.callbacks.ModelCheckpoint + init_args: + dirpath: "${trainer.default_root_dir}/checkpoints/" + monitor: "val/w_mse" # name of the logged metric which determines when model is improving + mode: "min" # "max" means higher metric value is better, can be also "min" + save_top_k: 1 # save k best models (determined by above metric) + save_last: True # additionaly always save model from last epoch + verbose: False + filename: "epoch_{epoch:03d}" + auto_insert_metric_name: False + + - class_path: pytorch_lightning.callbacks.EarlyStopping + init_args: + monitor: "val/w_mse" # name of the logged metric which determines when model is improving + mode: "min" # "max" means higher metric value is better, can be also "min" + patience: 5 # how many validation epochs of not improving until training stops + min_delta: 0. # minimum change in the monitored metric needed to qualify as an improvement + + - class_path: pytorch_lightning.callbacks.RichModelSummary + init_args: + max_depth: -1 + + - class_path: pytorch_lightning.callbacks.RichProgressBar + +# ---------------------------- MODEL ------------------------------------------- +model: + lr: 5e-4 + beta_1: 0.9 + beta_2: 0.999 + weight_decay: 1e-5 + warmup_epochs: 60 + max_epochs: 600 + warmup_start_lr: 1e-8 + eta_min: 1e-8 + pretrained_path: "https://huggingface.co/tungnd/climax/resolve/main/5.625deg.ckpt" + + net: + class_path: climax.climate_projection.arch.ClimaXClimateBench + init_args: + default_vars: [ + 'CO2', + 'SO2', + 'CH4', + 'BC' + ] + out_vars: "tas" # diurnal_temperature_range, tas, pr, pr90 + img_size: [32, 64] + time_history: 10 + patch_size: 2 + embed_dim: 1024 + depth: 8 + num_heads: 16 + mlp_ratio: 4 + drop_path: 0.1 + drop_rate: 0.1 + parallel_patch_embed: False + freeze_encoder: True + +# ---------------------------- DATA ------------------------------------------- +data: + root_dir: /home/data/datasets/climate-learn/climatebench/5.625deg/ + history: 10 + list_train_simu: [ + 'ssp126', + 'ssp370', + 'ssp585', + 'historical', + 'hist-GHG', + 'hist-aer' + ] + list_test_simu: ['ssp245'] + variables: [ + 'CO2', + 'SO2', + 'CH4', + 'BC' + ] + out_variables: 'tas' + train_ratio: 0.9 + batch_size: 1 + num_workers: 1 + pin_memory: False diff --git a/docs/usage.md b/docs/usage.md index 36aee53..d42946e 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -118,6 +118,42 @@ python src/climax/regional_forecast/train.py --config configs/regional_forecast_ ``` To train ClimaX from scratch, set `--model.pretrained_path=""`. +## Climate Projection + +### Data Preparation + +First, download [ClimateBench](https://doi.org/10.5281/zenodo.5196512) data. ClimaX can work with either the original ClimateBench data or the regridded version. In the experiment in the paper, we regridded to ClimateBench data to 5.625 degree. To do that, run +```bash +python src/data_preprocessing/regrid_climatebench.py /mnt/data/climatebench/train_val \ + --save_path /mnt/data/climatebench/5.625deg/train_val --ddeg_out 5.625 +``` +and +```bash +python src/data_preprocessing/regrid_climatebench.py /mnt/data/climatebench/test \ + --save_path /mnt/data/climatebench/5.625deg/test --ddeg_out 5.625 +``` + +### Training + +To finetune ClimaX for climate projection, use +``` +python src/climax/climate_projection/train.py --config +``` +For example, to finetune ClimaX on 8 GPUs use +```bash +python python src/climax/climate_projection/train.py --config configs/climate_projection.yaml \ + --trainer.strategy=ddp --trainer.devices=8 \ + --trainer.max_epochs=50 \ + --data.root_dir=/mnt/data/climatebench/5.625deg \ + --data.out_variables="tas" \ + --data.batch_size=16 \ + --model.pretrained_path='https://huggingface.co/tungnd/climax/resolve/main/5.625deg.ckpt' \ + --model.out_vars="tas" \ + --model.lr=5e-4 --model.beta_1="0.9" --model.beta_2="0.99" \ + --model.weight_decay=1e-5 +``` +To train ClimaX from scratch, set `--model.pretrained_path=""`. + ## Visualization Coming soon diff --git a/src/climax/climate_projection/__init__.py b/src/climax/climate_projection/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/climax/climate_projection/arch.py b/src/climax/climate_projection/arch.py new file mode 100644 index 0000000..2e1510e --- /dev/null +++ b/src/climax/climate_projection/arch.py @@ -0,0 +1,144 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +import numpy as np +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- +import torch +import torch.nn as nn +from climax.arch import ClimaX +from climax.utils.pos_embed import get_1d_sincos_pos_embed_from_grid + + +class ClimaXClimateBench(ClimaX): + def __init__( + self, + default_vars, + out_vars, + img_size=[32, 64], + time_history=1, + patch_size=2, + embed_dim=1024, + depth=8, + decoder_depth=2, + num_heads=16, + mlp_ratio=4.0, + drop_path=0.1, + drop_rate=0.1, + parallel_patch_embed=False, + freeze_encoder=False, + ): + assert out_vars is not None + + super().__init__( + default_vars, + img_size, + patch_size, + embed_dim, + depth, + decoder_depth, + num_heads, + mlp_ratio, + drop_path, + drop_rate, + parallel_patch_embed + ) + + self.out_vars = out_vars + self.time_history = time_history + self.freeze_encoder = freeze_encoder + + # used to aggregate multiple timesteps in the input + self.time_pos_embed = nn.Parameter(torch.zeros(1, time_history, embed_dim), requires_grad=True) + self.time_agg = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True) + self.time_query = nn.Parameter(torch.zeros(1, 1, embed_dim), requires_grad=True) + + # initialize time embedding + time_pos_embed = get_1d_sincos_pos_embed_from_grid(self.time_pos_embed.shape[-1], np.arange(self.time_history)) + self.time_pos_embed.data.copy_(torch.from_numpy(time_pos_embed).float().unsqueeze(0)) + + # overwrite ClimaX + # use a linear prediction head for this task + self.head = nn.Linear(embed_dim, img_size[0]*img_size[1]) + + if freeze_encoder: + for name, p in self.blocks.named_parameters(): + name = name.lower() + # we do not freeze the norm layers, as suggested by https://arxiv.org/abs/2103.05247 + if 'norm' in name: + continue + else: + p.requires_grad_(False) + + def forward_encoder(self, x: torch.Tensor, lead_times: torch.Tensor, variables): + # x: `[B, T, V, H, W]` shape. + + if isinstance(variables, list): + variables = tuple(variables) + + b, t, _, _, _ = x.shape + x = x.flatten(0, 1) # BxT, V, H, W + + # tokenize each variable separately + embeds = [] + var_ids = self.get_var_ids(variables, x.device) + + if self.parallel_patch_embed: + x = self.token_embeds(x, var_ids) # BxT, V, L, D + else: + for i in range(len(var_ids)): + id = var_ids[i] + embeds.append(self.token_embeds[id](x[:, i : i + 1])) + x = torch.stack(embeds, dim=1) # BxT, V, L, D + + # add variable embedding + var_embed = self.get_var_emb(self.var_embed, variables) + x = x + var_embed.unsqueeze(2) # BxT, V, L, D + + # variable aggregation + x = self.aggregate_variables(x) # BxT, L, D + + # add pos embedding + x = x + self.pos_embed + + # add time embedding + # time emb: 1, T, D + x = x.unflatten(0, sizes=(b, t)) # B, T, L, D + x = x + self.time_pos_embed.unsqueeze(2) + + # add lead time embedding + lead_time_emb = self.lead_time_embed(lead_times.unsqueeze(-1)) # B, D + lead_time_emb = lead_time_emb.unsqueeze(1).unsqueeze(2) + x = x + lead_time_emb # B, T, L, D + + x = x.flatten(0, 1) # BxT, L, D + + x = self.pos_drop(x) + + # apply Transformer blocks + for blk in self.blocks: + x = blk(x) + x = self.norm(x) # BxT, L, D + x = x.unflatten(0, sizes=(b, t)) # B, T, L, D + + # global average pooling, also used in CNN-LSTM baseline in ClimateBench + x = x.mean(-2) # B, T, D + time_query = self.time_query.repeat_interleave(x.shape[0], dim=0) + x, _ = self.time_agg(time_query, x, x) # B, 1, D + + return x + + def forward(self, x, y, lead_times, variables, out_variables, metric, lat): + x = self.forward_encoder(x, lead_times, variables) # B, 1, D + preds = self.head(x) + preds = preds.reshape(-1, 1, self.img_size[0], self.img_size[1]) # B, 1, H, W + if metric is None: + loss = None + else: + loss = [m(preds, y, out_variables, lat) for m in metric] + return loss, preds diff --git a/src/climax/climate_projection/datamodule.py b/src/climax/climate_projection/datamodule.py new file mode 100644 index 0000000..90e76f7 --- /dev/null +++ b/src/climax/climate_projection/datamodule.py @@ -0,0 +1,133 @@ +import os +from typing import Optional + +import numpy as np +import torch +from pytorch_lightning import LightningDataModule +from torch.utils.data import DataLoader + +from climax.climate_projection.dataset import ClimateBenchDataset, input_for_training, load_x_y, output_for_training, split_train_val + + +def collate_fn(batch): + inp = torch.stack([batch[i][0] for i in range(len(batch))]) + out = torch.stack([batch[i][1] for i in range(len(batch))]) + lead_times = torch.cat([batch[i][2] for i in range(len(batch))]) + variables = batch[0][3] + out_variables = batch[0][4] + return inp, out, lead_times, variables, out_variables + + +class ClimateBenchDataModule(LightningDataModule): + def __init__( + self, + root_dir, # contains metadata and train + val + test + history=10, + list_train_simu=[ + 'ssp126', + 'ssp370', + 'ssp585', + 'historical', + 'hist-GHG', + 'hist-aer' + ], + list_test_simu=[ + 'ssp245' + ], + variables=[ + 'CO2', + 'SO2', + 'CH4', + 'BC' + ], + out_variables='tas', + train_ratio=0.9, + batch_size: int = 128, + num_workers: int = 1, + pin_memory: bool = False, + ): + super().__init__() + + # this line allows to access init params with 'self.hparams' attribute + self.save_hyperparameters(logger=False) + + if isinstance(out_variables, str): + out_variables = [out_variables] + self.hparams.out_variables = out_variables + + # split train and val datasets + dict_x_train_val, dict_y_train_val, lat, lon = load_x_y(os.path.join(root_dir, 'train_val'), list_train_simu, out_variables) + self.lat, self.lon = lat, lon + x_train_val = np.concatenate([ + input_for_training( + dict_x_train_val[simu], skip_historical=(i<2), history=history, len_historical=165 + ) for i, simu in enumerate(dict_x_train_val.keys()) + ], axis = 0) # N, T, C, H, W + y_train_val = np.concatenate([ + output_for_training( + dict_y_train_val[simu], skip_historical=(i<2), history=history, len_historical=165 + ) for i, simu in enumerate(dict_y_train_val.keys()) + ], axis=0) # N, 1, H, W + x_train, y_train, x_val, y_val = split_train_val(x_train_val, y_train_val, train_ratio) + + self.dataset_train = ClimateBenchDataset( + x_train, y_train, variables, out_variables, lat, 'train' + ) + self.dataset_val = ClimateBenchDataset( + x_val, y_val, variables, out_variables, lat, 'val' + ) + self.dataset_val.set_normalize(self.dataset_train.inp_transform, self.dataset_train.out_transform) + + dict_x_test, dict_y_test, _, _ = load_x_y(os.path.join(root_dir, 'test'), list_test_simu, out_variables) + x_test = input_for_training( + dict_x_test[list_test_simu[0]], skip_historical=True, history=history, len_historical=165 + ) + y_test = output_for_training( + dict_y_test[list_test_simu[0]], skip_historical=True, history=history, len_historical=165 + ) + self.dataset_test = ClimateBenchDataset( + x_test, y_test, variables, out_variables, lat, 'test' + ) + self.dataset_test.set_normalize(self.dataset_train.inp_transform, self.dataset_train.out_transform) + + def get_lat_lon(self): + return self.lat, self.lon + + def set_patch_size(self, p): + self.patch_size = p + + def get_test_clim(self): + return self.dataset_test.y_normalization + + def train_dataloader(self): + return DataLoader( + self.dataset_train, + batch_size=self.hparams.batch_size, + shuffle=True, + # drop_last=True, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + collate_fn=collate_fn, + ) + + def val_dataloader(self): + return DataLoader( + self.dataset_val, + batch_size=self.hparams.batch_size, + shuffle=False, + # drop_last=True, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + collate_fn=collate_fn, + ) + + def test_dataloader(self): + return DataLoader( + self.dataset_test, + batch_size=self.hparams.batch_size, + shuffle=False, + # drop_last=True, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + collate_fn=collate_fn, + ) \ No newline at end of file diff --git a/src/climax/climate_projection/dataset.py b/src/climax/climate_projection/dataset.py new file mode 100644 index 0000000..ad60c1b --- /dev/null +++ b/src/climax/climate_projection/dataset.py @@ -0,0 +1,155 @@ +### Adapted from https://github.com/duncanwp/ClimateBench/blob/main/prep_input_data.ipynb + +import os + +import numpy as np +import torch +import xarray as xr +from torch.utils.data import Dataset +from torchvision.transforms import transforms + + +def load_x_y(data_path, list_simu, out_var): + x_all, y_all = {}, {} + for simu in list_simu: + input_name = 'inputs_' + simu + '.nc' + output_name = 'outputs_' + simu + '.nc' + if 'hist' in simu: + # load inputs + input_xr = xr.open_dataset(os.path.join(data_path, input_name)) + + # load outputs + output_xr = xr.open_dataset(os.path.join(data_path, output_name)).mean(dim='member') + output_xr = output_xr.assign({ + "pr": output_xr.pr * 86400, + "pr90": output_xr.pr90 * 86400 + }).rename({ + 'lon':'longitude', + 'lat': 'latitude' + }).transpose('time','latitude', 'longitude').drop(['quantile']) + + # Concatenate with historical data in the case of scenario 'ssp126', 'ssp370' and 'ssp585' + else: + # load inputs + input_xr = xr.open_mfdataset([ + os.path.join(data_path, 'inputs_historical.nc'), + os.path.join(data_path, input_name) + ]).compute() + + # load outputs + output_xr = xr.concat([ + xr.open_dataset(os.path.join(data_path, 'outputs_historical.nc')).mean(dim='member'), + xr.open_dataset(os.path.join(data_path, output_name)).mean(dim='member') + ], dim='time').compute() + output_xr = output_xr.assign({ + "pr": output_xr.pr * 86400, + "pr90": output_xr.pr90 * 86400 + }).rename({ + 'lon':'longitude', + 'lat': 'latitude' + }).transpose('time','latitude', 'longitude').drop(['quantile']) + + print(input_xr.dims, output_xr.dims, simu) + + x = input_xr.to_array().to_numpy() + x = x.transpose(1, 0, 2, 3).astype(np.float32) # N, C, H, W + x_all[simu] = x + + y = output_xr[out_var].to_array().to_numpy() # 1, N, H, W + # y = np.expand_dims(y, axis=1) # N, 1, H, W + y = y.transpose(1, 0, 2, 3).astype(np.float32) + y_all[simu] = y + + temp = xr.open_dataset(os.path.join(data_path, 'inputs_' + list_simu[0] + '.nc')).compute() + if 'latitude' in temp: + lat = np.array(temp['latitude']) + lon = np.array(temp['longitude']) + else: + lat = np.array(temp['lat']) + lon = np.array(temp['lon']) + + return x_all, y_all, lat, lon + +def input_for_training(x, skip_historical, history, len_historical): + time_length = x.shape[0] + # If we skip historical data, the first sequence created has as last element the first scenario data point + if skip_historical: + X_train_to_return = np.array([ + x[i:i+history] for i in range(len_historical-history+1, time_length-history+1) + ]) + # Else we just go through the whole dataset historical + scenario (does not matter in the case of 'hist-GHG' and 'hist_aer') + else: + X_train_to_return = np.array([x[i:i+history] for i in range(0, time_length-history+1)]) + + return X_train_to_return + +def output_for_training(y, skip_historical, history, len_historical): + time_length = y.shape[0] + # If we skip historical data, the first sequence created has as target element the first scenario data point + if skip_historical: + Y_train_to_return = np.array([ + y[i+history-1] for i in range(len_historical-history+1, time_length-history+1) + ]) + # Else we just go through the whole dataset historical + scenario (does not matter in the case of 'hist-GHG' and 'hist_aer') + else: + Y_train_to_return = np.array([y[i+history-1] for i in range(0, time_length-history+1)]) + + return Y_train_to_return + +def split_train_val(x, y, train_ratio=0.9): + shuffled_ids = np.random.permutation(x.shape[0]) + train_len = int(train_ratio * x.shape[0]) + train_ids = shuffled_ids[:train_len] + val_ids = shuffled_ids[train_len:] + return x[train_ids], y[train_ids], x[val_ids], y[val_ids] + +class ClimateBenchDataset(Dataset): + def __init__(self, X_train_all, Y_train_all, variables, out_variables, lat, partition='train'): + super().__init__() + self.X_train_all = X_train_all + self.Y_train_all = Y_train_all + self.len_historical = 165 + self.variables = variables + self.out_variables = out_variables + self.lat = lat + self.partition = partition + + if partition == 'train': + self.inp_transform = self.get_normalize(self.X_train_all) + # self.out_transform = self.get_normalize(self.Y_train_all) + self.out_transform = transforms.Normalize(np.array([0.]), np.array([1.])) + else: + self.inp_transform = None + self.out_transform = None + + if partition == 'test': + # only use 2080 - 2100 according to ClimateBench + self.X_train_all = self.X_train_all[-21:] + self.Y_train_all = self.Y_train_all[-21:] + self.get_rmse_normalization() + + def get_normalize(self, data): + mean = np.mean(data, axis=(0, 1, 3, 4)) + std = np.std(data, axis=(0, 1, 3, 4)) + return transforms.Normalize(mean, std) + + def set_normalize(self, inp_normalize, out_normalize): # for val and test + self.inp_transform = inp_normalize + self.out_transform = out_normalize + + def get_rmse_normalization(self): + y_avg = torch.from_numpy(self.Y_train_all).squeeze(1).mean(0) # H, W + w_lat = np.cos(np.deg2rad(self.lat)) # (H,) + w_lat = w_lat / w_lat.mean() + w_lat = torch.from_numpy(w_lat).unsqueeze(-1).to(dtype=y_avg.dtype, device=y_avg.device) # (H, 1) + self.y_normalization = torch.abs(torch.mean(y_avg * w_lat)) + + def __len__(self): + return self.X_train_all.shape[0] + + def __getitem__(self, index): + inp = self.inp_transform(torch.from_numpy(self.X_train_all[index])) + out = self.out_transform(torch.from_numpy(self.Y_train_all[index])) + # lead times = 0 + lead_times = torch.Tensor([0.0]).to(dtype=inp.dtype) + return inp, out, lead_times, self.variables, self.out_variables diff --git a/src/climax/climate_projection/module.py b/src/climax/climate_projection/module.py new file mode 100644 index 0000000..9eaee8a --- /dev/null +++ b/src/climax/climate_projection/module.py @@ -0,0 +1,221 @@ +from typing import Any, Dict + +import numpy as np +import torch +from pytorch_lightning import LightningModule +from climax.climate_projection.arch import ClimaXClimateBench +from climax.utils.lr_scheduler import LinearWarmupCosineAnnealingLR +from climax.utils.metrics import ( + mse, + lat_weighted_mse_val, + lat_weighted_nrmse, + lat_weighted_rmse, +) +from climax.utils.pos_embed import interpolate_pos_embed +from torchvision.transforms import transforms + + +class ClimateProjectionModule(LightningModule): + """Lightning module for climate projection with the ClimaXClimateBench model. + + Args: + net (ClimaXClimateBench): ClimaXClimateBench model. + pretrained_path (str, optional): Path to pre-trained checkpoint. + lr (float, optional): Learning rate. + beta_1 (float, optional): Beta 1 for AdamW. + beta_2 (float, optional): Beta 2 for AdamW. + weight_decay (float, optional): Weight decay for AdamW. + warmup_epochs (int, optional): Number of warmup epochs. + max_epochs (int, optional): Number of total epochs. + warmup_start_lr (float, optional): Starting learning rate for warmup. + eta_min (float, optional): Minimum learning rate. + """ + def __init__( + self, + net: ClimaXClimateBench, + pretrained_path: str = "", + lr: float = 5e-4, + beta_1: float = 0.9, + beta_2: float = 0.99, + weight_decay: float = 1e-5, + warmup_epochs: int = 60, + max_epochs: int = 600, + warmup_start_lr: float = 1e-8, + eta_min: float = 1e-8, + ): + super().__init__() + self.save_hyperparameters(logger=False, ignore=["net"]) + self.net = net + if len(pretrained_path) > 0: + self.load_mae_weights(pretrained_path) + + def load_mae_weights(self, pretrained_path): + if pretrained_path.startswith("http"): + checkpoint = torch.hub.load_state_dict_from_url(pretrained_path, map_location=torch.device("cpu")) + else: + checkpoint = torch.load(pretrained_path, map_location=torch.device("cpu")) + + print("Loading pre-trained checkpoint from: %s" % pretrained_path) + checkpoint_model = checkpoint["state_dict"] + # interpolate positional embedding + interpolate_pos_embed(self.net, checkpoint_model, new_size=self.net.img_size) + + state_dict = self.state_dict() + if self.net.parallel_patch_embed: + if "token_embeds.proj_weights" not in checkpoint_model.keys(): + raise ValueError( + "Pretrained checkpoint does not have token_embeds.proj_weights for parallel processing. Please convert the checkpoints first or disable parallel patch_embed tokenization." + ) + + for k in list(checkpoint_model.keys()): + if "channel" in k: + checkpoint_model[k.replace("channel", "var")] = checkpoint_model[k] + del checkpoint_model[k] + + if 'token_embeds' in k or 'head' in k: # initialize embedding from scratch + print(f"Removing key {k} from pretrained checkpoint") + del checkpoint_model[k] + continue + + for k in list(checkpoint_model.keys()): + if k not in state_dict.keys() or checkpoint_model[k].shape != state_dict[k].shape: + print(f"Removing key {k} from pretrained checkpoint") + del checkpoint_model[k] + + # load pre-trained model + msg = self.load_state_dict(checkpoint_model, strict=False) + print(msg) + + def set_denormalization(self, mean, std): + self.denormalization = transforms.Normalize(mean, std) + + def set_lat_lon(self, lat, lon): + self.lat = lat + self.lon = lon + + def set_pred_range(self, r): + self.pred_range = r + + def set_val_clim(self, clim): + self.val_clim = clim + + def set_test_clim(self, clim): + self.test_clim = clim + + def training_step(self, batch: Any, batch_idx: int): + x, y, lead_times, variables, out_variables = batch + + loss_dict, _ = self.net.forward(x, y, lead_times, variables, out_variables, [mse], lat=self.lat) + loss_dict = loss_dict[0] + for var in loss_dict.keys(): + self.log( + "train/" + var, + loss_dict[var], + on_step=True, + on_epoch=False, + prog_bar=True, + ) + loss = loss_dict['loss'] + + return loss + + def validation_step(self, batch: Any, batch_idx: int): + x, y, lead_times, variables, out_variables = batch + + all_loss_dicts = self.net.evaluate( + x, + y, + lead_times, + variables, + out_variables, + transform=self.denormalization, + metrics=[lat_weighted_mse_val, lat_weighted_rmse], + lat=self.lat, + clim=self.val_clim, + log_postfix=None + ) + + loss_dict = {} + for d in all_loss_dicts: + for k in d.keys(): + loss_dict[k] = d[k] + + for var in loss_dict.keys(): + self.log( + "val/" + var, + loss_dict[var], + on_step=False, + on_epoch=True, + prog_bar=False, + sync_dist=True, + ) + return loss_dict + + def test_step(self, batch: Any, batch_idx: int): + x, y, lead_times, variables, out_variables = batch + + all_loss_dicts = self.net.evaluate( + x, + y, + lead_times, + variables, + out_variables, + transform=self.denormalization, + metrics=[lat_weighted_mse_val, lat_weighted_rmse, lat_weighted_nrmse], + lat=self.lat, + clim=self.test_clim, + log_postfix=None + ) + + loss_dict = {} + for d in all_loss_dicts: + for k in d.keys(): + loss_dict[k] = d[k] + + for var in loss_dict.keys(): + self.log( + "test/" + var, + loss_dict[var], + on_step=False, + on_epoch=True, + prog_bar=False, + sync_dist=True, + ) + return loss_dict + + def configure_optimizers(self): + decay = [] + no_decay = [] + for name, m in self.named_parameters(): + if "var_embed" in name or "pos_embed" in name or "time_pos_embed" in name: + no_decay.append(m) + else: + decay.append(m) + + optimizer = torch.optim.AdamW( + [ + { + "params": decay, + "lr": self.hparams.lr, + "betas": (self.hparams.beta_1, self.hparams.beta_2), + "weight_decay": self.hparams.weight_decay, + }, + { + "params": no_decay, + "lr": self.hparams.lr, + "betas": (self.hparams.beta_1, self.hparams.beta_2), + "weight_decay": 0 + }, + ] + ) + + lr_scheduler = LinearWarmupCosineAnnealingLR( + optimizer, + self.hparams.warmup_epochs, + self.hparams.max_epochs, + self.hparams.warmup_start_lr, + self.hparams.eta_min, + ) + scheduler = {"scheduler": lr_scheduler, "interval": "step", "frequency": 1} + + return {"optimizer": optimizer, "lr_scheduler": scheduler} \ No newline at end of file diff --git a/src/climax/climate_projection/train.py b/src/climax/climate_projection/train.py new file mode 100644 index 0000000..37c47e9 --- /dev/null +++ b/src/climax/climate_projection/train.py @@ -0,0 +1,41 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import os + +from pytorch_lightning.cli import LightningCLI +from climax.climate_projection.module import ClimateProjectionModule +from climax.climate_projection.datamodule import ClimateBenchDataModule + + +def main(): + # Initialize Lightning with the model and data modules, and instruct it to parse the config yml + cli = LightningCLI( + model_class=ClimateProjectionModule, + datamodule_class=ClimateBenchDataModule, + seed_everything_default=42, + save_config_overwrite=True, + run=False, + # auto_registry=True, + parser_kwargs={"parser_mode": "omegaconf", "error_handler": None}, + ) + os.makedirs(cli.trainer.default_root_dir, exist_ok=True) + + normalization = cli.datamodule.dataset_train.out_transform + mean_norm, std_norm = normalization.mean, normalization.std + mean_denorm, std_denorm = -mean_norm / std_norm, 1 / std_norm + cli.model.set_denormalization(mean_denorm, std_denorm) + cli.model.set_lat_lon(*cli.datamodule.get_lat_lon()) + cli.model.set_pred_range(0) + cli.model.set_val_clim(None) + cli.model.set_test_clim(cli.datamodule.get_test_clim()) + + # fit() runs the training + cli.trainer.fit(cli.model, datamodule=cli.datamodule) + + # test the trained model + cli.trainer.test(cli.model, datamodule=cli.datamodule, ckpt_path='best') + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/climax/utils/metrics.py b/src/climax/utils/metrics.py index ca96e2b..bb1bb3b 100644 --- a/src/climax/utils/metrics.py +++ b/src/climax/utils/metrics.py @@ -162,10 +162,10 @@ def lat_weighted_acc(pred, y, transform, vars, lat, clim, log_postfix): return loss_dict -def lat_weighted_nrmses(pred, y, transform, vars, lat, log_steps, log_days, clim): +def lat_weighted_nrmses(pred, y, transform, vars, lat, clim, log_postfix): """ - y: [N, T, C, H, W] - pred: [N, T, C, H, W] + y: [B, V, H, W] + pred: [B V, H, W] vars: list of variable names lat: H """ @@ -175,26 +175,26 @@ def lat_weighted_nrmses(pred, y, transform, vars, lat, log_steps, log_days, clim y_normalization = clim # lattitude weights - w_lat = np.cos(np.deg2rad(lat)) # (H,) - w_lat = w_lat / w_lat.mean() + w_lat = np.cos(np.deg2rad(lat)) + w_lat = w_lat / w_lat.mean() # (H, ) w_lat = torch.from_numpy(w_lat).unsqueeze(-1).to(dtype=y.dtype, device=y.device) # (H, 1) loss_dict = {} with torch.no_grad(): for i, var in enumerate(vars): - for day, step in zip(log_days, log_steps): - pred_ = pred[:, step - 1, i] # N, H, W - y_ = y[:, step - 1, i] # N, H, W - error = (torch.mean(pred_, dim=0) - torch.mean(y_, dim=0)) ** 2 # (H, W) - error = torch.mean(error * w_lat) - loss_dict[f"w_nrmses_{var}"] = torch.sqrt(error) / y_normalization + pred_ = pred[:, i] # B, H, W + y_ = y[:, i] # B, H, W + error = (torch.mean(pred_, dim=0) - torch.mean(y_, dim=0)) ** 2 # H, W + error = torch.mean(error * w_lat) + loss_dict[f"w_nrmses_{var}"] = torch.sqrt(error) / y_normalization + return loss_dict -def lat_weighted_nrmseg(pred, y, transform, vars, lat, log_steps, log_days, clim): +def lat_weighted_nrmseg(pred, y, transform, vars, lat, clim, log_postfix): """ - y: [N, T, C, H, W] - pred: [N, T, C, H, W] + y: [B, V, H, W] + pred: [B V, H, W] vars: list of variable names lat: H """ @@ -204,32 +204,33 @@ def lat_weighted_nrmseg(pred, y, transform, vars, lat, log_steps, log_days, clim y_normalization = clim # lattitude weights - w_lat = np.cos(np.deg2rad(lat)) # (H,) - w_lat = w_lat / w_lat.mean() + w_lat = np.cos(np.deg2rad(lat)) + w_lat = w_lat / w_lat.mean() # (H, ) w_lat = torch.from_numpy(w_lat).unsqueeze(0).unsqueeze(-1).to(dtype=y.dtype, device=y.device) # (1, H, 1) loss_dict = {} with torch.no_grad(): for i, var in enumerate(vars): - for day, step in zip(log_days, log_steps): - pred_ = pred[:, step - 1, i] # N, H, W - pred_ = torch.mean(pred_ * w_lat, dim=(-2, -1)) # N - y_ = y[:, step - 1, i] # N, H, W - y_ = torch.mean(y_ * w_lat, dim=(-2, -1)) # N - error = torch.mean((pred_ - y_) ** 2) - loss_dict[f"w_nrmseg_{var}"] = torch.sqrt(error) / y_normalization + pred_ = pred[:, i] # B, H, W + pred_ = torch.mean(pred_ * w_lat, dim=(-2, -1)) # B + y_ = y[:, i] # B, H, W + y_ = torch.mean(y_ * w_lat, dim=(-2, -1)) # B + error = torch.mean((pred_ - y_) ** 2) + loss_dict[f"w_nrmseg_{var}"] = torch.sqrt(error) / y_normalization + return loss_dict -def lat_weighted_nrmse(pred, y, transform, vars, lat, log_steps, log_days, clim): +def lat_weighted_nrmse(pred, y, transform, vars, lat, clim, log_postfix): """ - y: [N, T, C, H, W] - pred: [N, T, C, H, W] + y: [B, V, H, W] + pred: [B V, H, W] vars: list of variable names lat: H """ - nrmses = lat_weighted_nrmses(pred, y, transform, vars, lat, log_steps, log_days, clim) - nrmseg = lat_weighted_nrmseg(pred, y, transform, vars, lat, log_steps, log_days, clim) + + nrmses = lat_weighted_nrmses(pred, y, transform, vars, lat, clim, log_postfix) + nrmseg = lat_weighted_nrmseg(pred, y, transform, vars, lat, clim, log_postfix) loss_dict = {} for var in vars: loss_dict[f"w_nrmses_{var}"] = nrmses[f"w_nrmses_{var}"] diff --git a/src/data_preprocessing/regrid_climatebench.py b/src/data_preprocessing/regrid_climatebench.py new file mode 100644 index 0000000..ab913b3 --- /dev/null +++ b/src/data_preprocessing/regrid_climatebench.py @@ -0,0 +1,97 @@ +import os +from glob import glob + +import click +import xarray as xr +import numpy as np +import xesmf as xe + +def regrid( + ds_in, + ddeg_out, + method='bilinear', + reuse_weights=True, + cmip=False, + rename=None +): + """ + Regrid horizontally. + :param ds_in: Input xarray dataset + :param ddeg_out: Output resolution + :param method: Regridding method + :param reuse_weights: Reuse weights for regridding + :return: ds_out: Regridded dataset + """ + # import pdb; pdb.set_trace() + # Rename to ESMF compatible coordinates + if 'latitude' in ds_in.coords: + ds_in = ds_in.rename({'latitude': 'lat', 'longitude': 'lon'}) + if cmip: + ds_in = ds_in.drop(('lat_bnds', 'lon_bnds')) + if hasattr(ds_in, 'plev_bnds'): + ds_in = ds_in.drop(('plev_bnds')) + if hasattr(ds_in, 'time_bnds'): + ds_in = ds_in.drop(('time_bnds')) + if rename is not None: + ds_in = ds_in.rename({rename[0]: rename[1]}) + + # Create output grid + grid_out = xr.Dataset( + { + 'lat': (['lat'], np.arange(-90+ddeg_out/2, 90, ddeg_out)), + 'lon': (['lon'], np.arange(0, 360, ddeg_out)), + } + ) + + # Create regridder + regridder = xe.Regridder( + ds_in, grid_out, method, periodic=True, reuse_weights=reuse_weights + ) + + # Hack to speed up regridding of large files + ds_out = regridder(ds_in, keep_attrs=True).astype('float32') + + if rename is not None: + if rename[0] == 'zg': + ds_out['z'] *= 9.807 + if rename[0] == 'rsdt': + ds_out['tisr'] *= 60*60 + ds_out = ds_out.isel(time=slice(1, None, 12)) + ds_out = ds_out.assign_coords({'time': ds_out.time + np.timedelta64(90, 'm')}) + + # # Regrid dataset + # ds_out = regridder(ds_in) + return ds_out + +@click.command() +@click.argument("path", type=click.Path(exists=True)) +@click.option("--save_path", type=str) +@click.option("--ddeg_out", type=float, default=5.625) +def main( + path, + save_path, + ddeg_out +): + if not os.path.exists(save_path): + os.makedirs(save_path, exist_ok=True) + + list_simu = ['hist-GHG.nc', 'hist-aer.nc', 'historical.nc', 'ssp126.nc', 'ssp370.nc', 'ssp585.nc', 'ssp245.nc'] + ps = glob(os.path.join(path, f"*.nc")) + ps_ = [] + for p in ps: + for simu in list_simu: + if simu in p: + ps_.append(p) + ps = ps_ + + constant_vars = ['CO2', 'CH4'] + for p in ps: + x = xr.open_dataset(p) + if 'input' in p: + for v in constant_vars: + x[v] = x[v].expand_dims(dim={'latitude': 96, 'longitude': 144}, axis=(1,2)) + x_regridded = regrid(x, ddeg_out, reuse_weights=False) + x_regridded.to_netcdf(os.path.join(save_path, os.path.basename(p))) + +if __name__ == "__main__": + main() \ No newline at end of file