-
Notifications
You must be signed in to change notification settings - Fork 103
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1192 from jmartisk/tool-resolution
Some fixes for resolution of tool providers, improve tests
- Loading branch information
Showing
14 changed files
with
419 additions
and
153 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
150 changes: 0 additions & 150 deletions
150
core/deployment/src/test/java/io/quarkiverse/langchain4j/test/ToolProviderTest.java
This file was deleted.
Oops, something went wrong.
47 changes: 47 additions & 0 deletions
47
...c/test/java/io/quarkiverse/langchain4j/test/toolresolution/AutomaticToolProviderTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
package io.quarkiverse.langchain4j.test.toolresolution; | ||
|
||
import static org.junit.jupiter.api.Assertions.assertEquals; | ||
|
||
import jakarta.enterprise.context.control.ActivateRequestContext; | ||
import jakarta.inject.Inject; | ||
|
||
import org.jboss.shrinkwrap.api.ShrinkWrap; | ||
import org.jboss.shrinkwrap.api.spec.JavaArchive; | ||
import org.junit.jupiter.api.Test; | ||
import org.junit.jupiter.api.extension.RegisterExtension; | ||
|
||
import dev.langchain4j.service.MemoryId; | ||
import dev.langchain4j.service.UserMessage; | ||
import io.quarkiverse.langchain4j.RegisterAiService; | ||
import io.quarkus.test.QuarkusUnitTest; | ||
|
||
/** | ||
* If the AI service does not explicitly specify tools nor a tool provider | ||
* and there is a bean that implements ToolProvider, that bean should be used. | ||
*/ | ||
public class AutomaticToolProviderTest { | ||
|
||
@RegisterExtension | ||
static final QuarkusUnitTest unitTest = new QuarkusUnitTest() | ||
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) | ||
.addClasses(TestAiSupplier.class, | ||
TestAiModel.class, | ||
ServiceWithDefaultToolProviderConfig.class, | ||
MyCustomToolProvider.class)); | ||
|
||
@RegisterAiService(chatLanguageModelSupplier = TestAiSupplier.class) | ||
interface ServiceWithDefaultToolProviderConfig { | ||
String chat(@UserMessage String msg, @MemoryId Object id); | ||
} | ||
|
||
@Inject | ||
ServiceWithDefaultToolProviderConfig service; | ||
|
||
@Test | ||
@ActivateRequestContext | ||
void testCall() { | ||
String answer = service.chat("hello", 1); | ||
assertEquals("TOOL1", answer); | ||
} | ||
|
||
} |
50 changes: 50 additions & 0 deletions
50
...java/io/quarkiverse/langchain4j/test/toolresolution/ExplicitToolProviderSupplierTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
package io.quarkiverse.langchain4j.test.toolresolution; | ||
|
||
import static org.junit.jupiter.api.Assertions.assertEquals; | ||
|
||
import jakarta.enterprise.context.control.ActivateRequestContext; | ||
import jakarta.inject.Inject; | ||
|
||
import org.jboss.shrinkwrap.api.ShrinkWrap; | ||
import org.jboss.shrinkwrap.api.spec.JavaArchive; | ||
import org.junit.jupiter.api.Test; | ||
import org.junit.jupiter.api.extension.RegisterExtension; | ||
|
||
import dev.langchain4j.service.MemoryId; | ||
import dev.langchain4j.service.UserMessage; | ||
import io.quarkiverse.langchain4j.RegisterAiService; | ||
import io.quarkus.test.QuarkusUnitTest; | ||
|
||
/** | ||
* If an AI service specifies an explicit tool provider (and no specific tools), | ||
* that tool provider should be used. | ||
*/ | ||
public class ExplicitToolProviderSupplierTest { | ||
|
||
@RegisterExtension | ||
static final QuarkusUnitTest unitTest = new QuarkusUnitTest() | ||
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) | ||
.addClasses(TestAiSupplier.class, | ||
TestAiModel.class, | ||
ServiceWithExplicitToolProviderSupplier.class, | ||
MyCustomToolProviderSupplier.class, | ||
MyCustomToolProvider.class)); | ||
|
||
@RegisterAiService(toolProviderSupplier = MyCustomToolProviderSupplier.class, chatLanguageModelSupplier = TestAiSupplier.class) | ||
interface ServiceWithExplicitToolProviderSupplier { | ||
|
||
String chat(@UserMessage String msg, @MemoryId Object id); | ||
|
||
} | ||
|
||
@Inject | ||
ServiceWithExplicitToolProviderSupplier service; | ||
|
||
@Test | ||
@ActivateRequestContext | ||
void testCall() { | ||
String answer = service.chat("hello", 1); | ||
assertEquals("TOOL1", answer); | ||
} | ||
|
||
} |
48 changes: 48 additions & 0 deletions
48
...o/quarkiverse/langchain4j/test/toolresolution/ExplicitToolsAndNoBeanToolProviderTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
package io.quarkiverse.langchain4j.test.toolresolution; | ||
|
||
import static org.junit.jupiter.api.Assertions.assertEquals; | ||
|
||
import jakarta.enterprise.context.control.ActivateRequestContext; | ||
import jakarta.inject.Inject; | ||
|
||
import org.jboss.shrinkwrap.api.ShrinkWrap; | ||
import org.jboss.shrinkwrap.api.spec.JavaArchive; | ||
import org.junit.jupiter.api.Test; | ||
import org.junit.jupiter.api.extension.RegisterExtension; | ||
|
||
import dev.langchain4j.service.MemoryId; | ||
import dev.langchain4j.service.UserMessage; | ||
import io.quarkiverse.langchain4j.RegisterAiService; | ||
import io.quarkus.test.QuarkusUnitTest; | ||
|
||
/** | ||
* If the AI service explicitly specifies tools, and there is a bean that implements ToolProvider, | ||
* but the service also declares a NoToolProviderSupplier, the explicit tools should be used. | ||
*/ | ||
public class ExplicitToolsAndNoBeanToolProviderTest { | ||
|
||
@RegisterExtension | ||
static final QuarkusUnitTest unitTest = new QuarkusUnitTest() | ||
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) | ||
.addClasses(TestAiSupplier.class, | ||
TestAiModel.class, | ||
ServiceWithExplicitToolsAndNoToolProviderSupplier.class, | ||
MyCustomToolProvider.class, | ||
ToolsClass.class)); | ||
|
||
@RegisterAiService(chatLanguageModelSupplier = TestAiSupplier.class, tools = ToolsClass.class, toolProviderSupplier = RegisterAiService.NoToolProviderSupplier.class) | ||
interface ServiceWithExplicitToolsAndNoToolProviderSupplier { | ||
String chat(@UserMessage String msg, @MemoryId Object id); | ||
} | ||
|
||
@Inject | ||
ServiceWithExplicitToolsAndNoToolProviderSupplier service; | ||
|
||
@Test | ||
@ActivateRequestContext | ||
void testCall() { | ||
String answer = service.chat("hello", 1); | ||
assertEquals("\"EXPLICIT TOOL\"", answer); | ||
} | ||
|
||
} |
53 changes: 53 additions & 0 deletions
53
.../io/quarkiverse/langchain4j/test/toolresolution/ExplicitToolsAndProviderSupplierTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
package io.quarkiverse.langchain4j.test.toolresolution; | ||
|
||
import jakarta.enterprise.context.control.ActivateRequestContext; | ||
import jakarta.inject.Inject; | ||
|
||
import org.assertj.core.api.Assertions; | ||
import org.jboss.shrinkwrap.api.ShrinkWrap; | ||
import org.jboss.shrinkwrap.api.spec.JavaArchive; | ||
import org.junit.jupiter.api.Test; | ||
import org.junit.jupiter.api.extension.RegisterExtension; | ||
|
||
import dev.langchain4j.service.MemoryId; | ||
import dev.langchain4j.service.UserMessage; | ||
import io.quarkiverse.langchain4j.RegisterAiService; | ||
import io.quarkus.test.QuarkusUnitTest; | ||
|
||
/** | ||
* | ||
*/ | ||
public class ExplicitToolsAndProviderSupplierTest { | ||
|
||
@RegisterExtension | ||
static final QuarkusUnitTest unitTest = new QuarkusUnitTest() | ||
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) | ||
.addClasses(TestAiSupplier.class, | ||
TestAiModel.class, | ||
ServiceWithToolClash.class, | ||
MyCustomToolProviderSupplier.class, | ||
MyCustomToolProvider.class, | ||
ToolsClass.class)); | ||
|
||
@RegisterAiService(toolProviderSupplier = MyCustomToolProviderSupplier.class, tools = ToolsClass.class, chatLanguageModelSupplier = TestAiSupplier.class) | ||
interface ServiceWithToolClash { | ||
|
||
String chat(@UserMessage String msg, @MemoryId Object id); | ||
|
||
} | ||
|
||
@Inject | ||
ServiceWithToolClash service; | ||
|
||
@Test | ||
@ActivateRequestContext | ||
void testCall() { | ||
try { | ||
String answer = service.chat("hello", 1); | ||
Assertions.fail("Exception expected"); | ||
} catch (Exception e) { | ||
Assertions.assertThat(e.getMessage()).contains(" Cannot use a tool provider when explicit tools are provided"); | ||
} | ||
} | ||
|
||
} |
Oops, something went wrong.