Skip to content

Commit

Permalink
Update high level api
Browse files Browse the repository at this point in the history
  • Loading branch information
abetlen committed Feb 9, 2024
1 parent cb984cd commit d88f325
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 25 deletions.
33 changes: 24 additions & 9 deletions ggml/experimental.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations

import enum
import ctypes

Expand Down Expand Up @@ -45,7 +46,7 @@ def __enter__(self):
self._current.append(self)
return self

def __exit__(self, exc_type, exc_value, traceback):
def __exit__(self, exc_type, exc_value, traceback): # type: ignore
self._current.pop()
return False

Expand Down Expand Up @@ -93,6 +94,7 @@ def alloc_buffer(self, size: int) -> "BackendBuffer":
buffer = ggml.ggml_backend_alloc_buffer(self.backend, size)
return BackendBuffer(buffer)


class BackendBuffer:
def __init__(self, buffer: ggml.ggml_backend_buffer_t):
self.buffer = buffer
Expand Down Expand Up @@ -196,6 +198,8 @@ def data(self):
return ggml.ggml_get_data(self.tensor)

def set_data(self, data: bytes):
if self.data is None:
raise ValueError("Data is not allocated")
return ctypes.memmove(self.data, data, self.nbytes())

def numpy(self):
Expand Down Expand Up @@ -645,16 +649,16 @@ def mul_mat(a: Tensor, b: Tensor, ctx: Optional[Context] = None):
return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a, b])

@staticmethod
def scale(a: Tensor, b: Tensor, ctx: Optional[Context] = None):
def scale(a: Tensor, s: float, ctx: Optional[Context] = None):
ctx = ctx or Context.current_context()
op = ggml.ggml_scale(ctx.context, a.tensor, b.tensor)
return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a, b])
op = ggml.ggml_scale(ctx.context, a.tensor, s)
return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a])

@staticmethod
def scale_inplace(a: Tensor, b: Tensor, ctx: Optional[Context] = None):
def scale_inplace(a: Tensor, s: float, ctx: Optional[Context] = None):
ctx = ctx or Context.current_context()
op = ggml.ggml_scale_inplace(ctx.context, a.tensor, b.tensor)
return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a, b])
op = ggml.ggml_scale_inplace(ctx.context, a.tensor, s)
return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a])

@staticmethod
def set(
Expand Down Expand Up @@ -1019,8 +1023,13 @@ def rope_back(
n_dims: int,
mode: int,
n_ctx: int,
n_orig_ctx: int,
freq_base: float,
freq_scale: float,
ext_factor: float,
attn_factor: float,
beta_fast: float,
beta_slow: float,
xpos_base: float,
xpos_down: bool,
ctx: Optional[Context] = None,
Expand All @@ -1033,10 +1042,15 @@ def rope_back(
n_dims,
mode,
n_ctx,
n_orig_ctx,
freq_base,
freq_scale,
ext_factor,
attn_factor,
beta_fast,
beta_slow,
xpos_base,
xpos_down,
xpos_down
)
return Tensor.with_buffer(tensor=op, ctx=ctx, src=[a])

Expand Down Expand Up @@ -1173,9 +1187,10 @@ def ggml_ftype_to_ggml_type(ftype: int):
class CGraph:
def __init__(self, cgraph: ggml.ggml_cgraph_p):
self.cgraph = cgraph
self._tensors: List[Tensor] = []
self._output_tensors: List[Tensor] = []

def build_forward_expand(self, tensor: Tensor):
self._output_tensors.append(tensor)
ggml.ggml_build_forward_expand(self.cgraph, tensor.tensor)

def compute(self, backend: Backend):
Expand Down
29 changes: 13 additions & 16 deletions tests/test_experimental_api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from ggml.experimental import ggml_context, ggml_cgraph, Tensor, GGML_TYPE, Backend


def test_experimental_api():
backend = Backend.cpu()

Expand Down Expand Up @@ -59,34 +60,30 @@ def test_experimental_api():
with ggml_context():
a = Tensor.with_shape((1,), ggml_type=GGML_TYPE.F32)
b = Tensor.with_shape((1,), ggml_type=GGML_TYPE.F32)
x = Tensor.with_shape((1,), ggml_type=GGML_TYPE.F32)

backend.alloc_ctx_tensors()

a[0] = 3.0
b[0] = 4.0

with ggml_context():
x = Tensor.with_shape((1,), ggml_type=GGML_TYPE.F32)

backend.alloc_ctx_tensors()

x[0] = 2.0
f = a * x * x + b

f = a * x * x + b
assert f.shape == (1,)

assert f.shape == (1,)
measure_allocr = backend.new_measure()

measure_allocr = backend.new_measure()
graph = ggml_cgraph(f)

graph = ggml_cgraph(f)
mem_size = measure_allocr.alloc_graph(graph)

mem_size = measure_allocr.alloc_graph(graph)
buffer = backend.alloc_buffer(mem_size)

buffer = backend.alloc_buffer(mem_size)
allocr = buffer.new_allocr()
allocr.alloc_graph(graph)

allocr = buffer.new_allocr()
allocr.alloc_graph(graph)
x[0] = 2.0

graph.compute(backend)
graph.compute(backend)

assert f[0] == 16.0
assert f[0] == 16.0

0 comments on commit d88f325

Please sign in to comment.