diff --git a/samples/c_cxx/MNIST/MNIST.cpp b/samples/c_cxx/MNIST/MNIST.cpp new file mode 100644 index 0000000000000..0f3c98c96a825 --- /dev/null +++ b/samples/c_cxx/MNIST/MNIST.cpp @@ -0,0 +1,228 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#define UNICODE +#include +#include +#include + +#pragma comment(lib, "user32.lib") +#pragma comment(lib, "gdi32.lib") +#pragma comment(lib, "onnxruntime.lib") + +Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "test"}; + +// This is the structure to interface with the MNIST model +// After instantiation, set the input_image_ data to be the 28x28 pixel image of the number to recognize +// Then call Run() to fill in the results_ data with the probabilities of each +// result_ holds the index with highest probability (aka the number the model thinks is in the image) +struct MNIST { + MNIST() { + auto allocator_info = Ort::AllocatorInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + input_tensor_ = Ort::Value::CreateTensor(allocator_info, input_image_.data(), input_image_.size(), input_shape_.data(), input_shape_.size()); + output_tensor_ = Ort::Value::CreateTensor(allocator_info, results_.data(), results_.size(), output_shape_.data(), output_shape_.size()); + } + + int Run() { + const char* input_names[] = {"Input3"}; + const char* output_names[] = {"Plus214_Output_0"}; + + session_.Run(Ort::RunOptions{nullptr}, input_names, &input_tensor_, 1, output_names, &output_tensor_, 1); + + result_ = std::distance(results_.begin(), std::max_element(results_.begin(), results_.end())); + return result_; + } + + static constexpr const int width_ = 28; + static constexpr const int height_ = 28; + + std::array input_image_{}; + std::array results_{}; + int result_{0}; + + private: + Ort::Session session_{env, L"model.onnx", Ort::SessionOptions{nullptr}}; + + Ort::Value input_tensor_{nullptr}; + std::array input_shape_{1, 1, width_, height_}; + + Ort::Value output_tensor_{nullptr}; + std::array output_shape_{1, 10}; +}; + +const constexpr int drawing_area_inset_{4}; // Number of pixels to inset the top left of the drawing area +const constexpr int drawing_area_scale_{4}; // Number of times larger to make the drawing area compared to the shape inputs +const constexpr int drawing_area_width_{MNIST::width_ * drawing_area_scale_}; +const constexpr int drawing_area_height_{MNIST::height_ * drawing_area_scale_}; + +MNIST mnist_; +HBITMAP dib_; +HDC hdc_dib_; +bool painting_{}; + +HBRUSH brush_winner_{CreateSolidBrush(RGB(128, 255, 128))}; +HBRUSH brush_bars_{CreateSolidBrush(RGB(128, 128, 255))}; + +struct DIBInfo : DIBSECTION { + DIBInfo(HBITMAP hBitmap) noexcept { ::GetObject(hBitmap, sizeof(DIBSECTION), this); } + + int Width() const noexcept { return dsBm.bmWidth; } + int Height() const noexcept { return dsBm.bmHeight; } + + void* Bits() const noexcept { return dsBm.bmBits; } + int Pitch() const noexcept { return dsBmih.biSizeImage / abs(dsBmih.biHeight); } +}; + +// We need to convert the true-color data in the DIB into the model's floating point format +// TODO: (also scales down the image and smooths the values, but this is not working properly) +void ConvertDibToMnist() { + DIBInfo info{dib_}; + + const DWORD* input = reinterpret_cast(info.Bits()); + float* output = mnist_.input_image_.data(); + + std::fill(mnist_.input_image_.begin(), mnist_.input_image_.end(), 0.f); + + for (unsigned y = 0; y < MNIST::height_; y++) { + for (unsigned x = 0; x < MNIST::width_; x++) { + output[x] += input[x] == 0 ? 1.0f : 0.0f; + } + input = reinterpret_cast(reinterpret_cast(input) + info.Pitch()); + output += MNIST::width_; + } +} + +LRESULT CALLBACK WndProc(HWND, UINT, WPARAM, LPARAM); + +// The Windows entry point function +int APIENTRY wWinMain(HINSTANCE hInstance, HINSTANCE hPrevInstance, LPTSTR lpCmdLine, int nCmdShow) { + { + WNDCLASSEX wc{}; + wc.cbSize = sizeof(WNDCLASSEX); + wc.style = CS_HREDRAW | CS_VREDRAW; + wc.lpfnWndProc = WndProc; + wc.hInstance = hInstance; + wc.hCursor = LoadCursor(NULL, IDC_ARROW); + wc.hbrBackground = (HBRUSH)(COLOR_WINDOW + 1); + wc.lpszClassName = L"ONNXTest"; + RegisterClassEx(&wc); + } + { + BITMAPINFO bmi{}; + bmi.bmiHeader.biSize = sizeof(bmi.bmiHeader); + bmi.bmiHeader.biWidth = MNIST::width_; + bmi.bmiHeader.biHeight = -MNIST::height_; + bmi.bmiHeader.biPlanes = 1; + bmi.bmiHeader.biBitCount = 32; + bmi.bmiHeader.biPlanes = 1; + bmi.bmiHeader.biCompression = BI_RGB; + + void* bits; + dib_ = CreateDIBSection(nullptr, &bmi, DIB_RGB_COLORS, &bits, nullptr, 0); + } + + hdc_dib_ = CreateCompatibleDC(nullptr); + SelectObject(hdc_dib_, dib_); + SelectObject(hdc_dib_, CreatePen(PS_SOLID, 2, RGB(0, 0, 0))); + FillRect(hdc_dib_, &RECT{0, 0, MNIST::width_, MNIST::height_}, (HBRUSH)GetStockObject(WHITE_BRUSH)); + + HWND hWnd = CreateWindow(L"ONNXTest", L"ONNX Runtime Sample - MNIST", WS_OVERLAPPEDWINDOW, CW_USEDEFAULT, CW_USEDEFAULT, 512, 256, nullptr, nullptr, hInstance, nullptr); + if (!hWnd) + return FALSE; + + ShowWindow(hWnd, nCmdShow); + + MSG msg; + while (GetMessage(&msg, NULL, 0, 0)) { + TranslateMessage(&msg); + DispatchMessage(&msg); + } + return (int)msg.wParam; +} + +LRESULT CALLBACK WndProc(HWND hWnd, UINT message, WPARAM wParam, LPARAM lParam) { + switch (message) { + case WM_PAINT: { + PAINTSTRUCT ps; + HDC hdc = BeginPaint(hWnd, &ps); + + // Draw the image + StretchBlt(hdc, drawing_area_inset_, drawing_area_inset_, drawing_area_width_, drawing_area_height_, hdc_dib_, 0, 0, MNIST::width_, MNIST::height_, SRCCOPY); + SelectObject(hdc, GetStockObject(BLACK_PEN)); + SelectObject(hdc, GetStockObject(NULL_BRUSH)); + Rectangle(hdc, drawing_area_inset_, drawing_area_inset_, drawing_area_inset_ + drawing_area_width_, drawing_area_inset_ + drawing_area_height_); + + constexpr int graphs_left = drawing_area_inset_ + drawing_area_width_ + 5; + constexpr int graph_width = 64; + SelectObject(hdc, brush_bars_); + + auto least = *std::min_element(mnist_.results_.begin(), mnist_.results_.end()); + auto greatest = mnist_.results_[mnist_.result_]; + auto range = greatest - least; + + auto graphs_zero = graphs_left - least * graph_width / range; + + // Hilight the winner + RECT rc{graphs_left, mnist_.result_ * 16, graphs_left + graph_width + 128, (mnist_.result_ + 1) * 16}; + FillRect(hdc, &rc, brush_winner_); + + // For every entry, draw the odds and the graph for it + SetBkMode(hdc, TRANSPARENT); + wchar_t value[80]; + for (unsigned i = 0; i < 10; i++) { + int y = 16 * i; + float result = mnist_.results_[i]; + + auto length = wsprintf(value, L"%2d: %d.%02d", i, int(result), abs(int(result * 100) % 100)); + TextOut(hdc, graphs_left + graph_width + 5, y, value, length); + + Rectangle(hdc, graphs_zero, y + 1, graphs_zero + result * graph_width / range, y + 14); + } + + // Draw the zero line + MoveToEx(hdc, graphs_zero, 0, nullptr); + LineTo(hdc, graphs_zero, 16 * 10); + + EndPaint(hWnd, &ps); + return 0; + } + + case WM_LBUTTONDOWN: { + SetCapture(hWnd); + painting_ = true; + int x = (GET_X_LPARAM(lParam) - drawing_area_inset_) / drawing_area_scale_; + int y = (GET_Y_LPARAM(lParam) - drawing_area_inset_) / drawing_area_scale_; + MoveToEx(hdc_dib_, x, y, nullptr); + return 0; + } + + case WM_MOUSEMOVE: + if (painting_) { + int x = (GET_X_LPARAM(lParam) - drawing_area_inset_) / drawing_area_scale_; + int y = (GET_Y_LPARAM(lParam) - drawing_area_inset_) / drawing_area_scale_; + LineTo(hdc_dib_, x, y); + InvalidateRect(hWnd, nullptr, false); + } + return 0; + + case WM_CAPTURECHANGED: + painting_ = false; + return 0; + + case WM_LBUTTONUP: + ReleaseCapture(); + ConvertDibToMnist(); + mnist_.Run(); + InvalidateRect(hWnd, nullptr, true); + return 0; + + case WM_RBUTTONDOWN: // Erase the image + FillRect(hdc_dib_, &RECT{0, 0, MNIST::width_, MNIST::height_}, (HBRUSH)GetStockObject(WHITE_BRUSH)); + InvalidateRect(hWnd, nullptr, false); + return 0; + + case WM_DESTROY: + PostQuitMessage(0); + return 0; + } + return DefWindowProc(hWnd, message, wParam, lParam); +} diff --git a/samples/c_cxx/MNIST/ReadMe.md b/samples/c_cxx/MNIST/ReadMe.md new file mode 100644 index 0000000000000..043a63dbc9e17 --- /dev/null +++ b/samples/c_cxx/MNIST/ReadMe.md @@ -0,0 +1,66 @@ +# MNIST Sample - Number recognition + +This sample uses the MNIST model from the Model Zoo: https://github.com/onnx/models/tree/master/mnist + +![Screenshot](Screenshot.png) + +## Requirements + +Compiled Onnxruntime.dll / lib (link to instructions on how to build dll) +Windows Visual Studio Compiler (cl.exe) + +## Build + +Run 'build.bat' in this directory to call cl.exe to generate MNIST.exe +Then just run MNIST.exe + +## How to use it + +Just draw a number with the left mouse button (or use touch) in the box on the left side. After releasing the mouse button the model will be run and the outputs of the model will be displayed. Note that when drawing numbers requiring multiple drawing strokes, the model will be run at the end of each stroke with probably wrong predictions (but it's amusing to see and avoids needing to press a 'run model' button). + +To clear the image, click the right mouse button anywhere. + +## How it works + +A single Ort::Env is created globally to initialize the runtime. +https://github.com/microsoft/onnxruntime/blob/521dc757984fbf9770d0051997178fbb9565cd52/samples/c_cxx/MNIST/MNIST.cpp#L12 + +The MNIST structure abstracts away all of the interaction with the Onnx Runtime, creating the tensors, and running the model. + +WWinMain is the Windows entry point, it creates the main window. + +WndProc is the window procedure for the window, handling the mouse input and drawing the graphics + +### Preprocessing the data + +MNIST's input is a {1,1,28,28} shaped float tensor, which is basically a 28x28 floating point grayscale image (0.0 = background, 1.0 = foreground). + +The sample stores the image in a 32-bit per pixel windows DIB section, since that's easy to draw into and draw to the screen for windows. The DIB is created here: +https://github.com/microsoft/onnxruntime/blob/521dc757984fbf9770d0051997178fbb9565cd52/samples/c_cxx/MNIST/MNIST.cpp#L109-L121 + +The function to convert the DIB data and writ it into the model's input tensor: +https://github.com/microsoft/onnxruntime/blob/521dc757984fbf9770d0051997178fbb9565cd52/samples/c_cxx/MNIST/MNIST.cpp#L77-L92 + +### Postprocessing the output + +MNIST's output is a simple {1,10} float tensor that holds the likelihood weights per number. The number with the highest value is the model's best guess. + +The MNIST structure uses std::max_element to do this and stores it in result_: +https://github.com/microsoft/onnxruntime/blob/521dc757984fbf9770d0051997178fbb9565cd52/samples/c_cxx/MNIST/MNIST.cpp#L31 + +To make things more interesting, the window painting handler graphs the probabilities and shows the weights here: +https://github.com/microsoft/onnxruntime/blob/521dc757984fbf9770d0051997178fbb9565cd52/samples/c_cxx/MNIST/MNIST.cpp#L164-L183 + +### The Ort::Session + +1. Creation: The Ort::Session is created inside the MNIST structure here: +https://github.com/microsoft/onnxruntime/blob/521dc757984fbf9770d0051997178fbb9565cd52/samples/c_cxx/MNIST/MNIST.cpp#L43 + +2. Setup inputs & outputs: The input & output tensors are created here: +https://github.com/microsoft/onnxruntime/blob/521dc757984fbf9770d0051997178fbb9565cd52/samples/c_cxx/MNIST/MNIST.cpp#L19-L23 +In this usage, we're providing the memory location for the data instead of having Ort allocate the buffers. This is simpler in this case since the buffers are small and can just be fixed members of the MNIST struct. + +3. Run: Running the session is done in the Run() method: +https://github.com/microsoft/onnxruntime/blob/521dc757984fbf9770d0051997178fbb9565cd52/samples/c_cxx/MNIST/MNIST.cpp#L25-L33 + + diff --git a/samples/c_cxx/MNIST/Screenshot.png b/samples/c_cxx/MNIST/Screenshot.png new file mode 100644 index 0000000000000..4c4ea23007e54 Binary files /dev/null and b/samples/c_cxx/MNIST/Screenshot.png differ diff --git a/samples/c_cxx/MNIST/build.bat b/samples/c_cxx/MNIST/build.bat new file mode 100644 index 0000000000000..eba19ffbd1926 --- /dev/null +++ b/samples/c_cxx/MNIST/build.bat @@ -0,0 +1 @@ +cl MNIST.cpp /Zi /EHsc /I..\..\..\include\onnxruntime\core\session /link /LIBPATH:..\..\..\build\Windows\Debug\Debug \ No newline at end of file