examples: adapt RAG examples for remote execution (#3117)
To be able to run the RAG DAG in a deployment, we need storage that is not tied to the local file system.
The POC was built to pass data between jobs using a local file. Since I want to deploy the jobs, I need a way to pass data between them without a shared file system, which deployed jobs do not have. A Postgres-based storage was created for that, and the examples were adapted to use it.
It's currently copied into both jobs; I will refactor the duplication away after this PR.

I also ended up removing NLTK from everywhere, along with a few doc fixes.
antoniivanov authored Feb 16, 2024
1 parent ddd7ac1 commit 6369417
Showing 29 changed files with 476 additions and 338 deletions.
90 changes: 90 additions & 0 deletions examples/confluence-reader/common/database_storage.py
@@ -0,0 +1,90 @@
# Copyright 2021-2024 VMware, Inc.
# SPDX-License-Identifier: Apache-2.0
import pickle
from typing import Any
from typing import List
from typing import Optional
from typing import Union

from common.storage import IStorage
from sqlalchemy import Column
from sqlalchemy import create_engine
from sqlalchemy import LargeBinary
from sqlalchemy import MetaData
from sqlalchemy import select
from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy.exc import IntegrityError


class DatabaseStorage(IStorage):
def __init__(self, connection_string: str):
self.engine = create_engine(connection_string)
self.metadata = MetaData()
self.table = Table(
"vdk_storage",
self.metadata,
Column("name", String, primary_key=True),
Column("content", LargeBinary),
Column("content_type", String),
)
self.metadata.create_all(self.engine)

def store(self, name: str, content: Union[str, bytes, Any]) -> None:
serialized_content, content_type = self._serialize_content(content)
ins = self.table.insert().values(
name=name, content=serialized_content, content_type=content_type
)
try:
with self.engine.connect() as conn:
conn.execute(ins)
conn.commit()
except IntegrityError:
# Handle duplicate name by updating existing content
upd = (
self.table.update()
.where(self.table.c.name == name)
.values(content=serialized_content, content_type=content_type)
)
with self.engine.connect() as conn:
conn.execute(upd)
conn.commit()

def retrieve(self, name: str) -> Optional[Union[str, bytes, Any]]:
sel = self.table.select().where(self.table.c.name == name)
with self.engine.connect() as conn:
result = conn.execute(sel).fetchone()
if result:
return self._deserialize_content(result.content, result.content_type)
return None

def list_contents(self) -> List[str]:
sel = select(self.table.c.name)
with self.engine.connect() as conn:
result = conn.execute(sel).fetchall()
return [row[0] for row in result]

def remove(self, name: str) -> bool:
del_stmt = self.table.delete().where(self.table.c.name == name)
with self.engine.connect() as conn:
result = conn.execute(del_stmt)
conn.commit()
return result.rowcount > 0

@staticmethod
def _serialize_content(content: Union[str, bytes, Any]) -> tuple[bytes, str]:
if isinstance(content, bytes):
return content, "bytes"
elif isinstance(content, str):
return content.encode(), "string"
else:
# Fallback to pickle for other types
return pickle.dumps(content), "pickle"

@staticmethod
def _deserialize_content(content: bytes, content_type: Optional[str]) -> Any:
if content_type == "pickle":
return pickle.loads(content)
if content_type == "string":
return content.decode()
return content
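
A minimal usage sketch for DatabaseStorage, not part of the commit. It assumes SQLAlchemy 2.x (the conn.commit() calls above require it) and uses a SQLite URL as a local stand-in for the Postgres connection string used in deployment:

from common.database_storage import DatabaseStorage

# Placeholder URL; in deployment this would be the Postgres connection string.
storage = DatabaseStorage("sqlite:///vdk_storage_demo.db")
storage.store("greeting", "hello")        # stored with content_type "string"
storage.store("payload", {"pages": 3})    # non-str/bytes falls back to pickle
assert storage.retrieve("greeting") == "hello"
assert storage.retrieve("payload") == {"pages": 3}
print(storage.list_contents())            # e.g. ['greeting', 'payload']
storage.remove("greeting")

Storing under an existing name overwrites the previous content via the IntegrityError fallback in store().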
57 changes: 57 additions & 0 deletions examples/confluence-reader/common/file_storage.py
@@ -0,0 +1,57 @@
# Copyright 2021-2024 VMware, Inc.
# SPDX-License-Identifier: Apache-2.0
import json
import os
from typing import Any
from typing import List
from typing import Optional
from typing import Union

from common.storage import IStorage


class FileStorage(IStorage):
def __init__(self, base_path: str):
self.base_path = base_path
if not os.path.exists(self.base_path):
os.makedirs(self.base_path)

def _get_file_path(self, name: str) -> str:
return os.path.join(self.base_path, name)

def store(
self,
name: str,
content: Union[str, bytes, Any],
content_type: Optional[str] = None,
) -> None:
file_path = self._get_file_path(name)
with open(file_path, "w") as file:
            if isinstance(content, (str, bytes)):
                # Write strings as-is; bytes are assumed to be UTF-8 text and are decoded
                file.write(content if isinstance(content, str) else content.decode())
else:
# Assume JSON serializable for other types
json.dump(content, file)

def retrieve(self, name: str) -> Optional[Union[str, bytes, Any]]:
file_path = self._get_file_path(name)
if not os.path.exists(file_path):
return None
with open(file_path) as file:
try:
return json.load(file)
except json.JSONDecodeError:
# Content was not JSON, return as string
file.seek(0)
return file.read()

def list_contents(self) -> List[str]:
return os.listdir(self.base_path)

def remove(self, name: str) -> bool:
file_path = self._get_file_path(name)
if os.path.exists(file_path):
os.remove(file_path)
return True
return False
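
A short sketch of FileStorage's round-trip behavior, again not part of the commit. Note the asymmetry in retrieve(): it tries json.load first, so a stored string that happens to be valid JSON comes back as the parsed value rather than the original string:

from common.file_storage import FileStorage

fs = FileStorage("/tmp/vdk_file_storage_demo")        # directory is created if missing
fs.store("config", {"space": "DOCS", "limit": 50})    # JSON-serialized
print(fs.retrieve("config"))    # {'space': 'DOCS', 'limit': 50}

fs.store("answer", "42")        # written verbatim as text
print(type(fs.retrieve("answer")))   # <class 'int'> -- parsed as JSON on the way back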
46 changes: 46 additions & 0 deletions examples/confluence-reader/common/storage.py
@@ -0,0 +1,46 @@
# Copyright 2021-2024 VMware, Inc.
# SPDX-License-Identifier: Apache-2.0
from typing import Any
from typing import List
from typing import Optional
from typing import Union


class IStorage:
def store(self, name: str, content: Union[str, bytes, Any]) -> None:
"""
        Stores the given content under the specified name. If the content is not a string or bytes,
        the implementation serializes it in an implementation-defined way (for example JSON or pickle).
:param name: The unique name to store the content under.
:param content: The content to store. Can be of type str, bytes, or any serializable type.
"""
pass

def retrieve(self, name: str) -> Optional[Union[str, bytes, Any]]:
"""
Retrieves the content stored under the specified name. The method attempts to deserialize
the content to its original type if possible.
:param name: The name of the content to retrieve.
:return: The retrieved content, which can be of type str, bytes, or any deserialized Python object.
Returns None if the content does not exist.
"""
pass

def list_contents(self) -> List[str]:
"""
Lists the names of all stored contents.
:return: A list of names representing the stored contents.
"""
pass

def remove(self, name: str) -> bool:
"""
Removes the content stored under the specified name.
:param name: The name of the content to remove.
:return: True if the content was successfully removed, False otherwise.
"""
pass
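
The point of the interface, as the commit message suggests, is that jobs can run locally against files and remotely against Postgres without changing job logic. A hedged sketch of that wiring; the environment variable name is an assumption for this sketch, not something the commit defines:

import os

from common.database_storage import DatabaseStorage
from common.file_storage import FileStorage
from common.storage import IStorage


def make_storage() -> IStorage:
    # STORAGE_CONNECTION_STRING is a hypothetical variable name for this sketch.
    conn = os.environ.get("STORAGE_CONNECTION_STRING")
    return DatabaseStorage(conn) if conn else FileStorage("./local_storage")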
50 changes: 31 additions & 19 deletions examples/confluence-reader/fetch_confluence_space.py
@@ -3,12 +3,15 @@
import json
import logging
import os
import pathlib
from datetime import datetime

from common.database_storage import DatabaseStorage
from confluence_document import ConfluenceDocument
from langchain_community.document_loaders import ConfluenceLoader
from vdk.api.job_input import IJobInput


log = logging.getLogger(__name__)


@@ -109,17 +112,12 @@ def __init__(self, confluence_url, token, space_key):
self.loader = ConfluenceLoader(url=self.confluence_url, token=self.token)

def fetch_confluence_documents(self, cql_query):
try:
# TODO: think about configurable limits ? or some streaming solution
# How do we fit all documents in memory ?
raw_documents = self.loader.load(cql=cql_query, limit=50, max_pages=200)
return [
ConfluenceDocument(doc.metadata, doc.page_content)
for doc in raw_documents
]
except Exception as e:
log.error(f"Error fetching documents from Confluence: {e}")
raise e
# TODO: think about configurable limits ? or some streaming solution
# How do we fit all documents in memory ?
raw_documents = self.loader.load(cql=cql_query, limit=50, max_pages=200)
return [
ConfluenceDocument(doc.metadata, doc.page_content) for doc in raw_documents
]

def fetch_updated_pages_in_confluence_space(
self, last_date="1900-02-06 17:54", parent_page_id=None
@@ -147,9 +145,10 @@ def fetch_all_pages_in_confluence_space(self, parent_page_id=None):


def get_value(job_input, key: str, default_value=None):
return job_input.get_arguments().get(
key, job_input.get_property(key, os.environ.get(key.upper(), default_value))
)
value = os.environ.get(key.upper(), default_value)
value = job_input.get_property(key, value)
value = job_input.get_secret(key, value)
return job_input.get_arguments().get(key, value)
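# Editorial note, not part of the commit: the chain above resolves a key with
# precedence arguments > secrets > properties > environment variable > default;
# each later lookup uses the previous result as its fallback, so e.g. a run
# argument overrides a secret of the same name.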


def set_property(job_input: IJobInput, key, value):
@@ -165,12 +164,20 @@ def run(job_input: IJobInput):
token = get_value(job_input, "confluence_token")
space_key = get_value(job_input, "confluence_space_key")
parent_page_id = get_value(job_input, "confluence_parent_page_id")
last_date = get_value(job_input, "last_date", "1900-01-01 12:00")
data_file = get_value(
job_input,
"data_file",
os.path.join(job_input.get_temporary_write_directory(), "confluence_data.json"),
last_date = (
job_input.get_property(confluence_url, {})
.setdefault(space_key, {})
.setdefault(parent_page_id, {})
.get("last_date", "1900-01-01 12:00")
)
data_file = os.path.join(
job_input.get_temporary_write_directory(), "confluence_data.json"
)
storage_name = get_value(job_input, "storage_name", "confluence_data")
storage = DatabaseStorage(get_value(job_input, "storage_connection_string"))
# TODO: this is not optimal. We only care about the IDs; we should not need to retrieve everything
data = storage.retrieve(storage_name)
pathlib.Path(data_file).write_text(data if data else "[]")

confluence_reader = ConfluenceDataSource(confluence_url, token, space_key)

@@ -189,3 +196,8 @@ def run(job_input: IJobInput):
data_file,
confluence_reader.fetch_all_pages_in_confluence_space(parent_page_id),
)

# TODO: it would be better to save each page in a separate row.
# But that's a quick solution for now to pass the data to the next job.

storage.store(storage_name, pathlib.Path(data_file).read_text())
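
With this change the reader job finishes by serializing the fetched pages into the shared storage, and a downstream job can pick them up without a shared file system. A sketch of the consuming side, assuming the mirrored DatabaseStorage copy and the default storage name above; the connection string is a placeholder:

import json

from common.database_storage import DatabaseStorage

storage = DatabaseStorage("postgresql+psycopg2://user:pass@host:5432/vdk")  # placeholder
raw = storage.retrieve("confluence_data")   # default storage_name in the reader job
documents = json.loads(raw) if raw else []
print(f"received {len(documents)} documents from the upstream job")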
2 changes: 2 additions & 0 deletions examples/confluence-reader/requirements.txt
@@ -5,3 +5,5 @@
atlassian-python-api
langchain_community
lxml
psycopg2-binary
sqlalchemy
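
The two new dependencies back the Postgres storage: sqlalchemy provides the table abstraction in database_storage.py, and psycopg2-binary is the driver it speaks to Postgres through. A typical connection string for that pairing (credentials are placeholders):

storage_connection_string = "postgresql+psycopg2://vdk_user:secret@db.example.com:5432/vdk"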
54 changes: 0 additions & 54 deletions examples/fetch-embed-job-example/10_fetch_confluence_space.py

This file was deleted.

