
mhubii/libtorch_custom_dataset


Libtorch Custom Dataset

This is a short example of how to create custom datasets for libtorch. The CustomDataset class in custom_dataset.h implements a torch::data::Dataset. It reads the image locations from file_names.csv into a std::vector<std::tuple<std::string, int>>, and its get method then loads each image from disk at runtime using OpenCV. You may want to change this and load all images into RAM up front instead. This can speed up training significantly if your data is not stored on an SSD.
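The CSV-loading step can be sketched in standard C++ alone, without libtorch or OpenCV. This is a minimal sketch that assumes each row of file_names.csv has the form `path,label` (an integer class label, no header line); the repository's actual file layout may differ. The names `parse_file_list`, `load_file_list` and `preload_bytes` are hypothetical and not part of the repository. `preload_bytes` illustrates the load-into-RAM idea by reading each file's raw, still-encoded bytes once. A fuller version would store decoded images instead.

```cpp
#include <cassert>
#include <fstream>
#include <istream>
#include <iterator>
#include <stdexcept>
#include <string>
#include <tuple>
#include <vector>

// Hypothetical helpers, not part of the repository.
// Assumed row format of file_names.csv: <image path>,<integer label> (no header).
using Entry = std::tuple<std::string, int>;  // (image path, class label)

// Parse "path,label" rows from any input stream.
std::vector<Entry> parse_file_list(std::istream& in) {
    std::vector<Entry> entries;
    std::string line;
    while (std::getline(in, line)) {
        if (!line.empty() && line.back() == '\r') line.pop_back();  // tolerate CRLF line endings
        if (line.empty()) continue;                                  // skip blank lines
        const auto comma = line.rfind(',');  // the last comma splits path from label
        if (comma == std::string::npos) {
            throw std::runtime_error("malformed row: " + line);
        }
        // std::stoi throws std::invalid_argument if the label is not an integer.
        entries.emplace_back(line.substr(0, comma), std::stoi(line.substr(comma + 1)));
    }
    return entries;
}

// Convenience wrapper that opens the CSV file by name.
std::vector<Entry> load_file_list(const std::string& csv_path) {
    std::ifstream file(csv_path);
    if (!file) throw std::runtime_error("cannot open " + csv_path);
    return parse_file_list(file);
}

// Eager alternative: read every image file's raw bytes into memory once,
// so later lookups by index no longer touch the disk.
std::vector<std::string> preload_bytes(const std::vector<Entry>& entries) {
    std::vector<std::string> blobs;
    blobs.reserve(entries.size());
    for (const auto& entry : entries) {
        const std::string& path = std::get<0>(entry);
        std::ifstream file(path, std::ios::binary);
        if (!file) throw std::runtime_error("cannot open " + path);
        blobs.emplace_back(std::istreambuf_iterator<char>(file),
                           std::istreambuf_iterator<char>());
    }
    assert(blobs.size() == entries.size());  // one blob per listed image
    return blobs;
}
```

In a dataset class built this way, the constructor could call `load_file_list`, the size method could return the number of entries, and `get(index)` could pass `std::get<0>(entries[index])` to OpenCV's `cv::imread`. With the preloaded variant, it could decode the in-memory bytes with `cv::imdecode` instead.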


Fig. 2: (Left) An apple in the dataset, (Right) a banana in the dataset.

Build

Make sure libtorch is installed and working. For a clean installation from Anaconda, check out this short tutorial. To download only the prebuilt binaries, follow this tutorial instead.

Clone this repository

git clone https://github.com/mhubii/libtorch_custom_dataset.git
cd libtorch_custom_dataset

Build the executables

mkdir build
cd build
cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
make

Train the classifier

cd build
./train

Classify an image

cd build
./classify filename
# for example
./classify ../data/apples/img0.jpg

Notes

The dataset is a modified version of a dataset available on Kaggle. The training loop in particular is inspired by Peter Goldsborough's implementation for the MNIST dataset in the PyTorch examples repository.

About

Demo dataset for libtorch
