Climate projection #26

Merged · 3 commits · Jul 10, 2023
121 changes: 121 additions & 0 deletions configs/climate_projection.yaml
@@ -0,0 +1,121 @@
seed_everything: 42

# ---------------------------- TRAINER -------------------------------------------
trainer:
  default_root_dir: ${oc.env:AMLT_OUTPUT_DIR,/home/tungnd/ClimaX/exps/climate_projection_climax}

  precision: 16

  gpus: null
  num_nodes: 1
  accelerator: gpu
  strategy: ddp

  min_epochs: 1
  max_epochs: 50
  enable_progress_bar: true

  sync_batchnorm: True
  enable_checkpointing: True
  resume_from_checkpoint: null

  # debugging
  fast_dev_run: false

  logger:
    class_path: pytorch_lightning.loggers.tensorboard.TensorBoardLogger
    init_args:
      save_dir: ${trainer.default_root_dir}/logs
      name: null
      version: null
      log_graph: False
      default_hp_metric: True
      prefix: ""

  callbacks:
    - class_path: pytorch_lightning.callbacks.LearningRateMonitor
      init_args:
        logging_interval: "step"

    - class_path: pytorch_lightning.callbacks.ModelCheckpoint
      init_args:
        dirpath: "${trainer.default_root_dir}/checkpoints/"
        monitor: "val/w_mse" # name of the logged metric that determines when the model is improving
        mode: "min" # "min" means lower metric value is better; can also be "max"
        save_top_k: 1 # save the k best models (determined by the metric above)
        save_last: True # additionally save the model from the last epoch
        verbose: False
        filename: "epoch_{epoch:03d}"
        auto_insert_metric_name: False

    - class_path: pytorch_lightning.callbacks.EarlyStopping
      init_args:
        monitor: "val/w_mse" # name of the logged metric that determines when the model is improving
        mode: "min" # "min" means lower metric value is better; can also be "max"
        patience: 5 # number of validation epochs with no improvement before training stops
        min_delta: 0. # minimum change in the monitored metric to qualify as an improvement

    - class_path: pytorch_lightning.callbacks.RichModelSummary
      init_args:
        max_depth: -1

    - class_path: pytorch_lightning.callbacks.RichProgressBar

# ---------------------------- MODEL -------------------------------------------
model:
  lr: 5e-4
  beta_1: 0.9
  beta_2: 0.999
  weight_decay: 1e-5
  warmup_epochs: 60
  max_epochs: 600
  warmup_start_lr: 1e-8
  eta_min: 1e-8
  pretrained_path: "https://huggingface.co/tungnd/climax/resolve/main/5.625deg.ckpt"

  net:
    class_path: climax.climate_projection.arch.ClimaXClimateBench
    init_args:
      default_vars: [
        'CO2',
        'SO2',
        'CH4',
        'BC'
      ]
      out_vars: "tas" # diurnal_temperature_range, tas, pr, pr90
      img_size: [32, 64]
      time_history: 10
      patch_size: 2
      embed_dim: 1024
      depth: 8
      num_heads: 16
      mlp_ratio: 4
      drop_path: 0.1
      drop_rate: 0.1
      parallel_patch_embed: False
      freeze_encoder: True

# ---------------------------- DATA -------------------------------------------
data:
  root_dir: /home/data/datasets/climate-learn/climatebench/5.625deg/
  history: 10
  list_train_simu: [
    'ssp126',
    'ssp370',
    'ssp585',
    'historical',
    'hist-GHG',
    'hist-aer'
  ]
  list_test_simu: ['ssp245']
  variables: [
    'CO2',
    'SO2',
    'CH4',
    'BC'
  ]
  out_variables: 'tas'
  train_ratio: 0.9
  batch_size: 1
  num_workers: 1
  pin_memory: False
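
For orientation, here is a minimal sketch of how a config like this is typically consumed: the `train.py` entry point added by this PR presumably wraps the model and data modules in a PyTorch Lightning `LightningCLI`. The imported module names below are assumptions for illustration, not the exact ones in this PR; only the `LightningCLI` usage itself follows the documented PyTorch Lightning API.

```python
# Minimal sketch of a LightningCLI entry point that consumes the YAML above.
# The imported class names are hypothetical.
from pytorch_lightning.cli import LightningCLI

from climax.climate_projection.module import ClimateProjectionModule  # hypothetical
from climax.climate_projection.datamodule import ClimateBenchDataModule  # hypothetical


def main():
    # Parses --config (e.g. configs/climate_projection.yaml) plus any
    # --trainer.*/--model.*/--data.* overrides from the command line.
    cli = LightningCLI(
        model_class=ClimateProjectionModule,
        datamodule_class=ClimateBenchDataModule,
        seed_everything_default=42,
        run=False,  # instantiate everything; call fit explicitly below
    )
    cli.trainer.fit(cli.model, datamodule=cli.datamodule)


if __name__ == "__main__":
    main()
```

Under that assumption, `python train.py --config configs/climate_projection.yaml` populates the `trainer`, `model`, and `data` sections above, and any `--model.*` or `--data.*` flag overrides the corresponding YAML value.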
36 changes: 36 additions & 0 deletions docs/usage.md
@@ -118,6 +118,42 @@ python src/climax/regional_forecast/train.py --config configs/regional_forecast_
```
To train ClimaX from scratch, set `--model.pretrained_path=""`.

## Climate Projection

### Data Preparation

First, download the [ClimateBench](https://doi.org/10.5281/zenodo.5196512) data. ClimaX can work with either the original ClimateBench data or a regridded version. For the experiments in the paper, we regridded the ClimateBench data to 5.625°. To do that, run
```bash
python src/data_preprocessing/regrid_climatebench.py /mnt/data/climatebench/train_val \
--save_path /mnt/data/climatebench/5.625deg/train_val --ddeg_out 5.625
```
and
```bash
python src/data_preprocessing/regrid_climatebench.py /mnt/data/climatebench/test \
--save_path /mnt/data/climatebench/5.625deg/test --ddeg_out 5.625
```
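
For intuition, here is a minimal sketch of what the regridding step does, assuming a plain xarray bilinear interpolation onto a 5.625° grid; the actual `regrid_climatebench.py` script and the coordinate names in the NetCDF files may differ.

```python
# Minimal regridding sketch (assumption: the real script may use a different
# interpolation scheme; file and coordinate names here are illustrative).
import numpy as np
import xarray as xr

ddeg_out = 5.625
ds = xr.open_dataset("inputs_ssp126.nc")  # hypothetical ClimateBench file

# Target 5.625-degree grid: 32 latitudes x 64 longitudes.
new_lat = np.arange(-90 + ddeg_out / 2, 90, ddeg_out)
new_lon = np.arange(0, 360, ddeg_out)

# Bilinear interpolation onto the new grid ("latitude"/"longitude" are
# assumed coordinate names; ClimateBench files may use "lat"/"lon").
ds_out = ds.interp(latitude=new_lat, longitude=new_lon)
ds_out.to_netcdf("5.625deg/inputs_ssp126.nc")
```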

### Training

To finetune ClimaX for climate projection, use
```bash
python src/climax/climate_projection/train.py --config <path/to/config>
```
For example, to finetune ClimaX on 8 GPUs, use
```bash
python src/climax/climate_projection/train.py --config configs/climate_projection.yaml \
--trainer.strategy=ddp --trainer.devices=8 \
--trainer.max_epochs=50 \
--data.root_dir=/mnt/data/climatebench/5.625deg \
--data.out_variables="tas" \
--data.batch_size=16 \
--model.pretrained_path='https://huggingface.co/tungnd/climax/resolve/main/5.625deg.ckpt' \
--model.out_vars="tas" \
--model.lr=5e-4 --model.beta_1="0.9" --model.beta_2="0.99" \
--model.weight_decay=1e-5
```
To train ClimaX from scratch, set `--model.pretrained_path=""`.

## Visualization

Coming soon
144 changes: 144 additions & 0 deletions src/climax/climate_projection/arch.py
@@ -0,0 +1,144 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn

from climax.arch import ClimaX
from climax.utils.pos_embed import get_1d_sincos_pos_embed_from_grid


class ClimaXClimateBench(ClimaX):
    def __init__(
        self,
        default_vars,
        out_vars,
        img_size=[32, 64],
        time_history=1,
        patch_size=2,
        embed_dim=1024,
        depth=8,
        decoder_depth=2,
        num_heads=16,
        mlp_ratio=4.0,
        drop_path=0.1,
        drop_rate=0.1,
        parallel_patch_embed=False,
        freeze_encoder=False,
    ):
        assert out_vars is not None

        super().__init__(
            default_vars,
            img_size,
            patch_size,
            embed_dim,
            depth,
            decoder_depth,
            num_heads,
            mlp_ratio,
            drop_path,
            drop_rate,
            parallel_patch_embed,
        )

        self.out_vars = out_vars
        self.time_history = time_history
        self.freeze_encoder = freeze_encoder

        # used to aggregate multiple timesteps in the input
        self.time_pos_embed = nn.Parameter(torch.zeros(1, time_history, embed_dim), requires_grad=True)
        self.time_agg = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.time_query = nn.Parameter(torch.zeros(1, 1, embed_dim), requires_grad=True)

        # initialize time embedding with a fixed 1D sin-cos table over the history steps
        time_pos_embed = get_1d_sincos_pos_embed_from_grid(self.time_pos_embed.shape[-1], np.arange(self.time_history))
        self.time_pos_embed.data.copy_(torch.from_numpy(time_pos_embed).float().unsqueeze(0))

        # overwrite ClimaX: use a linear prediction head that maps the aggregated
        # token directly to a flattened H x W output map for this task
        self.head = nn.Linear(embed_dim, img_size[0] * img_size[1])

        if freeze_encoder:
            for name, p in self.blocks.named_parameters():
                name = name.lower()
                # we do not freeze the norm layers, as suggested by https://arxiv.org/abs/2103.05247
                if 'norm' in name:
                    continue
                else:
                    p.requires_grad_(False)

    def forward_encoder(self, x: torch.Tensor, lead_times: torch.Tensor, variables):
        # x: `[B, T, V, H, W]` shape.

        if isinstance(variables, list):
            variables = tuple(variables)

        b, t, _, _, _ = x.shape
        x = x.flatten(0, 1)  # BxT, V, H, W

        # tokenize each variable separately
        embeds = []
        var_ids = self.get_var_ids(variables, x.device)

        if self.parallel_patch_embed:
            x = self.token_embeds(x, var_ids)  # BxT, V, L, D
        else:
            for i in range(len(var_ids)):
                var_id = var_ids[i]
                embeds.append(self.token_embeds[var_id](x[:, i : i + 1]))
            x = torch.stack(embeds, dim=1)  # BxT, V, L, D

        # add variable embedding
        var_embed = self.get_var_emb(self.var_embed, variables)
        x = x + var_embed.unsqueeze(2)  # BxT, V, L, D

        # variable aggregation
        x = self.aggregate_variables(x)  # BxT, L, D

        # add pos embedding
        x = x + self.pos_embed

        # add time embedding
        # time emb: 1, T, D
        x = x.unflatten(0, sizes=(b, t))  # B, T, L, D
        x = x + self.time_pos_embed.unsqueeze(2)

        # add lead time embedding
        lead_time_emb = self.lead_time_embed(lead_times.unsqueeze(-1))  # B, D
        lead_time_emb = lead_time_emb.unsqueeze(1).unsqueeze(2)
        x = x + lead_time_emb  # B, T, L, D

        x = x.flatten(0, 1)  # BxT, L, D

        x = self.pos_drop(x)

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)  # BxT, L, D
        x = x.unflatten(0, sizes=(b, t))  # B, T, L, D

        # global average pooling, also used in CNN-LSTM baseline in ClimateBench
        x = x.mean(-2)  # B, T, D
        time_query = self.time_query.repeat_interleave(x.shape[0], dim=0)
        x, _ = self.time_agg(time_query, x, x)  # B, 1, D

        return x

    def forward(self, x, y, lead_times, variables, out_variables, metric, lat):
        x = self.forward_encoder(x, lead_times, variables)  # B, 1, D
        preds = self.head(x)
        preds = preds.reshape(-1, 1, self.img_size[0], self.img_size[1])  # B, 1, H, W
        if metric is None:
            loss = None
        else:
            loss = [m(preds, y, out_variables, lat) for m in metric]
        return loss, preds
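
As a sanity check, here is a hedged sketch of a forward pass with dummy inputs shaped to match the config above (10 history steps of the 4 forcing variables on the 32x64 grid at 5.625°). It assumes the `climax` package is importable; note that instantiating the model with these defaults builds the full ViT-sized network.

```python
# Hedged usage sketch: dummy forward pass through ClimaXClimateBench.
# Shapes follow configs/climate_projection.yaml; with metric=None the model
# returns (None, preds), so no loss function or latitude weights are needed.
import torch

from climax.climate_projection.arch import ClimaXClimateBench

variables = ['CO2', 'SO2', 'CH4', 'BC']
model = ClimaXClimateBench(
    default_vars=variables,
    out_vars='tas',
    img_size=[32, 64],
    time_history=10,
    freeze_encoder=True,
)

B, T, V, H, W = 2, 10, 4, 32, 64
x = torch.randn(B, T, V, H, W)   # forcing inputs over the 10 history steps
y = torch.randn(B, 1, H, W)      # target map, only consumed by the metric
lead_times = torch.zeros(B)      # one scalar lead time per batch element

loss, preds = model(x, y, lead_times, variables,
                    out_variables='tas', metric=None, lat=None)
print(preds.shape)  # torch.Size([2, 1, 32, 64])
```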