Skip to content

Commit

Permalink
clippy + added to ci + rename kinds + fix link
Browse files Browse the repository at this point in the history
  • Loading branch information
charles-r-earp committed Mar 15, 2024
1 parent 0a6f249 commit b9338f1
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 54 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ jobs:
run: cargo build -p neural-network-benches --benches
- name: test
if: ${{ !cancelled() }}
run: cargo test --no-default-features --features "serde neural-network" -- --format=terse
run: cargo test --no-default-features --features "serde neural-network mnist" -- --format=terse
- name: test avx
env:
RUST_BUILD_RUSTFLAGS: -Ctarget-feature=+avx
Expand Down
12 changes: 6 additions & 6 deletions examples/neural-network-mnist/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,9 @@ impl LeNet5 {
#[derive(Clone, Copy, derive_more::Display, Debug, ValueEnum)]
enum Dataset {
#[display(fmt = "mnist")]
Mnist,
MNIST,
#[display(fmt = "fashion-mnist")]
Fashion,
FashionMNIST,
}

#[derive(Clone, Copy, derive_more::Display, Debug, ValueEnum)]
Expand All @@ -128,7 +128,7 @@ impl From<ScalarKind> for ScalarType {
struct Options {
#[arg(long)]
device: Option<usize>,
#[arg(long, default_value_t = Dataset::Mnist)]
#[arg(long, default_value_t = Dataset::MNIST)]
dataset: Dataset,
#[arg(long, default_value_t = ScalarKind::F32)]
scalar_type: ScalarKind,
Expand All @@ -148,8 +148,8 @@ fn main() -> Result<()> {
let options = Options::parse();
println!("{options:#?}");
let mnist_kind = match options.dataset {
Dataset::Mnist => MnistKind::Digits,
Dataset::Fashion => MnistKind::Fashion,
Dataset::MNIST => MnistKind::MNIST,
Dataset::FashionMNIST => MnistKind::FashionMNIST,
};
let Mnist {
train_images,
Expand Down Expand Up @@ -230,7 +230,7 @@ fn main() -> Result<()> {
let test_acc = test_stats.accuracy();
let epoch_elapsed = epoch_start.elapsed();
println!(
"[{epoch}] train_loss: {train_loss} train_acc: {train_acc}% {train_correct}/{train_count} test_loss: {test_loss} test_acc: {test_acc}% {test_correct}/{test_count} elapsed: {epoch_elapsed:?}"
"[{epoch}] train_loss: {train_loss:.5} train_acc: {train_acc:.2}% {train_correct}/{train_count} test_loss: {test_loss:.5} test_acc: {test_acc:.2}% {test_correct}/{test_count} elapsed: {epoch_elapsed:.2?}"
);
}
println!("Finished in {:?}.", start.elapsed());
Expand Down
88 changes: 41 additions & 47 deletions src/dataset/mnist.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,14 @@ use std::{
};

/// The kind of Mnist.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[derive(Clone, Copy, Debug, Eq, PartialEq, derive_more::Display)]
pub enum MnistKind {
/// [MNIST](<http://yann.lecun.com/exdb/mnist/>)
Digits,
/// [FashionMNIST](<https://github.com/zalandoresearch/fashion-mnist>)
Fashion,
}

impl MnistKind {
fn name(&self) -> &'static str {
match self {
Self::Digits => "mnist",
Self::Fashion => "fashion-mnist",
}
}
/// [mnist](<http://yann.lecun.org/exdb/mnist>)
#[display(fmt = "mnist")]
MNIST,
/// [fashion-mnist](<https://github.com/zalandoresearch/fashion-mnist>)
#[display(fmt = "fashion-mnist")]
FashionMNIST,
}

/// Mnist builder.
Expand All @@ -49,7 +42,7 @@ pub mod builders {
fn default() -> Self {
Self {
path: None,
kind: MnistKind::Digits,
kind: MnistKind::MNIST,
download: false,
verbose: false,
}
Expand All @@ -60,15 +53,13 @@ pub mod builders {
/// The path to load the dataset from.
///
/// This is the folder the files will be downloaded to / loaded from. If not specified, uses the OS specific "Downloads" directory or the "Temp" directory.
pub fn path(self, path: impl Into<PathBuf>) -> MnistBuilder {
MnistBuilder {
pub fn path(self, path: impl Into<PathBuf>) -> Self {
Self {
path: Some(path.into()),
kind: self.kind,
download: self.download,
verbose: self.verbose,
..self
}
}
/// The kind of Mnist to use. Defaults to [`MnistKind::Digits`] (ie the original MNIST dataset).
/// The kind of Mnist to use. Defaults to [`MnistKind::MNIST`].
pub fn kind(self, kind: MnistKind) -> Self {
Self { kind, ..self }
}
Expand Down Expand Up @@ -121,26 +112,30 @@ pub struct Mnist {

impl Mnist {
/// Returns an [`MnistBuilder`] used to specify options.
/*
```
/**
```no_run
# use autograph::{
# result::Result,
# anyhow::Result,
# dataset::mnist::{Mnist, MnistKind},
# };
# fn main() -> Result<()> {
let mnist = Mnist::builder()
.path("data")
.kind(MnistKind::Fashion)
.download(true)
.build()?;
# Ok(())
let mnist = Mnist::builder()
.path("data")
.kind(MnistKind::FashionMNIST)
.download(true)
.build()?;
# Ok(())
# }
*/
pub fn builder() -> MnistBuilder {
MnistBuilder::default()
}
fn build(builder: MnistBuilder) -> Result<Self> {
let mnist_name = builder.kind.name();
let kind = builder.kind;
let mnist_name = match kind {
MnistKind::MNIST => "mnist",
MnistKind::FashionMNIST => "fashion-mnist",
};
let mnist_path = builder
.path
.unwrap_or_else(|| dirs::download_dir().unwrap_or_else(std::env::temp_dir))
Expand All @@ -151,17 +146,17 @@ impl Mnist {
"t10k-images-idx3-ubyte",
"t10k-labels-idx1-ubyte",
];
let sizes = match builder.kind {
MnistKind::Digits => [9_912_422, 28_881, 1_648_877, 4_542],
MnistKind::Fashion => [26_421_880, 29_515, 4_422_102, 5_148],
let sizes = match kind {
MnistKind::MNIST => [9_912_422, 28_881, 1_648_877, 4_542],
MnistKind::FashionMNIST => [26_421_880, 29_515, 4_422_102, 5_148],
};
if !mnist_path.exists() {
if builder.download {
fs::create_dir_all(&mnist_path)?;
download(builder.kind, &mnist_path, names, sizes, builder.verbose)
.map_err(|e| e.context(format!("Downloading {mnist_name} failed!")))?;
.map_err(|e| e.context(format!("Downloading {kind:?} failed!")))?;
} else {
bail!("{mnist_name} not found at {mnist_path:?}!");
bail!("{kind:?} not found at {mnist_path:?}!");
}
}
let [train_images, train_classes, test_images, test_classes] =
Expand Down Expand Up @@ -212,9 +207,8 @@ fn download(
sizes: [usize; 4],
verbose: bool,
) -> Result<()> {
let mnist_name = kind.name();
if verbose {
eprintln!("Downloading {mnist_name} to {mnist_path:?}...");
eprintln!("Downloading {kind:?} to {mnist_path:?}...");
}
let style = ProgressStyle::with_template(
"[{elapsed}] eta {eta} [{bar:40}] {bytes:>7} / {total_bytes:7}: {msg}",
Expand All @@ -239,10 +233,10 @@ fn download(
let result = names.into_par_iter().zip(bars).try_for_each(|(name, bar)| {
let guard = AbortGuard::new(&done, &bar);
let url = match kind {
MnistKind::Digits => {
format!("http://yann.lecun.com/exdb/mnist/{}.gz", name)
MnistKind::MNIST => {
format!("http://yann.lecun.org/exdb/mnist/{}.gz", name)
}
MnistKind::Fashion => format!(
MnistKind::FashionMNIST => format!(
"http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/{}.gz",
name
),
Expand Down Expand Up @@ -283,7 +277,7 @@ fn unzip(mnist_path: &Path, names: [&str; 4], sizes: [usize; 4]) -> Result<[Vec<
let mut data = <[Vec<u8>; 4]>::default();
data.par_iter_mut()
.zip(names.into_par_iter().zip(sizes))
.try_for_each(|(mut data, (name, size))| {
.try_for_each(|(data, (name, size))| {
let gz_path = mnist_path.join(name).with_extension("gz");
let file = File::open(gz_path)?;
ensure!(file.metadata()?.len() == u64::try_from(size).unwrap());
Expand All @@ -300,7 +294,7 @@ fn unzip(mnist_path: &Path, names: [&str; 4], sizes: [usize; 4]) -> Result<[Vec<
ensure!(decoder.read_i32::<BigEndian>()? == 28);
}
*data = Vec::with_capacity(len);
decoder.read_to_end(&mut data)?;
decoder.read_to_end(data)?;
ensure!(data.len() == len);
Ok(())
})?;
Expand All @@ -312,10 +306,10 @@ mod tests {
use super::*;

#[test]
fn mnist_digits() {
fn mnist() {
let dir = tempfile::tempdir().unwrap();
Mnist::builder()
.kind(MnistKind::Digits)
.kind(MnistKind::MNIST)
.download(true)
.path(dir.path())
.verbose(false)
Expand All @@ -325,10 +319,10 @@ mod tests {
}

#[test]
fn mnist_fashion() {
fn fashion() {
let dir = tempfile::tempdir().unwrap();
Mnist::builder()
.kind(MnistKind::Fashion)
.kind(MnistKind::FashionMNIST)
.download(true)
.path(dir.path())
.verbose(false)
Expand Down

0 comments on commit b9338f1

Please sign in to comment.