diff --git a/examples/gpt2/examples/gpt2-no-ndarray.rs b/examples/gpt2/examples/gpt2-no-ndarray.rs index f62952a5..832bb39d 100644 --- a/examples/gpt2/examples/gpt2-no-ndarray.rs +++ b/examples/gpt2/examples/gpt2-no-ndarray.rs @@ -4,7 +4,7 @@ use std::{ sync::Arc }; -use ort::{download::language::machine_comprehension::GPT2, inputs, CUDAExecutionProvider, GraphOptimizationLevel, Session}; +use ort::{inputs, CUDAExecutionProvider, GraphOptimizationLevel, Session}; use rand::Rng; use tokenizers::Tokenizer; @@ -36,7 +36,7 @@ fn main() -> ort::Result<()> { let session = Session::builder()? .with_optimization_level(GraphOptimizationLevel::Level1)? .with_intra_threads(1)? - .with_model_downloaded(GPT2::GPT2LmHead)?; + .with_model_downloaded("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/gpt2.onnx")?; // Load the tokenizer and encode the prompt into a sequence of tokens. let tokenizer = Tokenizer::from_file(Path::new(env!("CARGO_MANIFEST_DIR")).join("data").join("tokenizer.json")).unwrap(); diff --git a/examples/gpt2/examples/gpt2.rs b/examples/gpt2/examples/gpt2.rs index 207b6505..c75cae9c 100644 --- a/examples/gpt2/examples/gpt2.rs +++ b/examples/gpt2/examples/gpt2.rs @@ -4,7 +4,7 @@ use std::{ }; use ndarray::{array, concatenate, s, Array1, Axis}; -use ort::{download::language::machine_comprehension::GPT2, inputs, CUDAExecutionProvider, GraphOptimizationLevel, Session, Tensor}; +use ort::{inputs, CUDAExecutionProvider, GraphOptimizationLevel, Session, Tensor}; use rand::Rng; use tokenizers::Tokenizer; @@ -36,7 +36,7 @@ fn main() -> ort::Result<()> { let session = Session::builder()? .with_optimization_level(GraphOptimizationLevel::Level1)? .with_intra_threads(1)? - .with_model_downloaded(GPT2::GPT2LmHead)?; + .with_model_downloaded("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/gpt2.onnx")?; // Load the tokenizer and encode the prompt into a sequence of tokens. let tokenizer = Tokenizer::from_file(Path::new(env!("CARGO_MANIFEST_DIR")).join("data").join("tokenizer.json")).unwrap(); diff --git a/src/download/language.rs b/src/download/language.rs deleted file mode 100644 index 5bcc3d4a..00000000 --- a/src/download/language.rs +++ /dev/null @@ -1,5 +0,0 @@ -//! Models for language understanding. - -pub mod machine_comprehension; - -pub use machine_comprehension::{MachineComprehension, RoBERTa, GPT2}; diff --git a/src/download/language/machine_comprehension.rs b/src/download/language/machine_comprehension.rs deleted file mode 100644 index 5e48a3fc..00000000 --- a/src/download/language/machine_comprehension.rs +++ /dev/null @@ -1,77 +0,0 @@ -#![allow(clippy::upper_case_acronyms)] - -//! Models for machine language comprehension. - -use crate::download::ModelUrl; - -/// Machine comprehension models. -/// -/// A subset of natural language processing models that answer questions about a given context paragraph. -#[derive(Debug, Clone)] -pub enum MachineComprehension { - /// Answers a query about a given context paragraph. - BiDAF, - /// Answers questions based on the context of the given input paragraph. - BERTSquad, - /// Large transformer-based model that predicts sentiment based on given input text. - RoBERTa(RoBERTa), - /// Generates synthetic text samples in response to the model being primed with an arbitrary input. - GPT2(GPT2) -} - -/// Large transformer-based model that predicts sentiment based on given input text. -#[derive(Debug, Clone)] -pub enum RoBERTa { - /// Base RoBERTa model. - RoBERTaBase, - /// RoBERTa model for sequence classification. 
- RoBERTaSequenceClassification -} - -/// Generates synthetic text samples in response to the model being primed with an arbitrary input. -#[derive(Debug, Clone)] -pub enum GPT2 { - /// Base GPT-2 model. - GPT2, - /// GPT-2 model with a causal LM head. - GPT2LmHead -} - -impl ModelUrl for MachineComprehension { - fn model_url(&self) -> &'static str { - match self { - MachineComprehension::BiDAF => { - "https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/text/machine_comprehension/bidirectional_attention_flow/model/bidaf-9.onnx" - } - MachineComprehension::BERTSquad => { - "https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/text/machine_comprehension/bert-squad/model/bertsquad-10.onnx" - } - MachineComprehension::RoBERTa(variant) => variant.model_url(), - MachineComprehension::GPT2(variant) => variant.model_url() - } - } -} - -impl ModelUrl for RoBERTa { - fn model_url(&self) -> &'static str { - match self { - RoBERTa::RoBERTaBase => { - "https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/text/machine_comprehension/roberta/model/roberta-base-11.onnx" - } - RoBERTa::RoBERTaSequenceClassification => { - "https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/text/machine_comprehension/roberta/model/roberta-sequence-classification-9.onnx" - } - } - } -} - -impl ModelUrl for GPT2 { - fn model_url(&self) -> &'static str { - match self { - GPT2::GPT2 => "https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/text/machine_comprehension/gpt-2/model/gpt2-10.onnx", - GPT2::GPT2LmHead => { - "https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/text/machine_comprehension/gpt-2/model/gpt2-lm-head-10.onnx" - } - } - } -} diff --git a/src/download/mod.rs b/src/download/mod.rs deleted file mode 100644 index 01e97819..00000000 --- a/src/download/mod.rs +++ /dev/null @@ -1,14 +0,0 @@ -pub mod language; -pub mod vision; - -/// Represents a type that returns an ONNX model URL. -pub trait ModelUrl { - /// Returns the model URL associated with this model. - fn model_url(&self) -> &'static str; -} - -impl ModelUrl for &'static str { - fn model_url(&self) -> &'static str { - self - } -} diff --git a/src/download/vision.rs b/src/download/vision.rs deleted file mode 100644 index adc3334c..00000000 --- a/src/download/vision.rs +++ /dev/null @@ -1,13 +0,0 @@ -//! Models for computer vision. - -pub mod body_face_gesture_analysis; -pub mod domain_based_image_classification; -pub mod image_classification; -pub mod image_manipulation; -pub mod object_detection_image_segmentation; - -pub use body_face_gesture_analysis::BodyFaceGestureAnalysis; -pub use domain_based_image_classification::DomainBasedImageClassification; -pub use image_classification::{ImageClassification, InceptionVersion, ResNetV1, ResNetV2, ShuffleNetVersion, Vgg}; -pub use image_manipulation::{FastNeuralStyleTransferStyle, ImageManipulation}; -pub use object_detection_image_segmentation::ObjectDetectionImageSegmentation; diff --git a/src/download/vision/body_face_gesture_analysis.rs b/src/download/vision/body_face_gesture_analysis.rs deleted file mode 100644 index 58fa4f7d..00000000 --- a/src/download/vision/body_face_gesture_analysis.rs +++ /dev/null @@ -1,26 +0,0 @@ -//! Models for body, face, & gesture analysis. - -use crate::download::ModelUrl; - -/// Models for body, face, & gesture analysis. 
-#[derive(Debug, Clone)] -pub enum BodyFaceGestureAnalysis { - /// A CNN based model for face recognition which learns discriminative features of faces and produces embeddings for - /// input face images. - ArcFace, - /// Deep CNN for emotion recognition trained on images of faces. - EmotionFerPlus -} - -impl ModelUrl for BodyFaceGestureAnalysis { - fn model_url(&self) -> &'static str { - match self { - BodyFaceGestureAnalysis::ArcFace => { - "https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/body_analysis/arcface/model/arcfaceresnet100-8.onnx" - } - BodyFaceGestureAnalysis::EmotionFerPlus => { - "https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/body_analysis/emotion_ferplus/model/emotion-ferplus-8.onnx" - } - } - } -} diff --git a/src/download/vision/domain_based_image_classification.rs b/src/download/vision/domain_based_image_classification.rs deleted file mode 100644 index fdfb245a..00000000 --- a/src/download/vision/domain_based_image_classification.rs +++ /dev/null @@ -1,20 +0,0 @@ -//! Models for domain-based image classification. - -use crate::download::ModelUrl; - -/// Models for domain-based image classification. -#[derive(Debug, Clone)] -pub enum DomainBasedImageClassification { - /// Handwritten digit prediction using CNN. - Mnist -} - -impl ModelUrl for DomainBasedImageClassification { - fn model_url(&self) -> &'static str { - match self { - DomainBasedImageClassification::Mnist => { - "https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/classification/mnist/model/mnist-8.onnx" - } - } - } -} diff --git a/src/download/vision/image_classification.rs b/src/download/vision/image_classification.rs deleted file mode 100644 index b0c46f0d..00000000 --- a/src/download/vision/image_classification.rs +++ /dev/null @@ -1,352 +0,0 @@ -//! Models for image classification. - -#![allow(clippy::upper_case_acronyms)] - -use crate::download::ModelUrl; - -/// Convolutional neural network for classification, which competed in the ImageNet Large Scale Visual Recognition -/// Challenge in 2012. -#[derive(Debug, Clone)] -pub enum AlexNet { - /// AlexNet at full fp32 precision. - /// - **Size**: 233 MB - /// - **Top-1 accuracy**: 54.80% - /// - **Top-5 accuracy**: 78.23% - FullPrecision, - /// AlexNet at int8 precision. - /// - **Size**: 58 MB - /// - **Top-1 accuracy**: 54.68% - /// - **Top-5 accuracy**: 78.23% - Int8, - /// AlexNet with QDQ quantization. - /// - **Size**: 59 MB - /// - **Top-1 accuracy**: 54.71% - /// - **Top-5 accuracy**: 78.22% - QDQ -} - -/// CaffeNet a variant of AlexNet. AlexNet is the name of a convolutional neural network for classification, which -/// competed in the ImageNet Large Scale Visual Recognition Challenge in 2012. -#[derive(Debug, Clone)] -pub enum CaffeNet { - /// CaffeNet at full fp32 precision. - /// - **Size**: 233 MB - /// - **Top-1 accuracy**: 56.27% - /// - **Top-5 accuracy**: 79.52% - FullPrecision, - /// CaffeNet at int8 precision. - /// - **Size**: 58 MB - /// - **Top-1 accuracy**: 56.22% - /// - **Top-5 accuracy**: 79.52% - Int8, - /// CaffeNet with QDQ quantization. - /// - **Size**: 59 MB - /// - **Top-1 accuracy**: 56.26% - /// - **Top-5 accuracy**: 79.45% - QDQ -} - -/// Models for image classification. -#[derive(Debug, Clone)] -pub enum ImageClassification { - /// Image classification aimed for mobile targets. 
- /// - /// > MobileNet models perform image classification - they take images as input and classify the major - /// > object in the image into a set of pre-defined classes. They are trained on ImageNet dataset which - /// > contains images from 1000 classes. MobileNet models are also very efficient in terms of speed and - /// > size and hence are ideal for embedded and mobile applications. - MobileNet, - /// A small CNN with AlexNet level accuracy on ImageNet with 50x fewer parameters. - /// - /// > SqueezeNet is a small CNN which achieves AlexNet level accuracy on ImageNet with 50x fewer parameters. - /// > SqueezeNet requires less communication across servers during distributed training, less bandwidth to - /// > export a new model from the cloud to an autonomous car and more feasible to deploy on FPGAs and other - /// > hardware with limited memory. - SqueezeNet, - /// Image classification, trained on ImageNet with 1000 classes. - /// - /// > VGG models provide very high accuracies but at the cost of increased model sizes. They are ideal for - /// > cases when high accuracy of classification is essential and there are limited constraints on model sizes. - Vgg(Vgg), - /// Convolutional neural network for classification, which competed in the ImageNet Large Scale Visual Recognition - /// Challenge in 2012. - AlexNet, - /// Convolutional neural network for classification, which competed in the ImageNet Large Scale Visual Recognition - /// Challenge in 2014. - GoogleNet, - /// Variant of AlexNet, it's the name of a convolutional neural network for classification, which competed in the - /// ImageNet Large Scale Visual Recognition Challenge in 2012. - CaffeNet, - /// Convolutional neural network for detection. - /// - /// > This model was made by transplanting the R-CNN SVM classifiers into a fc-rcnn classification layer. - RcnnIlsvrc13, - /// Convolutional neural network for classification. - DenseNet121, - /// Google's Inception - Inception(InceptionVersion), - /// Computationally efficient CNN architecture designed specifically for mobile devices with very limited computing - /// power. - ShuffleNet(ShuffleNetVersion), - /// Deep convolutional networks for classification. - /// - /// > This model's 4th layer has 512 maps instead of 1024 maps mentioned in the paper. - ZFNet512, - /// Image classification model that achieves state-of-the-art accuracy. - /// - /// > It is designed to run on mobile CPU, GPU, and EdgeTPU devices, allowing for applications on mobile and loT, - /// where computational resources are limited. - EfficientNetLite4 -} - -#[derive(Debug, Clone)] -pub enum InceptionVersion { - V1, - V2 -} - -/// ResNet models perform image classification - they take images as input and classify the major object in the image -/// into a set of pre-defined classes. They are trained on ImageNet dataset which contains images from 1000 classes. -/// ResNet models provide very high accuracies with affordable model sizes. They are ideal for cases when high accuracy -/// of classification is required. -#[derive(Debug, Clone)] -pub enum ResNetV1 { - /// ResNet v1 with 18 layers. - /// - **Size**: 44.7 MB - /// - **Top-1 accuracy**: 69.93% - /// - **Top-5 accuracy**: 89.29% - L18, - /// ResNet v1 with 34 layers. - /// - **Size**: 83.3 MB - /// - **Top-1 accuracy**: 73.73% - /// - **Top-5 accuracy**: 91.40% - L34, - /// ResNet v1 with 50 layers. - /// - **Size**: 97.8 MB - /// - **Top-1 accuracy**: 74.93% - /// - **Top-5 accuracy**: 92.38% - L50, - /// ResNet v1 with 101 layers. 
- /// - **Size**: 170.6 MB - /// - **Top-1 accuracy**: 76.48% - /// - **Top-5 accuracy**: 93.20% - L101, - /// ResNet v1 with 152 layers. - /// - **Size**: 230.6 MB - /// - **Top-1 accuracy**: 77.11% - /// - **Top-5 accuracy**: 93.61% - L152 -} - -/// ResNet models perform image classification - they take images as input and classify the major object in the image -/// into a set of pre-defined classes. They are trained on ImageNet dataset which contains images from 1000 classes. -/// ResNet models provide very high accuracies with affordable model sizes. They are ideal for cases when high accuracy -/// of classification is required. -/// -/// ResNet v2 uses pre-activation function, whereas [`ResNetV1`] uses post-activation for the residual blocks. ResNet v2 -/// models achieve slightly better top-5 accuracy than their ResNet v1 counterparts. -#[derive(Debug, Clone)] -pub enum ResNetV2 { - /// ResNet v2 with 18 layers. - /// - **Size**: 44.6 MB - /// - **Top-1 accuracy**: 69.70% - /// - **Top-5 accuracy**: 89.49% - L18, - /// ResNet v2 with 34 layers. - /// - **Size**: 83.2 MB - /// - **Top-1 accuracy**: 73.36% - /// - **Top-5 accuracy**: 91.43% - L34, - /// ResNet v2 with 50 layers. - /// - **Size**: 97.7 MB - /// - **Top-1 accuracy**: 75.81% - /// - **Top-5 accuracy**: 92.82% - L50, - /// ResNet v2 with 101 layers. - /// - **Size**: 170.4 MB - /// - **Top-1 accuracy**: 77.42% - /// - **Top-5 accuracy**: 93.61% - L101, - /// ResNet v2 with 152 layers. - /// - **Size**: 230.3 MB - /// - **Top-1 accuracy**: 78.20% - /// - **Top-5 accuracy**: 94.21% - L152 -} - -#[derive(Debug, Clone)] -pub enum Vgg { - /// VGG with 16 convolutional layers - Vgg16, - /// VGG with 16 convolutional layers, with batch normalization applied after each convolutional layer. - /// - /// The batch normalization leads to better convergence and slightly better accuracies. - Vgg16Bn, - /// VGG with 19 convolutional layers - Vgg19, - /// VGG with 19 convolutional layers, with batch normalization applied after each convolutional layer. - /// - /// The batch normalization leads to better convergence and slightly better accuracies. - Vgg19Bn -} - -/// Computationally efficient CNN architecture designed specifically for mobile devices with very limited computing -/// power. -#[derive(Debug, Clone)] -pub enum ShuffleNetVersion { - /// The original ShuffleNet. - V1, - /// ShuffleNetV2 is an improved architecture that is the state-of-the-art in terms of speed and accuracy tradeoff - /// used for image classification. 
- V2 -} - -impl ModelUrl for AlexNet { - fn model_url(&self) -> &'static str { - match self { - AlexNet::FullPrecision => { - "https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/classification/alexnet/model/bvlcalexnet-12.onnx" - } - AlexNet::Int8 => { - "https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/classification/alexnet/model/bvlcalexnet-12-int8.onnx" - } - AlexNet::QDQ => { - "https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/classification/alexnet/model/bvlcalexnet-12-qdq.onnx" - } - } - } -} - -impl ModelUrl for CaffeNet { - fn model_url(&self) -> &'static str { - match self { - CaffeNet::FullPrecision => { - "https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/classification/caffenet/model/caffenet-12.onnx" - } - CaffeNet::Int8 => { - "https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/classification/caffenet/model/caffenet-12-int8.onnx" - } - CaffeNet::QDQ => { - "https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/classification/caffenet/model/caffenet-12-qdq.onnx" - } - } - } -} - -impl ModelUrl for ImageClassification { - fn model_url(&self) -> &'static str { - match self { - ImageClassification::MobileNet => { - "https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/classification/mobilenet/model/mobilenetv2-7.onnx" - } - ImageClassification::SqueezeNet => { - "https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/classification/squeezenet/model/squeezenet1.1-7.onnx" - } - ImageClassification::Inception(version) => version.model_url(), - ImageClassification::Vgg(variant) => variant.model_url(), - ImageClassification::AlexNet => { - "https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/classification/alexnet/model/bvlcalexnet-9.onnx" - } - ImageClassification::GoogleNet => { - "https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/classification/inception_and_googlenet/googlenet/model/googlenet-9.onnx" - } - ImageClassification::CaffeNet => { - "https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/classification/caffenet/model/caffenet-9.onnx" - } - ImageClassification::RcnnIlsvrc13 => { - "https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/classification/rcnn_ilsvrc13/model/rcnn-ilsvrc13-9.onnx" - } - ImageClassification::DenseNet121 => { - "https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/classification/densenet-121/model/densenet-9.onnx" - } - ImageClassification::ShuffleNet(version) => version.model_url(), - ImageClassification::ZFNet512 => { - "https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/classification/zfnet-512/model/zfnet512-9.onnx" - } - ImageClassification::EfficientNetLite4 => { - "https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/classification/efficientnet-lite4/model/efficientnet-lite4.onnx" - } - } - } -} - -impl ModelUrl for InceptionVersion { - fn model_url(&self) -> &'static str { - match self { - InceptionVersion::V1 => { - "https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/classification/inception_and_googlenet/inception_v1/model/inception-v1-9.onnx" - } - InceptionVersion::V2 => { - 
"https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/classification/inception_and_googlenet/inception_v2/model/inception-v2-9.onnx" - } - } - } -} - -impl ModelUrl for ResNetV1 { - fn model_url(&self) -> &'static str { - match self { - ResNetV1::L18 => { - "https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/classification/resnet/model/resnet18-v1-7.onnx" - } - ResNetV1::L34 => { - "https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/classification/resnet/model/resnet34-v1-7.onnx" - } - ResNetV1::L50 => { - "https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/classification/resnet/model/resnet50-v1-7.onnx" - } - ResNetV1::L101 => { - "https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/classification/resnet/model/resnet101-v1-7.onnx" - } - ResNetV1::L152 => { - "https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/classification/resnet/model/resnet152-v1-7.onnx" - } - } - } -} - -impl ModelUrl for ResNetV2 { - fn model_url(&self) -> &'static str { - match self { - ResNetV2::L18 => { - "https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/classification/resnet/model/resnet18-v2-7.onnx" - } - ResNetV2::L34 => { - "https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/classification/resnet/model/resnet34-v2-7.onnx" - } - ResNetV2::L50 => { - "https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/classification/resnet/model/resnet50-v2-7.onnx" - } - ResNetV2::L101 => { - "https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/classification/resnet/model/resnet101-v2-7.onnx" - } - ResNetV2::L152 => { - "https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/classification/resnet/model/resnet152-v2-7.onnx" - } - } - } -} - -impl ModelUrl for Vgg { - fn model_url(&self) -> &'static str { - match self { - Vgg::Vgg16 => "https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/classification/vgg/model/vgg16-7.onnx", - Vgg::Vgg16Bn => "https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/classification/vgg/model/vgg16-bn-7.onnx", - Vgg::Vgg19 => "https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/classification/vgg/model/vgg19-7.onnx", - Vgg::Vgg19Bn => "https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/classification/vgg/model/vgg19-bn-7.onnx" - } - } -} - -impl ModelUrl for ShuffleNetVersion { - fn model_url(&self) -> &'static str { - match self { - ShuffleNetVersion::V1 => { - "https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/classification/shufflenet/model/shufflenet-9.onnx" - } - ShuffleNetVersion::V2 => { - "https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/classification/shufflenet/model/shufflenet-v2-10.onnx" - } - } - } -} diff --git a/src/download/vision/image_manipulation.rs b/src/download/vision/image_manipulation.rs deleted file mode 100644 index 9d6c9e8b..00000000 --- a/src/download/vision/image_manipulation.rs +++ /dev/null @@ -1,62 +0,0 @@ -use crate::download::ModelUrl; - -/// Image Manipulation -/// -/// > Image manipulation models use neural networks to transform input images to modified output images. 
Some -/// > popular models in this category involve style transfer or enhancing images by increasing resolution. -#[derive(Debug, Clone)] -pub enum ImageManipulation { - /// Super Resolution - /// - /// > The Super Resolution machine learning model sharpens and upscales the input image to refine the - /// > details and improve quality. - SuperResolution, - /// Fast Neural Style Transfer - /// - /// > This artistic style transfer model mixes the content of an image with the style of another image. - /// > Examples of the styles can be seen - /// > [in this PyTorch example](https://github.com/pytorch/examples/tree/master/fast_neural_style#models). - FastNeuralStyleTransfer(FastNeuralStyleTransferStyle) -} - -#[derive(Debug, Clone)] -pub enum FastNeuralStyleTransferStyle { - Mosaic, - Candy, - RainPrincess, - Udnie, - Pointilism -} - -impl ModelUrl for ImageManipulation { - fn model_url(&self) -> &'static str { - match self { - ImageManipulation::SuperResolution => { - "https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/super_resolution/sub_pixel_cnn_2016/model/super-resolution-10.onnx" - } - ImageManipulation::FastNeuralStyleTransfer(style) => style.model_url() - } - } -} - -impl ModelUrl for FastNeuralStyleTransferStyle { - fn model_url(&self) -> &'static str { - match self { - FastNeuralStyleTransferStyle::Mosaic => { - "https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/style_transfer/fast_neural_style/model/mosaic-9.onnx" - } - FastNeuralStyleTransferStyle::Candy => { - "https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/style_transfer/fast_neural_style/model/candy-9.onnx" - } - FastNeuralStyleTransferStyle::RainPrincess => { - "https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/style_transfer/fast_neural_style/model/rain-princess-9.onnx" - } - FastNeuralStyleTransferStyle::Udnie => { - "https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/style_transfer/fast_neural_style/model/udnie-9.onnx" - } - FastNeuralStyleTransferStyle::Pointilism => { - "https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/style_transfer/fast_neural_style/model/pointilism-9.onnx" - } - } - } -} diff --git a/src/download/vision/object_detection_image_segmentation.rs b/src/download/vision/object_detection_image_segmentation.rs deleted file mode 100644 index 343c7164..00000000 --- a/src/download/vision/object_detection_image_segmentation.rs +++ /dev/null @@ -1,94 +0,0 @@ -#![allow(clippy::upper_case_acronyms)] - -use crate::download::ModelUrl; - -/// Object Detection & Image Segmentation -/// -/// > Object detection models detect the presence of multiple objects in an image and segment out areas of the -/// > image where the objects are detected. Semantic segmentation models partition an input image by labeling each pixel -/// > into a set of pre-defined categories. -#[derive(Debug, Clone)] -pub enum ObjectDetectionImageSegmentation { - /// A real-time CNN for object detection that detects 20 different classes. A smaller version of the - /// more complex full YOLOv2 network. - TinyYoloV2, - /// Single Stage Detector: real-time CNN for object detection that detects 80 different classes. - Ssd, - /// A variant of MobileNet that uses the Single Shot Detector (SSD) model framework. The model detects 80 - /// different object classes and locates up to 10 objects in an image. 
- SSDMobileNetV1, - /// Increases efficiency from R-CNN by connecting a RPN with a CNN to create a single, unified network for - /// object detection that detects 80 different classes. - FasterRcnn, - /// A real-time neural network for object instance segmentation that detects 80 different classes. Extends - /// Faster R-CNN as each of the 300 elected ROIs go through 3 parallel branches of the network: label - /// prediction, bounding box prediction and mask prediction. - MaskRcnn, - /// A real-time dense detector network for object detection that addresses class imbalance through Focal Loss. - /// RetinaNet is able to match the speed of previous one-stage detectors and defines the state-of-the-art in - /// two-stage detectors (surpassing R-CNN). - RetinaNet, - /// A CNN model for real-time object detection system that can detect over 9000 object categories. It uses a - /// single network evaluation, enabling it to be more than 1000x faster than R-CNN and 100x faster than - /// Faster R-CNN. - YoloV2, - /// A CNN model for real-time object detection system that can detect over 9000 object categories. It uses - /// a single network evaluation, enabling it to be more than 1000x faster than R-CNN and 100x faster than - /// Faster R-CNN. This model is trained with COCO dataset and contains 80 classes. - YoloV2Coco, - /// A deep CNN model for real-time object detection that detects 80 different classes. A little bigger than - /// YOLOv2 but still very fast. As accurate as SSD but 3 times faster. - YoloV3, - /// A smaller version of YOLOv3 model. - TinyYoloV3, - /// Optimizes the speed and accuracy of object detection. Two times faster than EfficientDet. It improves - /// YOLOv3's AP and FPS by 10% and 12%, respectively, with mAP50 of 52.32 on the COCO 2017 dataset and - /// FPS of 41.7 on Tesla 100. - YoloV4, - /// Deep CNN based pixel-wise semantic segmentation model with >80% mIOU (mean Intersection Over Union). - /// Trained on cityscapes dataset, which can be effectively implemented in self driving vehicle systems. 
-	Duc
-}
-
-impl ModelUrl for ObjectDetectionImageSegmentation {
-	fn model_url(&self) -> &'static str {
-		match self {
-			ObjectDetectionImageSegmentation::TinyYoloV2 => {
-				"https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/object_detection_segmentation/tiny-yolov2/model/tinyyolov2-8.onnx"
-			}
-			ObjectDetectionImageSegmentation::Ssd => {
-				"https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/object_detection_segmentation/ssd/model/ssd-10.onnx"
-			}
-			ObjectDetectionImageSegmentation::SSDMobileNetV1 => {
-				"https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/object_detection_segmentation/ssd-mobilenetv1/model/ssd_mobilenet_v1_10.onnx"
-			}
-			ObjectDetectionImageSegmentation::FasterRcnn => {
-				"https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/object_detection_segmentation/faster-rcnn/model/FasterRCNN-10.onnx"
-			}
-			ObjectDetectionImageSegmentation::MaskRcnn => {
-				"https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/object_detection_segmentation/mask-rcnn/model/MaskRCNN-10.onnx"
-			}
-			ObjectDetectionImageSegmentation::RetinaNet => {
-				"https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/object_detection_segmentation/retinanet/model/retinanet-9.onnx"
-			}
-			ObjectDetectionImageSegmentation::YoloV2 => {
-				"https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/object_detection_segmentation/yolov2/model/yolov2-voc-8.onnx"
-			}
-			ObjectDetectionImageSegmentation::YoloV2Coco => {
-				"https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/object_detection_segmentation/yolov2-coco/model/yolov2-coco-9.onnx"
-			}
-			ObjectDetectionImageSegmentation::YoloV3 => {
-				"https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/object_detection_segmentation/yolov3/model/yolov3-10.onnx"
-			}
-			ObjectDetectionImageSegmentation::TinyYoloV3 => {
-				"https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/object_detection_segmentation/tiny-yolov3/model/tiny-yolov3-11.onnx"
-			}
-			ObjectDetectionImageSegmentation::YoloV4 => {
-				"https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/object_detection_segmentation/yolov4/model/yolov4.onnx"
-			}
-			ObjectDetectionImageSegmentation::Duc => {
-				"https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/object_detection_segmentation/duc/model/ResNet101-DUC-7.onnx"
-			}
-		}
-	}
-}
diff --git a/src/lib.rs b/src/lib.rs
index f481cd2a..7c783b0b 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -10,7 +10,6 @@
 //! `ort` is a Rust binding for [ONNX Runtime](https://onnxruntime.ai/). For information on how to get started with `ort`,
 //! see <https://ort.pyke.io>.
-pub mod download;
 pub(crate) mod environment;
 pub(crate) mod error;
 pub(crate) mod execution_providers;
diff --git a/src/session/mod.rs b/src/session/mod.rs
index 5bf37655..0e3de974 100644
--- a/src/session/mod.rs
+++ b/src/session/mod.rs
@@ -17,6 +17,8 @@ use std::{
 #[cfg(feature = "fetch-models")]
 use std::{path::PathBuf, time::Duration};
 
+#[cfg(feature = "fetch-models")]
+use super::error::FetchModelError;
 use super::{
 	api, char_p_to_string,
 	environment::get_environment,
@@ -30,8 +32,6 @@ use super::{
 	value::{Value, ValueType},
 	AllocatorType, GraphOptimizationLevel, MemType
 };
-#[cfg(feature = "fetch-models")]
-use super::{download::ModelUrl, error::FetchModelError};
 
 pub(crate) mod input;
 pub(crate) mod output;
@@ -293,18 +293,10 @@ impl SessionBuilder {
 		Ok(self)
 	}
 
-	/// Downloads a pre-trained ONNX model from the [ONNX Model Zoo](https://github.com/onnx/models) and builds the session.
+	/// Downloads a pre-trained ONNX model from the given URL and builds the session.
 	#[cfg(feature = "fetch-models")]
 	#[cfg_attr(docsrs, doc(cfg(feature = "fetch-models")))]
-	pub fn with_model_downloaded<M>(self, model: M) -> Result<Session>
-	where
-		M: ModelUrl
-	{
-		self.with_model_downloaded_monomorphized(model.model_url())
-	}
-
-	#[cfg(feature = "fetch-models")]
-	fn with_model_downloaded_monomorphized(self, model: &str) -> Result<Session> {
+	pub fn with_model_downloaded(self, model_url: impl AsRef<str>) -> Result<Session> {
 		let mut download_dir = ort_sys::internal::dirs::cache_dir()
 			.expect("could not determine cache directory")
 			.join("models");
@@ -312,7 +304,38 @@ impl SessionBuilder {
 			download_dir = std::env::current_dir().unwrap();
 		}
 
-		let downloaded_path = self.download_to(model, download_dir)?;
+		let url = model_url.as_ref();
+		let model_filename = PathBuf::from(url.split('/').last().unwrap());
+		let model_filepath = download_dir.join(model_filename);
+		let downloaded_path = if model_filepath.exists() {
+			tracing::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), "Model already exists, skipping download");
+			model_filepath
+		} else {
+			tracing::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), url = format!("{:?}", url).as_str(), "Downloading model");
+
+			let resp = ureq::get(url).call().map_err(Box::new).map_err(FetchModelError::FetchError)?;
+
+			assert!(resp.has("Content-Length"));
+			let len = resp.header("Content-Length").and_then(|s| s.parse::<usize>().ok()).unwrap();
+			tracing::info!(len, "Downloading {} bytes", len);
+
+			let mut reader = resp.into_reader();
+
+			let f = std::fs::File::create(&model_filepath).unwrap();
+			let mut writer = std::io::BufWriter::new(f);
+
+			let bytes_io_count = std::io::copy(&mut reader, &mut writer).map_err(FetchModelError::IoError)?;
+			if bytes_io_count == len as u64 {
+				model_filepath
+			} else {
+				return Err(FetchModelError::CopyError {
+					expected: len as u64,
+					io: bytes_io_count
+				}
+				.into());
+			}
+		};
+
 		self.with_model_from_file(downloaded_path)
 	}
 
diff --git a/tests/mnist.rs b/tests/mnist.rs
index dd270e6d..9ab190f9 100644
--- a/tests/mnist.rs
+++ b/tests/mnist.rs
@@ -1,7 +1,7 @@
 use std::path::Path;
 
 use image::{imageops::FilterType, ImageBuffer, Luma, Pixel};
-use ort::{download::vision::DomainBasedImageClassification, inputs, ArrayExtensions, GraphOptimizationLevel, Session, Tensor};
+use ort::{inputs, ArrayExtensions, GraphOptimizationLevel, Session, Tensor};
 use test_log::test;
 
 #[test]
@@ -13,7 +13,7 @@ fn mnist_5() -> ort::Result<()> {
 	let session = Session::builder()?
 		.with_optimization_level(GraphOptimizationLevel::Level1)?
.with_intra_threads(1)? - .with_model_downloaded(DomainBasedImageClassification::Mnist) + .with_model_downloaded("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/mnist.onnx") .expect("Could not download model from file"); let metadata = session.metadata()?; diff --git a/tests/squeezenet.rs b/tests/squeezenet.rs index 3c2a9421..3c308ec1 100644 --- a/tests/squeezenet.rs +++ b/tests/squeezenet.rs @@ -7,7 +7,7 @@ use std::{ use image::{imageops::FilterType, ImageBuffer, Pixel, Rgb}; use ndarray::s; -use ort::{download::vision::ImageClassification, inputs, ArrayExtensions, FetchModelError, GraphOptimizationLevel, Session, Tensor}; +use ort::{inputs, ArrayExtensions, FetchModelError, GraphOptimizationLevel, Session, Tensor}; use test_log::test; #[test] @@ -19,7 +19,7 @@ fn squeezenet_mushroom() -> ort::Result<()> { let session = Session::builder()? .with_optimization_level(GraphOptimizationLevel::Level1)? .with_intra_threads(1)? - .with_model_downloaded(ImageClassification::SqueezeNet) + .with_model_downloaded("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/squeezenet.onnx") .expect("Could not download model from file"); let metadata = session.metadata()?;
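
For callers migrating off the removed `ort::download` enums, here is a minimal caller-side sketch of the new call site (not part of the patch). It assumes the `fetch-models` feature is enabled and reuses the mnist.onnx URL already used in tests/mnist.rs above.

    use ort::{GraphOptimizationLevel, Session};

    fn main() -> ort::Result<()> {
    	// Before this change: .with_model_downloaded(DomainBasedImageClassification::Mnist)
    	// (the `ort::download` model-zoo enums no longer exist). Any model URL is accepted now;
    	// the file is cached under the ort-sys cache directory (or the current working directory
    	// as a fallback), keyed by the URL's last path segment (here "mnist.onnx"), so repeated
    	// runs skip the download.
    	let session = Session::builder()?
    		.with_optimization_level(GraphOptimizationLevel::Level1)?
    		.with_intra_threads(1)?
    		.with_model_downloaded("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/mnist.onnx")?;

    	// The resulting session behaves exactly like one built with `with_model_from_file`.
    	let _metadata = session.metadata()?;
    	Ok(())
    }

Unlike the old `ModelUrl` trait, nothing constrains the URL at compile time; a bad URL now surfaces at runtime as a `FetchModelError`.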