Skip to content

Commit

Permalink
Merge branch 'main' into jialli/init-doxygen
Browse files Browse the repository at this point in the history
  • Loading branch information
skyline75489 committed Dec 17, 2024
2 parents 6860cbb + 10932c1 commit bd0f247
Show file tree
Hide file tree
Showing 15 changed files with 95 additions and 88 deletions.
41 changes: 41 additions & 0 deletions .github/workflows/ios-build.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
name: "iOS ARM64 Build"
on:
workflow_dispatch:
push:
branches:
- main
- rel-*
pull_request:
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
iphonesimulator-arm64-build:
runs-on: macos-latest # arm64
steps:
- name: Checkout OnnxRuntime GenAI repo
uses: actions/checkout@v4
with:
submodules: true

- uses: actions/setup-python@v5
with:
python-version: '3.12.x'

- name: Install the python wheel and dependencies
run: |
python3 -m venv genai-macos-venv
source genai-macos-venv/bin/activate
python3 -m pip install requests
- name: Run iOS Build
run: |
set -e -x
source genai-macos-venv/bin/activate
python3 build.py --ios \
--parallel \
--apple_sysroot iphonesimulator \
--osx_arch arm64 \
--apple_deploy_target 15.4 \
--cmake_generator 'Xcode' \
--build_dir build_iphonesimulator
13 changes: 0 additions & 13 deletions examples/c/src/phi3.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,19 +164,6 @@ void CXX_API(const char* model_path) {
is_first_token = false;
}

// Show usage of GetOutput
std::unique_ptr<OgaTensor> output_logits = generator->GetOutput("logits");

// Assuming output_logits.Type() is float as it's logits
// Assuming shape is 1 dimensional with shape[0] being the size
auto logits = reinterpret_cast<float*>(output_logits->Data());

// Print out the logits using the following snippet, if needed
//auto shape = output_logits->Shape();
//for (size_t i=0; i < shape[0]; i++)
// std::cout << logits[i] << " ";
//std::cout << std::endl;

const auto num_tokens = generator->GetSequenceCount(0);
const auto new_token = generator->GetSequenceData(0)[num_tokens - 1];
std::cout << tokenizer_stream->Decode(new_token) << std::flush;
Expand Down
8 changes: 3 additions & 5 deletions src/java/src/main/java/ai/onnxruntime/genai/Adapters.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
*/
package ai.onnxruntime.genai;

/**
* A container of adapters.
*/
/** A container of adapters. */
public final class Adapters implements AutoCloseable {
private long nativeHandle = 0;

Expand Down Expand Up @@ -40,8 +38,8 @@ public void loadAdapter(String adapterFilePath, String adapterName) throws GenAI
}

/**
* Unloads the adapter with the given identifier from the previously loaded adapters.
* If the adapter is not found, or if it cannot be unloaded (when it is in use), an error is returned.
* Unloads the adapter with the given identifier from the previously loaded adapters. If the
* adapter is not found, or if it cannot be unloaded (when it is in use), an error is returned.
*
* @param adapterName A unique user supplied adapter identifier.
* @throws GenAIException If the call to the GenAI native API fails.
Expand Down
29 changes: 15 additions & 14 deletions src/java/src/main/java/ai/onnxruntime/genai/Config.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
package ai.onnxruntime.genai;

/**
* Use Config to set multiple ORT execution providers. The EP used will be chosen based on the
* Use Config to set the ORT execution providers (EPs) and their options. The EPs are applied based on
* insertion order.
*/
public final class Config implements AutoCloseable {
private long nativeHandle;

/**
* Creates an OgaConfig from the given configuration directory.
* Creates a Config from the given configuration directory.
*
* @param modelPath The path to the configuration directory.
* @throws GenAIException If the call to the GenAI native API fails.
Expand All @@ -20,7 +20,7 @@ public Config(String modelPath) throws GenAIException {
nativeHandle = createConfig(modelPath);
}

/** Clear all providers. */
/** Clear the list of providers in the config */
public void clearProviders() {
if (nativeHandle == 0) {
throw new IllegalStateException("Instance has been freed and is invalid");
Expand All @@ -29,29 +29,30 @@ public void clearProviders() {
}

/**
* Append a provider with the given name.
* Add the provider at the end of the list of providers in the given config if it doesn't already
* exist. If it already exists, does nothing.
*
* @param provider_name The provider name.
* @param providerName The provider name.
*/
public void appendProvider(String provider_name) {
public void appendProvider(String providerName) {
if (nativeHandle == 0) {
throw new IllegalStateException("Instance has been freed and is invalid");
}
appendProvider(nativeHandle, provider_name);
appendProvider(nativeHandle, providerName);
}

/**
* Set options for a provider.
* Set a provider option.
*
* @param provider_name The provider name.
* @param option_name The option name.
* @param option_value The option value.
* @param providerName The provider name.
* @param optionKey The key of the option to set.
* @param optionValue The value of the option to set.
*/
public void setProviderOption(String provider_name, String option_name, String option_value) {
public void setProviderOption(String providerName, String optionKey, String optionValue) {
if (nativeHandle == 0) {
throw new IllegalStateException("Instance has been freed and is invalid");
}
setProviderOption(nativeHandle, provider_name, option_name, option_value);
setProviderOption(nativeHandle, providerName, optionKey, optionValue);
}

@Override
Expand Down Expand Up @@ -83,5 +84,5 @@ long nativeHandle() {
private native void appendProvider(long configHandle, String provider_name);

private native void setProviderOption(
long configHandle, String provider_name, String option_name, String option_value);
long configHandle, String providerName, String optionKey, String optionValue);
}
4 changes: 2 additions & 2 deletions src/java/src/main/java/ai/onnxruntime/genai/GenAI.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ final class GenAI {
/** The short name of the ONNX runtime shared library */
static final String ONNXRUNTIME_LIBRARY_NAME = "onnxruntime";

/** The value of the system property */
/** The value of the GENAI_NATIVE_PATH system property */
private static String libraryDirPathProperty;

/** The OS & CPU architecture string */
Expand Down Expand Up @@ -268,7 +268,7 @@ private static Optional<File> extractFromResources(String library) {

/**
* Maps the library name into a platform dependent library filename. Converts macOS's "jnilib" to
* "dylib" but otherwise is the same as {\@link System#mapLibraryName(String)}.
* "dylib" but otherwise is the same as System#mapLibraryName(String).
*
* @param library The library name
* @return The library filename.
Expand Down
10 changes: 6 additions & 4 deletions src/java/src/main/java/ai/onnxruntime/genai/Generator.java
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ public void appendTokenSequences(Sequences sequences) throws GenAIException {
}

/**
* Rewinds the generator by the specified number of tokens.
* Rewinds the generator to the given length. This is useful when the user wants to rewind the
* generator to a specific length and continue generating from that point.
*
* @param newLength The desired length in tokens after rewinding.
* @throws GenAIException If the call to the GenAI native API fails.
Expand All @@ -108,7 +109,8 @@ public void rewindTo(int newLength) throws GenAIException {
}

/**
* Generates the next token in the sequence.
* Computes the logits from the model based on the input ids and the past state. The computed
* logits are stored in the generator.
*
* @throws GenAIException If the call to the GenAI native API fails.
*/
Expand Down Expand Up @@ -151,7 +153,7 @@ public int getLastTokenInSequence(long sequenceIndex) throws GenAIException {
}

/**
* Fetches and returns the output tensor with the given name.
* Returns a copy of the model output identified by the given name as a Tensor.
*
* @param name The name of the output needed.
* @return The tensor.
Expand All @@ -163,7 +165,7 @@ public Tensor getOutput(String name) throws GenAIException {
}

/**
* Activates one of the loaded adapters.
* Sets the adapter with the given adapter name as active.
*
* @param adapters The Adapters container.
* @param adapterName The adapter name that was previously loaded.
Expand Down
12 changes: 9 additions & 3 deletions src/java/src/main/java/ai/onnxruntime/genai/GeneratorParams.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,13 @@ public final class GeneratorParams implements AutoCloseable {
private long nativeHandle = 0;
private ByteBuffer tokenIdsBuffer;

GeneratorParams(Model model) throws GenAIException {
/**
* Creates a GeneratorParams from the given model.
*
* @param model The model to use.
* @throws GenAIException If the call to the GenAI native API fails.
*/
public GeneratorParams(Model model) throws GenAIException {
if (model.nativeHandle() == 0) {
throw new IllegalStateException("model has been freed and is invalid");
}
Expand All @@ -27,7 +33,7 @@ public final class GeneratorParams implements AutoCloseable {
*
* @param optionName The option name.
* @param value The option value.
* @throws GenAIException
* @throws GenAIException If the call to the GenAI native API fails.
*/
public void setSearchOption(String optionName, double value) throws GenAIException {
if (nativeHandle == 0) {
Expand All @@ -42,7 +48,7 @@ public void setSearchOption(String optionName, double value) throws GenAIExcepti
*
* @param optionName The option name.
* @param value The option value.
* @throws GenAIException
* @throws GenAIException If the call to the GenAI native API fails.
*/
public void setSearchOption(String optionName, boolean value) throws GenAIException {
if (nativeHandle == 0) {
Expand Down
35 changes: 1 addition & 34 deletions src/java/src/main/java/ai/onnxruntime/genai/Model.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public Model(String modelPath) throws GenAIException {
}

/**
* Construct a Model from Config
* Construct a Model from the given Config.
*
* @param config The config to use.
* @throws GenAIException If the call to the GenAI native API fails.
Expand All @@ -27,39 +27,6 @@ public Model(Config config) throws GenAIException {
nativeHandle = createModelFromConfig(config.nativeHandle());
}

/**
* Creates a Tokenizer instance for this model. The model contains the configuration information
* that determines the tokenizer to use.
*
* @return The Tokenizer instance.
* @throws GenAIException If the call to the GenAI native API fails.
*/
public Tokenizer createTokenizer() throws GenAIException {
if (nativeHandle == 0) {
throw new IllegalStateException("Instance has been freed and is invalid");
}

return new Tokenizer(this);
}

// NOTE: Having model.createGeneratorParams is still under discussion.
// model.createTokenizer is consistent with the python setup at least and agreed upon.

/**
* Creates a GeneratorParams instance for executing the model. NOTE: GeneratorParams internally
* uses the Model, so the Model instance must remain valid
*
* @return The GeneratorParams instance.
* @throws GenAIException If the call to the GenAI native API fails.
*/
public GeneratorParams createGeneratorParams() throws GenAIException {
if (nativeHandle == 0) {
throw new IllegalStateException("Instance has been freed and is invalid");
}

return new GeneratorParams(this);
}

@Override
public void close() {
if (nativeHandle != 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
package ai.onnxruntime.genai;

/**
* This class is an intermediate storage class that bridges the output of preprocessing and the
* input of the ONNX model.
* This class is a list of tensors with names that match up with model input names.
*/
public class NamedTensors implements AutoCloseable {
private long nativeHandle;
Expand Down
4 changes: 2 additions & 2 deletions src/java/src/main/java/ai/onnxruntime/genai/SimpleGenAI.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public class SimpleGenAI implements AutoCloseable {
*/
public SimpleGenAI(String modelPath) throws GenAIException {
model = new Model(modelPath);
tokenizer = model.createTokenizer();
tokenizer = new Tokenizer(model);
}

/**
Expand All @@ -48,7 +48,7 @@ public SimpleGenAI(String modelPath) throws GenAIException {
* @throws GenAIException on failure
*/
public GeneratorParams createGeneratorParams() throws GenAIException {
return model.createGeneratorParams();
return new GeneratorParams(model);
}

/**
Expand Down
4 changes: 2 additions & 2 deletions src/java/src/main/java/ai/onnxruntime/genai/Tensor.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

/** Wraps ORT native Tensor. */
/** Currently wraps an ORT Tensor. */
public final class Tensor implements AutoCloseable {
private long nativeHandle = 0;
private final ElementType elementType;
Expand All @@ -18,7 +18,7 @@ public final class Tensor implements AutoCloseable {

// The values in this enum must match ONNX values
// https://github.com/onnx/onnx/blob/159fa47b7c4d40e6d9740fcf14c36fff1d11ccd8/onnx/onnx.proto#L499-L544
/** ORT native element types. */
/** Element types that correspond to OnnxRuntime supported element types. */
public enum ElementType {
undefined,
float32,
Expand Down
8 changes: 7 additions & 1 deletion src/java/src/main/java/ai/onnxruntime/genai/Tokenizer.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@
public class Tokenizer implements AutoCloseable {
private long nativeHandle;

Tokenizer(Model model) throws GenAIException {
/**
* Creates a Tokenizer from the given model.
*
* @param model The model to use.
* @throws GenAIException If the call to the GenAI native API fails.
*/
public Tokenizer(Model model) throws GenAIException {
assert (model.nativeHandle() != 0); // internal code should never pass an invalid model

nativeHandle = createTokenizer(model.nativeHandle());
Expand Down
6 changes: 3 additions & 3 deletions src/java/src/main/native/ai_onnxruntime_genai_Config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ Java_ai_onnxruntime_genai_Config_appendProvider(JNIEnv* env, jobject thiz, jlong
}

JNIEXPORT void JNICALL
Java_ai_onnxruntime_genai_Config_setProvider(JNIEnv* env, jobject thiz, jlong native_handle, jstring provider_name, jstring option_name, jstring option_value) {
Java_ai_onnxruntime_genai_Config_setProvider(JNIEnv* env, jobject thiz, jlong native_handle, jstring provider_name, jstring option_key, jstring option_value) {
CString c_provider_name{env, provider_name};
CString c_option_name{env, option_name};
CString c_option_key{env, option_key};
CString c_option_value{env, option_value};
OgaConfig* config = reinterpret_cast<OgaConfig*>(native_handle);

ThrowIfError(env, OgaConfigSetProviderOption(config, c_provider_name, c_option_name, c_option_value));
ThrowIfError(env, OgaConfigSetProviderOption(config, c_provider_name, c_option_key, c_option_value));
}
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,15 @@ public void testUsageWithListener() throws GenAIException {
@EnabledIf("haveAdapters")
public void testUsageWithAdapters() throws GenAIException {
try (Model model = new Model(TestUtils.testAdapterTestModelPath());
Tokenizer tokenizer = model.createTokenizer()) {
Tokenizer tokenizer = new Tokenizer(model)) {
String[] prompts = {
TestUtils.applyPhi2ChatTemplate("def is_prime(n):"),
TestUtils.applyPhi2ChatTemplate("def compute_gcd(x, y):"),
TestUtils.applyPhi2ChatTemplate("def binary_search(arr, x):"),
};

try (Sequences sequences = tokenizer.encodeBatch(prompts);
GeneratorParams params = model.createGeneratorParams()) {
GeneratorParams params = new GeneratorParams(model)) {
params.setSearchOption("max_length", 200);
params.setSearchOption("batch_size", prompts.length);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public void testBatchEncodeDecode() throws GenAIException {
try (Model model = new Model(TestUtils.testVisionModelPath());
MultiModalProcessor multiModalProcessor = new MultiModalProcessor(model);
TokenizerStream stream = multiModalProcessor.createStream();
GeneratorParams generatorParams = model.createGeneratorParams()) {
GeneratorParams generatorParams = new GeneratorParams(model)) {
String inputs =
new String(
"<|user|>\n<|image_1|>\n Can you convert the table to markdown format?\n<|end|>\n<|assistant|>\n");
Expand Down

0 comments on commit bd0f247

Please sign in to comment.