Skip to content

Commit

Permalink
refactor: upgrade to SpringAI 1.0.0-M6
Browse files Browse the repository at this point in the history
  • Loading branch information
bsbodden committed Feb 18, 2025
1 parent 21c059b commit 6b2f371
Show file tree
Hide file tree
Showing 12 changed files with 176 additions and 86 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ public class Product {
schemaFieldType = SchemaFieldType.VECTOR, //
algorithm = VectorAlgorithm.HNSW, //
type = VectorType.FLOAT32, //
dimension = 768, //
dimension = 384, //
distanceMetric = DistanceMetric.COSINE, //
initialCapacity = 10
)
Expand Down
6 changes: 3 additions & 3 deletions redis-om-spring/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,9 @@
<elementary.version>2.0.1</elementary.version>
<gson.version>2.10.1</gson.version>
<djl.starter.version>0.26</djl.starter.version>
<djl.version>0.27.0</djl.version>
<djl.version>0.30.0</djl.version>
<junit-bom.version>5.10.2</junit-bom.version>
<spring-ai.version>1.0.0-M2</spring-ai.version>
<spring-ai.version>1.0.0-M6</spring-ai.version>
</properties>

<dependencyManagement>
Expand Down Expand Up @@ -200,7 +200,7 @@
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-vertex-ai-palm2</artifactId>
<artifactId>spring-ai-vertex-ai-embedding</artifactId>
<version>${spring-ai.version}</version>
<optional>true</optional>
</dependency>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,26 +35,29 @@
import org.springframework.ai.openai.OpenAiEmbeddingOptions;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.ai.vertexai.palm2.VertexAiPaLm2EmbeddingModel;
import org.springframework.ai.vertexai.palm2.api.VertexAiPaLm2Api;
import org.springframework.ai.transformers.TransformersEmbeddingModel;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails;
import org.springframework.ai.vertexai.embedding.text.VertexAiTextEmbeddingModel;
import org.springframework.ai.vertexai.embedding.text.VertexAiTextEmbeddingOptions;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.*;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Primary;
import org.springframework.lang.Nullable;
import org.springframework.util.StringUtils;
import org.springframework.web.client.RestClient;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentials;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.regions.Region;

import java.io.IOException;
import java.net.InetAddress;
import java.time.*;
import java.time.Duration;
import java.util.Map;

@ConditionalOnProperty(name = "redis.om.spring.ai.enabled")
Expand All @@ -70,8 +73,8 @@ public ImageFactory imageFactory() {
}

@Bean(name = "djlImageEmbeddingModelCriteria")
public Criteria<Image, byte[]> imageEmbeddingModelCriteria(RedisOMAiProperties properties) {
return Criteria.builder().setTypes(Image.class, byte[].class) //
public Criteria<Image, float[]> imageEmbeddingModelCriteria(RedisOMAiProperties properties) {
return Criteria.builder().setTypes(Image.class, float[].class) //
.optEngine(properties.getDjl().getImageEmbeddingModelEngine()) //
.optModelUrls(properties.getDjl().getImageEmbeddingModelModelUrls()) //
.build();
Expand Down Expand Up @@ -123,7 +126,8 @@ public Criteria<Image, float[]> faceEmbeddingModelCriteria( //
RedisOMAiProperties properties) {

return Criteria.builder() //
.setTypes(Image.class, float[].class).optModelUrls(properties.getDjl().getFaceEmbeddingModelModelUrls()) //
.setTypes(Image.class, float[].class) //
.optModelUrls(properties.getDjl().getFaceEmbeddingModelModelUrls()) //
.optModelName(properties.getDjl().getFaceEmbeddingModelName()) //
.optTranslator(translator) //
.optEngine(properties.getDjl().getFaceEmbeddingModelEngine()) //
Expand All @@ -142,8 +146,9 @@ public ZooModel<Image, float[]> faceEmbeddingModel(
}

@Bean(name = "djlImageEmbeddingModel")
public ZooModel<Image, byte[]> imageModel(
@Nullable @Qualifier("djlImageEmbeddingModelCriteria") Criteria<Image, byte[]> criteria) throws MalformedModelException, ModelNotFoundException, IOException {
public ZooModel<Image, float[]> imageModel(
@Nullable @Qualifier("djlImageEmbeddingModelCriteria") Criteria<Image, float[]> criteria)
throws MalformedModelException, ModelNotFoundException, IOException {
return criteria != null ? ModelZoo.loadModel(criteria) : null;
}

Expand Down Expand Up @@ -178,6 +183,28 @@ public HuggingFaceTokenizer sentenceTokenizer(RedisOMAiProperties properties) {
}
}

@Bean(name = "transformersEmbeddingModel")
public TransformersEmbeddingModel transformersEmbeddingModel(RedisOMAiProperties properties) {
TransformersEmbeddingModel embeddingModel = new TransformersEmbeddingModel();
if (properties.getTransformers().getTokenizerResource() != null) {
embeddingModel.setTokenizerResource(properties.getTransformers().getTokenizerResource());
}

if (properties.getTransformers().getModelResource() != null) {
embeddingModel.setModelResource(properties.getTransformers().getModelResource());
}

if (properties.getTransformers().getResourceCacheDirectory() != null) {
embeddingModel.setResourceCacheDirectory(properties.getTransformers().getResourceCacheDirectory());
}

if (!properties.getTransformers().getTokenizerOptions().isEmpty()) {
embeddingModel.setTokenizerOptions(properties.getTransformers().getTokenizerOptions());
}

return embeddingModel;
}

@ConditionalOnMissingBean
@Bean
public OpenAiEmbeddingModel openAITextVectorizer(RedisOMAiProperties properties,
Expand All @@ -204,7 +231,7 @@ public OpenAiEmbeddingModel openAITextVectorizer(RedisOMAiProperties properties,

// Rest of the configuration
return new OpenAiEmbeddingModel(openAiApi, MetadataMode.EMBED,
OpenAiEmbeddingOptions.builder().withModel("text-embedding-ada-002").build(),
OpenAiEmbeddingOptions.builder().model("text-embedding-ada-002").build(),
RetryUtils.DEFAULT_RETRY_TEMPLATE);
} else {
return null;
Expand Down Expand Up @@ -251,7 +278,7 @@ public OpenAIClient azureOpenAIClient(RedisOMAiProperties properties, //

@ConditionalOnMissingBean
@Bean
VertexAiPaLm2EmbeddingModel vertexAiPaLm2EmbeddingModel(RedisOMAiProperties properties, //
VertexAiTextEmbeddingModel vertexAiEmbeddingModel(RedisOMAiProperties properties, //
@Value("${spring.ai.vertex.ai.api-key:}") String apiKey,
@Value("${spring.ai.vertex.ai.ai.base-url:}") String baseUrl) {
if (!StringUtils.hasText(apiKey)) {
Expand Down Expand Up @@ -281,9 +308,15 @@ VertexAiPaLm2EmbeddingModel vertexAiPaLm2EmbeddingModel(RedisOMAiProperties prop
}

if (StringUtils.hasText(apiKey) && StringUtils.hasText(baseUrl)) {
VertexAiPaLm2Api vertexAiApi = new VertexAiPaLm2Api(baseUrl, apiKey, VertexAiPaLm2Api.DEFAULT_GENERATE_MODEL,
VertexAiPaLm2Api.DEFAULT_EMBEDDING_MODEL, RestClient.builder());
return new VertexAiPaLm2EmbeddingModel(vertexAiApi);

VertexAiEmbeddingConnectionDetails connectionDetails = VertexAiEmbeddingConnectionDetails.builder()
.projectId(System.getenv("VERTEX_AI_GEMINI_PROJECT_ID")).location(System.getenv("VERTEX_AI_GEMINI_LOCATION"))
.build();

VertexAiTextEmbeddingOptions options = VertexAiTextEmbeddingOptions.builder()
.model(VertexAiTextEmbeddingOptions.DEFAULT_MODEL_NAME).build();

return new VertexAiTextEmbeddingModel(connectionDetails, options);
} else {
return null;
}
Expand Down Expand Up @@ -346,7 +379,7 @@ BedrockCohereEmbeddingModel bedrockCohereEmbeddingModel(RedisOMAiProperties prop
if (!StringUtils.hasText(model)) {
model = properties.getBedrockCohere().getModel();
if (!StringUtils.hasText(model)) {
model = CohereEmbeddingModel.COHERE_EMBED_MULTILINGUAL_V1.id();
model = CohereEmbeddingModel.COHERE_EMBED_MULTILINGUAL_V3.id();
properties.getBedrockCohere().setModel(model);
}
}
Expand Down Expand Up @@ -439,19 +472,19 @@ BedrockTitanEmbeddingModel bedrockTitanEmbeddingModel(RedisOMAiProperties proper
@Primary
@Bean(name = "featureExtractor")
public Embedder featureExtractor(
@Nullable @Qualifier("djlImageEmbeddingModel") ZooModel<Image, byte[]> imageEmbeddingModel,
@Nullable @Qualifier("djlImageEmbeddingModel") ZooModel<Image, float[]> imageEmbeddingModel,
@Nullable @Qualifier("djlFaceEmbeddingModel") ZooModel<Image, float[]> faceEmbeddingModel,
@Nullable @Qualifier("djlImageFactory") ImageFactory imageFactory,
@Nullable @Qualifier("djlDefaultImagePipeline") Pipeline defaultImagePipeline,
@Nullable @Qualifier("djlSentenceTokenizer") HuggingFaceTokenizer sentenceTokenizer,
@Nullable @Qualifier("transformersEmbeddingModel") TransformersEmbeddingModel transformersEmbeddingModel,
@Nullable OpenAiEmbeddingModel openAITextVectorizer, @Nullable OpenAIClient azureOpenAIClient,
@Nullable VertexAiPaLm2EmbeddingModel vertexAiPaLm2EmbeddingModel,
@Nullable VertexAiTextEmbeddingModel vertexAiTextEmbeddingModel,
@Nullable BedrockCohereEmbeddingModel bedrockCohereEmbeddingModel,
@Nullable BedrockTitanEmbeddingModel bedrockTitanEmbeddingModel,
RedisOMAiProperties properties,
ApplicationContext ac) {
return new DefaultEmbedder(ac, imageEmbeddingModel, faceEmbeddingModel, imageFactory, defaultImagePipeline,
sentenceTokenizer, openAITextVectorizer, azureOpenAIClient, vertexAiPaLm2EmbeddingModel,
transformersEmbeddingModel, openAITextVectorizer, azureOpenAIClient, vertexAiTextEmbeddingModel,
bedrockCohereEmbeddingModel, bedrockTitanEmbeddingModel, properties);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,17 @@
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.ConfigurationProperties;

import java.util.HashMap;
import java.util.Map;

@ConditionalOnProperty(name = "redis.om.spring.ai.enabled")
@ConfigurationProperties(
prefix = "redis.om.spring.ai", ignoreInvalidFields = true
)
public class RedisOMAiProperties {
private boolean enabled = false;
private final Djl djl = new Djl();
private final Transformers transformers = new Transformers();
private final OpenAi openAi = new OpenAi();
private final AzureOpenAi azureOpenAi = new AzureOpenAi();
private final VertexAi vertexAi = new VertexAi();
Expand All @@ -31,6 +35,10 @@ public Djl getDjl() {
return djl;
}

public Transformers getTransformers() {
return transformers;
}

public OpenAi getOpenAi() {
return openAi;
}
Expand All @@ -55,6 +63,30 @@ public Ollama getOllama() {
return ollama;
}

// Transformer properties
public static class Transformers {
private String tokenizerResource;
private String modelResource;
private String resourceCacheDirectory;
private Map<String, String> tokenizerOptions = new HashMap<>();

public String getTokenizerResource() {
return tokenizerResource;
}

public String getModelResource() {
return modelResource;
}

public String getResourceCacheDirectory() {
return resourceCacheDirectory;
}

public Map<String, String> getTokenizerOptions() {
return tokenizerOptions;
}
}

// DJL properties
public static class Djl {
private static final String DEFAULT_ENGINE = "PyTorch";
Expand All @@ -73,7 +105,7 @@ public static class Djl {
@NotNull
private String sentenceTokenizerModelMaxLength = "768";
@NotNull
private String sentenceTokenizerModel = "sentence-transformers/all-mpnet-base-v2";
private String sentenceTokenizerModel = "sentence-transformers/msmarco-distilbert-dot-v5";

// face detection
@NotNull
Expand All @@ -91,6 +123,7 @@ public static class Djl {
@NotNull
private String faceEmbeddingModelModelUrls = "https://resources.djl.ai/test-models/pytorch/face_feature.zip";


public Djl() {
}

Expand Down Expand Up @@ -278,6 +311,24 @@ public static class VertexAi {
private String apiKey;
private String endPoint;
private String model;
private String projectId;
private String location;

public String getProjectId() {
return projectId;
}

public void setProjectId(String projectId) {
this.projectId = projectId;
}

public String getLocation() {
return location;
}

public void setLocation(String location) {
this.location = location;
}

public String getApiKey() {
return apiKey;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.redis.om.spring.annotations;

public enum EmbeddingProvider {
TRANSFORMERS,
DJL,
OPENAI,
OLLAMA,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi.TitanEmbeddingModel;
import org.springframework.ai.ollama.api.OllamaModel;
import org.springframework.ai.openai.api.OpenAiApi.EmbeddingModel;
import org.springframework.ai.vertexai.palm2.api.VertexAiPaLm2Api;

import java.lang.annotation.*;

Expand All @@ -16,17 +15,17 @@

EmbeddingType embeddingType() default EmbeddingType.SENTENCE;

EmbeddingProvider provider() default EmbeddingProvider.DJL;
EmbeddingProvider provider() default EmbeddingProvider.TRANSFORMERS;

EmbeddingModel openAiEmbeddingModel() default EmbeddingModel.TEXT_EMBEDDING_ADA_002;

OllamaModel ollamaEmbeddingModel() default OllamaModel.MISTRAL;

String azureOpenAiDeploymentName() default "text-embedding-ada-002";

String vertexAiPaLm2ApiModel() default VertexAiPaLm2Api.DEFAULT_EMBEDDING_MODEL;
String vertexAiPaLm2ApiModel() default "text-embedding-004";

CohereEmbeddingModel cohereEmbeddingModel() default CohereEmbeddingModel.COHERE_EMBED_MULTILINGUAL_V1;
CohereEmbeddingModel cohereEmbeddingModel() default CohereEmbeddingModel.COHERE_EMBED_MULTILINGUAL_V3;

TitanEmbeddingModel titanEmbeddingModel() default TitanEmbeddingModel.TITAN_EMBED_IMAGE_V1;
}
Loading

0 comments on commit 6b2f371

Please sign in to comment.