Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature][Transform] Add LLM transform #7303

Merged
merged 4 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 122 additions & 0 deletions docs/en/transform-v2/llm.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# LLM

> LLM transform plugin

## Description

Leverage the power of a large language model (LLM) to process data by sending it to the LLM and receiving the
generated results. Utilize the LLM's capabilities to label, clean, enrich data, perform data inference, and
more.

## Options

| name | type | required | default value |
|------------------|--------|----------|--------------------------------------------|
| model_provider | enum | yes | |
| output_data_type | enum | no | String |
| prompt | string | yes | |
| model | string | yes | |
| api_key | string | yes | |
| openai.api_path | string | no | https://api.openai.com/v1/chat/completions |

### model_provider

The model provider to use. The available options are:
OPENAI

### output_data_type

The data type of the output data. The available options are:
STRING,INT,BIGINT,DOUBLE,BOOLEAN.
Default value is STRING.

### prompt

The prompt to send to the LLM. This parameter defines how LLM will process and return data, eg:

The data read from source is a table like this:

| name | age |
|---------------|-----|
| Jia Fan | 20 |
| Hailin Wang | 20 |
| Eric | 20 |
| Guangdong Liu | 20 |

The prompt can be:

```
Determine whether someone is Chinese or American by their name
```

The result will be:

| name | age | llm_output |
|---------------|-----|------------|
| Jia Fan | 20 | Chinese |
| Hailin Wang | 20 | Chinese |
| Eric | 20 | American |
| Guangdong Liu | 20 | Chinese |

### model

The model to use. Different model providers have different models. For example, the OpenAI model can be `gpt-4o-mini`.
If you use OpenAI model, please refer https://platform.openai.com/docs/models/model-endpoint-compatibility of `/v1/chat/completions` endpoint.

### api_key

The API key to use for the model provider.
If you use OpenAI model, please refer https://platform.openai.com/docs/api-reference/api-keys of how to get the API key.

### openai.api_path

The API path to use for the OpenAI model provider. In most cases, you do not need to change this configuration. If you are using an API agent's service, you may need to configure it to the agent's API address.

### common options [string]

Transform plugin common parameters, please refer to [Transform Plugin](common-options.md) for details

## Example

Determine the user's country through a LLM.

```hocon
env {
parallelism = 1
job.mode = "BATCH"
}

source {
FakeSource {
row.num = 5
schema = {
fields {
id = "int"
name = "string"
}
}
rows = [
{fields = [1, "Jia Fan"], kind = INSERT}
{fields = [2, "Hailin Wang"], kind = INSERT}
{fields = [3, "Tomas"], kind = INSERT}
{fields = [4, "Eric"], kind = INSERT}
{fields = [5, "Guangdong Liu"], kind = INSERT}
]
}
}

transform {
LLM {
model_provider = OPENAI
model = gpt-4o-mini
api_key = sk-xxx
prompt = "Determine whether someone is Chinese or American by their name"
}
}

sink {
console {
}
}
```

Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.seatunnel.e2e.transform;

import org.apache.seatunnel.e2e.common.TestResource;
import org.apache.seatunnel.e2e.common.container.TestContainer;

import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.TestTemplate;
import org.testcontainers.containers.Container;
import org.testcontainers.containers.GenericContainer;
import org.testcontainers.containers.output.Slf4jLogConsumer;
import org.testcontainers.containers.wait.strategy.HttpWaitStrategy;
import org.testcontainers.lifecycle.Startables;
import org.testcontainers.utility.DockerImageName;
import org.testcontainers.utility.DockerLoggerFactory;
import org.testcontainers.utility.MountableFile;

import java.io.File;
import java.io.IOException;
import java.net.URL;
import java.util.Optional;
import java.util.stream.Stream;

public class TestLLMIT extends TestSuiteBase implements TestResource {

private static final String TMP_DIR = "/tmp";
private GenericContainer<?> mockserverContainer;
private static final String IMAGE = "mockserver/mockserver:5.14.0";

@BeforeAll
@Override
public void startUp() {
Optional<URL> resource =
Optional.ofNullable(TestLLMIT.class.getResource("/mockserver-config.json"));
this.mockserverContainer =
new GenericContainer<>(DockerImageName.parse(IMAGE))
.withNetwork(NETWORK)
.withNetworkAliases("mockserver")
.withExposedPorts(1080)
.withCopyFileToContainer(
MountableFile.forHostPath(
new File(
resource.orElseThrow(
() ->
new IllegalArgumentException(
"Can not get config file of mockServer"))
.getPath())
.getAbsolutePath()),
TMP_DIR + "/mockserver-config.json")
.withEnv(
"MOCKSERVER_INITIALIZATION_JSON_PATH",
TMP_DIR + "/mockserver-config.json")
.withEnv("MOCKSERVER_LOG_LEVEL", "WARN")
.withLogConsumer(new Slf4jLogConsumer(DockerLoggerFactory.getLogger(IMAGE)))
.waitingFor(new HttpWaitStrategy().forPath("/").forStatusCode(404));
Startables.deepStart(Stream.of(mockserverContainer)).join();
}

@AfterAll
@Override
public void tearDown() throws Exception {
if (mockserverContainer != null) {
mockserverContainer.stop();
}
}

@TestTemplate
public void testLLMWithOpenAI(TestContainer container)
throws IOException, InterruptedException {
Container.ExecResult execResult = container.executeJob("/llm_openai_transform.conf");
Assertions.assertEquals(0, execResult.getExitCode());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
######
###### This config file is a demonstration of streaming processing in seatunnel config
######

env {
job.mode = "BATCH"
}

source {
FakeSource {
row.num = 5
schema = {
fields {
id = "int"
name = "string"
}
}
rows = [
{fields = [1, "Jia Fan"], kind = INSERT}
{fields = [2, "Hailin Wang"], kind = INSERT}
{fields = [3, "Tomas"], kind = INSERT}
{fields = [4, "Eric"], kind = INSERT}
{fields = [5, "Guangdong Liu"], kind = INSERT}
]
}
}

transform {
LLM {
model_provider = OPENAI
model = gpt-4o-mini
api_key = sk-xxx
prompt = "Determine whether someone is Chinese or American by their name"
openai.api_path = "http://mockserver:1080/v1/chat/completions"
}
}

sink {
Assert {
rules =
{
field_rules = [
{
field_name = llm_output
field_type = string
field_value = [
{
rule_type = NOT_NULL
}
]
}
]
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// https://www.mock-server.com/mock_server/getting_started.html#request_matchers

[
{
"httpRequest": {
"method": "POST",
"path": "/v1/chat/completions"
},
"httpResponse": {
"body": {
"id": "chatcmpl-9s4hoBNGV0d9Mudkhvgzg64DAWPnx",
"object": "chat.completion",
"created": 1722674828,
"model": "gpt-4o-mini",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "[\"Chinese\"]"
},
"logprobs": null,
"finish_reason": "stop"
}
],
"usage": {
"prompt_tokens": 107,
"completion_tokens": 3,
"total_tokens": 110
},
"system_fingerprint": "fp_0f03d4f0ee",
"code": 0,
"msg": "ok"
},
"headers": {
"Content-Type": "application/json"
}
}
}
]
15 changes: 15 additions & 0 deletions seatunnel-transforms-v2/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@
<artifactId>seatunnel-transforms-v2</artifactId>
<name>SeaTunnel : Transforms : V2</name>

<properties>
<httpclient.version>4.5.13</httpclient.version>
<httpcore.version>4.4.4</httpcore.version>
</properties>

<dependencyManagement>
<dependencies>
<dependency>
Expand Down Expand Up @@ -77,6 +82,16 @@
<version>${project.version}</version>
<classifier>optional</classifier>
</dependency>
<dependency>
<groupId>org.apache.httpcomponents</groupId>
<artifactId>httpclient</artifactId>
<version>${httpclient.version}</version>
</dependency>
<dependency>
<groupId>org.apache.httpcomponents</groupId>
<artifactId>httpcore</artifactId>
<version>${httpcore.version}</version>
</dependency>
</dependencies>

<build>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,8 @@ public RowKind getRowKind() {
public Object getField(int pos) {
return row.getField(pos);
}

public Object[] getFields() {
return row.getFields();
}
}
Loading
Loading