Skip to content

Commit

Permalink
Merge branch 'main' into rollback
Browse files Browse the repository at this point in the history
  • Loading branch information
yiannisha authored Oct 2, 2024
2 parents c98921c + 03c7cdf commit 9d5581c
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 54 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "llm-tool"
version = "1.0.3"
version = "1.0.4"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
Expand Down
28 changes: 28 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,28 @@ The definition for the function above will look like this:
}
```

### Use with methods
```python
fomr llm_tool import tool

class TestClass:

@tool(self)
def test_method(self, a: int = 0) -> None:
'''
This is a test method.
:param a: This is a test parameter.
'''
pass

# get the definition from the class
definition = TestClass.test_method.definition
# or from an object
t = TestClass()
definition = t.test_method.definition
```

### Groq API Example
```python
from groq import Groq
Expand Down Expand Up @@ -220,3 +242,9 @@ def test2(a: int) -> None:

### Support
Currently only docstrings in the [reST](https://www.sphinx-doc.org/en/master/usage/restructuredtext/index.html) format are supported, but support for more doscstring formats will be added in the future.

### Roadmap
<ul><li>- [ ] </li></ul> Add support for Union types
<ul><li>- [ ] </li></ul> Add support for writing subtypes (e.g. `List[int]` instead of just `List`)
<ul><li>- [ ] </li></ul> Support for more doscstring formats

12 changes: 8 additions & 4 deletions python/llm_tool/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,19 @@ def __init__(self, message: str):

class DefinedFunction():

def __init__(self, func, definition = {}) -> None:
def __init__(self, func, definition = {}, context = None) -> None:
"""
A function that has been defined in the llm_tool tool.
"""

self._func = func
self._definition = definition
self._context = context

def __call__(self, *args, **kwargs):
if self._context:
return self._func(self._context, *args, **kwargs)

return self._func(*args, **kwargs)

@property
Expand All @@ -47,7 +51,7 @@ def get_type_name(type_: Union[type, _BaseGenericAlias, None]) -> str:

raise TypeParsingException(f"Failed to parse type: {type_}")

def tool(desc_required: Union[bool, None] = None, return_required: Union[bool, None] = None) -> Callable[[Callable], DefinedFunction]:
def tool(context = None, desc_required: Union[bool, None] = None, return_required: Union[bool, None] = None) -> Callable[[Callable], DefinedFunction]:
desc_required = desc_required if desc_required is not None else GlobalToolConfig.desc_required
return_required = return_required if return_required is not None else GlobalToolConfig.return_required

Expand All @@ -62,7 +66,7 @@ def inner(func: Callable) -> DefinedFunction:
}

if func_params := inspect.signature(func).parameters:
# ignore error

for key, value in func_params.items():

if key == 'self':
Expand Down Expand Up @@ -124,7 +128,7 @@ def inner(func: Callable) -> DefinedFunction:
}
}

func = DefinedFunction(func, out)
func = DefinedFunction(func, out, context)
return func

return inner
151 changes: 103 additions & 48 deletions tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@

class TestTool(unittest.TestCase):

maxDiff = None

def test_best_case(self):
@tool()
def test(a: str, b: int, c: Dict[str, str], d: List[str], e: bool, f: float, g: List[Dict[str, str]], h: Union[bool, None], i: Union[Dict[str, str], None] = None) -> Dict:
def test(a: str, b: int, c: Dict[str, str], d: List[bool], e: bool, f: float, g: List[Dict[str, str]], h: List[str] = ["1", "2", "3"]) -> Dict:
"""
This is a test function.
:param a: this is the description for a
Expand All @@ -25,54 +27,54 @@ def test(a: str, b: int, c: Dict[str, str], d: List[str], e: bool, f: float, g:
"""
pass

self.assertEqual(test.definition, {
'type': 'function',
'function': {
'name': 'test',
'description': 'This is a test function.\n\nReturn Type: Dict\n\nReturns: this is the description for reteurn',
'parameters': {
'type': 'object',
'properties': {
'a': {
'type': 'str',
'description': 'this is the description for a',
},
'b': {
'type': 'int',
'description': 'this is the description for b',
},
'c': {
'type': 'Dict',
'description': 'this is the description for c',
},
'd': {
'type': 'List',
'description': 'this is the description for d. Defaults to ["1", "2", "3"]',
},
'e': {
'type': 'bool',
'description': 'this is the description for e',
},
'f': {
'type': 'float',
'description': 'this is the description for f',
},
'g': {
'type': 'List',
'description': 'this is the description for g',
},
'h': {
'type': 'Union',
'description': 'this is the description for h',
},
'i': {
'type': 'Union',
'description': 'this is the description for i. Defaults to None',
},

self.assertEqual(test.definition, {
'type': 'function',
'function': {
'name': 'test',
'description': 'This is a test function.\n\nReturn Type: `Dict`\n\nReturn Description: this is the description for return',
'parameters': {
'type': 'object',
'properties': {
'a': {
'type': 'str',
'description': 'this is the description for a',
},
'required': ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'],
}
'b': {
'type': 'int',
'description': 'this is the description for b',
},
'c': {
'type': 'Dict',
'description': 'this is the description for c',
},
'd': {
'type': 'List',
'description': 'this is the description for d',
},
'e': {
'type': 'bool',
'description': 'this is the description for e',
},
'f': {
'type': 'float',
'description': 'this is the description for f',
},
'g': {
'type': 'List',
'description': 'this is the description for g',
},
'h': {
'type': 'List',
'description': 'this is the description for h Default Value: `[\'1\', \'2\', \'3\']`',
},
# 'i': {
# 'type': 'Union',
# 'description': ' Default Value: `None`',
# },

},
'required': ['a', 'b', 'c', 'd', 'e', 'f', 'g'],
}
}
})

Expand Down Expand Up @@ -114,6 +116,49 @@ def test(a: str) -> int:
with self.assertRaisesRegex(DocStringException, "Return description not found in docstring of `[a-zA-Z\d_]*` function signature."):
tool(return_required=True)(test)

def test_method(self):

class Test:

@tool()
def test(self, a: str, b: int = 2) -> List[Union[str, int]]:
'''
This is a test function.
:param a: this is the description for a
:param b: this is the description for b
:returns: this is the description for return
'''
return [a, b]

t = Test()
self.assertEqual(
t.test.definition,
{
'type': 'function',
'function': {
'name': 'test',
'description': 'This is a test function.\n\nReturn Type: `List`\n\nReturn Description: this is the description for return',
'parameters': {
'type': 'object',
'properties': {
'a': {
'type': 'str',
'description': 'this is the description for a',
},
'b': {
'type': 'int',
'description': 'this is the description for b Default Value: `2`',
},
},
'required': ['a']
}
}
}
)


class TestDefinedFunction(unittest.TestCase):

def test_call(self):
Expand All @@ -123,6 +168,16 @@ def test(a: int, b: int) -> List[int]:
definedFunction = DefinedFunction(test)
self.assertEqual(definedFunction(1, b=2), [1, 2])

def test_call_method(self):

class Test:
@tool(self)
def test(self, a: int) -> int:
return a

t = Test()
self.assertEqual(t.test(5), 5)

class TestGlobalToolConfig(unittest.TestCase):
def test_default(self):
self.assertEqual(GlobalToolConfig.desc_required, False)
Expand Down

0 comments on commit 9d5581c

Please sign in to comment.