Skip to content

Commit

Permalink
More ruff formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
NathalieCharbel committed Jan 21, 2025
1 parent fbbd2c4 commit 5a81f87
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 16 deletions.
19 changes: 14 additions & 5 deletions examples/customize/build_graph/pipeline/kg_builder_from_pdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@
logging.basicConfig(level=logging.INFO)


async def define_and_run_pipeline(neo4j_driver: neo4j.Driver, llm: LLMInterface) -> PipelineResult:
async def define_and_run_pipeline(
neo4j_driver: neo4j.Driver, llm: LLMInterface
) -> PipelineResult:
from neo4j_graphrag.experimental.pipeline import Pipeline

# Instantiate Entity and Relation objects
Expand All @@ -56,7 +58,9 @@ async def define_and_run_pipeline(neo4j_driver: neo4j.Driver, llm: LLMInterface)
),
]
relations = [
SchemaRelation(label="SITUATED_AT", description="Indicates the location of a person."),
SchemaRelation(
label="SITUATED_AT", description="Indicates the location of a person."
),
SchemaRelation(
label="LED_BY",
description="Indicates the leader of an organization.",
Expand All @@ -65,7 +69,9 @@ async def define_and_run_pipeline(neo4j_driver: neo4j.Driver, llm: LLMInterface)
label="OWNS",
description="Indicates the ownership of an item such as a Horcrux.",
),
SchemaRelation(label="INTERACTS", description="The interaction between two people."),
SchemaRelation(
label="INTERACTS", description="The interaction between two people."
),
]
potential_schema = [
("PERSON", "SITUATED_AT", "LOCATION"),
Expand All @@ -78,7 +84,8 @@ async def define_and_run_pipeline(neo4j_driver: neo4j.Driver, llm: LLMInterface)
pipe = Pipeline()
pipe.add_component(PdfLoader(), "pdf_loader")
pipe.add_component(
FixedSizeSplitter(chunk_size=4000, chunk_overlap=200, approximate=False), "splitter"
FixedSizeSplitter(chunk_size=4000, chunk_overlap=200, approximate=False),
"splitter",
)
pipe.add_component(SchemaBuilder(), "schema")
pipe.add_component(
Expand Down Expand Up @@ -126,7 +133,9 @@ async def main() -> PipelineResult:
"response_format": {"type": "json_object"},
},
)
driver = neo4j.GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password"))
driver = neo4j.GraphDatabase.driver(
"bolt://localhost:7687", auth=("neo4j", "password")
)
res = await define_and_run_pipeline(driver, llm)
driver.close()
await llm.async_client.close()
Expand Down
12 changes: 9 additions & 3 deletions examples/customize/build_graph/pipeline/kg_builder_from_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@
import neo4j


async def define_and_run_pipeline(neo4j_driver: neo4j.Driver, llm: LLMInterface) -> PipelineResult:
async def define_and_run_pipeline(
neo4j_driver: neo4j.Driver, llm: LLMInterface
) -> PipelineResult:
"""This is where we define and run the KG builder pipeline, instantiating a few
components:
- Text Splitter: in this example we use the fixed size text splitter
Expand Down Expand Up @@ -74,7 +76,9 @@ async def define_and_run_pipeline(neo4j_driver: neo4j.Driver, llm: LLMInterface)
# and how the output of previous components must be used
pipe.connect("splitter", "chunk_embedder", input_config={"text_chunks": "splitter"})
pipe.connect("schema", "extractor", input_config={"schema": "schema"})
pipe.connect("chunk_embedder", "extractor", input_config={"chunks": "chunk_embedder"})
pipe.connect(
"chunk_embedder", "extractor", input_config={"chunks": "chunk_embedder"}
)
pipe.connect(
"extractor",
"writer",
Expand Down Expand Up @@ -145,7 +149,9 @@ async def main() -> PipelineResult:
"response_format": {"type": "json_object"},
},
)
driver = neo4j.GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password"))
driver = neo4j.GraphDatabase.driver(
"bolt://localhost:7687", auth=("neo4j", "password")
)
res = await define_and_run_pipeline(driver, llm)
driver.close()
await llm.async_client.close()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,5 +79,7 @@ async def main(neo4j_driver: neo4j.Driver) -> PipelineResult:


if __name__ == "__main__":
with neo4j.GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password")) as driver:
with neo4j.GraphDatabase.driver(
"bolt://localhost:7687", auth=("neo4j", "password")
) as driver:
print(asyncio.run(main(driver)))
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,9 @@ async def define_and_run_pipeline(
)
# define the execution order of component
# and how the output of previous components must be used
pipe.connect("chunk_embedder", "extractor", input_config={"chunks": "chunk_embedder"})
pipe.connect(
"chunk_embedder", "extractor", input_config={"chunks": "chunk_embedder"}
)
pipe.connect("schema", "extractor", input_config={"schema": "schema"})
pipe.connect(
"extractor",
Expand Down Expand Up @@ -188,5 +190,7 @@ async def main(driver: neo4j.Driver) -> PipelineResult:


if __name__ == "__main__":
with neo4j.GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password")) as driver:
with neo4j.GraphDatabase.driver(
"bolt://localhost:7687", auth=("neo4j", "password")
) as driver:
print(asyncio.run(main(driver)))
Original file line number Diff line number Diff line change
Expand Up @@ -201,11 +201,15 @@ async def main(driver: neo4j.Driver) -> PipelineResult:
},
)
await build_lexical_graph(driver, lexical_graph_config, text=text)
res = await read_chunk_and_perform_entity_extraction(driver, llm, lexical_graph_config)
res = await read_chunk_and_perform_entity_extraction(
driver, llm, lexical_graph_config
)
await llm.async_client.close()
return res


if __name__ == "__main__":
with neo4j.GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password")) as driver:
with neo4j.GraphDatabase.driver(
"bolt://localhost:7687", auth=("neo4j", "password")
) as driver:
print(asyncio.run(main(driver)))
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,9 @@ def test_invalid_chunk_size() -> None:
("Hello World", 0, 0),
],
)
def test_adjust_chunk_start(text: str, approximate_start: int, expected_start: int) -> None:
def test_adjust_chunk_start(
text: str, approximate_start: int, expected_start: int
) -> None:
"""
Test that the _adjust_chunk_start function correctly shifts
the start index to avoid breaking words, unless no whitespace is found.
Expand All @@ -127,7 +129,9 @@ def test_adjust_chunk_start(text: str, approximate_start: int, expected_start: i
("Hello World", 6, 15, 15),
],
)
def test_adjust_chunk_end(text: str, start: int, approximate_end: int, expected_end: int) -> None:
def test_adjust_chunk_end(
text: str, start: int, approximate_end: int, expected_end: int
) -> None:
"""
Test that the _adjust_chunk_end function correctly shifts
the end index to avoid breaking words, unless no whitespace is found.
Expand Down Expand Up @@ -183,7 +187,11 @@ def test_adjust_chunk_end(text: str, start: int, approximate_end: int, expected_
],
)
async def test_fixed_size_splitter_run(
text: str, chunk_size: int, chunk_overlap: int, approximate: bool, expected_chunks: list[str]
text: str,
chunk_size: int,
chunk_overlap: int,
approximate: bool,
expected_chunks: list[str],
) -> None:
"""
Test that 'FixedSizeSplitter.run' returns the expected chunks
Expand Down

0 comments on commit 5a81f87

Please sign in to comment.