-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Ryanunderhill/MNIST sample #1330
Merged
Merged
Changes from all commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
0a79cfa
Initial commit of sample
RyanUnderhill 6fda348
Update ReadMe.md
RyanUnderhill 80628ec
Add more comments.
RyanUnderhill 6dc3d9d
Merge branch 'ryanunderhill/nmist_sample' of https://github.com/Micro…
RyanUnderhill 521dc75
Prettier view, fix some drawing bugs
RyanUnderhill 1444de9
Update ReadMe.md
RyanUnderhill ed7ca97
Updates + screenshot
RyanUnderhill aeca60a
Merge branch 'ryanunderhill/nmist_sample' of https://github.com/Micro…
RyanUnderhill 5501d86
Update ReadMe.md
RyanUnderhill File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,228 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
#define UNICODE
#include <windows.h>
#include <windowsx.h>

#include <algorithm>
#include <array>
#include <cstdint>
#include <cstdlib>
#include <iterator>

#include <onnxruntime_cxx_api.h>
|
||
#pragma comment(lib, "user32.lib") | ||
#pragma comment(lib, "gdi32.lib") | ||
#pragma comment(lib, "onnxruntime.lib") | ||
|
||
Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "test"}; | ||
|
||
// This is the structure to interface with the MNIST model | ||
// After instantiation, set the input_image_ data to be the 28x28 pixel image of the number to recognize | ||
// Then call Run() to fill in the results_ data with the probabilities of each | ||
// result_ holds the index with highest probability (aka the number the model thinks is in the image) | ||
struct MNIST { | ||
MNIST() { | ||
auto allocator_info = Ort::AllocatorInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); | ||
input_tensor_ = Ort::Value::CreateTensor<float>(allocator_info, input_image_.data(), input_image_.size(), input_shape_.data(), input_shape_.size()); | ||
output_tensor_ = Ort::Value::CreateTensor<float>(allocator_info, results_.data(), results_.size(), output_shape_.data(), output_shape_.size()); | ||
} | ||
|
||
int Run() { | ||
const char* input_names[] = {"Input3"}; | ||
const char* output_names[] = {"Plus214_Output_0"}; | ||
|
||
session_.Run(Ort::RunOptions{nullptr}, input_names, &input_tensor_, 1, output_names, &output_tensor_, 1); | ||
|
||
result_ = std::distance(results_.begin(), std::max_element(results_.begin(), results_.end())); | ||
return result_; | ||
} | ||
|
||
static constexpr const int width_ = 28; | ||
static constexpr const int height_ = 28; | ||
|
||
std::array<float, width_ * height_> input_image_{}; | ||
std::array<float, 10> results_{}; | ||
int result_{0}; | ||
|
||
private: | ||
Ort::Session session_{env, L"model.onnx", Ort::SessionOptions{nullptr}}; | ||
|
||
Ort::Value input_tensor_{nullptr}; | ||
std::array<int64_t, 4> input_shape_{1, 1, width_, height_}; | ||
|
||
Ort::Value output_tensor_{nullptr}; | ||
std::array<int64_t, 2> output_shape_{1, 10}; | ||
}; | ||
|
||
const constexpr int drawing_area_inset_{4}; // Number of pixels to inset the top left of the drawing area | ||
const constexpr int drawing_area_scale_{4}; // Number of times larger to make the drawing area compared to the shape inputs | ||
const constexpr int drawing_area_width_{MNIST::width_ * drawing_area_scale_}; | ||
const constexpr int drawing_area_height_{MNIST::height_ * drawing_area_scale_}; | ||
|
||
MNIST mnist_; | ||
HBITMAP dib_; | ||
HDC hdc_dib_; | ||
bool painting_{}; | ||
|
||
HBRUSH brush_winner_{CreateSolidBrush(RGB(128, 255, 128))}; | ||
HBRUSH brush_bars_{CreateSolidBrush(RGB(128, 128, 255))}; | ||
|
||
struct DIBInfo : DIBSECTION { | ||
DIBInfo(HBITMAP hBitmap) noexcept { ::GetObject(hBitmap, sizeof(DIBSECTION), this); } | ||
|
||
int Width() const noexcept { return dsBm.bmWidth; } | ||
int Height() const noexcept { return dsBm.bmHeight; } | ||
|
||
void* Bits() const noexcept { return dsBm.bmBits; } | ||
int Pitch() const noexcept { return dsBmih.biSizeImage / abs(dsBmih.biHeight); } | ||
}; | ||
|
||
// We need to convert the true-color data in the DIB into the model's floating point format | ||
// TODO: (also scales down the image and smooths the values, but this is not working properly) | ||
void ConvertDibToMnist() { | ||
DIBInfo info{dib_}; | ||
|
||
const DWORD* input = reinterpret_cast<const DWORD*>(info.Bits()); | ||
float* output = mnist_.input_image_.data(); | ||
|
||
std::fill(mnist_.input_image_.begin(), mnist_.input_image_.end(), 0.f); | ||
|
||
for (unsigned y = 0; y < MNIST::height_; y++) { | ||
for (unsigned x = 0; x < MNIST::width_; x++) { | ||
output[x] += input[x] == 0 ? 1.0f : 0.0f; | ||
} | ||
input = reinterpret_cast<const DWORD*>(reinterpret_cast<const BYTE*>(input) + info.Pitch()); | ||
output += MNIST::width_; | ||
} | ||
} | ||
|
||
LRESULT CALLBACK WndProc(HWND, UINT, WPARAM, LPARAM); | ||
|
||
// The Windows entry point function | ||
int APIENTRY wWinMain(HINSTANCE hInstance, HINSTANCE hPrevInstance, LPTSTR lpCmdLine, int nCmdShow) { | ||
{ | ||
WNDCLASSEX wc{}; | ||
wc.cbSize = sizeof(WNDCLASSEX); | ||
wc.style = CS_HREDRAW | CS_VREDRAW; | ||
wc.lpfnWndProc = WndProc; | ||
wc.hInstance = hInstance; | ||
wc.hCursor = LoadCursor(NULL, IDC_ARROW); | ||
wc.hbrBackground = (HBRUSH)(COLOR_WINDOW + 1); | ||
wc.lpszClassName = L"ONNXTest"; | ||
RegisterClassEx(&wc); | ||
} | ||
{ | ||
BITMAPINFO bmi{}; | ||
bmi.bmiHeader.biSize = sizeof(bmi.bmiHeader); | ||
bmi.bmiHeader.biWidth = MNIST::width_; | ||
bmi.bmiHeader.biHeight = -MNIST::height_; | ||
bmi.bmiHeader.biPlanes = 1; | ||
bmi.bmiHeader.biBitCount = 32; | ||
bmi.bmiHeader.biPlanes = 1; | ||
bmi.bmiHeader.biCompression = BI_RGB; | ||
|
||
void* bits; | ||
dib_ = CreateDIBSection(nullptr, &bmi, DIB_RGB_COLORS, &bits, nullptr, 0); | ||
} | ||
|
||
hdc_dib_ = CreateCompatibleDC(nullptr); | ||
SelectObject(hdc_dib_, dib_); | ||
SelectObject(hdc_dib_, CreatePen(PS_SOLID, 2, RGB(0, 0, 0))); | ||
FillRect(hdc_dib_, &RECT{0, 0, MNIST::width_, MNIST::height_}, (HBRUSH)GetStockObject(WHITE_BRUSH)); | ||
|
||
HWND hWnd = CreateWindow(L"ONNXTest", L"ONNX Runtime Sample - MNIST", WS_OVERLAPPEDWINDOW, CW_USEDEFAULT, CW_USEDEFAULT, 512, 256, nullptr, nullptr, hInstance, nullptr); | ||
if (!hWnd) | ||
return FALSE; | ||
|
||
ShowWindow(hWnd, nCmdShow); | ||
|
||
MSG msg; | ||
while (GetMessage(&msg, NULL, 0, 0)) { | ||
TranslateMessage(&msg); | ||
DispatchMessage(&msg); | ||
} | ||
return (int)msg.wParam; | ||
} | ||
|
||
LRESULT CALLBACK WndProc(HWND hWnd, UINT message, WPARAM wParam, LPARAM lParam) { | ||
switch (message) { | ||
case WM_PAINT: { | ||
PAINTSTRUCT ps; | ||
HDC hdc = BeginPaint(hWnd, &ps); | ||
|
||
// Draw the image | ||
StretchBlt(hdc, drawing_area_inset_, drawing_area_inset_, drawing_area_width_, drawing_area_height_, hdc_dib_, 0, 0, MNIST::width_, MNIST::height_, SRCCOPY); | ||
SelectObject(hdc, GetStockObject(BLACK_PEN)); | ||
SelectObject(hdc, GetStockObject(NULL_BRUSH)); | ||
Rectangle(hdc, drawing_area_inset_, drawing_area_inset_, drawing_area_inset_ + drawing_area_width_, drawing_area_inset_ + drawing_area_height_); | ||
|
||
constexpr int graphs_left = drawing_area_inset_ + drawing_area_width_ + 5; | ||
constexpr int graph_width = 64; | ||
SelectObject(hdc, brush_bars_); | ||
|
||
auto least = *std::min_element(mnist_.results_.begin(), mnist_.results_.end()); | ||
auto greatest = mnist_.results_[mnist_.result_]; | ||
auto range = greatest - least; | ||
|
||
auto graphs_zero = graphs_left - least * graph_width / range; | ||
|
||
// Hilight the winner | ||
RECT rc{graphs_left, mnist_.result_ * 16, graphs_left + graph_width + 128, (mnist_.result_ + 1) * 16}; | ||
FillRect(hdc, &rc, brush_winner_); | ||
|
||
// For every entry, draw the odds and the graph for it | ||
SetBkMode(hdc, TRANSPARENT); | ||
wchar_t value[80]; | ||
for (unsigned i = 0; i < 10; i++) { | ||
int y = 16 * i; | ||
float result = mnist_.results_[i]; | ||
|
||
auto length = wsprintf(value, L"%2d: %d.%02d", i, int(result), abs(int(result * 100) % 100)); | ||
TextOut(hdc, graphs_left + graph_width + 5, y, value, length); | ||
|
||
Rectangle(hdc, graphs_zero, y + 1, graphs_zero + result * graph_width / range, y + 14); | ||
} | ||
|
||
// Draw the zero line | ||
MoveToEx(hdc, graphs_zero, 0, nullptr); | ||
LineTo(hdc, graphs_zero, 16 * 10); | ||
|
||
EndPaint(hWnd, &ps); | ||
return 0; | ||
} | ||
|
||
case WM_LBUTTONDOWN: { | ||
SetCapture(hWnd); | ||
painting_ = true; | ||
int x = (GET_X_LPARAM(lParam) - drawing_area_inset_) / drawing_area_scale_; | ||
int y = (GET_Y_LPARAM(lParam) - drawing_area_inset_) / drawing_area_scale_; | ||
MoveToEx(hdc_dib_, x, y, nullptr); | ||
return 0; | ||
} | ||
|
||
case WM_MOUSEMOVE: | ||
if (painting_) { | ||
int x = (GET_X_LPARAM(lParam) - drawing_area_inset_) / drawing_area_scale_; | ||
int y = (GET_Y_LPARAM(lParam) - drawing_area_inset_) / drawing_area_scale_; | ||
LineTo(hdc_dib_, x, y); | ||
InvalidateRect(hWnd, nullptr, false); | ||
} | ||
return 0; | ||
|
||
case WM_CAPTURECHANGED: | ||
painting_ = false; | ||
return 0; | ||
|
||
case WM_LBUTTONUP: | ||
ReleaseCapture(); | ||
ConvertDibToMnist(); | ||
mnist_.Run(); | ||
InvalidateRect(hWnd, nullptr, true); | ||
return 0; | ||
|
||
case WM_RBUTTONDOWN: // Erase the image | ||
FillRect(hdc_dib_, &RECT{0, 0, MNIST::width_, MNIST::height_}, (HBRUSH)GetStockObject(WHITE_BRUSH)); | ||
InvalidateRect(hWnd, nullptr, false); | ||
return 0; | ||
|
||
case WM_DESTROY: | ||
PostQuitMessage(0); | ||
return 0; | ||
} | ||
return DefWindowProc(hWnd, message, wParam, lParam); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
# MNIST Sample - Number recognition | ||
|
||
This sample uses the MNIST model from the Model Zoo: https://github.com/onnx/models/tree/master/mnist | ||
|
||
![Screenshot](Screenshot.png) | ||
|
||
## Requirements | ||
|
||
Compiled Onnxruntime.dll / lib (link to instructions on how to build dll) | ||
Windows Visual Studio Compiler (cl.exe) | ||
|
||
## Build | ||
|
||
Run 'build.bat' in this directory to call cl.exe to generate MNIST.exe | ||
Then just run MNIST.exe | ||
|
||
## How to use it | ||
|
||
Just draw a number with the left mouse button (or use touch) in the box on the left side. After releasing the mouse button the model will be run and the outputs of the model will be displayed. Note that when drawing numbers requiring multiple drawing strokes, the model will be run at the end of each stroke with probably wrong predictions (but it's amusing to see and avoids needing to press a 'run model' button). | ||
|
||
To clear the image, click the right mouse button anywhere. | ||
|
||
## How it works | ||
|
||
A single Ort::Env is created globally to initialize the runtime. | ||
https://github.com/microsoft/onnxruntime/blob/521dc757984fbf9770d0051997178fbb9565cd52/samples/c_cxx/MNIST/MNIST.cpp#L12 | ||
|
||
The MNIST structure abstracts away all of the interaction with the Onnx Runtime, creating the tensors, and running the model. | ||
|
||
wWinMain is the Windows entry point; it creates the main window. | ||
|
||
WndProc is the window procedure for the window, handling the mouse input and drawing the graphics. | ||
|
||
### Preprocessing the data | ||
|
||
MNIST's input is a {1,1,28,28} shaped float tensor, which is basically a 28x28 floating point grayscale image (0.0 = background, 1.0 = foreground). | ||
|
||
The sample stores the image in a 32-bit per pixel windows DIB section, since that's easy to draw into and draw to the screen for windows. The DIB is created here: | ||
https://github.com/microsoft/onnxruntime/blob/521dc757984fbf9770d0051997178fbb9565cd52/samples/c_cxx/MNIST/MNIST.cpp#L109-L121 | ||
|
||
The function that converts the DIB data and writes it into the model's input tensor: | ||
https://github.com/microsoft/onnxruntime/blob/521dc757984fbf9770d0051997178fbb9565cd52/samples/c_cxx/MNIST/MNIST.cpp#L77-L92 | ||
|
||
### Postprocessing the output | ||
|
||
MNIST's output is a simple {1,10} float tensor that holds the likelihood weights per number. The number with the highest value is the model's best guess. | ||
|
||
The MNIST structure uses std::max_element to do this and stores it in result_: | ||
https://github.com/microsoft/onnxruntime/blob/521dc757984fbf9770d0051997178fbb9565cd52/samples/c_cxx/MNIST/MNIST.cpp#L31 | ||
|
||
To make things more interesting, the window painting handler graphs the probabilities and shows the weights here: | ||
https://github.com/microsoft/onnxruntime/blob/521dc757984fbf9770d0051997178fbb9565cd52/samples/c_cxx/MNIST/MNIST.cpp#L164-L183 | ||
|
||
### The Ort::Session | ||
|
||
1. Creation: The Ort::Session is created inside the MNIST structure here: | ||
https://github.com/microsoft/onnxruntime/blob/521dc757984fbf9770d0051997178fbb9565cd52/samples/c_cxx/MNIST/MNIST.cpp#L43 | ||
|
||
2. Setup inputs & outputs: The input & output tensors are created here: | ||
https://github.com/microsoft/onnxruntime/blob/521dc757984fbf9770d0051997178fbb9565cd52/samples/c_cxx/MNIST/MNIST.cpp#L19-L23 | ||
In this usage, we're providing the memory location for the data instead of having Ort allocate the buffers. This is simpler in this case since the buffers are small and can just be fixed members of the MNIST struct. | ||
|
||
3. Run: Running the session is done in the Run() method: | ||
https://github.com/microsoft/onnxruntime/blob/521dc757984fbf9770d0051997178fbb9565cd52/samples/c_cxx/MNIST/MNIST.cpp#L25-L33 | ||
|
||
|
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
rem Compiles and links MNIST.cpp with the Visual C++ compiler (cl.exe).
rem   /Zi      generate debug information
rem   /EHsc    enable standard C++ exception handling
rem   /I       add the ONNX Runtime header directory to the include search path
rem   /LIBPATH add a library search directory for the linker (presumably where onnxruntime.lib was built)
cl MNIST.cpp /Zi /EHsc /I..\..\..\include\onnxruntime\core\session /link /LIBPATH:..\..\..\build\Windows\Debug\Debug |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
and define _UNICODE ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Windows uses #ifdef UNICODE, does it also use _UNICODE?