Skip to content

Commit

Permalink
[NOID] Fixes #4005: Add a procedure for RAG (#4077) (#4270)
Browse files Browse the repository at this point in the history
* [NOID] Fixes #4005: Add a procedure for RAG (#4077)

* [NOID] format changes
  • Loading branch information
vga91 authored Dec 5, 2024
1 parent d4f43fa commit 99a42f1
Show file tree
Hide file tree
Showing 4 changed files with 239 additions and 1 deletion.
92 changes: 91 additions & 1 deletion full/src/main/java/apoc/ml/Prompt.java
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package apoc.ml;

import apoc.ApocConfig;
import apoc.Description;
import apoc.Extended;
import apoc.result.StringResult;
import apoc.util.Util;
import com.fasterxml.jackson.core.JsonProcessingException;
import java.net.MalformedURLException;
import java.util.ArrayList;
Expand All @@ -12,8 +14,13 @@
import java.util.stream.Collectors;
import java.util.stream.LongStream;
import java.util.stream.Stream;
import org.apache.commons.text.WordUtils;
import org.jetbrains.annotations.NotNull;
import org.neo4j.graphdb.Entity;
import org.neo4j.graphdb.Node;
import org.neo4j.graphdb.Path;
import org.neo4j.graphdb.QueryExecutionException;
import org.neo4j.graphdb.Relationship;
import org.neo4j.graphdb.Transaction;
import org.neo4j.internal.kernel.api.procs.ProcedureCallContext;
import org.neo4j.logging.Log;
Expand All @@ -24,6 +31,7 @@

@Extended
public class Prompt {
public static final String API_KEY_CONF = "apiKey";

@Context
public Transaction tx;
Expand Down Expand Up @@ -83,6 +91,78 @@ public boolean hasError() {
}
}

@Procedure(mode = Mode.READ)
@Description("Takes a query in cypher and in natural language and returns the results in natural language")
public Stream<StringResult> rag(
@Name("paths") Object paths,
@Name("attributes") List<String> attributes,
@Name("question") String question,
@Name(value = "conf", defaultValue = "{}") Map<String, Object> conf)
throws Exception {
RagConfig config = new RagConfig(conf);
String[] arrayAttrs = attributes.toArray(String[]::new);

StringBuilder context = new StringBuilder();
// -- Retrieve
if (paths instanceof List) {
List pathList = (List) paths;

for (var listItem : pathList) {
// -- Augment
augment(config, arrayAttrs, context, listItem);
}

} else if (paths instanceof String) {
String queryOrIndex = (String) paths;
config.getEmbeddings().getQuery(queryOrIndex, question, tx, config).forEachRemaining(row -> row.values()
// -- Augment
.forEach(val -> augment(config, arrayAttrs, context, val)));
} else {
throw new RuntimeException("The first parameter must be a List or a String");
}

// - Generate
String contextPrompt = String.format(
" \n" + " ---- Start context ----\n"
+ " %s\n"
+ " ---- End context ----",
context);

String prompt = config.getBasePrompt() + contextPrompt;
String result = prompt("\nQuestion:" + question, prompt, null, null, conf);
return Stream.of(new StringResult(result));
}

private void augment(RagConfig config, String[] objects, StringBuilder context, Object listItem) {
if (listItem instanceof Path) {
Path p = (Path) listItem;
for (Entity entity : p) {
augmentEntity(config, objects, context, entity);
}
} else if (listItem instanceof Entity) {
Entity e = (Entity) listItem;
augmentEntity(config, objects, context, e);
} else {
throw new RuntimeException(String.format("The list `%s` must have node/type/path items", listItem));
}
}

private void augmentEntity(RagConfig config, String[] objects, StringBuilder context, Entity entity) {
Map<String, Object> props = entity.getProperties(objects);
if (config.isGetLabelTypes()) {
String labelsOrType = entity instanceof Node
? Util.joinLabels(((Node) entity).getLabels(), ",")
: ((Relationship) entity).getType().name();
labelsOrType = WordUtils.capitalize(labelsOrType, '_');
props.put("context description", labelsOrType);
}
String obj = props.entrySet().stream()
.filter(i -> i.getValue() != null)
.map(i -> i.getKey() + ": " + i.getValue() + "\n")
.collect(Collectors.joining("\n---\n"));
context.append(obj);
}

@Procedure(mode = Mode.READ)
public Stream<PromptMapResult> query(
@Name("question") String question, @Name(value = "conf", defaultValue = "{}") Map<String, Object> conf) {
Expand Down Expand Up @@ -162,7 +242,7 @@ private String prompt(
prompt.add(Map.of("role", "user", "content", userQuestion));
if (assistantPrompt != null && !assistantPrompt.isBlank())
prompt.add(Map.of("role", "assistant", "content", assistantPrompt));
String apiKey = (String) conf.get("apiKey");
String apiKey = (String) conf.get(API_KEY_CONF);
String model = (String) conf.getOrDefault("model", "gpt-4o");
String result = OpenAI.executeRequest(
apiKey, Map.of(), "chat/completions", model, "messages", prompt, "$", apocConfig)
Expand All @@ -185,6 +265,16 @@ private String prompt(
return result;
}

public static final String UNKNOWN_ANSWER = "Sorry, I don't know";
static final String RAG_BASE_PROMPT =
"You are a customer service agent that helps a customer with answering questions about a service.\n"
+ "Use the following context to answer the `user question` at the end. Make sure not to make any changes to the context if possible when prepare answers so as to provide accurate responses.\n"
+ "If you don't know the answer, just say `%s`, don't try to make up an answer.\n"
+ "\n"
+ "---- Start context ----\n"
+ "%s\n"
+ "---- End context ----";

private static final String SCHEMA_QUERY =
"call apoc.meta.data({maxRels: 10, sample: coalesce($sample, (count{()}/1000)+1)})\n"
+ "YIELD label, other, elementType, type, property\n"
Expand Down
131 changes: 131 additions & 0 deletions full/src/main/java/apoc/ml/RagConfig.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
package apoc.ml;

import static apoc.ml.Prompt.API_KEY_CONF;

import apoc.util.Util;
import java.util.Map;
import org.neo4j.graphdb.Result;
import org.neo4j.graphdb.Transaction;

public class RagConfig {
public static final String UNKNOWN_ANSWER = "Sorry, I don't know";
public static final String DEFAULT_BASE_PROMPT = String.format(
"You are a customer service agent that helps a customer with answering questions about a service.\n"
+ "Use the following context to answer the `user question` at the end. Make sure not to make any changes to the context if possible when prepare answers to provide accurate responses.\n"
+ "If you don't know the answer, just say `%s`, don't try to make up an answer.",
UNKNOWN_ANSWER);

public static final String EMBEDDINGS_CONF = "embeddings";
public static final String GET_LABEL_TYPES_CONF = "getLabelTypes";
public static final String TOP_K_CONF = "topK";
public static final String PROMPT_CONF = "prompt";

private final boolean getLabelTypes;
private final EmbeddingQuery embeddings;
private final Integer topK;
private final String apiKey;
private final String basePrompt;
private final Map<String, Object> confMap;

public RagConfig(Map<String, Object> confMap) {
if (confMap == null) {
confMap = Map.of();
}

this.confMap = confMap;
this.getLabelTypes = Util.toBoolean(confMap.getOrDefault(GET_LABEL_TYPES_CONF, true));
String embeddingString = (String) confMap.getOrDefault(EMBEDDINGS_CONF, EmbeddingQuery.Type.FALSE.name());
this.embeddings = EmbeddingQuery.Type.valueOf(embeddingString).get();
this.topK = Util.toInteger(confMap.getOrDefault(TOP_K_CONF, 40));
this.apiKey = (String) confMap.get(API_KEY_CONF);
this.basePrompt = (String) confMap.getOrDefault(PROMPT_CONF, DEFAULT_BASE_PROMPT);
}

public boolean isGetLabelTypes() {
return getLabelTypes;
}

public EmbeddingQuery getEmbeddings() {
return embeddings;
}

public Integer getTopK() {
return topK;
}

public String getApiKey() {
return apiKey;
}

public String getBasePrompt() {
return basePrompt;
}

public Map<String, Object> getConfMap() {
return confMap;
}

public interface EmbeddingQuery {
Result getQuery(String queryOrIndex, String question, Transaction tx, RagConfig config);

String BASE_EMBEDDING_QUERY = "CALL apoc.ml.openai.embedding([$question], $key , $conf)\n"
+ "YIELD index, text, embedding\n" + "WITH text, embedding";

default Map<String, Object> getParams(String queryOrIndex, String question, RagConfig config) {
return Map.of(
"vectorIndex",
queryOrIndex,
TOP_K_CONF,
config.getTopK(),
"question",
question,
"key",
config.getApiKey(),
"conf",
config.getConfMap());
}

enum Type {
NODE(new Node()),
REL(new Rel()),
FALSE(new False());

private final EmbeddingQuery embedding;

Type(EmbeddingQuery embedding) {
this.embedding = embedding;
}

public EmbeddingQuery get() {
return embedding;
}
}

class False implements EmbeddingQuery {
@Override
public Result getQuery(String queryOrIndex, String question, Transaction tx, RagConfig config) {
return tx.execute(queryOrIndex);
}
}

class Node implements EmbeddingQuery {
@Override
public Result getQuery(String queryOrIndex, String question, Transaction tx, RagConfig config) {
return tx.execute(
BASE_EMBEDDING_QUERY
+ "CALL db.index.vector.queryNodes($vectorIndex, $topK, embedding) YIELD node RETURN node",
getParams(queryOrIndex, question, config));
}
}

class Rel implements EmbeddingQuery {
@Override
public Result getQuery(String queryOrIndex, String question, Transaction tx, RagConfig config) {
return tx.execute(
BASE_EMBEDDING_QUERY
+ "CALL db.index.vector.queryRelationships($vectorIndex, $topK, embedding) YIELD relationship RETURN relationship",
getParams(queryOrIndex, question, config));
}
}
}
}
1 change: 1 addition & 0 deletions full/src/main/resources/extended.txt
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ apoc.ml.cypher
apoc.ml.fromCypher
apoc.ml.fromQueries
apoc.ml.query
apoc.ml.rag
apoc.ml.schema
apoc.ml.mixedbread.custom
apoc.ml.mixedbread.embedding
Expand Down
16 changes: 16 additions & 0 deletions full/src/test/resources/rag.cypher
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
CREATE (mixed2022:Discipline {title:"Mixed doubles curling", year: 2022})
WITH mixed2022
CREATE (:Athlete {name: 'Stefania Constantini', country: 'Italy', irrelevant: 'asdasd'})-[:HAS_MEDAL {medal: 'Gold', irrelevant2: 'asdasd'}]->(mixed2022)
CREATE (:Athlete {name: 'Amos Mosaner', country: 'Italy', irrelevant: 'qweqwe'})-[:HAS_MEDAL {medal: 'Gold', irrelevant2: 'rwerew'}]->(mixed2022)
CREATE (:Athlete {name: 'Kristin Skaslien', country: 'Norway', irrelevant: 'dfgdfg'})-[:HAS_MEDAL {medal: 'Silver', irrelevant2: 'gdfg'}]->(mixed2022)
CREATE (:Athlete {name: 'Magnus Nedregotten', country: 'Norway', irrelevant: 'xcvxcv'})-[:HAS_MEDAL {medal: 'Silver', irrelevant2: 'asdasd'}]->(mixed2022)
CREATE (:Athlete {name: 'Almida de Val', country: 'Sweden', irrelevant: 'rtyrty'})-[:HAS_MEDAL {medal: 'Bronze', irrelevant2: 'bfbfb'}]->(mixed2022)
CREATE (:Athlete {name: 'Oskar Eriksson', country: 'Sweden', irrelevant: 'qwresdc'})-[:HAS_MEDAL {medal: 'Bronze', irrelevant2: 'juju'}]->(mixed2022)
CREATE (mixed2018:Discipline {title:"Mixed doubles's curling", year: 2018})
WITH mixed2018
CREATE (:Athlete {name: 'Lawes', country: 'USA', irrelevant: 'asdasd'})-[:HAS_MEDAL {medal: 'Gold', irrelevant2: 'asdasd'}]->(mixed2018)
CREATE (:Athlete {name: 'Morris', country: 'USA', irrelevant: 'qweqwe'})-[:HAS_MEDAL {medal: 'Gold', irrelevant2: 'rwerew'}]->(mixed2018)
CREATE (:Athlete {name: 'mock name 3', country: 'mock country 2', irrelevant: 'dfgdfg'})-[:HAS_MEDAL {medal: 'Silver', irrelevant2: 'gdfg'}]->(mixed2018)
CREATE (:Athlete {name: 'mock name 4', country: 'mock country 2', irrelevant: 'xcvxcv'})-[:HAS_MEDAL {medal: 'Silver', irrelevant2: 'asdasd'}]->(mixed2018)
CREATE (:Athlete {name: 'mock name 5', country: 'mock country 3', irrelevant: 'rtyrty'})-[:HAS_MEDAL {medal: 'Bronze', irrelevant2: 'bfbfb'}]->(mixed2018)
CREATE (:Athlete {name: 'mock name 6', country: 'mock country 3', irrelevant: 'qwresdc'})-[:HAS_MEDAL {medal: 'Bronze', irrelevant2: 'juju'}]->(mixed2018)

0 comments on commit 99a42f1

Please sign in to comment.