⚗️ Benchmark experiments loading ERA5 Zarr data using kvikIO #4
Conversation
Setting up a LightningDataModule with a torchdata DataPipe and looping through one epoch (23 mini-batches) as a timing benchmark experiment using the kvikIO engine. The DataPipe does loading from Zarr, slicing with xbatcher, batching, and collating to torch.Tensor objects. Timer uses Python's time.perf_counter.
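For reference, the epoch timing is roughly this shape (a minimal sketch: `WeatherBench2DataModule` is the datamodule added in this PR, the `setup`/`train_dataloader` calls are the standard LightningDataModule API, and the loop body is elided):

```python
import time

# Sketch of the epoch timer described above, using the datamodule added
# in this PR. One epoch here is 23 mini-batches.
datamodule = WeatherBench2DataModule(engine="kvikio")
datamodule.setup(stage="fit")
train_dataloader = datamodule.train_dataloader()

tic = time.perf_counter()
for batch in train_dataloader:
    pass  # just pull each mini-batch through the DataPipe
toc = time.perf_counter()
print(f"One epoch took {toc - tic:.4f} seconds")
```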
Hi @negin513, this is the benchmark script I mentioned two days ago during the Pangeo ML Working Group meeting. Feel free to give this a try if you have time, and let me know if you have any questions!
I'm gonna experiment on this a bit more over the next few days, and will hopefully come up with some proper numbers and nice graphs.
Cc @maxrjones for visibility.
```python
# Step 1.2 - Slice datacube along time-dimension into 12 hour chunks (2x 6-hourly)
.slice_with_xbatcher(
    input_dims={"latitude": 721, "longitude": 1440, "time": 2},
    preload_batch=False,
)
```
Seems like the `kvikio` engine (11s) is slightly faster than `zarr` (14s) when `preload_batch=False`, whereas `zarr` (10s) is slightly faster than `kvikio` (11s) when the default `preload_batch=True` is set. Maybe because loading from dask.Array objects is not as optimized for kvikIO as it is for Zarr yet?

For benchmark purposes though, it's probably best to disable the `preload_batch` setting, since it acts somewhat like a cache and we want to measure raw IO speed. And yes, the timings are probably not significantly different, so I'll run it over more epochs as mentioned at #4 (comment) to get a better average time.
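For context, `preload_batch` is an option on xbatcher's `BatchGenerator`: when True, each batch's data is eagerly loaded before being yielded, which masks the per-batch read cost we want to measure. A minimal sketch of the difference (the store path is a placeholder, and `engine="kvikio"` assumes the cupy-xarray backend referenced in this PR is installed):

```python
import xarray as xr
import xbatcher

ds = xr.open_dataset("era5.zarr", engine="kvikio")  # placeholder store path

bgen = xbatcher.BatchGenerator(
    ds,
    input_dims={"latitude": 721, "longitude": 1440, "time": 2},
    preload_batch=False,  # keep batches lazy so raw IO shows up per batch
)
for batch in bgen:
    batch.load()  # with preload_batch=False, the actual read happens here
```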
```python
torch.set_float32_matmul_precision(precision="medium")

# Setup data
datamodule: L.LightningDataModule = WeatherBench2DataModule(engine="kvikio")
```
Currently changing the engine manually here (either `kvikio` or `zarr`). Should make a CLI flag to set the engine here.
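Something like this could work (a sketch only, not in the PR yet; the flag name is made up):

```python
import argparse

parser = argparse.ArgumentParser(description="ERA5 Zarr loading benchmark")
parser.add_argument(
    "--engine",
    choices=["kvikio", "zarr"],
    default="kvikio",
    help="xarray backend engine to read the Zarr store with",
)
args = parser.parse_args()

# Setup data with the engine chosen on the command line
datamodule: L.LightningDataModule = WeatherBench2DataModule(engine=args.engine)
```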
```python
# Training loop
for i, batch in tqdm.tqdm(iterable=enumerate(train_dataloader), total=23):
    input, target, metadata = batch
    # Compute Mean Squared Error loss between t=0 and t=1, just for fun
    loss: torch.Tensor = torch.functional.F.mse_loss(input=input, target=target)
    print(f"Batch {i}, MSE Loss: {loss}")
```
TODO: train on more than just 1 epoch (maybe 100?) to get a nicer average result comparing `zarr` and `kvikio`.
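Once that's in, the averaging could look something like this (a sketch, reusing the time.perf_counter timer from the description):

```python
import statistics
import time

epoch_times: list[float] = []
for epoch in range(100):  # more epochs -> less noisy comparison
    tic = time.perf_counter()
    for batch in train_dataloader:
        pass
    epoch_times.append(time.perf_counter() - tic)

print(f"mean: {statistics.mean(epoch_times):.2f}s")
print(f"stdev: {statistics.stdev(epoch_times):.2f}s")
```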
Gonna merge this as is and iterate/improve on it later (if I have time).
Setting up a LightningDataModule with a torchdata DataPipe and looping through one epoch (23 mini-batches with a batch-size of 32) as a timing benchmark experiment using the kvikIO engine.

The DataPipe does:
- loading from Zarr
- slicing with xbatcher
- batching
- collating to torch.Tensor objects

Timer uses Python's time.perf_counter. May look into proper profiling later. A rough sketch of the pipeline composition is below.
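The stages compose roughly like this (a sketch only: the store path and collate function are stand-ins, and the `slice_with_xbatcher` functional is assumed to come from zen3geo's xbatcher wrapper seen in the diff above):

```python
import torch
import torchdata.datapipes as dp
import xarray as xr
import zen3geo  # noqa: F401  assumed to register the slice_with_xbatcher functional

def collate_to_tensor(samples):
    # Stand-in collate function: stack each xarray batch into one torch.Tensor
    return torch.stack([torch.as_tensor(sample.to_array().data) for sample in samples])

ds = xr.open_dataset("era5.zarr", engine="kvikio")  # placeholder store path
datapipe = (
    dp.iter.IterableWrapper(iterable=[ds])
    .slice_with_xbatcher(input_dims={"latitude": 721, "longitude": 1440, "time": 2})
    .batch(batch_size=32)
    .collate(collate_fn=collate_to_tensor)
)
```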
References:
- `WeatherBench2DataModule` is adapted from the `ZarrDataPipeModule` at https://gitlab.com/frontierdevelopmentlab/2022-us-sarchangedetection/deepslide/-/blob/main/src/datamodules/datapipemodule.py?ref_type=heads
- `kvikio` backend/engine is from Add Kvikio backend entrypoint (xarray-contrib/cupy-xarray#10)