
Deep-learning based semantic and instance segmentation for 3D Electron Microscopy and other bioimage analysis problems based on pytorch.


Marei33/torch-em

 
 


torch-em

Deep-learning based semantic and instance segmentation for 3D Electron Microscopy and other bioimage analysis problems based on PyTorch. Any feedback is highly appreciated; just open an issue!

Highlights:

  • Functional API with sensible defaults to train a state-of-the-art segmentation model with a few lines of code.
  • Differentiable augmentations on GPU and CPU thanks to kornia.
  • Off-the-shelf logging with tensorboard or wandb.
  • Export trained models to the bioimage.io model format with one function call to deploy them in ilastik or deepimageJ.
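Segmentation augmentations have to transform an image and its labels with the same random parameters, otherwise the two stop lining up. torch_em relies on kornia for this. The sketch below only illustrates the principle with numpy flips and 90-degree rotations. It is not kornia and not differentiable, and `joint_random_flip_rot90` is a hypothetical helper.

```python
import numpy as np


def joint_random_flip_rot90(image, labels, rng):
    """Hypothetical helper: apply one random flip / 90-degree rotation to a 2D image and its labels."""
    # Draw the random parameters once ...
    k = int(rng.integers(0, 4))      # number of 90-degree rotations
    flip = bool(rng.integers(0, 2))  # whether to mirror left-right
    # ... and reuse them for both arrays so they stay aligned.
    out = []
    for arr in (image, labels):
        arr = np.rot90(arr, k=k)
        if flip:
            arr = np.flip(arr, axis=1)
        out.append(np.ascontiguousarray(arr))
    return out[0], out[1]


rng = np.random.default_rng(seed=0)
image = np.arange(16, dtype=np.float32).reshape(4, 4)
labels = (image > 7).astype(np.uint8)  # toy segmentation: the bright half of the image
aug_image, aug_labels = joint_random_flip_rot90(image, labels, rng)
# The labels still match the transformed image pixel by pixel.
print(np.array_equal(aug_labels, (aug_image > 7).astype(np.uint8)))  # True
```

Flips and 90-degree rotations only move pixels, so integer labels survive unchanged. Transforms that interpolate, such as arbitrary-angle rotations or elastic deformations, would need nearest-neighbour interpolation for the labels.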

Design:

  • All parameters are specified in code, no configuration files.
  • No callback logic; to extend the core functionality, inherit from torch_em.trainer.DefaultTrainer instead.
  • All data-loading is lazy to support training on large datasets.
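"Lazy" means that a dataset keeps only a handle to the data on disk and reads a patch when a sample is requested. This lets a volume larger than RAM be used for training. Here is a minimal numpy sketch of the idea, assuming a raw binary volume opened with `np.memmap`. `LazyPatchDataset` is hypothetical and is not torch_em's dataset class.

```python
import os
import tempfile

import numpy as np


class LazyPatchDataset:
    """Hypothetical dataset that samples random patches from an on-disk volume."""

    def __init__(self, path, shape, dtype, patch_shape, n_samples, seed=0):
        # np.memmap maps the file; pixel data is only read when it is sliced.
        self.data = np.memmap(path, dtype=dtype, mode="r", shape=shape)
        self.patch_shape = tuple(patch_shape)
        self.n_samples = n_samples
        self.rng = np.random.default_rng(seed)

    def __len__(self):
        return self.n_samples

    def __getitem__(self, index):
        # index is ignored: every call draws a new random patch.
        # Pick a random corner so that the patch lies fully inside the volume.
        start = [
            int(self.rng.integers(0, size - psize + 1))
            for size, psize in zip(self.data.shape, self.patch_shape)
        ]
        bb = tuple(slice(s, s + p) for s, p in zip(start, self.patch_shape))
        # Only the part of the file covering this bounding box is read and copied into memory.
        return np.array(self.data[bb])


# Write a toy volume to disk as a stand-in for a large EM dataset.
shape = (64, 128, 128)
path = os.path.join(tempfile.mkdtemp(), "volume.raw")
np.random.default_rng(1).integers(0, 255, size=shape, dtype=np.uint8).tofile(path)

dataset = LazyPatchDataset(path, shape, np.uint8, patch_shape=(32, 64, 64), n_samples=100)
patch = dataset[0]
print(len(dataset), patch.shape)  # 100 (32, 64, 64)
```

The same pattern carries over to chunked formats such as HDF5 or zarr, where slicing a dataset reads only the chunks the slice touches.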

torch_em can be installed via conda: conda install -c conda-forge torch_em. Below is an example script that trains a 2D U-Net with it; check out the documentation for more details.

# Train a 2d U-Net for foreground and boundary segmentation of nuclei, using data from
# https://github.com/mpicbg-csbd/stardist/releases/download/0.1.0/dsb2018.zip

import torch_em
from torch_em.model import UNet2d
from torch_em.data.datasets import get_dsb_loader

model = UNet2d(in_channels=1, out_channels=2)

# Transform to convert from instance segmentation labels to foreground and boundary probabilities.
label_transform = torch_em.transform.BoundaryTransform(add_binary_target=True, ndim=2)

# Create the training and validation data loader.
data_path = "./dsb"  # The training data will be downloaded and saved here.
train_loader = get_dsb_loader(
    data_path, 
    patch_shape=(1, 256, 256),
    batch_size=8,
    split="train",
    download=True,
    label_transform=label_transform,
)
val_loader = get_dsb_loader(
    data_path, 
    patch_shape=(1, 256, 256),
    batch_size=8,
    split="test",
    label_transform=label_transform,
)

# The trainer handles the details of the training process.
# It will save checkpoints in "checkpoints/dsb-boundary-model"
# and the tensorboard logs in "logs/dsb-boundary-model".
trainer = torch_em.default_segmentation_trainer(
    name="dsb-boundary-model",
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    learning_rate=1e-4,
)
trainer.fit(iterations=5000)  # Fit for 5000 iterations.

# Export the trained model to the bioimage.io model format.
from glob import glob
import imageio
from torch_em.util import export_bioimageio_model

# Load one of the images to use as reference image.
# Crop it to a shape that is guaranteed to fit the network.
test_im = imageio.imread(glob(f"{data_path}/test/images/*.tif")[0])[:256, :256]

# Export the model.
export_bioimageio_model("./checkpoints/dsb-boundary-model", "./bioimageio-model", test_im)
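In the script, `BoundaryTransform(add_binary_target=True, ndim=2)` converts the instance labels into the two target channels that the model predicts, which is why the model uses `out_channels=2`. The following simplified numpy sketch shows that conversion using a 4-neighbour definition of a boundary. torch_em's actual transform may compute boundaries differently. This sketch puts foreground first and boundaries second, and whether torch_em uses the same channel order is an assumption here. `foreground_and_boundary_targets` is a hypothetical helper.

```python
import numpy as np


def foreground_and_boundary_targets(labels):
    """Hypothetical helper: turn 2D instance labels into a (2, H, W) float32 target.

    Channel 0: foreground, i.e. every pixel with a non-zero instance id.
    Channel 1: boundaries, i.e. pixels with a 4-neighbour that has a different id
    (background counts as id 0, so object outlines are included).
    """
    labels = np.asarray(labels)
    foreground = labels > 0

    # Edge padding repeats border values, so the image border is not a boundary.
    padded = np.pad(labels, 1, mode="edge")
    center = padded[1:-1, 1:-1]
    boundary = np.zeros(labels.shape, dtype=bool)
    # Compare each pixel with its upper, lower, left and right neighbour.
    for neighbour in (padded[:-2, 1:-1], padded[2:, 1:-1], padded[1:-1, :-2], padded[1:-1, 2:]):
        boundary |= neighbour != center

    return np.stack([foreground, boundary]).astype(np.float32)


# Two touching instances (ids 1 and 2) on background.
labels = np.zeros((8, 8), dtype=np.int64)
labels[1:7, 1:4] = 1
labels[1:7, 4:7] = 2
target = foreground_and_boundary_targets(labels)
print(target.shape)              # (2, 8, 8)
print(target[1, 3].astype(int))  # [1 1 0 1 1 0 1 1]
```

The foreground channel alone would merge the two touching instances into one blob. The boundary channel marks the pixels on both sides of their contact, which is what allows a post-processing step (for example a watershed) to separate them again.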
