Skip to content

Commit

Permalink
[Improve][Connector-V2] update vectorType (#7446)
Browse files Browse the repository at this point in the history
  • Loading branch information
corgy-w authored Aug 24, 2024
1 parent 8a28290 commit 1bba723
Show file tree
Hide file tree
Showing 9 changed files with 257 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,6 @@ private int getBytesForValue(Object v, SeaTunnelDataType<?> dataType) {
case TIMESTAMP:
return 48;
case FLOAT_VECTOR:
return getArrayNotNullSize((Object[]) v) * 4;
case FLOAT16_VECTOR:
case BFLOAT16_VECTOR:
case BINARY_VECTOR:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,23 @@

package org.apache.seatunnel.api.table.type;

import org.apache.seatunnel.api.annotation.Experimental;

import java.nio.ByteBuffer;
import java.util.Map;
import java.util.Objects;

/**
* VectorType represents a vector type in SeaTunnel.
*
* <p>Experimental feature, use with caution
*/
@Experimental
public class VectorType<T> implements SeaTunnelDataType<T> {
private static final long serialVersionUID = 2L;

public static final VectorType<Float> VECTOR_FLOAT_TYPE =
new VectorType<>(Float.class, SqlType.FLOAT_VECTOR);
public static final VectorType<ByteBuffer> VECTOR_FLOAT_TYPE =
new VectorType<>(ByteBuffer.class, SqlType.FLOAT_VECTOR);

public static final VectorType<Map> VECTOR_SPARSE_FLOAT_TYPE =
new VectorType<>(Map.class, SqlType.SPARSE_FLOAT_VECTOR);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.seatunnel.common.utils;

import java.nio.Buffer;
import java.nio.ByteBuffer;

public class BufferUtils {

public static ByteBuffer toByteBuffer(Short[] shortArray) {
ByteBuffer byteBuffer = ByteBuffer.allocate(shortArray.length * 2);

for (Short value : shortArray) {
byteBuffer.putShort(value);
}

// Compatible compilation and running versions are not consistent
// Flip the buffer to prepare for reading
((Buffer) byteBuffer).flip();

return byteBuffer;
}

public static Short[] toShortArray(ByteBuffer byteBuffer) {
Short[] shortArray = new Short[byteBuffer.capacity() / 2];

for (int i = 0; i < shortArray.length; i++) {
shortArray[i] = byteBuffer.getShort();
}

return shortArray;
}

public static ByteBuffer toByteBuffer(Float[] floatArray) {
ByteBuffer byteBuffer = ByteBuffer.allocate(floatArray.length * 4);

for (float value : floatArray) {
byteBuffer.putFloat(value);
}

((Buffer) byteBuffer).flip();

return byteBuffer;
}

public static Float[] toFloatArray(ByteBuffer byteBuffer) {
Float[] floatArray = new Float[byteBuffer.capacity() / 4];

for (int i = 0; i < floatArray.length; i++) {
floatArray[i] = byteBuffer.getFloat();
}

return floatArray;
}

public static ByteBuffer toByteBuffer(Double[] doubleArray) {
ByteBuffer byteBuffer = ByteBuffer.allocate(doubleArray.length * 8);

for (double value : doubleArray) {
byteBuffer.putDouble(value);
}

((Buffer) byteBuffer).flip();

return byteBuffer;
}

public static Double[] toDoubleArray(ByteBuffer byteBuffer) {
Double[] doubleArray = new Double[byteBuffer.capacity() / 8];

for (int i = 0; i < doubleArray.length; i++) {
doubleArray[i] = byteBuffer.getDouble();
}

return doubleArray;
}

public static ByteBuffer toByteBuffer(Integer[] intArray) {
ByteBuffer byteBuffer = ByteBuffer.allocate(intArray.length * 4);

for (int value : intArray) {
byteBuffer.putInt(value);
}

((Buffer) byteBuffer).flip();

return byteBuffer;
}

public static Integer[] toIntArray(ByteBuffer byteBuffer) {
Integer[] intArray = new Integer[byteBuffer.capacity() / 4];

for (int i = 0; i < intArray.length; i++) {
intArray[i] = byteBuffer.getInt();
}

return intArray;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.seatunnel.common.utils;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

import java.nio.ByteBuffer;

public class BufferUtilsTest {

@Test
public void testToByteBufferAndToShortArray() {
Short[] shortArray = {1, 2, 3, 4, 5};
ByteBuffer byteBuffer = BufferUtils.toByteBuffer(shortArray);
Short[] resultArray = BufferUtils.toShortArray(byteBuffer);

Assertions.assertArrayEquals(shortArray, resultArray, "Short array conversion failed");
}

@Test
public void testToByteBufferAndToFloatArray() {
Float[] floatArray = {1.1f, 2.2f, 3.3f, 4.4f, 5.5f};
ByteBuffer byteBuffer = BufferUtils.toByteBuffer(floatArray);
Float[] resultArray = BufferUtils.toFloatArray(byteBuffer);

Assertions.assertArrayEquals(floatArray, resultArray, "Float array conversion failed");
}

@Test
public void testToByteBufferAndToDoubleArray() {
Double[] doubleArray = {1.1, 2.2, 3.3, 4.4, 5.5};
ByteBuffer byteBuffer = BufferUtils.toByteBuffer(doubleArray);
Double[] resultArray = BufferUtils.toDoubleArray(byteBuffer);

Assertions.assertArrayEquals(doubleArray, resultArray, "Double array conversion failed");
}

@Test
public void testToByteBufferAndToIntArray() {
Integer[] intArray = {1, 2, 3, 4, 5};
ByteBuffer byteBuffer = BufferUtils.toByteBuffer(intArray);
Integer[] resultArray = BufferUtils.toIntArray(byteBuffer);

Assertions.assertArrayEquals(intArray, resultArray, "Integer array conversion failed");
}

@Test
public void testEmptyArrayConversion() {
// Test empty arrays
Short[] shortArray = {};
ByteBuffer shortBuffer = BufferUtils.toByteBuffer(shortArray);
Short[] shortResultArray = BufferUtils.toShortArray(shortBuffer);
Assertions.assertArrayEquals(
shortArray, shortResultArray, "Empty Short array conversion failed");

Float[] floatArray = {};
ByteBuffer floatBuffer = BufferUtils.toByteBuffer(floatArray);
Float[] floatResultArray = BufferUtils.toFloatArray(floatBuffer);
Assertions.assertArrayEquals(
floatArray, floatResultArray, "Empty Float array conversion failed");

Double[] doubleArray = {};
ByteBuffer doubleBuffer = BufferUtils.toByteBuffer(doubleArray);
Double[] doubleResultArray = BufferUtils.toDoubleArray(doubleBuffer);
Assertions.assertArrayEquals(
doubleArray, doubleResultArray, "Empty Double array conversion failed");

Integer[] intArray = {};
ByteBuffer intBuffer = BufferUtils.toByteBuffer(intArray);
Integer[] intResultArray = BufferUtils.toIntArray(intBuffer);
Assertions.assertArrayEquals(
intArray, intResultArray, "Empty Integer array conversion failed");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@

package org.apache.seatunnel.connectors.seatunnel.fake.utils;

import org.apache.seatunnel.common.utils.BufferUtils;
import org.apache.seatunnel.connectors.seatunnel.fake.config.FakeConfig;

import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.RandomStringUtils;
import org.apache.commons.lang3.RandomUtils;

import java.math.BigDecimal;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.time.LocalDate;
import java.time.LocalDateTime;
Expand Down Expand Up @@ -178,14 +178,14 @@ public ByteBuffer randomBinaryVector() {
return ByteBuffer.wrap(RandomUtils.nextBytes(byteCount));
}

public Float[] randomFloatVector() {
public ByteBuffer randomFloatVector() {
Float[] floatVector = new Float[fakeConfig.getVectorDimension()];
for (int i = 0; i < fakeConfig.getVectorDimension(); i++) {
floatVector[i] =
RandomUtils.nextFloat(
fakeConfig.getVectorFloatMin(), fakeConfig.getVectorFloatMax());
}
return floatVector;
return BufferUtils.toByteBuffer(floatVector);
}

public ByteBuffer randomFloat16Vector() {
Expand All @@ -196,7 +196,7 @@ public ByteBuffer randomFloat16Vector() {
fakeConfig.getVectorFloatMin(), fakeConfig.getVectorFloatMax());
float16Vector[i] = floatToFloat16(value);
}
return shortArrayToByteBuffer(float16Vector);
return BufferUtils.toByteBuffer(float16Vector);
}

public ByteBuffer randomBFloat16Vector() {
Expand All @@ -207,7 +207,7 @@ public ByteBuffer randomBFloat16Vector() {
fakeConfig.getVectorFloatMin(), fakeConfig.getVectorFloatMax());
bfloat16Vector[i] = floatToBFloat16(value);
}
return shortArrayToByteBuffer(bfloat16Vector);
return BufferUtils.toByteBuffer(bfloat16Vector);
}

public Map<Integer, Float> randomSparseFloatVector() {
Expand Down Expand Up @@ -242,20 +242,6 @@ private static short floatToFloat16(float value) {
return (short) (sign | (exponent << 10) | (mantissa >> 13));
}

private static ByteBuffer shortArrayToByteBuffer(Short[] shortArray) {
ByteBuffer byteBuffer = ByteBuffer.allocate(shortArray.length * 2);

for (Short value : shortArray) {
byteBuffer.putShort(value);
}

// Compatible compilation and running versions are not consistent
// Flip the buffer to prepare for reading
((Buffer) byteBuffer).flip();

return byteBuffer;
}

private static short floatToBFloat16(float value) {
int intBits = Float.floatToIntBits(value);
return (short) (intBits >> 16);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,15 @@
import org.apache.seatunnel.api.table.type.SqlType;
import org.apache.seatunnel.common.exception.CommonError;
import org.apache.seatunnel.common.exception.CommonErrorCodeDeprecated;
import org.apache.seatunnel.common.utils.BufferUtils;
import org.apache.seatunnel.connectors.seatunnel.jdbc.exception.JdbcConnectorErrorCode;
import org.apache.seatunnel.connectors.seatunnel.jdbc.exception.JdbcConnectorException;
import org.apache.seatunnel.connectors.seatunnel.jdbc.internal.converter.AbstractJdbcRowConverter;
import org.apache.seatunnel.connectors.seatunnel.jdbc.internal.dialect.DatabaseIdentifier;
import org.apache.seatunnel.connectors.seatunnel.jdbc.utils.JdbcFieldTypeUtils;

import java.math.BigDecimal;
import java.nio.ByteBuffer;
import java.sql.Date;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
Expand All @@ -40,8 +42,6 @@
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.LocalTime;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

public class OceanBaseMysqlJdbcRowConverter extends AbstractJdbcRowConverter {
Expand Down Expand Up @@ -89,11 +89,13 @@ public SeaTunnelRow toInternal(ResultSet rs, TableSchema tableSchema) throws SQL
fields[fieldIndex] = JdbcFieldTypeUtils.getFloat(rs, resultSetIndex);
break;
case FLOAT_VECTOR:
List<Float> vector = new ArrayList<>();
for (Object o : (Object[]) rs.getObject(fieldIndex)) {
vector.add(Float.parseFloat(o.toString()));
Object[] objects = (Object[]) rs.getObject(fieldIndex);
Float[] arrays = new Float[objects.length];
for (int i = 0; i < objects.length; i++) {
arrays[i] = Float.parseFloat(objects[i].toString());
}
fields[fieldIndex] = vector;
fields[fieldIndex] = BufferUtils.toByteBuffer(arrays);
break;
case DOUBLE:
fields[fieldIndex] = JdbcFieldTypeUtils.getDouble(rs, resultSetIndex);
break;
Expand Down Expand Up @@ -172,8 +174,10 @@ public PreparedStatement toExternal(
statement.setFloat(statementIndex, (Float) row.getField(fieldIndex));
break;
case FLOAT_VECTOR:
if (row.getField(fieldIndex) instanceof Float[]) {
Float[] floatArray = (Float[]) row.getField(fieldIndex);
if (row.getField(fieldIndex) instanceof ByteBuffer) {
ByteBuffer byteBuffer = (ByteBuffer) row.getField(fieldIndex);
// Convert ByteBuffer to Float[]
Float[] floatArray = BufferUtils.toFloatArray(byteBuffer);
StringBuilder vector = new StringBuilder();
vector.append("[");
for (Float aFloat : floatArray) {
Expand Down
Loading

0 comments on commit 1bba723

Please sign in to comment.