Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ryanunderhill/MNIST sample #1330

Merged
merged 9 commits into from
Jul 2, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
228 changes: 228 additions & 0 deletions samples/c_cxx/MNIST/MNIST.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#define UNICODE
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and define _UNICODE ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Windows uses #ifdef UNICODE, does it also use _UNICODE?

#include <windows.h>
#include <windowsx.h>
#include <onnxruntime_cxx_api.h>

#pragma comment(lib, "user32.lib")
#pragma comment(lib, "gdi32.lib")
#pragma comment(lib, "onnxruntime.lib")

// Global ONNX Runtime environment. It is passed to the Ort::Session created inside the MNIST
// struct below, so it has to be defined before any MNIST instance is constructed.
// NOTE(review): the second argument ("test") is presumably the log identifier - confirm against the Ort::Env API.
Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "test"};

// Wraps everything needed to run the MNIST digit-recognition model through ONNX Runtime.
//
// Usage:
//   1. Fill input_image_ with the width_ x height_ image of the number to recognize
//      (row-major, one float per pixel).
//   2. Call Run(). The model writes its ten output values into results_.
//   3. result_ (also returned by Run()) is the index of the largest entry in results_,
//      i.e. the number the model thinks is in the image.
struct MNIST {
  // Binds the input/output tensors directly onto input_image_ / results_ so that no data has to be
  // copied in or out around Run().
  MNIST() {
    auto allocator_info = Ort::AllocatorInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
    input_tensor_ = Ort::Value::CreateTensor<float>(allocator_info, input_image_.data(), input_image_.size(),
                                                    input_shape_.data(), input_shape_.size());
    output_tensor_ = Ort::Value::CreateTensor<float>(allocator_info, results_.data(), results_.size(),
                                                     output_shape_.data(), output_shape_.size());
  }

  // The tensors were created over pointers into this object's own arrays, so a copied or moved
  // MNIST would still refer to the buffers of the object it came from. Forbid both.
  MNIST(const MNIST&) = delete;
  MNIST& operator=(const MNIST&) = delete;

  // Runs the model on the current contents of input_image_.
  // Returns the index of the highest value in results_ and stores it in result_.
  int Run() {
    const char* input_names[] = {"Input3"};
    const char* output_names[] = {"Plus214_Output_0"};

    session_.Run(Ort::RunOptions{nullptr}, input_names, &input_tensor_, 1, output_names, &output_tensor_, 1);

    const auto best = std::max_element(results_.begin(), results_.end());
    // std::distance yields a ptrdiff_t; the value is always in [0, 9], so the narrowing is safe.
    result_ = static_cast<int>(std::distance(results_.begin(), best));
    return result_;
  }

  static constexpr int width_ = 28;
  static constexpr int height_ = 28;

  std::array<float, width_ * height_> input_image_{};  // Model input, one float per pixel
  std::array<float, 10> results_{};                    // Model output, one value per digit 0-9
  int result_{0};                                      // Index of the largest entry in results_

 private:
  Ort::Session session_{env, L"model.onnx", Ort::SessionOptions{nullptr}};

  Ort::Value input_tensor_{nullptr};
  // The image buffer is filled one row of width_ floats at a time, so width_ must be the
  // innermost (last) dimension and height_ the one before it.
  std::array<int64_t, 4> input_shape_{1, 1, height_, width_};

  Ort::Value output_tensor_{nullptr};
  std::array<int64_t, 2> output_shape_{1, 10};
};

// Geometry of the on-screen drawing area, in window pixels.
constexpr int drawing_area_inset_ = 4;  // Distance of the drawing area's top-left corner from the client origin
constexpr int drawing_area_scale_ = 4;  // Screen pixels shown per pixel of the model's input image
constexpr int drawing_area_width_ = MNIST::width_ * drawing_area_scale_;
constexpr int drawing_area_height_ = MNIST::height_ * drawing_area_scale_;

// Application state shared between wWinMain and WndProc.
MNIST mnist_;             // The model wrapper
HBITMAP dib_;             // DIB section the user draws into
HDC hdc_dib_;             // Memory DC that dib_ is selected into
bool painting_ = false;   // True while a stroke is being drawn with the left button

// Brushes used when painting the results graph.
HBRUSH brush_winner_ = CreateSolidBrush(RGB(128, 255, 128));
HBRUSH brush_bars_ = CreateSolidBrush(RGB(128, 128, 255));

struct DIBInfo : DIBSECTION {
DIBInfo(HBITMAP hBitmap) noexcept { ::GetObject(hBitmap, sizeof(DIBSECTION), this); }

int Width() const noexcept { return dsBm.bmWidth; }
int Height() const noexcept { return dsBm.bmHeight; }

void* Bits() const noexcept { return dsBm.bmBits; }
int Pitch() const noexcept { return dsBmih.biSizeImage / abs(dsBmih.biHeight); }
};

// We need to convert the true-color data in the DIB into the model's floating point format
// TODO: (also scales down the image and smooths the values, but this is not working properly)
void ConvertDibToMnist() {
DIBInfo info{dib_};

const DWORD* input = reinterpret_cast<const DWORD*>(info.Bits());
float* output = mnist_.input_image_.data();

std::fill(mnist_.input_image_.begin(), mnist_.input_image_.end(), 0.f);

for (unsigned y = 0; y < MNIST::height_; y++) {
for (unsigned x = 0; x < MNIST::width_; x++) {
output[x] += input[x] == 0 ? 1.0f : 0.0f;
}
input = reinterpret_cast<const DWORD*>(reinterpret_cast<const BYTE*>(input) + info.Pitch());
output += MNIST::width_;
}
}

LRESULT CALLBACK WndProc(HWND, UINT, WPARAM, LPARAM);

// The Windows entry point function.
// Registers the window class, creates the 32-bit DIB section (and its memory DC) that the user
// draws into, creates and shows the main window, then pumps messages until WM_QUIT.
// Returns the exit code carried by WM_QUIT, or FALSE if any part of the setup fails.
int APIENTRY wWinMain(HINSTANCE hInstance, HINSTANCE hPrevInstance, LPTSTR lpCmdLine, int nCmdShow) {
  UNREFERENCED_PARAMETER(hPrevInstance);
  UNREFERENCED_PARAMETER(lpCmdLine);

  {
    WNDCLASSEX wc{};
    wc.cbSize = sizeof(WNDCLASSEX);
    wc.style = CS_HREDRAW | CS_VREDRAW;
    wc.lpfnWndProc = WndProc;
    wc.hInstance = hInstance;
    wc.hCursor = LoadCursor(nullptr, IDC_ARROW);
    wc.hbrBackground = reinterpret_cast<HBRUSH>(static_cast<INT_PTR>(COLOR_WINDOW + 1));
    wc.lpszClassName = L"ONNXTest";
    if (!RegisterClassEx(&wc))
      return FALSE;
  }
  {
    BITMAPINFO bmi{};
    bmi.bmiHeader.biSize = sizeof(bmi.bmiHeader);
    bmi.bmiHeader.biWidth = MNIST::width_;
    bmi.bmiHeader.biHeight = -MNIST::height_;  // Negative height requests a top-down bitmap
    bmi.bmiHeader.biPlanes = 1;
    bmi.bmiHeader.biBitCount = 32;
    bmi.bmiHeader.biCompression = BI_RGB;

    void* bits = nullptr;  // Not used here; the pixel pointer is fetched again through DIBInfo
    dib_ = CreateDIBSection(nullptr, &bmi, DIB_RGB_COLORS, &bits, nullptr, 0);
    if (!dib_)
      return FALSE;
  }

  hdc_dib_ = CreateCompatibleDC(nullptr);
  if (!hdc_dib_)
    return FALSE;

  SelectObject(hdc_dib_, dib_);
  SelectObject(hdc_dib_, CreatePen(PS_SOLID, 2, RGB(0, 0, 0)));

  // Start with a blank (white) canvas. A named RECT is used because taking the address of a
  // temporary is not valid C++.
  const RECT canvas{0, 0, MNIST::width_, MNIST::height_};
  FillRect(hdc_dib_, &canvas, static_cast<HBRUSH>(GetStockObject(WHITE_BRUSH)));

  HWND hWnd = CreateWindow(L"ONNXTest", L"ONNX Runtime Sample - MNIST", WS_OVERLAPPEDWINDOW, CW_USEDEFAULT,
                           CW_USEDEFAULT, 512, 256, nullptr, nullptr, hInstance, nullptr);
  if (!hWnd)
    return FALSE;

  ShowWindow(hWnd, nCmdShow);

  MSG msg{};
  for (;;) {
    // GetMessage returns 0 for WM_QUIT and -1 on failure; -1 must not be treated as a message.
    const BOOL status = GetMessage(&msg, nullptr, 0, 0);
    if (status == 0)
      break;
    if (status == -1)
      return FALSE;
    TranslateMessage(&msg);
    DispatchMessage(&msg);
  }
  return static_cast<int>(msg.wParam);
}

LRESULT CALLBACK WndProc(HWND hWnd, UINT message, WPARAM wParam, LPARAM lParam) {
switch (message) {
case WM_PAINT: {
PAINTSTRUCT ps;
HDC hdc = BeginPaint(hWnd, &ps);

// Draw the image
StretchBlt(hdc, drawing_area_inset_, drawing_area_inset_, drawing_area_width_, drawing_area_height_, hdc_dib_, 0, 0, MNIST::width_, MNIST::height_, SRCCOPY);
SelectObject(hdc, GetStockObject(BLACK_PEN));
SelectObject(hdc, GetStockObject(NULL_BRUSH));
Rectangle(hdc, drawing_area_inset_, drawing_area_inset_, drawing_area_inset_ + drawing_area_width_, drawing_area_inset_ + drawing_area_height_);

constexpr int graphs_left = drawing_area_inset_ + drawing_area_width_ + 5;
constexpr int graph_width = 64;
SelectObject(hdc, brush_bars_);

auto least = *std::min_element(mnist_.results_.begin(), mnist_.results_.end());
auto greatest = mnist_.results_[mnist_.result_];
auto range = greatest - least;

auto graphs_zero = graphs_left - least * graph_width / range;

// Hilight the winner
RECT rc{graphs_left, mnist_.result_ * 16, graphs_left + graph_width + 128, (mnist_.result_ + 1) * 16};
FillRect(hdc, &rc, brush_winner_);

// For every entry, draw the odds and the graph for it
SetBkMode(hdc, TRANSPARENT);
wchar_t value[80];
for (unsigned i = 0; i < 10; i++) {
int y = 16 * i;
float result = mnist_.results_[i];

auto length = wsprintf(value, L"%2d: %d.%02d", i, int(result), abs(int(result * 100) % 100));
TextOut(hdc, graphs_left + graph_width + 5, y, value, length);

Rectangle(hdc, graphs_zero, y + 1, graphs_zero + result * graph_width / range, y + 14);
}

// Draw the zero line
MoveToEx(hdc, graphs_zero, 0, nullptr);
LineTo(hdc, graphs_zero, 16 * 10);

EndPaint(hWnd, &ps);
return 0;
}

case WM_LBUTTONDOWN: {
SetCapture(hWnd);
painting_ = true;
int x = (GET_X_LPARAM(lParam) - drawing_area_inset_) / drawing_area_scale_;
int y = (GET_Y_LPARAM(lParam) - drawing_area_inset_) / drawing_area_scale_;
MoveToEx(hdc_dib_, x, y, nullptr);
return 0;
}

case WM_MOUSEMOVE:
if (painting_) {
int x = (GET_X_LPARAM(lParam) - drawing_area_inset_) / drawing_area_scale_;
int y = (GET_Y_LPARAM(lParam) - drawing_area_inset_) / drawing_area_scale_;
LineTo(hdc_dib_, x, y);
InvalidateRect(hWnd, nullptr, false);
}
return 0;

case WM_CAPTURECHANGED:
painting_ = false;
return 0;

case WM_LBUTTONUP:
ReleaseCapture();
ConvertDibToMnist();
mnist_.Run();
InvalidateRect(hWnd, nullptr, true);
return 0;

case WM_RBUTTONDOWN: // Erase the image
FillRect(hdc_dib_, &RECT{0, 0, MNIST::width_, MNIST::height_}, (HBRUSH)GetStockObject(WHITE_BRUSH));
InvalidateRect(hWnd, nullptr, false);
return 0;

case WM_DESTROY:
PostQuitMessage(0);
return 0;
}
return DefWindowProc(hWnd, message, wParam, lParam);
}
66 changes: 66 additions & 0 deletions samples/c_cxx/MNIST/ReadMe.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# MNIST Sample - Number recognition

This sample uses the MNIST model from the Model Zoo: https://github.com/onnx/models/tree/master/mnist

![Screenshot](Screenshot.png)

## Requirements

Compiled Onnxruntime.dll / lib (link to instructions on how to build dll)
Windows Visual Studio Compiler (cl.exe)

## Build

Run 'build.bat' in this directory to call cl.exe to generate MNIST.exe
Then just run MNIST.exe

## How to use it

Just draw a number with the left mouse button (or use touch) in the box on the left side. After releasing the mouse button the model will be run and the outputs of the model will be displayed. Note that when drawing numbers requiring multiple drawing strokes, the model will be run at the end of each stroke with probably wrong predictions (but it's amusing to see and avoids needing to press a 'run model' button).

To clear the image, click the right mouse button anywhere.

## How it works

A single Ort::Env is created globally to initialize the runtime.
https://github.com/microsoft/onnxruntime/blob/521dc757984fbf9770d0051997178fbb9565cd52/samples/c_cxx/MNIST/MNIST.cpp#L12

The MNIST structure abstracts away all of the interaction with the Onnx Runtime, creating the tensors, and running the model.

wWinMain is the Windows entry point; it creates the main window.

WndProc is the window procedure for the window, handling the mouse input and drawing the graphics.

### Preprocessing the data

MNIST's input is a {1,1,28,28} shaped float tensor, which is basically a 28x28 floating point grayscale image (0.0 = background, 1.0 = foreground).

The sample stores the image in a 32-bit per pixel windows DIB section, since that's easy to draw into and draw to the screen for windows. The DIB is created here:
https://github.com/microsoft/onnxruntime/blob/521dc757984fbf9770d0051997178fbb9565cd52/samples/c_cxx/MNIST/MNIST.cpp#L109-L121

The function that converts the DIB data and writes it into the model's input tensor is here:
https://github.com/microsoft/onnxruntime/blob/521dc757984fbf9770d0051997178fbb9565cd52/samples/c_cxx/MNIST/MNIST.cpp#L77-L92

### Postprocessing the output

MNIST's output is a simple {1,10} float tensor that holds the likelihood weights per number. The number with the highest value is the model's best guess.

The MNIST structure uses std::max_element to do this and stores it in result_:
https://github.com/microsoft/onnxruntime/blob/521dc757984fbf9770d0051997178fbb9565cd52/samples/c_cxx/MNIST/MNIST.cpp#L31

To make things more interesting, the window painting handler graphs the probabilities and shows the weights here:
https://github.com/microsoft/onnxruntime/blob/521dc757984fbf9770d0051997178fbb9565cd52/samples/c_cxx/MNIST/MNIST.cpp#L164-L183

### The Ort::Session

1. Creation: The Ort::Session is created inside the MNIST structure here:
https://github.com/microsoft/onnxruntime/blob/521dc757984fbf9770d0051997178fbb9565cd52/samples/c_cxx/MNIST/MNIST.cpp#L43

2. Setup inputs & outputs: The input & output tensors are created here:
https://github.com/microsoft/onnxruntime/blob/521dc757984fbf9770d0051997178fbb9565cd52/samples/c_cxx/MNIST/MNIST.cpp#L19-L23
In this usage, we're providing the memory location for the data instead of having Ort allocate the buffers. This is simpler in this case since the buffers are small and can just be fixed members of the MNIST struct.

3. Run: Running the session is done in the Run() method:
https://github.com/microsoft/onnxruntime/blob/521dc757984fbf9770d0051997178fbb9565cd52/samples/c_cxx/MNIST/MNIST.cpp#L25-L33


Binary file added samples/c_cxx/MNIST/Screenshot.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions samples/c_cxx/MNIST/build.bat
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
:: Builds MNIST.cpp with the Visual C++ compiler (cl.exe).
::   /Zi       emit debug information
::   /EHsc     enable standard C++ exception handling
::   /I        add the ONNX Runtime session header directory to the include search path
::   /LIBPATH  directory the linker searches for libraries; the libraries themselves are named
::             by the #pragma comment(lib, ...) lines in MNIST.cpp
cl MNIST.cpp /Zi /EHsc /I..\..\..\include\onnxruntime\core\session /link /LIBPATH:..\..\..\build\Windows\Debug\Debug