Skip to content

Commit

Permalink
Add releaseSession interface
Browse files Browse the repository at this point in the history
  • Loading branch information
qjia7 committed Jan 8, 2024
1 parent 57477cd commit 7377d5f
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 1 deletion.
6 changes: 6 additions & 0 deletions js/web/lib/wasm/binding/ort-wasm.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,12 @@ export interface OrtWasmModule extends EmscriptenModule {
jsepCreateDownloader:
(gpuBuffer: GPUBuffer, size: number,
type: Tensor.GpuBufferDataTypes) => () => Promise<Tensor.DataTypeMap[Tensor.GpuBufferDataTypes]>;
/**
* [exported from js_internal_api.js] Release a session.
* @param sessionId - specify the session ID.
* @returns
*/
jsepReleaseSession: (sessionId: number) => void;
// #endregion
}

Expand Down
7 changes: 7 additions & 0 deletions js/web/lib/wasm/jsep/backend-webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -571,4 +571,11 @@ export class WebGpuBackend {
}
this.status = StatusType.default;
}

releaseSession(sessionId: number): void {
if (this.capturedCommandList.has(sessionId)) {
this.capturedCommandList.delete(sessionId);
}
this.gpuDataManager.releaseSession(sessionId);
}
}
41 changes: 40 additions & 1 deletion js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,15 @@ export interface GpuDataManager {
unregisterExternalBuffer(buffer: GPUBuffer): void;

/**
* destroy all gpu buffers. Call this when the session.release is called.
* destroy all gpu buffers.
*/
dispose(): void;

/**
* release session related data.
* @param sessionId - specify the session ID.
*/
releaseSession(sessionId: number): void;
}

interface StorageCacheValue {
Expand Down Expand Up @@ -139,13 +145,18 @@ class GpuDataManagerImpl implements GpuDataManager {
// The external buffers registered users for IO Binding.
private externalBuffers: Map<GPUBuffer, GpuDataId>;

// The pendingBuffers for capture graph.
// a SessionID -> GPUBuffer[] mapping.
private capturedPendingBuffers: Map<number, GPUBuffer[]>;

constructor(private backend: WebGpuBackend) {
this.storageCache = new Map();
this.freeBuffers = new Map();
this.freeUniformBuffers = new Map();
this.buffersForUploadingPending = [];
this.buffersPending = [];
this.externalBuffers = new Map();
this.capturedPendingBuffers = new Map();
}

upload(id: GpuDataId, data: Uint8Array): void {
Expand Down Expand Up @@ -313,6 +324,10 @@ class GpuDataManagerImpl implements GpuDataManager {
}
this.buffersForUploadingPending = [];

if (this.buffersPending.length === 0) {
return;
}

// Don't release intermediate tensors in non-default mode.
if (this.backend.status === StatusType.default) {
for (const buffer of this.buffersPending) {
Expand All @@ -329,6 +344,16 @@ class GpuDataManagerImpl implements GpuDataManager {
}
}
this.buffersPending = [];
} else {
let capturedBuffers = this.capturedPendingBuffers.get(this.backend.currentSessionId!);
if (!capturedBuffers) {
capturedBuffers = [];
this.capturedPendingBuffers.set(this.backend.currentSessionId!, capturedBuffers);
}
for (const buffer of this.buffersPending) {
capturedBuffers.push(buffer);
}
this.buffersPending = [];
}
}

Expand All @@ -348,9 +373,23 @@ class GpuDataManagerImpl implements GpuDataManager {
storage.gpuData.buffer.destroy();
});

this.capturedPendingBuffers.forEach((buffers) => {
buffers.forEach(buffer => {
buffer.destroy();
});
});
this.storageCache = new Map();
this.freeBuffers = new Map();
this.freeUniformBuffers = new Map();
this.capturedPendingBuffers = new Map();
}

releaseSession(sessionId: number) {
// release the captured pending buffers.
const pendingBffers = this.capturedPendingBuffers.get(sessionId);
pendingBffers!.forEach(buffer => {
buffer.destroy();
});
}
}

Expand Down
1 change: 1 addition & 0 deletions js/web/lib/wasm/wasm-core-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ export const releaseSession = (sessionId: number): void => {
}

wasm.jsepUnregisterBuffers?.(sessionId);
wasm.jsepReleaseSession?.(sessionId);

inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));
outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/wasm/js_internal_api.js
Original file line number Diff line number Diff line change
Expand Up @@ -169,4 +169,7 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea
Module['jsepCreateDownloader'] = (gpuBuffer, size, type) => {
return backend['createDownloader'](gpuBuffer, size, type);
};
Module['jsepReleaseSession'] = sessionId => {
backend['releaseSession'](sessionId);
};
};

0 comments on commit 7377d5f

Please sign in to comment.