1
1
from __future__ import annotations
2
+ from typing import Optional
2
3
import json
3
4
from attr import field , Factory , define
4
5
from griptape .artifacts import TextArtifact , ListArtifact , ErrorArtifact
5
6
from griptape .engines import BaseExtractionEngine
6
7
from griptape .utils import J2
7
8
from griptape .utils import PromptStack
9
+ from griptape .rules import Ruleset , rule
8
10
9
11
10
12
@define
@@ -15,7 +17,10 @@ class JsonExtractionEngine(BaseExtractionEngine):
15
17
)
16
18
17
19
def extract (
18
- self , text : str | ListArtifact , template_schema : dict , ** kwargs
20
+ self ,
21
+ text : str | ListArtifact ,
22
+ template_schema : dict ,
23
+ rulesets : Optional [Ruleset ] = None ,
19
24
) -> ListArtifact | ErrorArtifact :
20
25
try :
21
26
json_schema = json .dumps (template_schema )
@@ -27,6 +32,7 @@ def extract(
27
32
else [TextArtifact (text )],
28
33
json_schema ,
29
34
[],
35
+ rulesets = rulesets ,
30
36
),
31
37
item_separator = "\n " ,
32
38
)
@@ -41,10 +47,13 @@ def _extract_rec(
41
47
artifacts : list [TextArtifact ],
42
48
json_template_schema : str ,
43
49
extractions : list [TextArtifact ],
50
+ rulesets : Optional [Ruleset ] = None ,
44
51
) -> list [TextArtifact ]:
45
52
artifacts_text = self .chunk_joiner .join ([a .value for a in artifacts ])
46
53
full_text = self .template_generator .render (
47
- json_template_schema = json_template_schema , text = artifacts_text
54
+ json_template_schema = json_template_schema ,
55
+ text = artifacts_text ,
56
+ rulesets = J2 ("rulesets/rulesets.j2" ).render (rulesets = rulesets ),
48
57
)
49
58
50
59
if (
@@ -69,7 +78,9 @@ def _extract_rec(
69
78
else :
70
79
chunks = self .chunker .chunk (artifacts_text )
71
80
partial_text = self .template_generator .render (
72
- template_schema = json_template_schema , text = chunks [0 ].value
81
+ template_schema = json_template_schema ,
82
+ text = chunks [0 ].value ,
83
+ rulesets = J2 ("rulesets/rulesets.j2" ).render (rulesets = rulesets ),
73
84
)
74
85
75
86
extractions .extend (
@@ -87,5 +98,5 @@ def _extract_rec(
87
98
)
88
99
89
100
return self ._extract_rec (
90
- chunks [1 :], json_template_schema , extractions
101
+ chunks [1 :], json_template_schema , extractions , rulesets = rulesets
91
102
)
0 commit comments