From beb475a63fb1c4055213403b56b9b20191e3c581 Mon Sep 17 00:00:00 2001
From: Andrei Betlen
Date: Fri, 9 Feb 2024 11:27:53 -0500
Subject: [PATCH] feat: Update experimental api

Replace the implicit per-operation contexts with an explicit
context-manager API: Context keeps a stack of active contexts,
Backend/BackendBuffer/Allocr wrap the backend allocation calls, and
graphs are built with ggml_cgraph() and computed on a Backend.

---
 ggml/experimental.py           | 347 ++++++++++++++++++++-------------
 tests/test_experimental.py     |  26 ---
 tests/test_experimental_api.py |  92 +++++++++
 3 files changed, 301 insertions(+), 164 deletions(-)
 delete mode 100644 tests/test_experimental.py
 create mode 100644 tests/test_experimental_api.py

diff --git a/ggml/experimental.py b/ggml/experimental.py
index 39f2e7b..72d3826 100644
--- a/ggml/experimental.py
+++ b/ggml/experimental.py
@@ -2,7 +2,7 @@
 import enum
 import ctypes
 
-from typing import Any, Callable, List, Optional, Sequence, Tuple
+from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
 
 import ggml
 from ggml.utils import from_numpy, to_numpy
@@ -29,16 +29,91 @@ def __init__(
 
 
 class Context:
+    _current: List[Context] = []
+
     def __init__(self, init_params: InitParams):
         self.init_params = init_params
-        self.context: ggml.ggml_context_p = ggml.ggml_init(init_params.params)
+        context: Optional[ggml.ggml_context_p] = ggml.ggml_init(init_params.params)
+        if context is None:
+            raise ValueError("Failed to initialize context")
+        self.context = context
 
     def __del__(self):
         ggml.ggml_free(self.context)
 
+    def __enter__(self):
+        self._current.append(self)
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self._current.pop()
+        return False
+
     @classmethod
-    def with_tensor_overhead(cls):
-        return cls(InitParams(mem_size=ggml.ggml_tensor_overhead(), no_alloc=True))
+    def current_context(cls):
+        if len(cls._current) == 0:
+            raise ValueError(
+                "No current context. Create a context using ggml_context() "
+                "and set it as the current context using a with statement:\n"
+                "    with ggml_context() as ctx:\n"
+                "        ..."
+            )
+        return cls._current[-1]
+
+
+def ggml_context(
+    mem_size: int = ggml.GGML_DEFAULT_GRAPH_SIZE * ggml.ggml_tensor_overhead()
+    + ggml.ggml_graph_overhead(),
+    mem_buffer: Optional[ctypes.c_void_p] = None,
+    no_alloc: bool = True,
+):
+    return Context(InitParams(mem_size, mem_buffer, no_alloc))
+
+
+class Backend:
+    def __init__(self, backend: ggml.ggml_backend_t):
+        self.backend = backend
+
+    def alloc_ctx_tensors(self, ctx: Optional[Context] = None):
+        ctx = ctx or Context.current_context()
+        return ggml.ggml_backend_alloc_ctx_tensors(ctx.context, self.backend)
+
+    @staticmethod
+    def cpu():
+        backend = ggml.ggml_backend_cpu_init()
+        if backend is None:
+            raise ValueError("Failed to initialize CPU backend")
+        return Backend(backend=backend)
+
+    def new_measure(self) -> "Allocr":
+        allocr = ggml.ggml_allocr_new_measure_from_backend(self.backend)
+        return Allocr(allocr)
+
+    def alloc_buffer(self, size: int) -> "BackendBuffer":
+        buffer = ggml.ggml_backend_alloc_buffer(self.backend, size)
+        return BackendBuffer(buffer)
+
+
+class BackendBuffer:
+    def __init__(self, buffer: ggml.ggml_backend_buffer_t):
+        self.buffer = buffer
+
+    def __del__(self):
+        ggml.ggml_backend_buffer_free(self.buffer)
+
+    def new_allocr(self) -> "Allocr":
+        allocr = ggml.ggml_allocr_new_from_buffer(self.buffer)
+        return Allocr(allocr)
+
+
+class Allocr:
+    def __init__(self, allocr: ggml.ggml_allocr_t):
+        self.allocr = allocr
+
+    def __del__(self):
+        ggml.ggml_allocr_free(self.allocr)
+
+    def alloc_graph(self, graph: "CGraph") -> int:
+        return ggml.ggml_allocr_alloc_graph(self.allocr, graph.cgraph)
 
 
 class GGML_TYPE(enum.IntEnum):
@@ -63,6 +138,7 @@ class GGML_TYPE(enum.IntEnum):
     np.int32: GGML_TYPE.I32,
 }
 
+
 GGML_TYPE_TO_NUMPY_DTYPE = {v: k for k, v in NUMPY_DTYPE_TO_GGML_TYPE.items()}
 
 
@@ -113,7 +189,7 @@ def ggml_type(self):
 
     @property
     def shape(self) -> Tuple[int, ...]:
-        return tuple(self.tensor.contents.ne[: self.tensor.contents.n_dims])
+        return tuple(self.tensor.contents.ne[: ggml.ggml_n_dims(self.tensor)])
 
     @property
     def data(self):
@@ -130,7 +206,12 @@ def __len__(self):
         return self.nelements()
 
     @classmethod
-    def with_buffer(cls, tensor: ggml.ggml_tensor_p, ctx: Optional[Context] = None, src: Optional[List[Tensor]] = None):
+    def with_buffer(
+        cls,
+        tensor: ggml.ggml_tensor_p,
+        ctx: Optional[Context] = None,
+        src: Optional[List[Tensor]] = None,
+    ):
         src = src or []
         if tensor.contents.data is not None:
             return cls(tensor=tensor, ctx=ctx, src=src)
@@ -139,41 +220,50 @@ def with_buffer(tensor: ggml.ggml_tensor_p, ctx: Optional[Context] = None,
         tensor.contents.data = ctypes.cast(data, ctypes.c_void_p)
         return cls(tensor=tensor, ctx=ctx, data=data, src=src)
 
+    def __getitem__(self, key: int):
+        return self.numpy()[key]
+
+    def __setitem__(self, key: int, value: Any):
+        self.numpy()[key] = value
+
     def __add__(self, other: Tensor):
-        ctx = Context.with_tensor_overhead()
+        ctx = Context.current_context()
         op = ggml.ggml_add(ctx.context, self.tensor, other.tensor)
         return Tensor.with_buffer(op, ctx, src=[self, other])
 
     def __sub__(self, other: Tensor):
-        ctx = Context.with_tensor_overhead()
+        ctx = Context.current_context()
         op = ggml.ggml_sub(ctx.context, self.tensor, other.tensor)
         return Tensor.with_buffer(op, ctx, src=[self, other])
 
     def __mul__(self, other: Tensor):
-        ctx = Context.with_tensor_overhead()
+        ctx = Context.current_context()
        op = ggml.ggml_mul(ctx.context, self.tensor, other.tensor)
         return Tensor.with_buffer(op, ctx, src=[self, other])
 
     def __truediv__(self, other: Tensor):
-        ctx = Context.with_tensor_overhead()
+        ctx = Context.current_context()
         op = ggml.ggml_div(ctx.context, self.tensor, other.tensor)
         return Tensor.with_buffer(op, ctx, src=[self, other])
 
     def __neg__(self):
-        ctx = Context.with_tensor_overhead()
+        ctx = Context.current_context()
         op = ggml.ggml_neg(ctx.context, self.tensor)
         return Tensor.with_buffer(op, ctx, src=[self])
 
     def __abs__(self):
-        ctx = Context.with_tensor_overhead()
+        ctx = Context.current_context()
         op = ggml.ggml_abs(ctx.context, self.tensor)
         return Tensor.with_buffer(op, ctx, src=[self])
 
     @classmethod
     def with_shape(
-        cls, shape: Sequence[int], ggml_type: GGML_TYPE, ctx: Optional[Context] = None
+        cls,
+        shape: Sequence[int],
+        ggml_type: GGML_TYPE = GGML_TYPE.F32,
+        ctx: Optional[Context] = None,
     ):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         tensor = ggml.ggml_new_tensor(
             ctx.context,
             ggml_type.value,
@@ -184,7 +274,7 @@ def with_shape(
 
     @classmethod
     def from_numpy(cls, x: npt.NDArray[Any], ctx: Optional[Context] = None):
-        _ctx = ctx or Context.with_tensor_overhead()
+        _ctx = ctx or Context.current_context()
         tensor = from_numpy(x, _ctx.context)
         obj = cls.with_buffer(tensor=tensor, ctx=_ctx)
         if ctx is None:
@@ -197,7 +287,7 @@ def new_tensor(
         shape: Sequence[int] = (),
         ctx: Optional[Context] = None,
     ):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         tensor = ggml.ggml_new_tensor(
             ctx.context,
             ggml_type.value,
@@ -212,7 +302,7 @@ def new_tensor_1d(
         ne0: int = 0,
         ctx: Optional[Context] = None,
     ):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         tensor = ggml.ggml_new_tensor_1d(
             ctx.context,
             ggml_type.value,
@@ -227,7 +317,7 @@ def new_tensor_2d(
         ne1: int = 0,
         ctx: Optional[Context] = None,
     ):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         tensor = ggml.ggml_new_tensor_2d(
             ctx.context,
             ggml_type.value,
@@ -244,7 +334,7 @@ def new_tensor_3d(
         ne2: int = 0,
         ctx: Optional[Context] = None,
     ):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         tensor = ggml.ggml_new_tensor_3d(
             ctx.context,
             ggml_type.value,
@@ -263,7 +353,7 @@ def new_tensor_4d(
         ne3: int = 0,
         ctx: Optional[Context] = None,
     ):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         tensor = ggml.ggml_new_tensor_4d(
             ctx.context,
             ggml_type.value,
@@ -279,7 +369,7 @@ def new_i32(
         value: int,
         ctx: Optional[Context] = None,
     ):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         tensor = ggml.ggml_new_i32(ctx.context, value)
         return Tensor.with_buffer(tensor=tensor, ctx=ctx)
 
@@ -288,19 +378,19 @@ def new_f32(
         value: float,
         ctx: Optional[Context] = None,
     ):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         tensor = ggml.ggml_new_f32(ctx.context, value)
         return Tensor.with_buffer(tensor=tensor, ctx=ctx)
 
     @staticmethod
     def dup_tensor(a: Tensor, ctx: Optional[Context] = None):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_dup_tensor(ctx.context, a.tensor)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a])
 
     @staticmethod
     def view(a: Tensor, ctx: Optional[Context] = None):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_view_tensor(ctx.context, a.tensor)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a])
 
@@ -350,25 +440,25 @@ def set_name(a: Tensor, name: bytes):
 
     @staticmethod
     def dup(a: Tensor, ctx: Optional[Context] = None):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_dup(ctx.context, a.tensor)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a])
 
     @staticmethod
     def add(a: Tensor, b: Tensor, ctx: Optional[Context] = None):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_add(ctx.context, a.tensor, b.tensor)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a, b])
 
     @staticmethod
     def add_inplace(a: Tensor, b: Tensor, ctx: Optional[Context] = None):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_add_inplace(ctx.context, a.tensor, b.tensor)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a, b])
 
     @staticmethod
     def add1(a: Tensor, b: Tensor, ctx: Optional[Context] = None):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_add1(ctx.context, a.tensor, b.tensor)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a, b])
 
@@ -382,7 +472,7 @@ def acc(
         offset: int,
         ctx: Optional[Context] = None,
     ):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_acc(
             ctx.context,
             a.tensor,
@@ -404,7 +494,7 @@ def acc_inplace(
         offset: int,
         ctx: Optional[Context] = None,
     ):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_acc_inplace(
             ctx.context,
             a.tensor,
@@ -418,151 +508,151 @@ def acc_inplace(
 
     @staticmethod
     def sub(a: Tensor, b: Tensor, ctx: Optional[Context] = None):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_sub(ctx.context, a.tensor, b.tensor)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a, b])
 
     @staticmethod
     def mul(a: Tensor, b: Tensor, ctx: Optional[Context] = None):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_mul(ctx.context, a.tensor, b.tensor)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a, b])
 
     @staticmethod
     def div(a: Tensor, b: Tensor, ctx: Optional[Context] = None):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_div(ctx.context, a.tensor, b.tensor)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a, b])
 
     @staticmethod
     def sqr(a: Tensor, ctx: Optional[Context] = None):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_sqr(ctx.context, a.tensor)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a])
 
     @staticmethod
     def sqrt(a: Tensor, ctx: Optional[Context] = None):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_sqrt(ctx.context, a.tensor)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a])
 
     @staticmethod
     def log(a: Tensor, ctx: Optional[Context] = None):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_log(ctx.context, a.tensor)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a])
 
     @staticmethod
     def log_inplace(a: Tensor, ctx: Optional[Context] = None):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_log_inplace(ctx.context, a.tensor)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a])
 
     @staticmethod
     def sum(a: Tensor, ctx: Optional[Context] = None):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_sum(ctx.context, a.tensor)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a])
 
     @staticmethod
     def sum_rows(a: Tensor, ctx: Optional[Context] = None):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_sum_rows(ctx.context, a.tensor)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a])
 
     @staticmethod
     def mean(a: Tensor, ctx: Optional[Context] = None):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_mean(ctx.context, a.tensor)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a])
 
     @staticmethod
     def repeat(a: Tensor, b: Tensor, ctx: Optional[Context] = None):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_repeat(ctx.context, a.tensor, b.tensor)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a, b])
 
     @staticmethod
     def abs(a: Tensor, ctx: Optional[Context] = None):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_abs(ctx.context, a.tensor)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a])
 
     @staticmethod
     def sgn(a: Tensor, ctx: Optional[Context] = None):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_sgn(ctx.context, a.tensor)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a])
 
     @staticmethod
     def neg(a: Tensor, ctx: Optional[Context] = None):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_neg(ctx.context, a.tensor)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a])
 
     @staticmethod
     def step(a: Tensor, ctx: Optional[Context] = None):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_step(ctx.context, a.tensor)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a])
 
     @staticmethod
     def relu(a: Tensor, ctx: Optional[Context] = None):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_relu(ctx.context, a.tensor)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a])
 
     @staticmethod
     def gelu(a: Tensor, ctx: Optional[Context] = None):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_gelu(ctx.context, a.tensor)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a])
 
     @staticmethod
     def silu(a: Tensor, ctx: Optional[Context] = None):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_silu(ctx.context, a.tensor)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a])
 
     @staticmethod
     def silu_back(a: Tensor, b: Tensor, ctx: Optional[Context] = None):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_silu_back(ctx.context, a.tensor, b.tensor)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a, b])
 
     @staticmethod
     def norm(a: Tensor, eps: float, ctx: Optional[Context] = None):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_norm(ctx.context, a.tensor, eps)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a])
 
     @staticmethod
-    def rms_norm(a: Tensor, eps: float, ctx: Optional[Context] = None):
-        ctx = ctx or Context.with_tensor_overhead()
+    def rms_norm(a: Tensor, eps: float, ctx: Optional[Context] = None):
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_rms_norm(ctx.context, a.tensor, eps)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a])
 
     @staticmethod
     def rms_norm_back(a: Tensor, b: Tensor, eps: float, ctx: Optional[Context] = None):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_rms_norm_back(ctx.context, a.tensor, b.tensor, eps)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a, b])
 
     @staticmethod
     def mul_mat(a: Tensor, b: Tensor, ctx: Optional[Context] = None):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_mul_mat(ctx.context, a.tensor, b.tensor)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a, b])
 
     @staticmethod
     def scale(a: Tensor, b: Tensor, ctx: Optional[Context] = None):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_scale(ctx.context, a.tensor, b.tensor)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a, b])
 
     @staticmethod
     def scale_inplace(a: Tensor, b: Tensor, ctx: Optional[Context] = None):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_scale_inplace(ctx.context, a.tensor, b.tensor)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a, b])
 
@@ -576,7 +666,7 @@ def set(
         offset: int,
         ctx: Optional[Context] = None,
     ):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_set(
             ctx.context,
             a.tensor,
@@ -598,7 +688,7 @@ def set_inplace(
         offset: int,
         ctx: Optional[Context] = None,
     ):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_set_inplace(
             ctx.context,
             a.tensor,
@@ -612,7 +702,7 @@ def set_inplace(
 
     @staticmethod
     def set_1d(a: Tensor, b: Tensor, offset: int, ctx: Optional[Context] = None):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_set_1d(ctx.context, a.tensor, b.tensor, offset)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a, b])
 
@@ -620,7 +710,7 @@ def set_1d(a: Tensor, b: Tensor, offset: int, ctx: Optional[Context] = None):
     def set_1d_inplace(
         a: Tensor, b: Tensor, offset: int, ctx: Optional[Context] = None
     ):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_set_1d_inplace(ctx.context, a.tensor, b.tensor, offset)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a, b])
 
@@ -632,7 +722,7 @@ def set_2d(
         offset: int,
         ctx: Optional[Context] = None,
     ):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_set_2d(
             ctx.context,
             a.tensor,
@@ -650,7 +740,7 @@ def set_2d_inplace(
         offset: int,
         ctx: Optional[Context] = None,
     ):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_set_2d_inplace(
             ctx.context,
             a.tensor,
@@ -662,31 +752,31 @@ def set_2d_inplace(
 
     @staticmethod
     def cpy(a: Tensor, b: Tensor, ctx: Optional[Context] = None):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_cpy(ctx.context, a.tensor, b.tensor)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a, b])
 
     @staticmethod
     def cont(a: Tensor, ctx: Optional[Context] = None):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_cont(ctx.context, a.tensor)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a])
 
     @staticmethod
     def reshape(a: Tensor, b: Tensor, ctx: Optional[Context] = None):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_reshape(ctx.context, a.tensor, b.tensor)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a, b])
 
     @staticmethod
     def reshape_1d(a: Tensor, ne0: int, ctx: Optional[Context] = None):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_reshape_1d(ctx.context, a.tensor, ne0)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a])
 
     @staticmethod
     def reshape_2d(a: Tensor, ne0: int, ne1: int, ctx: Optional[Context] = None):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_reshape_2d(ctx.context, a.tensor, ne0, ne1)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a])
 
@@ -694,7 +784,7 @@ def reshape_2d(a: Tensor, ne0: int, ne1: int, ctx: Optional[Context] = None):
     def reshape_3d(
         a: Tensor, ne0: int, ne1: int, ne2: int, ctx: Optional[Context] = None
     ):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_reshape_3d(
             ctx.context,
             a.tensor,
@@ -713,7 +803,7 @@ def reshape_4d(
         ne3: int,
         ctx: Optional[Context] = None,
     ):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_reshape_4d(
             ctx.context,
             a.tensor,
@@ -731,7 +821,7 @@ def view_1d(
         offset: int,
         ctx: Optional[Context] = None,
     ):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_view_1d(ctx.context, a.tensor, ne0, offset)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a])
 
@@ -744,7 +834,7 @@ def view_2d(
         offset: int,
         ctx: Optional[Context] = None,
     ):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_view_2d(
             ctx.context,
             a.tensor,
@@ -766,7 +856,7 @@ def view_3d(
         offset: int,
         ctx: Optional[Context] = None,
     ):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_view_3d(
             ctx.context,
             a.tensor,
@@ -792,7 +882,7 @@ def view_4d(
         offset: int,
         ctx: Optional[Context] = None,
     ):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_view_4d(
             ctx.context,
             a.tensor,
@@ -816,7 +906,7 @@ def permute(
         axis3: int,
         ctx: Optional[Context] = None,
     ):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_permute(
             ctx.context,
             a.tensor,
@@ -829,61 +919,61 @@ def permute(
 
     @staticmethod
     def transpose(a: Tensor, ctx: Optional[Context] = None):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_transpose(ctx.context, a.tensor)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a])
 
     @staticmethod
     def get_rows(a: Tensor, b: Tensor, ctx: Optional[Context] = None):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_get_rows(ctx.context, a.tensor, b.tensor)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a, b])
 
     @staticmethod
     def get_rows_back(a: Tensor, b: Tensor, c: Tensor, ctx: Optional[Context] = None):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_get_rows_back(ctx.context, a.tensor, b.tensor, c.tensor)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a, b, c])
 
     @staticmethod
     def diag(a: Tensor, ctx: Optional[Context] = None):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_diag(ctx.context, a.tensor)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a])
 
     @staticmethod
     def diag_mask_inf(a: Tensor, n_past: int, ctx: Optional[Context] = None):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_diag_mask_inf(ctx.context, a.tensor, n_past)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a])
 
     @staticmethod
     def diag_mask_inf_inplace(a: Tensor, n_past: int, ctx: Optional[Context] = None):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_diag_mask_inf_inplace(ctx.context, a.tensor, n_past)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a])
 
     @staticmethod
     def diag_mask_zero(a: Tensor, n_past: int, ctx: Optional[Context] = None):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
        op = ggml.ggml_diag_mask_zero(ctx.context, a.tensor, n_past)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a])
 
     @staticmethod
     def diag_mask_zero_inplace(a: Tensor, n_past: int, ctx: Optional[Context] = None):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_diag_mask_zero_inplace(ctx.context, a.tensor, n_past)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a])
 
     @staticmethod
     def soft_max(a: Tensor, ctx: Optional[Context] = None):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_soft_max(ctx.context, a.tensor)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a])
 
     @staticmethod
     def soft_max_inplace(a: Tensor, ctx: Optional[Context] = None):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_soft_max_inplace(ctx.context, a.tensor)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a])
 
@@ -896,7 +986,7 @@ def rope(
         n_ctx: int,
         ctx: Optional[Context] = None,
     ):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_rope(
             ctx.context,
             a.tensor,
@@ -916,8 +1006,10 @@ def rope_inplace(
         n_ctx: int,
         ctx: Optional[Context] = None,
     ):
-        ctx = ctx or Context.with_tensor_overhead()
-        op = ggml.ggml_rope_inplace(ctx.context, a.tensor, b.tensor, n_dims, mode, n_ctx)
+        ctx = ctx or Context.current_context()
+        op = ggml.ggml_rope_inplace(
+            ctx.context, a.tensor, b.tensor, n_dims, mode, n_ctx
+        )
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a])
 
     @staticmethod
@@ -933,7 +1025,7 @@ def rope_back(
         xpos_down: bool,
         ctx: Optional[Context] = None,
     ):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_rope_back(
             ctx.context,
             a.tensor,
@@ -956,7 +1048,7 @@ def alibi(
         bias_max: float,
         ctx: Optional[Context] = None,
     ):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_alibi(
             ctx.context,
             a.tensor,
@@ -968,7 +1060,7 @@ def alibi(
 
     @staticmethod
     def clamp(a: Tensor, min: float, max: float, ctx: Optional[Context] = None):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_clamp(ctx.context, a.tensor, min, max)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a])
 
@@ -981,7 +1073,7 @@ def conv_1d(
         d0: int,
         ctx: Optional[Context] = None,
     ):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_conv_1d(ctx.context, a.tensor, b.tensor, s0, p0, d0)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a, b])
 
@@ -997,7 +1089,7 @@ def conv_2d(
         d1: int,
         ctx: Optional[Context] = None,
     ):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_conv_2d(ctx.context, a.tensor, b.tensor, s0, s1, p0, p1, d0, d1)
         return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a, b])
 
@@ -1009,7 +1101,7 @@ def flash_attn(
         masked: bool,
         ctx: Optional[Context] = None,
     ):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_flash_attn(
             ctx.context,
             q.tensor,
@@ -1028,7 +1120,7 @@ def flash_ff(
         c1: Tensor,
         ctx: Optional[Context] = None,
    ):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_flash_ff(
             ctx.context, a.tensor, b0.tensor, b1.tensor, c0.tensor, c1.tensor
         )
@@ -1041,7 +1133,7 @@ def map_unary_f32(
         fun: Callable[[float], float],
         ctx: Optional[Context] = None,
     ):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_map_unary_f32(  # type: ignore
             ctx.context, a.tensor, ctypes.CFUNCTYPE(ctypes.c_float, ctypes.c_float)(fun)
         )
@@ -1055,7 +1147,7 @@ def map_binary_f32(
         fun: Callable[[float, float], float],
         ctx: Optional[Context] = None,
     ):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_map_binary_f32(  # type: ignore
             ctx.context,
             a.tensor,
@@ -1069,7 +1161,7 @@ def set_param(
         a: Tensor,
         ctx: Optional[Context] = None,
     ):
-        ctx = ctx or Context.with_tensor_overhead()
+        ctx = ctx or Context.current_context()
         op = ggml.ggml_set_param(ctx.context, a.tensor)
         return Tensor.with_buffer(tensor=op, ctx=ctx)
 
@@ -1079,52 +1171,31 @@ def ggml_ftype_to_ggml_type(ftype: int):
 
 
 class CGraph:
-    def __init__(self, cgraph: ggml.ggml_cgraph):
+    def __init__(self, cgraph: ggml.ggml_cgraph_p):
         self.cgraph = cgraph
         self._tensors: List[Tensor] = []
 
-    def compute(self, n_threads: int = 1):
-        gp = ggml.ggml_graph_plan(ctypes.pointer(self.cgraph), n_threads=n_threads)
-        if gp.work_size > 0:
-            work_data = (ctypes.c_uint8 * gp.work_size)()
-            gp.work_data = ctypes.cast(work_data, ctypes.c_void_p)
-            ggml.ggml_graph_compute(ctypes.pointer(self.cgraph), ctypes.pointer(gp))
-        else:
-            ggml.ggml_graph_compute(ctypes.pointer(self.cgraph), ctypes.pointer(gp))
-
-    def reset(self):
-        ggml.ggml_graph_reset(ctypes.pointer(self.cgraph))
-        self._tensors = []
-
-    def get_tensor(self, name: bytes):
-        return Tensor(
-            tensor=ggml.ggml_graph_get_tensor(ctypes.pointer(self.cgraph), name),
-        )
+    def build_forward_expand(self, tensor: Tensor):
+        ggml.ggml_build_forward_expand(self.cgraph, tensor.tensor)
 
-    def graph_export(self, fname: bytes):
-        ggml.ggml_graph_export(ctypes.pointer(self.cgraph), fname)
+    def compute(self, backend: Backend):
+        ggml.ggml_backend_graph_compute(backend.backend, self.cgraph)
 
-    def build_forward_expand(self, tensor: Tensor):
-        ggml.ggml_build_forward_expand(ctypes.pointer(self.cgraph), tensor.tensor)
-        self._tensors.append(tensor)
 
-    @staticmethod
-    def print(a: CGraph):
-        ggml.ggml_graph_print(ctypes.pointer(a.cgraph))
+def ggml_cgraph(
+    tensor_or_tensors: Union[None, Tensor, Sequence[Tensor]],
+    ctx: Optional[Context] = None,
+):
+    ctx = ctx or Context.current_context()
+    graph = CGraph(ggml.ggml_new_graph(ctx.context))
 
-    @staticmethod
-    def dump_dot(
-        gb: CGraph,
-        gf: Optional[CGraph],
-        filename: bytes,
-    ):
-        gf_p = ctypes.pointer(gf.cgraph) if gf else None
-        ggml.ggml_graph_dump_dot(
-            ctypes.pointer(gb.cgraph), gf_p, filename  # type: ignore
-        )
+    if tensor_or_tensors is None:
+        tensor_or_tensors = []
 
-    @classmethod
-    def build_forward(cls, tensor: Tensor):
-        obj = CGraph(cgraph=ggml.ggml_build_forward(tensor.tensor))
-        obj._tensors.append(tensor)
-        return obj
+    if isinstance(tensor_or_tensors, Tensor):
+        tensor_or_tensors = [tensor_or_tensors]
+
+    for tensor in tensor_or_tensors:
+        graph.build_forward_expand(tensor)
+
+    return graph
diff --git a/tests/test_experimental.py b/tests/test_experimental.py
deleted file mode 100644
index 7d65a95..0000000
--- a/tests/test_experimental.py
+++ /dev/null
@@ -1,26 +0,0 @@
-from ggml.experimental import CGraph, Tensor, GGML_TYPE
-
-import pytest
-
-import numpy as np
-
-
-@pytest.mark.skip(reason="not implemented")
-def test_tensor():
-    x = np.ones((3,), dtype=np.float32)
-    assert x.shape == (3,)
-    t = Tensor.from_numpy(x)
-    assert t.shape == (3,)
-    assert t.ggml_type == GGML_TYPE.F32
-    assert np.allclose(t.numpy(), x)
-
-@pytest.mark.skip(reason="not implemented")
-def test_tensor_compute():
-    x = Tensor.from_numpy(np.array([2.0], dtype=np.float32))
-    a = Tensor.from_numpy(np.array([3.0], dtype=np.float32))
-    b = Tensor.from_numpy(np.array([4.0], dtype=np.float32))
-    x2 = x * x
-    f = a * x2 + b
-    gf = CGraph.build_forward(f)
-    gf.compute()
-    assert np.allclose(f.numpy(), np.array([16.0], dtype=np.float32))
diff --git a/tests/test_experimental_api.py b/tests/test_experimental_api.py
new file mode 100644
index 0000000..4a55f70
--- /dev/null
+++ b/tests/test_experimental_api.py
@@ -0,0 +1,92 @@
+from ggml.experimental import ggml_context, ggml_cgraph, Tensor, GGML_TYPE, Backend
+
+def test_experimental_api():
+    backend = Backend.cpu()
+
+    with ggml_context():
+        a = Tensor.with_shape((1,), ggml_type=GGML_TYPE.F32)
+        b = Tensor.with_shape((1,), ggml_type=GGML_TYPE.F32)
+        x = Tensor.with_shape((1,), ggml_type=GGML_TYPE.F32)
+
+        f = a * x * x + b
+
+        assert f.shape == (1,)
+
+        backend.alloc_ctx_tensors()
+
+        a[0] = 3.0
+        b[0] = 4.0
+        x[0] = 2.0
+
+        assert a[0] == 3.0
+        assert b[0] == 4.0
+        assert x[0] == 2.0
+
+        graph = ggml_cgraph(f)
+        graph.compute(backend)
+
+        assert f[0] == 16.0
+
+    with ggml_context():
+        a = Tensor.with_shape((1,), ggml_type=GGML_TYPE.F32)
+        b = Tensor.with_shape((1,), ggml_type=GGML_TYPE.F32)
+
+        backend.alloc_ctx_tensors()
+
+        a[0] = 3.0
+        b[0] = 4.0
+
+        with ggml_context():
+            x = Tensor.with_shape((1,), ggml_type=GGML_TYPE.F32)
+
+            f = a * x * x + b
+
+            assert f.shape == (1,)
+
+            backend.alloc_ctx_tensors()
+
+            x[0] = 2.0
+
+            assert a[0] == 3.0
+            assert b[0] == 4.0
+            assert x[0] == 2.0
+
+            graph = ggml_cgraph(f)
+            graph.compute(backend)
+
+            assert f[0] == 16.0
+
+    with ggml_context():
+        a = Tensor.with_shape((1,), ggml_type=GGML_TYPE.F32)
+        b = Tensor.with_shape((1,), ggml_type=GGML_TYPE.F32)
+
+        backend.alloc_ctx_tensors()
+
+        a[0] = 3.0
+        b[0] = 4.0
+
+        with ggml_context():
+            x = Tensor.with_shape((1,), ggml_type=GGML_TYPE.F32)
+
+            backend.alloc_ctx_tensors()
+
+            x[0] = 2.0
+
+            f = a * x * x + b
+
+            assert f.shape == (1,)
+
+            measure_allocr = backend.new_measure()
+
+            graph = ggml_cgraph(f)
+
+            mem_size = measure_allocr.alloc_graph(graph)
+
+            buffer = backend.alloc_buffer(mem_size)
+
+            allocr = buffer.new_allocr()
+            allocr.alloc_graph(graph)
+
+            graph.compute(backend)
+
+            assert f[0] == 16.0
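
Usage note: a minimal sketch of the new context-manager API for reviewers. It
mirrors the first case in tests/test_experimental_api.py above and assumes the
ggml Python package from this repo is built and importable.

    from ggml.experimental import (
        Backend,
        GGML_TYPE,
        Tensor,
        ggml_cgraph,
        ggml_context,
    )

    backend = Backend.cpu()

    with ggml_context():
        # Tensors (and the ops built from them) are created in the
        # innermost active context.
        a = Tensor.with_shape((1,), ggml_type=GGML_TYPE.F32)
        x = Tensor.with_shape((1,), ggml_type=GGML_TYPE.F32)
        f = a * x

        # Allocate backend memory for every tensor in the current
        # context, then write values through the numpy views.
        backend.alloc_ctx_tensors()
        a[0] = 3.0
        x[0] = 2.0

        # Build the forward graph from the output tensor and run it.
        graph = ggml_cgraph(f)
        graph.compute(backend)
        assert f[0] == 6.0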