Skip to content

Commit

Permalink
webgpu: ensure api restrictions
Browse files Browse the repository at this point in the history
The webgpu API can only be used in the context of a durable object
and when the "webgpu" compatibility flag is set.
  • Loading branch information
edevil committed Aug 18, 2023
1 parent baf7c0b commit da2bad2
Show file tree
Hide file tree
Showing 12 changed files with 130 additions and 21 deletions.
4 changes: 3 additions & 1 deletion src/workerd/api/global-scope.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ class Navigator: public jsg::Object {
public:
kj::StringPtr getUserAgent() { return "Cloudflare-Workers"_kj; }
#ifdef WORKERD_EXPERIMENTAL_ENABLE_WEBGPU
jsg::Ref<api::gpu::GPU> getGPU() { return jsg::alloc<api::gpu::GPU>(); }
jsg::Ref<api::gpu::GPU> getGPU(CompatibilityFlags::Reader flags) {
return jsg::alloc<api::gpu::GPU>(flags);
}
#endif

JSG_RESOURCE_TYPE(Navigator) {
Expand Down
13 changes: 12 additions & 1 deletion src/workerd/api/gpu/gpu.c++
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// https://opensource.org/licenses/Apache-2.0

#include "gpu.h"
#include "workerd/jsg/exception.h"
#include <dawn/dawn_proc.h>

namespace workerd::api::gpu {
Expand All @@ -21,7 +22,17 @@ void initialize() {
dawnProcSetProcs(&dawn::native::GetProcs());
}

GPU::GPU() {
GPU::GPU(CompatibilityFlags::Reader flags) {
// is this a durable object?
KJ_IF_MAYBE (actor, IoContext::current().getActor()) {
JSG_REQUIRE(actor->getPersistent() != nullptr, TypeError,
"webgpu api is only available in Durable Objects (no storage)");
} else {
JSG_FAIL_REQUIRE(TypeError, "webgpu api is only available in Durable Objects (no actor)");
};

JSG_REQUIRE(flags.getWebgpu(), TypeError, "webgpu needs the webgpu compatibility flag set");

instance_.DiscoverDefaultAdapters();
}

Expand Down
2 changes: 1 addition & 1 deletion src/workerd/api/gpu/gpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ struct GPURequestAdapterOptions {

class GPU : public jsg::Object {
public:
explicit GPU();
explicit GPU(CompatibilityFlags::Reader flags);
JSG_RESOURCE_TYPE(GPU) {
JSG_METHOD(requestAdapter);
}
Expand Down
9 changes: 8 additions & 1 deletion src/workerd/api/gpu/webgpu-buffer-test.gpu-wd-test
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,15 @@ const unitTests :Workerd.Config = (
modules = [
(name = "worker", esModule = embed "webgpu-buffer-test.js")
],
durableObjectNamespaces = [
(className = "DurableObjectExample", uniqueKey = "210bd0cbd803ef7883a1ee9d86cce06e"),
],
durableObjectStorage = (inMemory = void),
bindings = [
(name = "ns", durableObjectNamespace = "DurableObjectExample"),
],
compatibilityDate = "2023-01-15",
compatibilityFlags = ["experimental", "nodejs_compat"],
compatibilityFlags = ["experimental", "nodejs_compat", "webgpu"],
)
),
],
Expand Down
24 changes: 20 additions & 4 deletions src/workerd/api/gpu/webgpu-buffer-test.js
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import { deepEqual, ok } from "node:assert";
import { deepEqual, ok, equal } from "node:assert";

// run manually for now
// bazel run --//src/workerd/io:enable_experimental_webgpu //src/workerd/server:workerd -- test `realpath ./src/workerd/api/gpu/webgpu-buffer-test.gpu-wd-test` --verbose --experimental

export const read_sync_stack = {
async test(ctrl, env, ctx) {
export class DurableObjectExample {
constructor(state) {
this.state = state;
}

async fetch() {
ok(navigator.gpu);
const adapter = await navigator.gpu.requestAdapter();
ok(adapter);
Expand Down Expand Up @@ -56,6 +60,18 @@ export const read_sync_stack = {
const copyArrayBuffer = gpuReadBuffer.getMappedRange();
ok(copyArrayBuffer);

deepEqual(new Uint8Array(copyArrayBuffer), new Uint8Array([ 0, 1, 2, 3 ]));
deepEqual(new Uint8Array(copyArrayBuffer), new Uint8Array([0, 1, 2, 3]));

return new Response("OK");
}
}

export const buffer_mapping = {
async test(ctrl, env, ctx) {
let id = env.ns.idFromName("A");
let obj = env.ns.get(id);
let res = await obj.fetch("http://foo/test");
let text = await res.text();
equal(text, "OK");
},
};
9 changes: 8 additions & 1 deletion src/workerd/api/gpu/webgpu-compute-test.gpu-wd-test
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,15 @@ const unitTests :Workerd.Config = (
modules = [
(name = "worker", esModule = embed "webgpu-compute-test.js")
],
durableObjectNamespaces = [
(className = "DurableObjectExample", uniqueKey = "210bd0cbd803ef7883a1ee9d86cce06e"),
],
durableObjectStorage = (inMemory = void),
bindings = [
(name = "ns", durableObjectNamespace = "DurableObjectExample"),
],
compatibilityDate = "2023-01-15",
compatibilityFlags = ["experimental", "nodejs_compat"],
compatibilityFlags = ["experimental", "nodejs_compat", "webgpu"],
)
),
],
Expand Down
22 changes: 19 additions & 3 deletions src/workerd/api/gpu/webgpu-compute-test.js
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import { deepEqual, ok } from "node:assert";
import { deepEqual, ok, equal } from "node:assert";

// run manually for now
// bazel run --//src/workerd/io:enable_experimental_webgpu //src/workerd/server:workerd -- test `realpath ./src/workerd/api/gpu/webgpu-compute-test.gpu-wd-test` --verbose --experimental

export const read_sync_stack = {
async test(ctrl, env, ctx) {
export class DurableObjectExample {
constructor(state) {
this.state = state;
}

async fetch() {
ok(navigator.gpu);
if (!("gpu" in navigator)) {
console.log(
Expand Down Expand Up @@ -271,5 +275,17 @@ export const read_sync_stack = {
new Float32Array(arrayBuffer),
new Float32Array([2, 2, 50, 60, 114, 140])
);

return new Response("OK");
}
}

export const compute_shader = {
async test(ctrl, env, ctx) {
let id = env.ns.idFromName("A");
let obj = env.ns.get(id);
let res = await obj.fetch("http://foo/test");
let text = await res.text();
equal(text, "OK");
},
};
9 changes: 8 additions & 1 deletion src/workerd/api/gpu/webgpu-errors-test.gpu-wd-test
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,15 @@ const unitTests :Workerd.Config = (
modules = [
(name = "worker", esModule = embed "webgpu-errors-test.js")
],
durableObjectNamespaces = [
(className = "DurableObjectExample", uniqueKey = "210bd0cbd803ef7883a1ee9d86cce06e"),
],
durableObjectStorage = (inMemory = void),
bindings = [
(name = "ns", durableObjectNamespace = "DurableObjectExample"),
],
compatibilityDate = "2023-01-15",
compatibilityFlags = ["experimental", "nodejs_compat"],
compatibilityFlags = ["experimental", "nodejs_compat", "webgpu"],
)
),
],
Expand Down
22 changes: 19 additions & 3 deletions src/workerd/api/gpu/webgpu-errors-test.js
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import { ok } from "node:assert";
import { ok, equal } from "node:assert";

// run manually for now
// bazel run --//src/workerd/io:enable_experimental_webgpu //src/workerd/server:workerd -- test `realpath ./src/workerd/api/gpu/webgpu-errors-test.gpu-wd-test` --verbose --experimental

export const read_sync_stack = {
async test(ctrl, env, ctx) {
export class DurableObjectExample {
constructor(state) {
this.state = state;
}

async fetch() {
ok(navigator.gpu);

const adapter = await navigator.gpu.requestAdapter();
Expand Down Expand Up @@ -79,5 +83,17 @@ export const read_sync_stack = {

// ensure callback with error was indeed called
ok(callbackCalled);

return new Response("OK");
}
}

export const error_handling = {
async test(ctrl, env, ctx) {
let id = env.ns.idFromName("A");
let obj = env.ns.get(id);
let res = await obj.fetch("http://foo/test");
let text = await res.text();
equal(text, "OK");
},
};
9 changes: 8 additions & 1 deletion src/workerd/api/gpu/webgpu-write-test.gpu-wd-test
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,15 @@ const unitTests :Workerd.Config = (
modules = [
(name = "worker", esModule = embed "webgpu-write-test.js")
],
durableObjectNamespaces = [
(className = "DurableObjectExample", uniqueKey = "210bd0cbd803ef7883a1ee9d86cce06e"),
],
durableObjectStorage = (inMemory = void),
bindings = [
(name = "ns", durableObjectNamespace = "DurableObjectExample"),
],
compatibilityDate = "2023-01-15",
compatibilityFlags = ["experimental", "nodejs_compat"],
compatibilityFlags = ["experimental", "nodejs_compat", "webgpu"],
)
),
],
Expand Down
24 changes: 20 additions & 4 deletions src/workerd/api/gpu/webgpu-write-test.js
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import { ok, deepEqual } from "node:assert";
import { ok, deepEqual, equal } from "node:assert";

// run manually for now
// bazel run --//src/workerd/io:enable_experimental_webgpu //src/workerd/server:workerd -- test `realpath ./src/workerd/api/gpu/webgpu-write-test.gpu-wd-test` --verbose --experimental

export const read_sync_stack = {
async test(ctrl, env, ctx) {
export class DurableObjectExample {
constructor(state) {
this.state = state;
}

async fetch() {
ok(navigator.gpu);
const adapter = await navigator.gpu.requestAdapter();
ok(adapter);
Expand All @@ -23,6 +27,18 @@ export const read_sync_stack = {

// Write bytes to buffer.
new Uint8Array(arrayBuffer).set([0, 1, 2, 3]);
deepEqual(new Uint8Array(arrayBuffer), new Uint8Array([ 0, 1, 2, 3 ]));
deepEqual(new Uint8Array(arrayBuffer), new Uint8Array([0, 1, 2, 3]));

return new Response("OK");
}
}

export const buffer_write = {
async test(ctrl, env, ctx) {
let id = env.ns.idFromName("A");
let obj = env.ns.get(id);
let res = await obj.fetch("http://foo/test");
let text = await res.text();
equal(text, "OK");
},
};
4 changes: 4 additions & 0 deletions src/workerd/io/compatibility-date.capnp
Original file line number Diff line number Diff line change
Expand Up @@ -340,4 +340,8 @@ struct CompatibilityFlags @0x8f8c1b68151b6cef {
$compatEnableFlag("rtti_api")
$experimental;
# Enables the `workerd:rtti` module for querying runtime-type-information from JavaScript.

webgpu @35 :Bool
$compatEnableFlag("webgpu")
$experimental;
}

0 comments on commit da2bad2

Please sign in to comment.