Skip to content

Commit

Permalink
feat: changes to address service test
Browse files Browse the repository at this point in the history
  • Loading branch information
telpirion committed May 12, 2023
1 parent c9d1cf0 commit 3433aa0
Show file tree
Hide file tree
Showing 6 changed files with 169 additions and 98 deletions.
9 changes: 3 additions & 6 deletions src/document_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def async_document_extract(
bucket: str,
name: str,
timeout: int = 420,
) -> Tuple[str, str]:
) -> str:
"""Perform OCR with PDF/TIFF as source files on GCS.
Original sample is here:
Expand All @@ -41,7 +41,7 @@ def async_document_extract(
Returns:
tuple: (text, gcs_output_path)
str: the complete text
"""

gcs_source_uri = f'gs://{bucket}/{name}'
Expand Down Expand Up @@ -102,7 +102,4 @@ def async_document_extract(

complete_text = complete_text + annotation['text']

blob = bucket.blob(gcs_output_path)
blob.upload_from_string(complete_text)

return (complete_text, gcs_output_path)
return complete_text
20 changes: 1 addition & 19 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,25 +21,7 @@
from document_extract import async_document_extract
from storage import upload_to_gcs
from vertex_llm import predict_large_language_model_hack


def coerce_datetime_zulu(input_datetime: str) -> datetime.datetime:
    """Parse an ISO-8601 timestamp string that ends with the 'Z' (UTC) suffix.

    NOTE(review): the previous annotation said ``datetime.datetime``, but the
    code calls ``regex.search``/``startswith`` on the argument — it is a string.

    Args:
        input_datetime (str): a timestamp such as '2023-05-12T01:02:03Z'

    Returns:
        datetime.datetime: the parsed, timezone-aware (UTC) datetime.

    Raises:
        RuntimeError: if the input does not end with 'Z'.
    """
    regex = re.compile(r"(.*)(Z$)")
    regex_match = regex.search(input_datetime)
    if regex_match:
        # These asserts re-check what the regex already guarantees; they are
        # effectively no-ops (and are stripped under `python -O`).
        assert input_datetime.startswith(regex_match.group(1))
        assert input_datetime.endswith(regex_match.group(2))
        # fromisoformat (pre-3.11) rejects 'Z', so replace it with '+00:00'.
        return datetime.datetime.fromisoformat(f'{input_datetime[:-1]}+00:00')
    raise RuntimeError(
        'The input datetime is not in the expected format. '
        'Please check format of the input datetime. Expected "Z" at the end'
    )
from utils import coerce_datetime_zulu


@dataclass
Expand Down
68 changes: 68 additions & 0 deletions src/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import re

# Heuristic character budgets for extracted sections: recommended max word
# count multiplied by an assumed average of 10 characters per word.
ABSTRACT_LENGTH = 150 * 10  # Abstract recommended max word length * avg 10 letters long
CONCLUSION_LENGTH = 200 * 10  # Conclusion max word length * avg 10 letters long
ABSTRACT_H1 = 'abstract'  # Heading text marking the start of the abstract (lower-cased search)
CONCLUSION_H1 = 'conclusion'  # Heading text marking the start of the conclusion (lower-cased search)

def coerce_datetime_zulu(input_datetime: str) -> datetime.datetime:
    """Parse an ISO-8601 timestamp that uses the 'Z' (Zulu/UTC) suffix.

    ``datetime.datetime.fromisoformat`` (before Python 3.11) does not accept
    the 'Z' suffix, so it is rewritten as an explicit '+00:00' offset first.

    Args:
        input_datetime (str): a timestamp string such as
            '2023-05-12T01:02:03Z'. (The previous annotation said
            ``datetime.datetime``, but the code operates on a string.)

    Returns:
        datetime.datetime: the parsed, timezone-aware (UTC) datetime.

    Raises:
        RuntimeError: if the input does not end with 'Z'.
    """
    # `Z$` anchors the check at the end of the string.  The original
    # capture groups — and the asserts re-verifying them — were redundant
    # (and asserts are stripped under `python -O`), so they are dropped.
    if re.search(r"Z$", input_datetime):
        return datetime.datetime.fromisoformat(f'{input_datetime[:-1]}+00:00')
    raise RuntimeError(
        'The input datetime is not in the expected format. '
        'Please check format of the input datetime. Expected "Z" at the end'
    )


def truncate_complete_text(complete_text: str) -> str:
    """Extracts the abstract and conclusion from an academic paper.

    Uses a heuristic to approximate the extent of the abstract and conclusion:
    each section is assumed to begin at the first occurrence of its heading
    ('abstract' / 'conclusion') and to run for a fixed character budget
    (ABSTRACT_LENGTH / CONCLUSION_LENGTH).

    Args:
        complete_text (str): the complete text of the academic paper

    Returns:
        str: the truncated paper (lower-cased abstract + conclusion excerpt)
    """
    complete_text = complete_text.lower()

    # If a heading is missing, fall back to the start of the document rather
    # than slicing from index -1 (which would yield garbage).
    abstract_start = max(complete_text.find(ABSTRACT_H1), 0)
    conclusion_start = max(complete_text.find(CONCLUSION_H1), 0)

    # Bug fix: the slice must be relative to the heading position.  The old
    # code sliced `[abstract_start:ABSTRACT_LENGTH]`, i.e. up to absolute
    # index 1500, which returned an empty abstract whenever the heading
    # appeared later in the document.
    abstract = complete_text[abstract_start:abstract_start + ABSTRACT_LENGTH]
    conclusion = complete_text[conclusion_start:conclusion_start + CONCLUSION_LENGTH]

    return f"""
Abstract: {abstract}
Conclusion: {conclusion}
"""
75 changes: 3 additions & 72 deletions src/vertex_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,6 @@
from google.cloud import aiplatform
from vertexai.preview.language_models import TextGenerationModel

import datetime
import requests
from requests.adapters import HTTPAdapter
import urllib3


def predict_large_language_model(
project_id: str,
Expand All @@ -38,10 +33,10 @@ def predict_large_language_model(
Args:
project_id (str): the Google Cloud project ID
model_name (str): the name of the LLM model to use
temperature (float): TODO(nicain)
temperature (float): controls the randomness of predictions
max_decode_steps (int): TODO(nicain)
top_p (float): TODO(nicain)
top_k (int): TODO(nicain)
top_p (float): cumulative probability of parameter highest vocabulary tokens
top_k (int): number of highest propbability vocabulary tokens to keep for top-k-filtering
content (str): the text to summarize
location (str): the Google Cloud region to run in
tuned_model_name (str): TODO(nicain)
Expand All @@ -62,68 +57,4 @@ def predict_large_language_model(
return response.text


def predict_large_language_model_hack(
    project_id: str,
    model_name: str,
    temperature: float,
    max_decode_steps: int,
    top_p: float,
    top_k: int,
    content: str,
    location: str = "us-central1",
    tuned_model_name: str = "",
) -> str:
    """Predict using a Large Language Model via a raw REST call.

    Workaround ("hack") variant that posts directly to the Vertex AI
    predict endpoint with `requests` instead of using the client library.

    NOTE(review): `location` and `tuned_model_name` are accepted but never
    used — the endpoint URL is hard-coded to us-central1. Confirm intent.

    Args:
        project_id (str): the Google Cloud project ID
        model_name (str): the name of the LLM model endpoint to call
        temperature (float): controls the randomness of predictions
        max_decode_steps (int): maximum number of tokens to generate
        top_p (float): cumulative probability cutoff for nucleus sampling
        top_k (int): number of highest-probability tokens kept for top-k filtering
        content (str): the text to summarize
        location (str): the Google Cloud region to run in (currently unused)
        tuned_model_name (str): name of a tuned model (currently unused)

    Returns:
        The summarization of the content
    """
    # `auth` is presumably google.auth (imported elsewhere in this module);
    # refresh the default credentials to obtain a bearer token.
    credentials, project_id = auth.default()
    request = auth.transport.requests.Request()
    credentials.refresh(request)

    # Hard-coded regional endpoint; ignores the `location` parameter.
    audience = f'https://us-central1-aiplatform.googleapis.com/v1/projects/cloud-large-language-models/locations/us-central1/endpoints/{model_name}:predict'
    s = requests.Session()
    # Retry transient failures: rate limiting (429) and server errors (500).
    retries = urllib3.util.Retry(
        connect=10,
        read=1,
        backoff_factor=0.1,
        status_forcelist=[429, 500],
    )

    headers = {}
    headers["Content-type"] = "application/json"
    headers["Authorization"] = f"Bearer {credentials.token}"

    json_data = {
        "instances": [
            {"content": content},
        ],
        "parameters": {
            "temperature": temperature,
            "maxDecodeSteps": max_decode_steps,
            "topP": top_p,
            "topK": top_k,
        }
    }

    s.mount('https://', HTTPAdapter(max_retries=retries))
    response = s.post(
        audience,
        headers=headers,
        # 15-minute timeout expressed in seconds.
        timeout=datetime.timedelta(minutes=15).total_seconds(),
        json=json_data,
    )

    # Assumes a successful JSON response; raises KeyError/ValueError otherwise.
    return response.json()['predictions'][0]['content']
2 changes: 1 addition & 1 deletion test/system-test/document_extract_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,5 @@

@backoff.on_exception(backoff.expo, Exception, max_tries=3)
def test_async_document_extract(capsys):
out, _ = async_document_extract(BUCKET_NAME, FILE_NAME)
out = async_document_extract(BUCKET_NAME, FILE_NAME)
assert 'Abstract' in out
93 changes: 93 additions & 0 deletions test/system-test/services_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import backoff
import datetime
import os

from google.cloud import storage

from src.bigquery import write_summarization_to_table
from src.document_extract import async_document_extract
from src.storage import upload_to_gcs
from src.utils import truncate_complete_text
from src.vertex_llm import predict_large_language_model

PROJECT_ID = os.environ["PROJECT_ID"]
BUCKET_NAME = os.environ["BUCKET"]
DATASET_ID = "academic_papers"
TABLE_ID = "summarizations"
FILE_NAME = 'pdfs/9404001v1.pdf'
MODEL_NAME = 'text-bison@001'


def check_blob_exists(bucket, filename) -> bool:
    """Return True if `filename` exists as an object in the GCS bucket `bucket`."""
    gcs_client = storage.Client()
    target_blob = gcs_client.bucket(bucket).blob(filename)
    return target_blob.exists()

@backoff.on_exception(backoff.expo, Exception, max_tries=3)
def test_up16_services():
    """End-to-end system test: OCR extract -> upload -> summarize -> record.

    Exercises the full pipeline against live GCP services (Document/Vision
    OCR, GCS, Vertex AI, BigQuery); retried with exponential backoff.
    """
    # Step 1: OCR the PDF stored in GCS into plain text.
    extracted_text = async_document_extract(BUCKET_NAME, FILE_NAME)

    assert "Abstract" in extracted_text

    # Step 2: persist the full extracted text to GCS and verify the object.
    complete_text_filename = f'system-test/{FILE_NAME.replace(".pdf", "")}_fulltext.txt'
    upload_to_gcs(
        BUCKET_NAME,
        complete_text_filename,
        extracted_text,
    )

    assert check_blob_exists(BUCKET_NAME, complete_text_filename)

    # Step 3: truncate to abstract+conclusion and summarize with the LLM.
    # TODO(erschmid): replace truncate with better solution
    extracted_text_ = truncate_complete_text(extracted_text)
    summary = predict_large_language_model(
        project_id=PROJECT_ID,
        model_name=MODEL_NAME,
        temperature=0.2,
        max_decode_steps=1024,
        top_p=0.8,
        top_k=40,
        content=f'Summarize:\n{extracted_text_}',
        location="us-central1",
    )

    assert summary != ""

    # Step 4: persist the summary to GCS and verify the object.
    output_filename = f'system-test/{FILE_NAME.replace(".pdf", "")}_summary.txt'
    upload_to_gcs(
        BUCKET_NAME,
        output_filename,
        summary,
    )

    assert check_blob_exists(BUCKET_NAME, output_filename)

    # Step 5: record the summarization run in BigQuery; an empty error list
    # from the insert indicates success.
    errors = write_summarization_to_table(
        project_id=PROJECT_ID,
        dataset_id=DATASET_ID,
        table_id=TABLE_ID,
        bucket=BUCKET_NAME,
        filename=output_filename,
        complete_text=extracted_text,
        complete_text_uri=complete_text_filename,
        summary=summary,
        summary_uri=output_filename,
        timestamp=datetime.datetime.now()
    )

    assert len(errors) == 0

0 comments on commit 3433aa0

Please sign in to comment.