From 10236ecabbaaaa471288f697547346c3e2ce33f6 Mon Sep 17 00:00:00 2001 From: hawk9821 Date: Tue, 10 Sep 2024 12:52:20 +0800 Subject: [PATCH] [Future][Transforms-V2] llm trans support field projection --- docs/en/transform-v2/llm.md | 26 ++++++- docs/zh/transform-v2/llm.md | 43 +++++++---- .../seatunnel/e2e/transform/TestLLMIT.java | 7 ++ .../llm_openai_transform_columns.conf | 76 +++++++++++++++++++ .../transform/nlpmodel/llm/LLMTransform.java | 2 + .../nlpmodel/llm/LLMTransformConfig.java | 8 ++ .../nlpmodel/llm/remote/AbstractModel.java | 61 ++++++++++++++- .../llm/remote/custom/CustomModel.java | 3 +- .../llm/remote/openai/OpenAIModel.java | 3 +- .../transform/llm/LLMRequestJsonTest.java | 41 ++++++++++ 10 files changed, 247 insertions(+), 23 deletions(-) create mode 100644 seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/llm_openai_transform_columns.conf diff --git a/docs/en/transform-v2/llm.md b/docs/en/transform-v2/llm.md index 8caaad00a0e..d1b8e6fc6ec 100644 --- a/docs/en/transform-v2/llm.md +++ b/docs/en/transform-v2/llm.md @@ -11,17 +11,18 @@ more. 
## Options | name | type | required | default value | -|------------------------|--------|----------|---------------| +| ---------------------- | ------ | -------- | ------------- | | model_provider | enum | yes | | | output_data_type | enum | no | String | | prompt | string | yes | | +| inference_columns | list | no | | | model | string | yes | | | api_key | string | yes | | | api_path | string | no | | -| custom_config | map | no | | -| custom_response_parse | string | no | | +| custom_config | map | no | | +| custom_response_parse | string | no | | | custom_request_headers | map | no | | -| custom_request_body | map | no | | +| custom_request_body | map | no | | ### model_provider @@ -62,6 +63,23 @@ The result will be: | Eric | 20 | American | | Guangdong Liu | 20 | Chinese | +### inference_columns + +The `inference_columns` option allows you to specify which columns from the input data should be used as inputs for the LLM. By default, all columns will be used as inputs. + +For example: +```hocon +transform { + LLM { + model_provider = OPENAI + model = gpt-4o-mini + api_key = sk-xxx + inference_columns = ["name", "age"] + prompt = "Determine whether someone is Chinese or American by their name" + } +} +``` + ### model The model to use. Different model providers have different models. For example, the OpenAI model can be `gpt-4o-mini`. 
diff --git a/docs/zh/transform-v2/llm.md b/docs/zh/transform-v2/llm.md index 5efcf47125d..c2d3c0f6ca3 100644 --- a/docs/zh/transform-v2/llm.md +++ b/docs/zh/transform-v2/llm.md @@ -8,18 +8,19 @@ ## 属性 -| 名称 | 类型 | 是否必须 | 默认值 | -|------------------------|--------|------|--------| -| model_provider | enum | yes | | -| output_data_type | enum | no | String | -| prompt | string | yes | | -| model | string | yes | | -| api_key | string | yes | | -| api_path | string | no | | -| custom_config | map | no | | -| custom_response_parse | string | no | | -| custom_request_headers | map | no | | -| custom_request_body | map | no | | +| 名称 | 类型 | 是否必须 | 默认值 | +| ---------------------- | ------ | -------- | ------ | +| model_provider | enum | yes | | +| output_data_type | enum | no | String | +| prompt | string | yes | | +| inference_columns | list | no | | +| model | string | yes | | +| api_key | string | yes | | +| api_path | string | no | | +| custom_config | map | no | | +| custom_response_parse | string | no | | +| custom_request_headers | map | no | | +| custom_request_body | map | no | | ### model_provider @@ -60,6 +61,23 @@ Determine whether someone is Chinese or American by their name | Eric | 20 | American | | Guangdong Liu | 20 | Chinese | +### inference_columns + +`inference_columns`选项允许您指定应该将输入数据中的哪些列用作LLM的输入。默认情况下,所有列都将用作输入。 + +例如: +```hocon +transform { + LLM { + model_provider = OPENAI + model = gpt-4o-mini + api_key = sk-xxx + inference_columns = ["name", "age"] + prompt = "Determine whether someone is Chinese or American by their name" + } +} +``` + ### model 要使用的模型。不同的模型提供者有不同的模型。例如,OpenAI 模型可以是 `gpt-4o-mini`。 @@ -253,4 +271,3 @@ sink { } ``` - diff --git a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/java/org/apache/seatunnel/e2e/transform/TestLLMIT.java b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/java/org/apache/seatunnel/e2e/transform/TestLLMIT.java index
712a6d7f908..244bca1e9c6 100644 --- a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/java/org/apache/seatunnel/e2e/transform/TestLLMIT.java +++ b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/java/org/apache/seatunnel/e2e/transform/TestLLMIT.java @@ -93,6 +93,14 @@ public void testLLMWithOpenAIBoolean(TestContainer container) throws IOException, InterruptedException { Container.ExecResult execResult = container.executeJob("/llm_openai_transform_boolean.conf"); + Assertions.assertEquals(0, execResult.getExitCode()); + } + + @TestTemplate + public void testLLMWithOpenAIColumns(TestContainer container) + throws IOException, InterruptedException { + Container.ExecResult execResult = + container.executeJob("/llm_openai_transform_columns.conf"); Assertions.assertEquals(0, execResult.getExitCode()); } diff --git a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/llm_openai_transform_columns.conf b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/llm_openai_transform_columns.conf new file mode 100644 index 00000000000..e4286ba7621 --- /dev/null +++ b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/llm_openai_transform_columns.conf @@ -0,0 +1,76 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +###### +###### This config file is a demonstration of streaming processing in seatunnel config +###### + +env { + job.mode = "BATCH" +} + +source { + FakeSource { + row.num = 5 + schema = { + fields { + id = "int" + name = "string" + } + } + rows = [ + {fields = [1, "Jia Fan"], kind = INSERT} + {fields = [2, "Hailin Wang"], kind = INSERT} + {fields = [3, "Tomas"], kind = INSERT} + {fields = [4, "Eric"], kind = INSERT} + {fields = [5, "Guangdong Liu"], kind = INSERT} + ] + result_table_name = "fake" + } +} + +transform { + LLM { + source_table_name = "fake" + model_provider = OPENAI + model = gpt-4o-mini + api_key = sk-xxx + inference_columns = ["name"] + prompt = "Determine whether someone is Chinese or American by their name" + openai.api_path = "http://mockserver:1080/v1/chat/completions" + result_table_name = "llm_output" + } +} + +sink { + Assert { + source_table_name = "llm_output" + rules = + { + field_rules = [ + { + field_name = llm_output + field_type = string + field_value = [ + { + rule_type = NOT_NULL + } + ] + } + ] + } + } +} diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransform.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransform.java index 92db061ccca..705253a2fee 100644 --- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransform.java +++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransform.java @@ -79,6 +79,7 @@ public void open() { new CustomModel( 
inputCatalogTable.getSeaTunnelRowType(), outputDataType.getSqlType(), + config.get(LLMTransformConfig.INFERENCE_COLUMNS), config.get(LLMTransformConfig.PROMPT), config.get(LLMTransformConfig.MODEL), provider.usedLLMPath(config.get(LLMTransformConfig.API_PATH)), @@ -97,6 +98,7 @@ public void open() { new OpenAIModel( inputCatalogTable.getSeaTunnelRowType(), outputDataType.getSqlType(), + config.get(LLMTransformConfig.INFERENCE_COLUMNS), config.get(LLMTransformConfig.PROMPT), config.get(LLMTransformConfig.MODEL), config.get(LLMTransformConfig.API_KEY), diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransformConfig.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransformConfig.java index 8800f061db7..c45bfb8f396 100644 --- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransformConfig.java +++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/LLMTransformConfig.java @@ -21,6 +21,8 @@ import org.apache.seatunnel.api.configuration.Options; import org.apache.seatunnel.transform.nlpmodel.ModelTransformConfig; +import java.util.List; + public class LLMTransformConfig extends ModelTransformConfig { public static final Option PROMPT = @@ -29,6 +31,12 @@ public class LLMTransformConfig extends ModelTransformConfig { .noDefaultValue() .withDescription("The prompt of LLM"); + public static final Option> INFERENCE_COLUMNS = + Options.key("inference_columns") + .listType() + .noDefaultValue() + .withDescription("The row projection field of each inference"); + public static final Option INFERENCE_BATCH_SIZE = Options.key("inference_batch_size") .intType() diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/AbstractModel.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/AbstractModel.java index 4ee271c4085..5d0fcee637c 
100644 --- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/AbstractModel.java +++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/AbstractModel.java @@ -21,25 +21,59 @@ import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.node.ArrayNode; import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.node.ObjectNode; +import org.apache.seatunnel.api.table.type.SeaTunnelDataType; import org.apache.seatunnel.api.table.type.SeaTunnelRow; import org.apache.seatunnel.api.table.type.SeaTunnelRowType; import org.apache.seatunnel.api.table.type.SqlType; import org.apache.seatunnel.format.json.RowToJsonConverters; +import com.google.common.annotations.VisibleForTesting; + import java.io.IOException; +import java.util.ArrayList; import java.util.List; public abstract class AbstractModel implements Model { protected static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); - private final RowToJsonConverters.RowToJsonConverter rowToJsonConverters; + private final RowToJsonConverters.RowToJsonConverter rowToJsonConverter; + private final SeaTunnelRowType rowType; private final String prompt; private final SqlType outputType; + private final List projectionColumns; - public AbstractModel(SeaTunnelRowType rowType, SqlType outputType, String prompt) { + public AbstractModel( + SeaTunnelRowType rowType, + SqlType outputType, + List projectionColumns, + String prompt) { + this.rowType = rowType; this.prompt = prompt; this.outputType = outputType; - this.rowToJsonConverters = new RowToJsonConverters().createConverter(rowType, null); + this.projectionColumns = projectionColumns; + this.rowToJsonConverter = getRowToJsonConverter(); + } + + public RowToJsonConverters.RowToJsonConverter getRowToJsonConverter() { + RowToJsonConverters converters = new RowToJsonConverters(); + if (projectionColumns != null && !projectionColumns.isEmpty()) { + List fieldTypes = new ArrayList<>(); + for 
(String fieldName : projectionColumns) { + int fieldIndex = rowType.indexOf(fieldName); + if (fieldIndex != -1) { + fieldTypes.add(rowType.getFieldType(fieldIndex)); + } else { + throw new IllegalArgumentException( + "Field name " + fieldName + " does not exist in the row type."); + } + } + SeaTunnelRowType projectionRowType = + new SeaTunnelRowType( + projectionColumns.toArray(new String[0]), + fieldTypes.toArray(new SeaTunnelDataType[0])); + return converters.createConverter(projectionRowType, null); + } + return converters.createConverter(rowType, null); } private String getPromptWithLimit() { @@ -58,12 +92,31 @@ public List inference(List rows) throws IOException { ArrayNode rowsNode = OBJECT_MAPPER.createArrayNode(); for (SeaTunnelRow row : rows) { ObjectNode rowNode = OBJECT_MAPPER.createObjectNode(); - rowToJsonConverters.convert(OBJECT_MAPPER, rowNode, row); + rowToJsonConverter.convert(OBJECT_MAPPER, rowNode, createProjectionSeaTunnelRow(row)); rowsNode.add(rowNode); } return chatWithModel(getPromptWithLimit(), OBJECT_MAPPER.writeValueAsString(rowsNode)); } + @VisibleForTesting + public SeaTunnelRow createProjectionSeaTunnelRow(SeaTunnelRow row) { + if (row == null || projectionColumns == null || projectionColumns.isEmpty()) { + return row; + } + SeaTunnelRow projectionRow = new SeaTunnelRow(projectionColumns.size()); + for (int i = 0; i < projectionColumns.size(); i++) { + String fieldName = projectionColumns.get(i); + int fieldIndex = rowType.indexOf(fieldName); + if (fieldIndex != -1) { + projectionRow.setField(i, row.getField(fieldIndex)); + } else { + throw new IllegalArgumentException( + "Field name " + fieldName + " does not exist in the row type."); + } + } + return projectionRow; + } + protected abstract List chatWithModel(String promptWithLimit, String rowsJson) throws IOException; diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/custom/CustomModel.java 
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/custom/CustomModel.java index af893e92ddc..dfc2bfc8681 100644 --- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/custom/CustomModel.java +++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/custom/CustomModel.java @@ -56,13 +56,14 @@ public class CustomModel extends AbstractModel { public CustomModel( SeaTunnelRowType rowType, SqlType outputType, + List projectionColumns, String prompt, String model, String apiPath, Map header, Map body, String parse) { - super(rowType, outputType, prompt); + super(rowType, outputType, projectionColumns, prompt); this.apiPath = apiPath; this.model = model; this.header = header; diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/openai/OpenAIModel.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/openai/OpenAIModel.java index 8dc12ec0cd3..aeea00b49bd 100644 --- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/openai/OpenAIModel.java +++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/openai/OpenAIModel.java @@ -54,11 +54,12 @@ public class OpenAIModel extends AbstractModel { public OpenAIModel( SeaTunnelRowType rowType, SqlType outputType, + List projectionColumns, String prompt, String model, String apiKey, String apiPath) { - super(rowType, outputType, prompt); + super(rowType, outputType, projectionColumns, prompt); this.apiKey = apiKey; this.apiPath = apiPath; this.model = model; diff --git a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/llm/LLMRequestJsonTest.java b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/llm/LLMRequestJsonTest.java index 2de785a1a8b..97eb1b8f964 100644 --- 
a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/llm/LLMRequestJsonTest.java +++ b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/llm/LLMRequestJsonTest.java @@ -22,14 +22,18 @@ import org.apache.seatunnel.api.table.type.BasicType; import org.apache.seatunnel.api.table.type.SeaTunnelDataType; +import org.apache.seatunnel.api.table.type.SeaTunnelRow; import org.apache.seatunnel.api.table.type.SeaTunnelRowType; import org.apache.seatunnel.api.table.type.SqlType; +import org.apache.seatunnel.format.json.RowToJsonConverters; import org.apache.seatunnel.transform.nlpmodel.llm.remote.custom.CustomModel; import org.apache.seatunnel.transform.nlpmodel.llm.remote.openai.OpenAIModel; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; +import com.google.common.collect.Lists; + import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; @@ -50,6 +54,7 @@ void testOpenAIRequestJson() throws IOException { new OpenAIModel( rowType, SqlType.STRING, + null, "Determine whether someone is Chinese or American by their name", "gpt-3.5-turbo", "sk-xxx", @@ -64,6 +69,41 @@ void testOpenAIRequestJson() throws IOException { model.close(); } + @Test + void testOpenAIProjectionRequestJson() throws IOException { + SeaTunnelRowType rowType = + new SeaTunnelRowType( + new String[] {"id", "name", "city"}, + new SeaTunnelDataType[] { + BasicType.INT_TYPE, BasicType.STRING_TYPE, BasicType.STRING_TYPE + }); + OpenAIModel model = + new OpenAIModel( + rowType, + SqlType.STRING, + Lists.newArrayList("name", "city"), + "Determine whether someone is Chinese or American by their name", + "gpt-3.5-turbo", + "sk-xxx", + "https://api.openai.com/v1/chat/completions"); + + SeaTunnelRow row = new SeaTunnelRow(rowType.getFieldTypes().length); + row.setField(0, 1); + row.setField(1, "John"); + row.setField(2, "New York"); + ObjectNode rowNode = OBJECT_MAPPER.createObjectNode(); + RowToJsonConverters.RowToJsonConverter 
rowToJsonConverter = model.getRowToJsonConverter(); + rowToJsonConverter.convert(OBJECT_MAPPER, rowNode, model.createProjectionSeaTunnelRow(row)); + ObjectNode node = + model.createJsonNodeFromData( + "Determine whether someone is Chinese or American by their name", + OBJECT_MAPPER.writeValueAsString(rowNode)); + Assertions.assertEquals( + "{\"model\":\"gpt-3.5-turbo\",\"messages\":[{\"role\":\"system\",\"content\":\"Determine whether someone is Chinese or American by their name\"},{\"role\":\"user\",\"content\":\"{\\\"name\\\":\\\"John\\\",\\\"city\\\":\\\"New York\\\"}\"}]}", + OBJECT_MAPPER.writeValueAsString(node)); + model.close(); + } + @Test void testCustomRequestJson() throws IOException { SeaTunnelRowType rowType = @@ -95,6 +135,7 @@ void testCustomRequestJson() throws IOException { new CustomModel( rowType, SqlType.STRING, + null, "Determine whether someone is Chinese or American by their name", "custom-model", "https://api.custom.com/v1/chat/completions",