
CarloLucibello/Tsunami.jl


Tsunami.jl


A high-level deep learning framework for the Julia language that helps you focus on and organize the relevant parts of your code while removing the boilerplate.

Tsunami is built on top of Flux.jl and is heavily inspired by pytorch-lightning (although LightningAI is not involved in this project).

Features

  • Use Tsunami.fit! instead of implementing a training loop.
  • Logging (TensorBoard).
  • Checkpoints (save and resume training).
  • Hyperparameter schedulers.
  • CUDA, AMDGPU, and Metal GPU support.
  • Progress bars.
  • Clean organization of your code.
  • Automatic differentiation through Zygote or Enzyme.

Installation

Install Tsunami using the Julia package manager:

pkg> add Tsunami

Usage

Define your model by subtyping the FluxModule abstract type and implementing a few required methods, then let the Trainer train the model on your dataset with Tsunami.fit!. Tsunami handles the boilerplate (training loop, logging, GPU movement, validation, ...).
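For comparison, the following is a minimal sketch of the kind of hand-written training loop that Tsunami.fit! replaces, written with plain Flux and Optimisers. The `train_loop!` helper and the random toy dataset are hypothetical, for illustration only, and are not part of Tsunami. The sketch deliberately leaves out the validation, logging, checkpointing, progress bars, and GPU movement that Tsunami adds on top.

```julia
using Flux, Optimisers
using MLUtils: DataLoader

# Hypothetical toy data: 100 samples with 4 features and 3 classes (illustration only)
x = rand(Float32, 4, 100)
y = Flux.onehotbatch(rand(0:2, 100), 0:2)
loader = DataLoader((x, y), batchsize=16, shuffle=true)

# Hypothetical helper: a bare training loop, roughly the boilerplate that
# Tsunami.fit! takes care of (without validation, logging, checkpoints, or devices)
function train_loop!(model, loader; epochs=5)
    opt_state = Optimisers.setup(Optimisers.AdamW(1e-3), model)
    losses = Float32[]
    for _ in 1:epochs
        for (xb, yb) in loader
            # Loss value and gradient with respect to the model (Zygote via Flux)
            loss, grads = Flux.withgradient(m -> Flux.logitcrossentropy(m(xb), yb), model)
            # Apply one optimiser step, updating the model parameters in place
            Optimisers.update!(opt_state, model, grads[1])
            push!(losses, loss)
        end
    end
    return losses
end

model = Chain(Dense(4 => 8, relu), Dense(8 => 3))
losses = train_loop!(model, loader; epochs=5)
```

With Tsunami, the body of this loop shrinks to the `train_step` method and the optimiser setup moves to `configure_optimisers`, as in the script that follows.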

In the following script, we train a Multilayer Perceptron on the FashionMNIST dataset using Tsunami:

using Flux, Optimisers, Statistics, Tsunami, MLDatasets
using MLUtils: DataLoader, flatten, mapobs
## uncomment one of the following for GPU acceleration
# using CUDA
# using AMDGPU
# using Metal

## Define the model 

struct MLP{T} <: FluxModule
    net::T
end

MLP() = MLP(Chain(Dense(28^2 => 512, relu), Dense(512 => 10)))

(model::MLP)(x) = model.net(flatten(x))

function loss_and_accuracy(model::MLP, batch)
    x, y = batch
    ŷ = model(x)
    return Flux.logitcrossentropy(ŷ, y), Tsunami.accuracy(ŷ, y)
end

function Tsunami.train_step(model::MLP, trainer, batch)
    loss, acc = loss_and_accuracy(model, batch)
    Tsunami.log(trainer, "loss/train", loss, prog_bar=true)
    Tsunami.log(trainer, "accuracy/train", acc, prog_bar=true)
    return loss
end

function Tsunami.val_step(model::MLP, trainer, batch)
    loss, acc = loss_and_accuracy(model, batch)
    Tsunami.log(trainer, "loss/val", loss)
    Tsunami.log(trainer, "accuracy/val", acc)
end

Tsunami.configure_optimisers(model::MLP, trainer) = 
    Optimisers.setup(Optimisers.AdamW(1e-3), model)

## Prepare the data

function mnist_transform(batch)
    x, y = batch
    y = Flux.onehotbatch(y, 0:9)
    return (x, y)
end

train_data = FashionMNIST(split=:train)
train_data = mapobs(mnist_transform, train_data)[:]
train_loader = DataLoader(train_data, batchsize=128, shuffle=true)

test_data = FashionMNIST(split=:test)
test_data = mapobs(mnist_transform, test_data)[:]
test_loader = DataLoader(test_data, batchsize=128)

## Create and train the model

model = MLP()
trainer = Trainer(max_epochs=5)
Tsunami.fit!(model, trainer, train_loader, test_loader)

What follows is the final output of the script. The script trains the model on the GPU if one is available and also writes TensorBoard logs and model checkpoints to disk.

See the documentation and check the examples folder to learn more.

Contributions are welcome!

If you want to contribute to Tsunami, please open an issue or a pull request. Any help is appreciated!

Similar Julia libraries

About

Neural network training, fast and easy.
