Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ggml : cgraph export/import/eval example + GPU support #108

Merged
merged 27 commits into from
May 29, 2023
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
f477d4f
ggml : cgraph export brainstorming
ggerganov Apr 24, 2023
2a03421
mnist : code style
ggerganov May 21, 2023
85dcc0c
mnist : minor
ggerganov May 21, 2023
95c8507
ggml : initial cgraph export
ggerganov May 24, 2023
3120189
ggml : initial graph import (wip)
ggerganov May 24, 2023
d2d1c22
ggml : import op args correctly
ggerganov May 25, 2023
4cfd92b
ggml : add ggml_get_tensor_by_name()
ggerganov May 27, 2023
b0450c2
mnist : add compute graph evaluation on CPU example
ggerganov May 27, 2023
ddea488
ggml : add ggml_tensor_overhead()
ggerganov May 27, 2023
f698dbf
ggml : rename new functions to ggml_cgraph_...
ggerganov May 27, 2023
bf93623
mnist : add Metal inference skeleton (WIP)
ggerganov May 27, 2023
bb126f9
mnist : working on the Metal pipeline (WIP)
ggerganov May 28, 2023
24ea9dd
mnist : prepare the Metal encoder (WIP)
ggerganov May 28, 2023
2ec1dff
mnist : first Metal kernel for F32 ADD
ggerganov May 28, 2023
966f9e6
mnist : looks like MTLHeap does not work
ggerganov May 28, 2023
1bc9181
mnist : initial full pass of MNIST on the GPU (not verified)
ggerganov May 28, 2023
4134bac
mnist : minor cleanup
ggerganov May 28, 2023
a556b57
mnist : full GPU inference works
ggerganov May 28, 2023
8f8653b
mnist : use custom soft_max kernel since MPSMatrixSoftMax is bugged
ggerganov May 28, 2023
3b97377
mnist : use constant for soft_max instead of hardcoded 10
ggerganov May 28, 2023
e350f13
mnist : check multiple predictions (Metal)
ggerganov May 28, 2023
4fa01f0
mnist : minor
ggerganov May 28, 2023
79dcbfd
ggml : move cgraph import / export to ggml
ggerganov May 28, 2023
25adade
mnist : remove common dependencies
ggerganov May 28, 2023
e6dc506
mnist : fix soft_max threadgroup size
ggerganov May 29, 2023
f9b04df
mnist : init no_alloc member
ggerganov May 29, 2023
c8013c5
ggml : improve "get tensor" API
ggerganov May 29, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
463 changes: 463 additions & 0 deletions examples/common-ggml.cpp

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions examples/common-ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,7 @@ bool ggml_common_quantize_0(
const ggml_ftype ftype,
const std::vector<std::string> & to_quant,
const std::vector<std::string> & to_skip);

// cgraph export / import helpers — these will move to ggml when ready
//
// ggml_cgraph_export: write the compute graph to the file fname
// ggml_cgraph_import: load a graph previously written by ggml_cgraph_export;
//   on return *ctx_data and *ctx_eval hold the loaded tensors and must be
//   released by the caller with ggml_free()
void ggml_cgraph_export(const struct ggml_cgraph * cgraph, const char * fname);
ggml_cgraph ggml_cgraph_import(const char * fname, struct ggml_context ** ctx_data, struct ggml_context ** ctx_eval);
30 changes: 29 additions & 1 deletion examples/mnist/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,33 @@

#
# mnist

set(TEST_TARGET mnist)
add_executable(${TEST_TARGET} main.cpp)
# note: removed stale duplicate link line (pre-change leftover that linked
# only "ggml common"); the target needs common-ggml for the cgraph helpers
target_link_libraries(${TEST_TARGET} PRIVATE ggml common common-ggml)

#
# mnist-cpu

set(TEST_TARGET mnist-cpu)
add_executable(${TEST_TARGET} main-cpu.cpp)
target_link_libraries(${TEST_TARGET} PRIVATE ggml common common-ggml)

if (APPLE)
    #
    # mnist-mtl

    # Apple frameworks required by the Metal inference example
    find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
    find_library(METAL_FRAMEWORK Metal REQUIRED)
    find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
    find_library(METALPERFORMANCE_FRAMEWORK MetalPerformanceShaders REQUIRED)

    set(TEST_TARGET mnist-mtl)
    add_executable(${TEST_TARGET} main-mtl.cpp main-mtl.h main-mtl.m)
    target_link_libraries(${TEST_TARGET} PRIVATE
        ggml
        common
        common-ggml
        ${FOUNDATION_LIBRARY}
        ${METAL_FRAMEWORK}
        ${METALKIT_FRAMEWORK}
        ${METALPERFORMANCE_FRAMEWORK}
    )
endif()
118 changes: 118 additions & 0 deletions examples/mnist/main-cpu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
// Use a pre-generated MNIST compute graph for inference on the CPU
//
// You can generate a compute graph using the "mnist" tool:
//
// $ ./bin/mnist ./models/mnist/ggml-model-f32.bin ../examples/mnist/models/mnist/t10k-images.idx3-ubyte
//
// This command creates the "mnist.ggml" file, which contains the generated compute graph.
// Now, you can re-use the compute graph with the "mnist-cpu" tool:
//
// $ ./bin/mnist-cpu ./models/mnist/mnist.ggml ../examples/mnist/models/mnist/t10k-images.idx3-ubyte
//

#include "ggml/ggml.h"

#include "common.h"
#include "common-ggml.h"

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <fstream>
#include <string>
#include <vector>

// evaluate the MNIST compute graph
//
// - fname_cgraph: path to the exported compute graph file
// - n_threads:    number of threads to use for ggml_graph_compute()
// - digit:        784 pixel values (28x28 image, row-major)
//
// returns the 0 - 9 prediction, or -1 if the work buffer cannot be allocated
int mnist_eval(
        const char * fname_cgraph,
        const int n_threads,
        std::vector<float> digit
        ) {
    // load the compute graph; ctx_data / ctx_eval own the imported tensors
    struct ggml_context * ctx_data = NULL;
    struct ggml_context * ctx_eval = NULL;

    struct ggml_cgraph gfi = ggml_cgraph_import(fname_cgraph, &ctx_data, &ctx_eval);
    gfi.n_threads = n_threads;

    // allocate eval context
    // needed during ggml_graph_compute() to allocate a work tensor
    // static: allocated once and reused across calls (intentionally never freed)
    static size_t buf_size = gfi.work_size; // TODO
    static void * buf = malloc(buf_size);

    if (buf == NULL) {
        fprintf(stderr, "%s: failed to allocate %zu bytes for the work buffer\n", __func__, buf_size);
        return -1;
    }

    struct ggml_init_params params = {
        .mem_size   = buf_size,
        .mem_buffer = buf,
        .no_alloc   = false, // explicitly initialize — was previously left unspecified
    };

    struct ggml_context * ctx0 = ggml_init(params);

    // copy the input pixels into the graph's "input" tensor
    struct ggml_tensor * input = ggml_get_tensor_by_name(&gfi, "input");
    memcpy(input->data, digit.data(), ggml_nbytes(input));

    ggml_graph_compute(ctx0, &gfi);

    const float * probs_data = ggml_get_data_f32(ggml_get_tensor_by_name(&gfi, "probs"));

    // argmax over the 10 class probabilities
    const int prediction = std::max_element(probs_data, probs_data + 10) - probs_data;

    ggml_free(ctx0);
    ggml_free(ctx_data);
    ggml_free(ctx_eval);

    return prediction;
}

// load a random digit from the MNIST test set, run the pre-generated
// compute graph on it, and print the prediction
int main(int argc, char ** argv) {
    srand(time(NULL));
    ggml_time_init();

    if (argc != 3) {
        fprintf(stderr, "Usage: %s models/mnist/mnist.ggml models/mnist/t10k-images.idx3-ubyte\n", argv[0]);
        return 1; // usage error: exit with a nonzero status
    }

    uint8_t buf[784];
    std::vector<float> digit;

    // read a random digit from the test set
    {
        std::ifstream fin(argv[2], std::ios::binary);
        if (!fin) {
            fprintf(stderr, "%s: failed to open '%s'\n", __func__, argv[2]);
            return 1;
        }

        // seek to a random digit: 16-byte header + 28*28 * (random 0 - 9999)
        fin.seekg(16 + 784 * (rand() % 10000));
        if (!fin.read((char *) buf, sizeof(buf))) {
            // previously unchecked: a short/truncated file yielded garbage pixels
            fprintf(stderr, "%s: failed to read a digit from '%s'\n", __func__, argv[2]);
            return 1;
        }
    }

    // render the digit in ASCII and convert it to float input
    {
        digit.resize(sizeof(buf));

        for (int row = 0; row < 28; row++) {
            for (int col = 0; col < 28; col++) {
                fprintf(stderr, "%c ", (float)buf[row*28 + col] > 230 ? '*' : '_');
                digit[row*28 + col] = ((float)buf[row*28 + col]);
            }

            fprintf(stderr, "\n");
        }

        fprintf(stderr, "\n");
    }

    const int prediction = mnist_eval(argv[1], 1, digit);

    fprintf(stdout, "%s: predicted digit is %d\n", __func__, prediction);

    return 0;
}
107 changes: 107 additions & 0 deletions examples/mnist/main-mtl.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
// Use a pre-generated MNIST compute graph for inference on the M1 GPU via MPS
//

#include "ggml/ggml.h"

#include "main-mtl.h"

#include "common-ggml.h"

#include <cmath>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <fstream>

// evaluate the MNIST compute graph on the GPU via Metal
//
// - fname_cgraph: path to the exported compute graph file
// - n_threads: number of threads to use (stored on the graph; presumably
//   not used by the Metal path — TODO confirm)
// - digit: 784 pixel values (28x28 image, row-major)
//
// returns 0 - 9 prediction
int mnist_eval(
        const char * fname_cgraph,
        const int n_threads,
        std::vector<float> digit
        ) {
    // load the compute graph; ctx_data / ctx_eval own the imported tensors
    struct ggml_context * ctx_data = NULL;
    struct ggml_context * ctx_eval = NULL;

    struct ggml_cgraph gf = ggml_cgraph_import(fname_cgraph, &ctx_data, &ctx_eval);
    gf.n_threads = n_threads;

    // allocate eval context
    // needed during ggml_graph_compute() to allocate a work tensor
    // static: allocated once and reused across calls (intentionally never freed)
    static size_t buf_size = gf.work_size; // TODO
    static void * buf = malloc(buf_size);

    struct ggml_init_params params = {
        .mem_size   = buf_size,
        .mem_buffer = buf,
    };

    struct ggml_context * ctx_work = ggml_init(params);

    // copy the input pixels into the graph's "input" tensor
    struct ggml_tensor * input = ggml_get_tensor_by_name(&gf, "input");
    memcpy(input->data, digit.data(), ggml_nbytes(input));

    // set up the Metal context, run the graph on the GPU, then tear it down
    auto ctx_mtl = mnist_mtl_init(ctx_data, ctx_eval, ctx_work, &gf);
    const int prediction = mnist_mtl_eval(ctx_mtl, &gf);
    mnist_mtl_free(ctx_mtl);

    ggml_free(ctx_work);
    ggml_free(ctx_data);
    ggml_free(ctx_eval);

    return prediction;
}

// pick a random digit from the MNIST test set, evaluate the pre-generated
// compute graph on it via Metal, and print the prediction
int main(int argc, char ** argv) {
    srand(time(NULL));
    ggml_time_init();

    if (argc != 3) {
        fprintf(stderr, "Usage: %s models/mnist/mnist.ggml models/mnist/t10k-images.idx3-ubyte\n", argv[0]);
        exit(0);
    }

    uint8_t pixels[784];
    std::vector<float> digit;

    // read a random digit from the test set
    {
        std::ifstream fin(argv[2], std::ios::binary);
        if (!fin) {
            fprintf(stderr, "%s: failed to open '%s'\n", __func__, argv[2]);
            return 1;
        }

        // seek to a random digit: 16-byte header + 28*28 * (random 0 - 10000)
        fin.seekg(16 + 784 * (rand() % 10000));
        fin.read((char *) &pixels, sizeof(pixels));
    }

    // render the digit in ASCII while converting it to float input
    {
        digit.resize(sizeof(pixels));

        // single flat pass over all 784 pixels; emit a newline after each row of 28
        for (int i = 0; i < 784; i++) {
            fprintf(stderr, "%c ", (float)pixels[i] > 230 ? '*' : '_');
            digit[i] = ((float)pixels[i]);

            if (i % 28 == 27) {
                fprintf(stderr, "\n");
            }
        }

        fprintf(stderr, "\n");
    }

    const int prediction = mnist_eval(argv[1], 1, digit);

    fprintf(stdout, "%s: predicted digit is %d\n", __func__, prediction);

    return 0;
}
26 changes: 26 additions & 0 deletions examples/mnist/main-mtl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#pragma once

// C interface for evaluating the MNIST compute graph on the GPU via Metal.
// Implemented in main-mtl.m (Objective-C), called from main-mtl.cpp.

struct ggml_context;
struct ggml_cgraph;

#ifdef __cplusplus
extern "C" {
#endif

// opaque Metal evaluation state (defined in main-mtl.m)
struct ggml_mtl_context;

// create the Metal context for evaluating the graph gf
// - ctx_data: context holding the imported tensor data (weights)
// - ctx_eval: context holding the imported graph tensors
// - ctx_work: context providing work memory for evaluation
struct ggml_mtl_context * mnist_mtl_init(
            struct ggml_context * ctx_data,
            struct ggml_context * ctx_eval,
            struct ggml_context * ctx_work,
            struct ggml_cgraph * gf);

// release all resources owned by the Metal context
void mnist_mtl_free(struct ggml_mtl_context * ctx);

// evaluate the graph gf on the GPU; returns the 0 - 9 prediction
int mnist_mtl_eval(
            struct ggml_mtl_context * ctx,
            struct ggml_cgraph * gf);

#ifdef __cplusplus
}
#endif
Loading