Skip to content

Commit

Permalink
Merge pull request #123 from yuzawa-san/shared-arena-config
Browse files Browse the repository at this point in the history
Shared  environment arena config
  • Loading branch information
yuzawa-san authored Jun 2, 2023
2 parents ac5b327 + dd0ca1a commit ed1d8a1
Show file tree
Hide file tree
Showing 19 changed files with 62 additions and 1 deletion.
1 change: 1 addition & 0 deletions src/main/java/com/jyuzawa/onnxruntime/Api.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
/**
* The top-level API of the ONNX runtime.
*
* @since 1.0.0
*/
public interface Api {

Expand Down
4 changes: 4 additions & 0 deletions src/main/java/com/jyuzawa/onnxruntime/ApiImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ final class ApiImpl implements Api {
final CastTypeInfoToTensorInfo CastTypeInfoToTensorInfo;
final CreateAllocator CreateAllocator;
final CreateAndRegisterAllocator CreateAndRegisterAllocator;
final CreateArenaCfgV2 CreateArenaCfgV2;
final CreateCpuMemoryInfo CreateCpuMemoryInfo;
final CreateCUDAProviderOptions CreateCUDAProviderOptions;
final CreateDnnlProviderOptions CreateDnnlProviderOptions;
Expand Down Expand Up @@ -84,6 +85,7 @@ final class ApiImpl implements Api {
final ModelMetadataGetProducerName ModelMetadataGetProducerName;
final ModelMetadataGetVersion ModelMetadataGetVersion;
final ModelMetadataLookupCustomMetadataMap ModelMetadataLookupCustomMetadataMap;
final RegisterAllocator RegisterAllocator;
final RegisterCustomOpsLibrary RegisterCustomOpsLibrary;
final ReleaseAllocator ReleaseAllocator;
final ReleaseAvailableProviders ReleaseAvailableProviders;
Expand Down Expand Up @@ -155,6 +157,7 @@ final class ApiImpl implements Api {
this.CastTypeInfoToTensorInfo = OrtApi.CastTypeInfoToTensorInfo(memorySegment, memorySession);
this.CreateAllocator = OrtApi.CreateAllocator(memorySegment, memorySession);
this.CreateAndRegisterAllocator = OrtApi.CreateAndRegisterAllocator(memorySegment, memorySession);
this.CreateArenaCfgV2 = OrtApi.CreateArenaCfgV2(memorySegment, memorySession);
this.CreateCpuMemoryInfo = OrtApi.CreateCpuMemoryInfo(memorySegment, memorySession);
this.CreateCUDAProviderOptions = OrtApi.CreateCUDAProviderOptions(memorySegment, memorySession);
this.CreateDnnlProviderOptions = OrtApi.CreateDnnlProviderOptions(memorySegment, memorySession);
Expand Down Expand Up @@ -208,6 +211,7 @@ final class ApiImpl implements Api {
this.ModelMetadataGetVersion = OrtApi.ModelMetadataGetVersion(memorySegment, memorySession);
this.ModelMetadataLookupCustomMetadataMap =
OrtApi.ModelMetadataLookupCustomMetadataMap(memorySegment, memorySession);
this.RegisterAllocator = OrtApi.RegisterAllocator(memorySegment, memorySession);
this.RegisterCustomOpsLibrary = OrtApi.RegisterCustomOpsLibrary(memorySegment, memorySession);
this.ReleaseAllocator = OrtApi.ReleaseAllocator(memorySegment, memorySession);
this.ReleaseAvailableProviders = OrtApi.ReleaseAvailableProviders(memorySegment, memorySession);
Expand Down
12 changes: 12 additions & 0 deletions src/main/java/com/jyuzawa/onnxruntime/Environment.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@
*/
package com.jyuzawa.onnxruntime;

import java.util.Map;

/**
* The environment in which model evaluation sessions can be constructed. Only
* one environment should be used in your application. This class is thread safe.
*
* @since 1.0.0
*/
public interface Environment extends AutoCloseable {

Expand Down Expand Up @@ -83,6 +87,14 @@ public interface Builder {
*/
Builder setGlobalSpinControl(boolean globalSpinControl);

/**
* Set whether the environment's shared allocator for CPU should be arena-based
* @param config the key/value configuration of the arena
* @return the builder
* @since 1.2.0
*/
Builder setArenaConfig(Map<String, Long> config);

/**
* Constructs the {@link Environment}.
*
Expand Down
28 changes: 27 additions & 1 deletion src/main/java/com/jyuzawa/onnxruntime/EnvironmentImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@
*/
package com.jyuzawa.onnxruntime;

import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.C_LONG;
import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.C_POINTER;
import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.ORT_PROJECTION_JAVA;
import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.OrtArenaAllocator;
import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.OrtMemTypeDefault;

import java.lang.foreign.MemoryAddress;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.MemorySession;
import java.util.Map;

final class EnvironmentImpl extends ManagedImpl implements Environment {

Expand Down Expand Up @@ -48,8 +51,25 @@ final class EnvironmentImpl extends ManagedImpl implements Environment {
this.memoryInfo = api.create(
memorySession, out -> api.CreateCpuMemoryInfo.apply(OrtArenaAllocator(), OrtMemTypeDefault(), out));
memorySession.addCloseAction(() -> api.ReleaseMemoryInfo.apply(memoryInfo));
api.checkStatus(api.CreateAndRegisterAllocator.apply(address, memoryInfo, MemoryAddress.NULL));
this.ortAllocator = api.create(memorySession, out -> api.GetAllocatorWithDefaultOptions.apply(out));
Map<String, Long> arenaConfig = builder.arenaConfig;
if (arenaConfig == null) {
api.RegisterAllocator.apply(address, ortAllocator);
} else {
int size = arenaConfig.size();
MemorySegment keyArray = temporarySession.allocateArray(C_POINTER, size);
MemorySegment valueArray = temporarySession.allocateArray(C_LONG, size);
int i = 0;
for (Map.Entry<String, Long> entry : arenaConfig.entrySet()) {
keyArray.setAtIndex(C_POINTER, i, temporarySession.allocateUtf8String(entry.getKey()));
valueArray.setAtIndex(C_LONG, i, entry.getValue());
i++;
}
MemoryAddress arenaConfigAddress = api.create(
temporarySession,
out -> api.CreateArenaCfgV2.apply(keyArray.address(), valueArray.address(), size, out));
api.checkStatus(api.CreateAndRegisterAllocator.apply(address, memoryInfo, arenaConfigAddress));
}
}
}

Expand Down Expand Up @@ -82,6 +102,7 @@ static final class Builder implements Environment.Builder {
private Integer globalInterOpNumThreads;
private Integer globalIntraOpNumThreads;
private Boolean globalSpinControl;
private Map<String, Long> arenaConfig;

Builder(ApiImpl api) {
this.api = api;
Expand Down Expand Up @@ -129,6 +150,11 @@ public Builder setGlobalSpinControl(boolean globalSpinControl) {
return this;
}

public Builder setArenaConfig(Map<String, Long> config) {
this.arenaConfig = config;
return this;
}

private MemoryAddress newThreadingOptions(MemorySession memorySession) {
MemoryAddress threadingOptions = api.create(memorySession, out -> api.CreateThreadingOptions.apply(out));
memorySession.addCloseAction(() -> api.ReleaseThreadingOptions.apply(threadingOptions));
Expand Down
1 change: 1 addition & 0 deletions src/main/java/com/jyuzawa/onnxruntime/MapInfo.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
/**
* A description of the type information related to ONNX's "Map" type.
*
* @since 1.0.0
*/
public interface MapInfo {

Expand Down
1 change: 1 addition & 0 deletions src/main/java/com/jyuzawa/onnxruntime/ModelMetadata.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
/**
* A representation of the metadata stored in an ONNX model.
*
* @since 1.0.0
*/
public interface ModelMetadata {
/**
Expand Down
1 change: 1 addition & 0 deletions src/main/java/com/jyuzawa/onnxruntime/NamedCollection.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
/**
* An immutable collection which can be accessed by list index or name.
* @param <V> the type in the collection
* @since 1.0.0
*/
public interface NamedCollection<V> {

Expand Down
1 change: 1 addition & 0 deletions src/main/java/com/jyuzawa/onnxruntime/NodeInfo.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
/**
* A tuple of name and type information.
*
* @since 1.0.0
*/
public interface NodeInfo {

Expand Down
2 changes: 2 additions & 0 deletions src/main/java/com/jyuzawa/onnxruntime/OnnxMap.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
/**
* A map view of an {@link OnnxValue}. Use the {@link MapInfo#getKeyType()} to
* select a type-safe view. A {@link NoSuchElementException} will be thrown if a view does not exist for this instance's type.
*
* @since 1.0.0
*/
public interface OnnxMap {

Expand Down
1 change: 1 addition & 0 deletions src/main/java/com/jyuzawa/onnxruntime/OnnxRuntime.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
/**
* This is the entrypoint into this library. It is a singleton that is accessible using {@link #get()}.
*
* @since 1.0.0
*/
public interface OnnxRuntime {

Expand Down
1 change: 1 addition & 0 deletions src/main/java/com/jyuzawa/onnxruntime/OnnxSequence.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
* A sequence view of an {@link OnnxValue}. Extends {@link java.util.List} to provide an UNMODIFIABLE view.
* Use {@link #add()} to populate the sequence.
*
* @since 1.0.0
*/
public interface OnnxSequence extends List<OnnxValue> {

Expand Down
1 change: 1 addition & 0 deletions src/main/java/com/jyuzawa/onnxruntime/OnnxTensor.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
/**
* A representation of a dense tensor. Use {@link #getInfo()} to select the proper buffer type. A {@link NoSuchElementException} will be thrown if a view does not exist for this instance's type.
*
* @since 1.0.0
*/
public interface OnnxTensor {

Expand Down
1 change: 1 addition & 0 deletions src/main/java/com/jyuzawa/onnxruntime/OnnxTypedMap.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
* A type-safe map of an {@link OnnxValue}. Extends {@link java.util.Map} to provide an UNMODIFIABLE view. Use {@link #set(Object)} to populate the map.
*
* @param <K> key type
* @since 1.0.0
*/
public interface OnnxTypedMap<K> extends Map<K, OnnxValue> {

Expand Down
1 change: 1 addition & 0 deletions src/main/java/com/jyuzawa/onnxruntime/OnnxValue.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
/**
* A representation of a value provided as an input or output. Use {@link #getType()} to find out which more specific view is present. A {@link NoSuchElementException} will be thrown if a view does not exist for this instance's type.
*
* @since 1.0.0
*/
public interface OnnxValue {

Expand Down
2 changes: 2 additions & 0 deletions src/main/java/com/jyuzawa/onnxruntime/Session.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

/**
* An evaluation session loaded from an ONNX file or bytes. This class is thread safe.
*
* @since 1.0.0
*/
public interface Session extends AutoCloseable {

Expand Down
1 change: 1 addition & 0 deletions src/main/java/com/jyuzawa/onnxruntime/TensorInfo.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
/**
* A description of a tensor: type and shape.
*
* @since 1.0.0
*/
public interface TensorInfo {

Expand Down
1 change: 1 addition & 0 deletions src/main/java/com/jyuzawa/onnxruntime/Transaction.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
/**
* A representation of a model evaluation. Should NOT be reused. This class is NOT thread-safe.
*
* @since 1.0.0
*/
public interface Transaction extends AutoCloseable {
/**
Expand Down
2 changes: 2 additions & 0 deletions src/main/java/com/jyuzawa/onnxruntime/TypeInfo.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

/**
* A description of the type of an input or output. Use {@link #getType()} to determine what type of additional information is present. A {@link NoSuchElementException} will be thrown from getters which are not valid for the instance's type.
*
* @since 1.0.0
*/
public interface TypeInfo {

Expand Down
1 change: 1 addition & 0 deletions src/test/java/com/jyuzawa/onnxruntime/SessionTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ public static void setup() {
environment = api.newEnvironment()
.setLogSeverityLevel(OnnxRuntimeLoggingLevel.VERBOSE)
.setLogId("testing")
.setArenaConfig(Map.of("initial_chunk_size_bytes", 65535L, "initial_growth_chunk_size_bytes", 65535L))
.build();
}

Expand Down

0 comments on commit ed1d8a1

Please sign in to comment.