Skip to content

Commit b82436f

Browse files
collinduttervasinov
authored andcommitted
Add rules to all Tasks (#415)
1 parent 00d87a8 commit b82436f

24 files changed

+233
-100
lines changed

griptape/engines/extraction/base_extraction_engine.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from __future__ import annotations
2+
from typing import Optional
23
from abc import ABC, abstractmethod
34
from attr import define, field, Factory
45
from griptape.artifacts import ListArtifact
56
from griptape.chunkers import BaseChunker, TextChunker
67
from griptape.drivers import BasePromptDriver, OpenAiChatPromptDriver
8+
from griptape.rules import Ruleset
79
from griptape.tokenizers import OpenAiTokenizer
810

911

@@ -54,5 +56,10 @@ def min_response_tokens(self) -> int:
5456
)
5557

5658
@abstractmethod
57-
def extract(self, text: str | ListArtifact, **kwargs) -> ListArtifact:
59+
def extract(
60+
self,
61+
text: str | ListArtifact,
62+
rulesets: Optional[list[Ruleset]] = None,
63+
**kwargs,
64+
) -> ListArtifact:
5865
...

griptape/engines/extraction/csv_extraction_engine.py

+17-4
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from __future__ import annotations
2+
from typing import Optional
23
import csv
34
import io
45
from attr import field, Factory, define
@@ -11,6 +12,7 @@
1112
from griptape.utils import PromptStack
1213
from griptape.engines import BaseExtractionEngine
1314
from griptape.utils import J2
15+
from griptape.rules import Ruleset, rule
1416

1517

1618
@define
@@ -21,7 +23,10 @@ class CsvExtractionEngine(BaseExtractionEngine):
2123
)
2224

2325
def extract(
24-
self, text: str | ListArtifact, column_names: list[str], **kwargs
26+
self,
27+
text: str | ListArtifact,
28+
column_names: list[str],
29+
rulesets: Optional[Ruleset] = None,
2530
) -> ListArtifact | ErrorArtifact:
2631
try:
2732
return ListArtifact(
@@ -31,6 +36,7 @@ def extract(
3136
else [TextArtifact(text)],
3237
column_names,
3338
[],
39+
rulesets=rulesets,
3440
),
3541
item_separator="\n",
3642
)
@@ -57,10 +63,13 @@ def _extract_rec(
5763
artifacts: list[TextArtifact],
5864
column_names: list[str],
5965
rows: list[CsvRowArtifact],
66+
rulesets: Optional[Ruleset] = None,
6067
) -> list[CsvRowArtifact]:
6168
artifacts_text = self.chunk_joiner.join([a.value for a in artifacts])
6269
full_text = self.template_generator.render(
63-
column_names=column_names, text=artifacts_text
70+
column_names=column_names,
71+
text=artifacts_text,
72+
rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets),
6473
)
6574

6675
if (
@@ -86,7 +95,9 @@ def _extract_rec(
8695
else:
8796
chunks = self.chunker.chunk(artifacts_text)
8897
partial_text = self.template_generator.render(
89-
column_names=column_names, text=chunks[0].value
98+
column_names=column_names,
99+
text=chunks[0].value,
100+
rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets),
90101
)
91102

92103
rows.extend(
@@ -104,4 +115,6 @@ def _extract_rec(
104115
)
105116
)
106117

107-
return self._extract_rec(chunks[1:], column_names, rows)
118+
return self._extract_rec(
119+
chunks[1:], column_names, rows, rulesets=rulesets
120+
)

griptape/engines/extraction/json_extraction_engine.py

+15-4
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from __future__ import annotations
2+
from typing import Optional
23
import json
34
from attr import field, Factory, define
45
from griptape.artifacts import TextArtifact, ListArtifact, ErrorArtifact
56
from griptape.engines import BaseExtractionEngine
67
from griptape.utils import J2
78
from griptape.utils import PromptStack
9+
from griptape.rules import Ruleset, rule
810

911

1012
@define
@@ -15,7 +17,10 @@ class JsonExtractionEngine(BaseExtractionEngine):
1517
)
1618

1719
def extract(
18-
self, text: str | ListArtifact, template_schema: dict, **kwargs
20+
self,
21+
text: str | ListArtifact,
22+
template_schema: dict,
23+
rulesets: Optional[Ruleset] = None,
1924
) -> ListArtifact | ErrorArtifact:
2025
try:
2126
json_schema = json.dumps(template_schema)
@@ -27,6 +32,7 @@ def extract(
2732
else [TextArtifact(text)],
2833
json_schema,
2934
[],
35+
rulesets=rulesets,
3036
),
3137
item_separator="\n",
3238
)
@@ -41,10 +47,13 @@ def _extract_rec(
4147
artifacts: list[TextArtifact],
4248
json_template_schema: str,
4349
extractions: list[TextArtifact],
50+
rulesets: Optional[Ruleset] = None,
4451
) -> list[TextArtifact]:
4552
artifacts_text = self.chunk_joiner.join([a.value for a in artifacts])
4653
full_text = self.template_generator.render(
47-
json_template_schema=json_template_schema, text=artifacts_text
54+
json_template_schema=json_template_schema,
55+
text=artifacts_text,
56+
rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets),
4857
)
4958

5059
if (
@@ -69,7 +78,9 @@ def _extract_rec(
6978
else:
7079
chunks = self.chunker.chunk(artifacts_text)
7180
partial_text = self.template_generator.render(
72-
template_schema=json_template_schema, text=chunks[0].value
81+
template_schema=json_template_schema,
82+
text=chunks[0].value,
83+
rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets),
7384
)
7485

7586
extractions.extend(
@@ -87,5 +98,5 @@ def _extract_rec(
8798
)
8899

89100
return self._extract_rec(
90-
chunks[1:], json_template_schema, extractions
101+
chunks[1:], json_template_schema, extractions, rulesets=rulesets
91102
)

griptape/engines/query/base_query_engine.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,18 @@
22
from typing import Optional
33
from attr import define
44
from griptape.artifacts import TextArtifact, ListArtifact
5+
from griptape.rules import Ruleset
56

67

78
@define
89
class BaseQueryEngine(ABC):
910
@abstractmethod
1011
def query(
11-
self, query: str, namespace: Optional[str] = None, **kwargs
12+
self,
13+
query: str,
14+
namespace: Optional[str] = None,
15+
rulesets: Optional[list[Ruleset]] = None,
16+
**kwargs
1217
) -> TextArtifact:
1318
...
1419

griptape/engines/query/vector_query_engine.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def query(
3535
metadata: Optional[str] = None,
3636
top_n: Optional[int] = None,
3737
namespace: Optional[str] = None,
38+
rulesets: Optional[str] = None,
3839
) -> TextArtifact:
3940
tokenizer = self.prompt_driver.tokenizer
4041
result = self.vector_store_driver.query(query, top_n, namespace)
@@ -52,7 +53,10 @@ def query(
5253
text_segments.append(artifact.value)
5354

5455
message = self.template_generator.render(
55-
metadata=metadata, query=query, text_segments=text_segments
56+
metadata=metadata,
57+
query=query,
58+
text_segments=text_segments,
59+
rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets),
5660
)
5761
message_token_count = self.prompt_driver.token_count(
5862
PromptStack(
@@ -69,7 +73,12 @@ def query(
6973
text_segments.pop()
7074

7175
message = self.template_generator.render(
72-
metadata=metadata, query=query, text_segments=text_segments
76+
metadata=metadata,
77+
query=query,
78+
text_segments=text_segments,
79+
rulesets=J2("rulesets/rulesets.j2").render(
80+
rulesets=rulesets
81+
),
7382
)
7483

7584
break
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,21 @@
11
from abc import ABC, abstractmethod
2+
from typing import Optional
23
from attr import define
34
from griptape.artifacts import TextArtifact, ListArtifact
5+
from griptape.rules import Ruleset
46

57

68
@define
79
class BaseSummaryEngine(ABC):
8-
def summarize_text(self, text: str) -> str:
10+
def summarize_text(
11+
self, text: str, rulesets: Optional[list[Ruleset]] = None
12+
) -> str:
913
return self.summarize_artifacts(
10-
ListArtifact([TextArtifact(text)])
14+
ListArtifact([TextArtifact(text)]), rulesets=rulesets
1115
).value
1216

1317
@abstractmethod
14-
def summarize_artifacts(self, artifacts: ListArtifact) -> TextArtifact:
18+
def summarize_artifacts(
19+
self, artifacts: ListArtifact, rulesets: Optional[list[Ruleset]] = None
20+
) -> TextArtifact:
1521
...

griptape/engines/summary/prompt_summary_engine.py

+18-5
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from griptape.engines import BaseSummaryEngine
88
from griptape.utils import J2
99
from griptape.tokenizers import OpenAiTokenizer
10+
from griptape.rules import Ruleset
1011

1112

1213
@define
@@ -57,18 +58,27 @@ def min_response_tokens(self) -> int:
5758
* self.max_token_multiplier
5859
)
5960

60-
def summarize_artifacts(self, artifacts: ListArtifact) -> TextArtifact:
61-
return self.summarize_artifacts_rec(artifacts.value, None)
61+
def summarize_artifacts(
62+
self, artifacts: ListArtifact, rulesets: Optional[Ruleset] = None
63+
) -> TextArtifact:
64+
return self.summarize_artifacts_rec(
65+
artifacts.value, None, rulesets=rulesets
66+
)
6267

6368
def summarize_artifacts_rec(
64-
self, artifacts: list[BaseArtifact], summary: Optional[str]
69+
self,
70+
artifacts: list[BaseArtifact],
71+
summary: Optional[str],
72+
rulesets: Optional[Ruleset] = None,
6573
) -> TextArtifact:
6674
artifacts_text = self.chunk_joiner.join(
6775
[a.to_text() for a in artifacts]
6876
)
6977

7078
full_text = self.template_generator.render(
71-
summary=summary, text=artifacts_text
79+
summary=summary,
80+
text=artifacts_text,
81+
rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets),
7282
)
7383

7484
if (
@@ -86,7 +96,9 @@ def summarize_artifacts_rec(
8696
chunks = self.chunker.chunk(artifacts_text)
8797

8898
partial_text = self.template_generator.render(
89-
summary=summary, text=chunks[0].value
99+
summary=summary,
100+
text=chunks[0].value,
101+
rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets),
90102
)
91103

92104
return self.summarize_artifacts_rec(
@@ -100,4 +112,5 @@ def summarize_artifacts_rec(
100112
]
101113
)
102114
).value,
115+
rulesets=rulesets,
103116
)

griptape/tasks/base_text_input_task.py

+49
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,21 @@
33

44
from attr import define, field
55
from griptape.artifacts import TextArtifact
6+
from griptape.rules import Ruleset, Rule
67
from griptape.tasks import BaseTask
78
from griptape.utils import J2
89

910

1011
@define
1112
class BaseTextInputTask(BaseTask, ABC):
1213
DEFAULT_INPUT_TEMPLATE = "{{ args[0] }}"
14+
DEFAULT_RULESET_NAME = "Default Ruleset"
15+
ADDITIONAL_RULESET_NAME = "Additional Ruleset"
1316

1417
input_template: str = field(default=DEFAULT_INPUT_TEMPLATE)
1518
context: dict[str, Any] = field(factory=dict, kw_only=True)
19+
rulesets: list[Ruleset] = field(factory=list, kw_only=True)
20+
rules: list[Rule] = field(factory=list, kw_only=True)
1621

1722
@property
1823
def input(self) -> TextArtifact:
@@ -31,6 +36,50 @@ def full_context(self) -> dict[str, Any]:
3136
else:
3237
return {}
3338

39+
@rulesets.validator
40+
def validate_rulesets(self, _, rulesets: list[Ruleset]) -> None:
41+
if not rulesets:
42+
return
43+
44+
if self.rules:
45+
raise ValueError("Can't have both rulesets and rules specified.")
46+
47+
@rules.validator
48+
def validate_rules(self, _, rules: list[Rule]) -> None:
49+
if not rules:
50+
return
51+
52+
if self.rulesets:
53+
raise ValueError("Can't have both rules and rulesets specified.")
54+
55+
@property
56+
def all_rulesets(self) -> list[Ruleset]:
57+
structure_rulesets = []
58+
59+
if self.structure:
60+
if self.structure.rulesets:
61+
structure_rulesets = self.structure.rulesets
62+
elif self.structure.rules:
63+
structure_rulesets = [
64+
Ruleset(
65+
name=self.DEFAULT_RULESET_NAME,
66+
rules=self.structure.rules,
67+
)
68+
]
69+
70+
task_rulesets = []
71+
if self.rulesets:
72+
task_rulesets = self.rulesets
73+
elif self.rules:
74+
if structure_rulesets:
75+
task_ruleset_name = self.ADDITIONAL_RULESET_NAME
76+
else:
77+
task_ruleset_name = self.DEFAULT_RULESET_NAME
78+
79+
task_rulesets = [Ruleset(name=task_ruleset_name, rules=self.rules)]
80+
81+
return structure_rulesets + task_rulesets
82+
3483
def before_run(self) -> None:
3584
super().before_run()
3685

griptape/tasks/extraction_task.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,6 @@ class ExtractionTask(BaseTextInputTask):
1010
args: dict = field(kw_only=True)
1111

1212
def run(self) -> ListArtifact:
13-
return self.extraction_engine.extract(self.input.to_text(), **self.args)
13+
return self.extraction_engine.extract(
14+
self.input.to_text(), rulesets=self.all_rulesets, **self.args
15+
)

0 commit comments

Comments
 (0)