Merge pull request #10 from deepopinion/feature/llama31

Feature/llama31

peerdavid authored Aug 14, 2024
2 parents efa5e24 + 7040b0d commit 0eda798

Showing 11 changed files with 66 additions and 16 deletions.
2 changes: 1 addition & 1 deletion scripts/run_qa.sh

@@ -1,4 +1,4 @@
 #!/bin/bash
 
 ./scripts/run.sh doc_vqa $1
-./scripts/run.sh mp_doc_vqa $1
+# ./scripts/run.sh mp_doc_vqa $1
2 changes: 1 addition & 1 deletion src/benchmark_doc_vqa.py

@@ -36,7 +36,7 @@
 #
 # Evaluate a single sample
 #
-semaphore = asyncio.Semaphore(10)
+semaphore = asyncio.Semaphore(7)
 async def evaluate_sample(sample):
     async with semaphore:
         try:
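Note: the same Semaphore(10) to Semaphore(7) reduction recurs in all seven benchmark scripts in this commit. The semaphore caps how many evaluate_sample coroutines run at once, which bounds peak memory because each sample's images are only loaded while it holds a slot. A minimal runnable sketch of the pattern (the sleep is a stand-in for image loading and model inference):

```python
import asyncio

semaphore = asyncio.Semaphore(7)  # at most 7 samples in flight

async def evaluate_sample(sample):
    async with semaphore:          # wait for a free slot
        await asyncio.sleep(0.1)   # stand-in for image loading + inference
        return sample * 2

async def main():
    # gather() creates all 20 tasks, but only 7 ever run concurrently
    results = await asyncio.gather(*(evaluate_sample(s) for s in range(20)))
    print(results)

asyncio.run(main())
```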
2 changes: 1 addition & 1 deletion src/benchmark_kleister_charity.py

@@ -90,7 +90,7 @@ def load_dataset():
 #
 # Evaluate a single sample
 #
-semaphore = asyncio.Semaphore(10)
+semaphore = asyncio.Semaphore(7)
 async def evaluate_sample(sample):
     # This semaphore limits memory consumption, as we do not load all images at once.
     async with semaphore:
2 changes: 1 addition & 1 deletion src/benchmark_kleister_nda.py

@@ -78,7 +78,7 @@ def load_dataset():
 #
 # Evaluate a single sample
 #
-semaphore = asyncio.Semaphore(10)
+semaphore = asyncio.Semaphore(7)
 async def evaluate_sample(sample):
     async with semaphore:
         try:
2 changes: 1 addition & 1 deletion src/benchmark_mp_doc_vqa.py

@@ -33,7 +33,7 @@
 #
 # Evaluate a single sample
 #
-semaphore = asyncio.Semaphore(10)
+semaphore = asyncio.Semaphore(7)
 async def evaluate_sample(sample):
     async with semaphore:
         try:
2 changes: 1 addition & 1 deletion src/benchmark_sroie.py

@@ -68,7 +68,7 @@ def load_dataset():
 
     return gt
 
-semaphore = asyncio.Semaphore(10)
+semaphore = asyncio.Semaphore(7)
 async def evaluate_sample(file_name, label):
     async with semaphore:
         try:
2 changes: 1 addition & 1 deletion src/benchmark_vrdu_ad_buy.py

@@ -122,7 +122,7 @@ def load_dataset():
         samples.append((item["filename"], kv))
     return samples
 
-semaphore = asyncio.Semaphore(10)
+semaphore = asyncio.Semaphore(7)
 async def evaluate_sample(ds, idx):
     async with semaphore:
         sample = ds[idx]
2 changes: 1 addition & 1 deletion src/benchmark_vrdu_registration.py

@@ -75,7 +75,7 @@ def load_dataset():
         samples.append((item["filename"], kv))
     return samples
 
-semaphore = asyncio.Semaphore(10)
+semaphore = asyncio.Semaphore(7)
 async def evaluate_sample(ds, idx):
     async with semaphore:
         sample = ds[idx]
2 changes: 2 additions & 0 deletions src/utils/__init__.py

@@ -1,10 +1,12 @@
+from utils.json_parser import JsonParser
 from utils.misc import (
     ainvoke_die,
     ainvoke_vqa,
     log_result,
 )
 
 __all__ = [
+    "JsonParser",
     "ainvoke_die",
     "ainvoke_vqa",
     "log_result",
22 changes: 22 additions & 0 deletions src/utils/json_parser.py

@@ -0,0 +1,22 @@
+import json
+import re
+
+from langchain_core.runnables import Runnable
+
+_json_markdown_re = re.compile(r"```(json)?(.*)```", re.DOTALL)
+
+class JsonParser(Runnable):
+
+    def invoke(self, *args, **kwargs) -> dict:
+        try:
+            text = args[0].content
+            match = _json_markdown_re.search(text)
+            if match:
+                text = match.group(2)
+
+            text = text.replace("\\n", "\n")
+            text = text.replace("```", "")
+
+            return text
+        except Exception as e:
+            raise
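A minimal usage sketch of the new parser; the AIMessage input is an assumption (in this commit the parser sits inside an LCEL chain and receives the chat model's message). Note that despite the `-> dict` annotation, invoke() returns the body of the fenced block as a string; actual JSON parsing is left to the next stage of the chain:

```python
from langchain_core.messages import AIMessage
from utils.json_parser import JsonParser

reply = AIMessage(content='```json\n{"total": 42}\n```')
text = JsonParser().invoke(reply)
print(repr(text))  # '\n{"total": 42}\n' -- fences stripped, still a string
```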
42 changes: 34 additions & 8 deletions src/utils/misc.py

@@ -10,18 +10,25 @@
 from io import BytesIO
 from PIL import Image
 
+from google.auth import default, transport
+
 from langchain.pydantic_v1 import BaseModel
 from langchain.output_parsers import PydanticOutputParser
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.messages import HumanMessage
 
 
 from langchain_openai import ChatOpenAI
-from langchain_google_vertexai import ChatVertexAI, HarmCategory, HarmBlockThreshold
+from langchain_google_vertexai import (
+    ChatVertexAI,
+    HarmCategory,
+    HarmBlockThreshold,
+)
+
 from langchain_mistralai.chat_models import ChatMistralAI
 from langchain_anthropic import ChatAnthropic
 
-from utils import vision
-from utils import latin
+from utils import vision, latin, JsonParser
 
 try:
     from utils import sft
@@ -86,6 +93,7 @@ def create_llm(*, model:str):
             convert_system_message_to_human=True,  # This parameter still does not work -- if it is used, an exception is raised
             **settings,
         )
+
     elif provider == "mistral":
         endpoint = os.environ.get("MISTRAL_ENDPOINT")
         api_key = os.environ.get("MISTRAL_API_KEY")
@@ -100,6 +108,21 @@ def create_llm(*, model:str):
             model_name=model,
             **settings,
         )
+
+    elif provider == "model_garden":
+        credentials, _ = default(scopes=['https://www.googleapis.com/auth/cloud-platform'])
+        auth_request = transport.requests.Request()
+        credentials.refresh(auth_request)
+
+        PROJECT_ID = os.environ.get("VERTEXAI_PROJECT_ID")
+        MODEL_LOCATION = "us-central1"
+
+        llm = ChatOpenAI(
+            model=model,
+            base_url=f"https://{MODEL_LOCATION}-aiplatform.googleapis.com/v1beta1/projects/{PROJECT_ID}/locations/{MODEL_LOCATION}/endpoints/openapi/chat/completions?",
+            api_key=credentials.token,
+        )
+        return llm
 
     raise Exception(f"Unknown provider: {provider}")
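For reference, the new branch as a standalone sketch. The project id and model name below are placeholders (the commit reads VERTEXAI_PROJECT_ID from the environment). One caveat: credentials.token is a short-lived OAuth access token, typically valid for about an hour, so a very long benchmark run may need to refresh it and rebuild the client.

```python
from google.auth import default, transport
from langchain_openai import ChatOpenAI

# Obtain an OAuth token for the current gcloud identity.
credentials, _ = default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
credentials.refresh(transport.requests.Request())

project_id = "my-gcp-project"  # placeholder
location = "us-central1"

# Vertex AI Model Garden exposes an OpenAI-compatible chat endpoint, so
# ChatOpenAI works once it is pointed at that URL with the token as API key.
llm = ChatOpenAI(
    model="meta/llama-3.1-405b-instruct-maas",  # placeholder model id
    base_url=f"https://{location}-aiplatform.googleapis.com/v1beta1/projects/{project_id}/locations/{location}/endpoints/openapi/chat/completions?",
    api_key=credentials.token,
)
print(llm.invoke("Say hello.").content)
```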

@@ -121,8 +144,8 @@ def get_provider(model:str):
         return "mistral"
     elif model.startswith("claude"):
         return "anthropic"
-
-    raise Exception(f"Unknown model: {model}")
+    else:
+        return "model_garden"  # We use the model garden otherwise
 
 
 def sys_message(model:str):
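With this fallthrough, get_provider no longer raises on unrecognized names; any model without a known prefix is routed to Model Garden, which is how the Llama 3.1 models reach the new branch in create_llm. A quick hypothetical check (model names assumed for illustration):

```python
from utils.misc import get_provider

assert get_provider("claude-3-5-sonnet") == "anthropic"          # known prefix
assert get_provider("llama-3.1-70b-instruct") == "model_garden"  # falls through
```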
@@ -131,7 +154,7 @@ def sys_message(model:str):
     Therefore we convert it manually.
     """
     provider = get_provider(model)
-    return "user" if provider == "vertexai" else "system"
+    return "user" if provider in ["vertexai", "model_garden"] else "system"
 
 def requires_human_message(model:str):
     provider = get_provider(model)
@@ -267,7 +290,8 @@ async def ainvoke_die(benchmark:str, model:str, method:str, pydantic_object:BaseModel
     llm = create_llm(model=model)
     die_prompt = create_die_prompt(benchmark, model, method, images)
     prompt = ChatPromptTemplate.from_messages(die_prompt)
-    chain = prompt | llm | parser
+    json_parser = JsonParser()
+    chain = prompt | llm | json_parser | parser
 
     # Inference model a single time
     async def _invoke():
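The chain now runs JsonParser between the model and the PydanticOutputParser, so markdown-fenced JSON output (common with Llama 3.1) is unwrapped before validation. A self-contained sketch of the composition; the Receipt schema and the fake LLM are assumptions for illustration:

```python
from langchain.pydantic_v1 import BaseModel
from langchain.output_parsers import PydanticOutputParser
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableLambda

from utils.json_parser import JsonParser

class Receipt(BaseModel):  # hypothetical extraction schema
    total: float

parser = PydanticOutputParser(pydantic_object=Receipt)

# Fake model that wraps its answer in a markdown fence, as chat models often do.
fake_llm = RunnableLambda(lambda _: AIMessage(content='```json\n{"total": 19.99}\n```'))

chain = fake_llm | JsonParser() | parser
print(chain.invoke("extract the total"))  # Receipt(total=19.99)
```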
@@ -344,8 +368,10 @@ async def _invoke():
     output = await retry_invoke(_invoke)
 
     # Return answer
-    write_cache(benchmark, model, method, images, output.content)
+    output = output.content
+    output = output.replace("\\n", "\n")
+
+    write_cache(benchmark, model, method, images, output)
     return output
