Skip to content

Commit

Permalink
feat: wired up services to function
Browse files Browse the repository at this point in the history
  • Loading branch information
telpirion committed May 12, 2023
1 parent 3433aa0 commit f27e726
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 19 deletions.
69 changes: 55 additions & 14 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,24 @@

from dataclasses import dataclass
import datetime
import os
import re

from document_extract import async_document_extract
from storage import upload_to_gcs
from vertex_llm import predict_large_language_model_hack
from utils import coerce_datetime_zulu
from google.cloud import functions
from google.cloud import logging

from .bigquery import write_summarization_to_table
from .document_extract import async_document_extract
from .storage import upload_to_gcs
from .vertex_llm import predict_large_language_model
from .utils import coerce_datetime_zulu, truncate_complete_text

FUNCTIONS_GCS_EVENT_LOGGER = 'function-triggered-by-storage'
# TODO(erschmid): replace PROJECT_ID, MODEL_NAME, TABLE_ID, and DATASET_ID with env vars
PROJECT_ID = 'velociraptor-16p1-dev'
MODEL_NAME = 'text-bison@001'
DATASET_ID = 'academic_papers'
TABLE_ID = 'summarizations'

@dataclass
class CloudEventData:
Expand All @@ -44,11 +55,12 @@ def read_datetimes(cls, kwargs):

# WEBHOOK FUNCTION
@functions_framework.cloud_event
def entrypoint(cloud_event) -> dict:
def entrypoint(cloud_event: dict, context: object) -> dict:
"""Entrypoint for Cloud Function
Args:
cloud_event (CloudEvent): an event from EventArc
context (google.cloud.functions.Context): the context of this event; UNUSED
Returns:
dictionary with 'summary' and 'output_filename' keys
Expand All @@ -62,26 +74,55 @@ def entrypoint(cloud_event) -> dict:
timeCreated = coerce_datetime_zulu(cloud_event.data["timeCreated"])
updated = coerce_datetime_zulu(cloud_event.data["updated"])

logging_client = logging.Client()
logger = logging_client.logger(FUNCTIONS_GCS_EVENT_LOGGER)
logger.log(f"""
cloud_event: id: {event_id}, event_type: {event_type}, bucket: {bucket}, file: {name}, time: {timeCreated}
""",
severity="INFO")


extracted_text = async_document_extract(bucket, name)
summary = predict_large_language_model_hack(
project_id="velociraptor-16p1-src",
model_name="text-bison-001",

complete_text_filename = f'summaries/{name.replace(".pdf", "")}_fulltext.txt'
upload_to_gcs(
bucket,
complete_text_filename,
extracted_text,
)

# TODO(erschmid): replace truncate with better solution
extracted_text_ = truncate_complete_text(extracted_text)
summary = predict_large_language_model(
project_id=PROJECT_ID,
model_name=MODEL_NAME,
temperature=0.2,
max_decode_steps=1024,
top_p=0.8,
top_k=40,
content=f'Summarize:\n{extracted_text}',
content=f'Summarize:\n{extracted_text_}',
location="us-central1",
)

output_filename = f'{name.replace(".pdf", "")}_summary.txt'
output_filename = f'system-test/{name.replace(".pdf", "")}_summary.txt'
upload_to_gcs(
bucket,
output_filename,
summary,
)

return {
'summary': summary,
'output_filename': output_filename,
}
# If we have any errors, they'll be caught by the bigquery module
errors = write_summarization_to_table(
project_id=PROJECT_ID,
dataset_id=DATASET_ID,
table_id=TABLE_ID,
bucket=bucket,
filename=output_filename,
complete_text=extracted_text,
complete_text_uri=complete_text_filename,
summary=summary,
summary_uri=output_filename,
timestamp=datetime.datetime.now()
)

return errors
24 changes: 20 additions & 4 deletions test/function_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import mock

from dataclasses import dataclass

from src.main import entrypoint

@dataclass
class CloudEventDataMock:
bucket: str
Expand Down Expand Up @@ -42,10 +49,19 @@ def __getitem__(self, key):
id='7631145714375969',
type='google.cloud.storage.object.v1.finalized',
data=CloudEventDataMock(
bucket='velociraptor-16p1-mock-users-bucket',
name='9404001v1.pdf',
bucket='velociraptor-16p1-src',
name='system-test/inputs/9404003v2.pdf',
metageneration='1',
timeCreated='2023-05-08T19:28:55.255Z',
updated='2023-05-08T19:28:55.255Z',
timeCreated=f"{datetime.datetime.now().isoformat()}Z",
updated=f"{datetime.datetime.now().isoformat()}Z",
)
)


def test_function_entrypoint():
context = mock.MagicMock()
context.event_id = f'system-test-{datetime.datetime.now().strftime("%m-%d-%Y-%H.%M.%S")}'
context.event_type = 'gcs-event'

errors = entrypoint(MOCK_CLOUD_EVENT, context)
assert len(errors) == 0
2 changes: 1 addition & 1 deletion test/requirements-test.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
backoff==2.2.1
mock
pytest==7.3.1
flaky==3.7.0
google-cloud-storage

0 comments on commit f27e726

Please sign in to comment.