Validator, logging and modelling improvements #127

Merged · 14 commits · Jan 29, 2024
Changes from 3 commits
3 changes: 3 additions & 0 deletions src/tanuki/__init__.py
@@ -294,10 +294,13 @@ def wrapper(*args, **kwargs) -> Union[Embedding, Any]:
# Configure the function modeler using incoming parameters
function_modeler.environment_id = environment_id
if ignore_finetuning:
logging.info(f"The flag for ignoring finetuning has been set True for {test_func.__name__}. No model distillation will be performed.")
function_modeler.execute_finetune_blacklist.append(func_hash)
if ignore_finetune_fetching:
logging.info(f"The flag for ignoring searching for finetuned models has been set True for {test_func.__name__}. No already finetuned models will be looked for.")
function_modeler.check_finetune_blacklist.append(func_hash)
if ignore_data_storage:
logging.info(f"The flag for ignoring data storage has been set True for {test_func.__name__}. No data will be read or saved and model distillation will not be performed.")
function_modeler.store_data_blacklist.append(func_hash)
task_type = function_description.type
if len(teacher_models) > 0:
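Note: for context, a rough usage sketch of how these flags might be set by a caller. The @tanuki.patch entry point and its keyword arguments are assumptions inferred from the parameters visible in this hunk, not confirmed by it.

import tanuki

# Hypothetical usage sketch: keyword names mirror the flags handled above.
@tanuki.patch(
    ignore_finetuning=True,         # skip model distillation for this function
    ignore_finetune_fetching=True,  # do not search for already finetuned models
    ignore_data_storage=False,      # still read and save align/patch data
)
def classify_sentiment(text: str) -> str:
    """Classify the sentiment of the input text as 'positive' or 'negative'."""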
29 changes: 18 additions & 11 deletions src/tanuki/function_modeler.py
@@ -4,7 +4,7 @@
import json
from typing import List, Tuple, Dict, Union

import openai
import logging

from tanuki.constants import EXAMPLE_ELEMENT_LIMIT, PATCHES, SYMBOLIC_ALIGNMENTS, POSITIVE_EMBEDDABLE_ALIGNMENTS, \
NEGATIVE_EMBEDDABLE_ALIGNMENTS, OPENAI_PROVIDER
@@ -318,7 +318,7 @@ def load_function_config(self, func_hash, function_description):

def _check_for_finetunes(self, function_description: FunctionDescription, finetune_provider : str) -> Tuple[bool, Dict]:
# hash the function_hash into 16 characters (to embed it into the name of OpenAI finetunes, for later retrieval)

logging.info(f"Checking for finetunes for {function_description.__name__} using {finetune_provider}")
finetune_hash = function_description.__hash__(purpose="finetune") + encode_int(self.environment_id)
# List the existing fine-tuning jobs (up to 1000)
finetunes: List[FinetuneJob] = self.api_provider[finetune_provider].list_finetuned(limit=1000)
@@ -333,10 +333,12 @@ def _check_for_finetunes(self, function_description: FunctionDescription, finetune_provider: str):
config = self._construct_config_from_finetune(finetune_hash, finetune)
# save the config
self.data_worker.update_function_config(function_description.__hash__(), config)
logging.info(f"Found finetuned model for {function_description.__name__} [{config.distilled_model.model_name}]")
return True, config
except:
logging.info(f"Found finetuned model for {function_description.__name__} [{finetune.fine_tuned_model.model_name}] but could not load it")
return False, {}

logging.info(f"No finetuned model found for {function_description.__name__}")
return False, {}

def _construct_config_from_finetune(self, finetune_hash: str, finetune: FinetuneJob):
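Note: the comment above describes shortening the function hash and embedding it in the finetune job name so the job can be retrieved later. A minimal sketch of that idea, with illustrative names only (md5 stands in for whatever hashing the library actually uses):

import hashlib

def finetune_suffix(function_hash: str, environment_id: int) -> str:
    # Shorten the hash to 16 characters and append the environment id,
    # so the suffix can later be matched against finetune job names.
    return hashlib.md5(function_hash.encode()).hexdigest()[:16] + str(environment_id)

def find_matching_finetune(finetune_jobs, suffix: str):
    # Return the first job whose fine-tuned model name embeds the suffix.
    for job in finetune_jobs:
        model_name = getattr(job, "fine_tuned_model", "") or ""
        if suffix in model_name:
            return job
    return None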
@@ -426,16 +428,16 @@ def check_for_finetuning(self, function_description, func_hash):
# check if already finetuning
if "job_id" in self.function_configs[func_hash].current_training_run:
# check for job status
self._check_finetuning_status(func_hash)
self._check_finetuning_status(func_hash, function_description)
else:
# check for finetuning condition
if self._check_finetuning_condition(func_hash):
if self._check_finetuning_condition(func_hash, function_description):
self._execute_finetuning(function_description, func_hash)
except Exception as e:
print(e)
print("Error checking for finetuning")

def _check_finetuning_condition(self, func_hash):
def _check_finetuning_condition(self, func_hash, function_description):
"""
Check if the finetuning condition is met
Currently finetuning condition is dependent on the number of symbolic datapoints since last finetuning
@@ -453,6 +455,9 @@ def _check_finetuning_condition(self, func_hash):
# if we haven't read in the patch dataset size yet, read it in
patch_dataset_size = self._get_dataset_info(PATCHES, func_hash, type="length")
self.dataset_sizes[PATCHES][func_hash] = patch_dataset_size
logging.info(f"Function {function_description.__name__} [{align_dataset_size} aligns | {patch_dataset_size} runs] will be finetuned from"\
f" {self.function_configs[func_hash].teacher_models.model_name} using {self.function_configs[func_hash].distilled_model.provider} in"\
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't teacher_models a list? If so, how can you get model_name from it?

f"{training_threshold-(patch_dataset_size + align_dataset_size)} runs")

return (patch_dataset_size + align_dataset_size) > training_threshold
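Note: on the reviewer's question above, if teacher_models is a list the log message would need to aggregate model names rather than read a single model_name. A hedged sketch of one way to build that message (the helper and its parameters are illustrative, not part of the library):

import logging

def log_finetune_plan(function_name, align_dataset_size, patch_dataset_size,
                      training_threshold, teacher_models, distilled_provider):
    # Join the names of all teacher models instead of assuming a single model.
    teacher_names = ", ".join(model.model_name for model in teacher_models)
    remaining = training_threshold - (patch_dataset_size + align_dataset_size)
    logging.info(
        f"Function {function_name} [{align_dataset_size} aligns | {patch_dataset_size} runs] "
        f"will be finetuned from [{teacher_names}] using {distilled_provider} in {remaining} runs"
    )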

@@ -529,8 +534,10 @@ def _execute_finetuning(self, function_description, func_hash):
# Use the stream as a file
try:
finetune_provider = self.function_configs[func_hash].distilled_model.provider
logging.info(f"Starting finetuning for {function_description.__name__} using {finetune_provider}")
finetuning_response: FinetuneJob = self.api_provider[finetune_provider].finetune(file=temp_file, suffix=finetune_hash)
except Exception as e:
logging.info(f"Could not start finetuning for {function_description.__name__} using {finetune_provider}. Error: {e}")
return

self.function_configs[func_hash].current_training_run = {"job_id": finetuning_response.id,
@@ -544,7 +551,7 @@
print(e)
print("Could not update config file to register a finetuning run")

def _check_finetuning_status(self, func_hash):
def _check_finetuning_status(self, func_hash, function_description):
"""
Check the status of the current finetuning job
If the job is finished, update the config file to reflect the new model
@@ -560,18 +567,18 @@
self.function_configs[func_hash].current_training_run["last_checked"] = datetime.datetime.now().strftime(
"%Y-%m-%d %H:%M:%S")
if response.status == "succeeded" or response.status == "failed":
self._update_finetune_config(response, func_hash)
self._update_finetune_config(response, func_hash, function_description)
else:
self._update_config_file(func_hash)

def _update_finetune_config(self, response: FinetuneJob, func_hash):
def _update_finetune_config(self, response: FinetuneJob, func_hash, function_description):
"""
Update the config file to reflect the new model and switch the current model to the finetuned model
"""
self.function_configs[func_hash].update_with_finetuned_response(response)
logging.info(f"Finetuning for {function_description.__name__} using {self.function_configs[func_hash].distilled_model.provider} finished with status {response.status}")
try:
self._update_config_file(func_hash)
except Exception as e:
print(e)
print("Could not update config file after a successful finetuning run")
logging.info(f"Could not update the function configuration file with the finetuned model for {function_description.__name__}. Error: {e}")
pass
25 changes: 16 additions & 9 deletions src/tanuki/language_models/language_model_manager.py
@@ -83,17 +83,21 @@ def generate(self, args, kwargs, function_description, llm_parameters={}):
The main generation function: given the args, kwargs, function description and model type, generate a response and check whether the datapoint can be saved to the finetune dataset
"""

func_hash = function_description.__hash__()
prompt, model, save_to_finetune, is_distilled_model = self.get_generation_case(args, kwargs,
function_description,
llm_parameters)
func_hash = function_description.__hash__()
llm_parameters,
func_hash)
# logging
if func_hash not in self.current_generators:
logging.info(f"Generating function outputs with {model.model_name}")
self.current_generators[func_hash] = model.model_name
elif self.current_generators[func_hash] != model.model_name:
logging.info(f"Switching output generation from {self.current_generators[func_hash]} to {model.model_name}")
self.current_generators[func_hash] = model.model_name
current_generator = self.current_generators.get(func_hash, None)
if current_generator:
generator_model = current_generator["model"]
if generator_model == "":
logging.info(f"Found {len(current_generator['examples'])} align statements for {function_description.name}. Generating function outputs with {model.model_name}.")
self.current_generators[func_hash]["model"] = model.model_name
elif generator_model != model.model_name:
logging.info(f"Switching output generation from {generator_model} to {model.model_name} for funcion {function_description.name}.")
self.current_generators[func_hash]["model"] = model.model_name

choice = self._synthesise_answer(prompt, model, llm_parameters)
output = LanguageModelOutput(choice, save_to_finetune, is_distilled_model)
@@ -114,7 +118,7 @@ def _synthesise_answer(self, prompt, model, llm_parameters):
return self.api_provider[model.provider].generate(model, system_message, prompt, **llm_parameters)


def get_generation_case(self, args, kwargs, function_description, llm_parameters):
def get_generation_case(self, args, kwargs, function_description, llm_parameters, func_hash):
"""
Get the generation case with the correct prompt and model
First get the current model, then if distilled model, do zero-shot prompt and return False as suitable_for_finetune
@@ -136,6 +140,9 @@ def get_generation_case(self, args, kwargs, function_description, llm_parameters):
examples = [f"Inputs:\nArgs: {align['args']}\nKwargs: {align['kwargs']}\nOutput: {align['output']}" for align in
aligns]

if func_hash not in self.current_generators:
self.current_generators[func_hash] = {"model": "", "examples": examples}
Contributor: Can we change the variable name current_generators to something more intuitive?


examples_token_count = sum([approximate_token_count(example) for example in examples])
generation_tokens = llm_parameters.get("max_new_tokens", self.default_generation_length)
model = self.choose_model_from_tokens(teacher_models,
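Note: on the naming comment above, one possible refactor is to wrap the per-function generator state in a small, descriptively named structure instead of a bare dict. A sketch with illustrative names only:

import logging
from dataclasses import dataclass, field
from typing import Dict, List

@dataclass
class OutputGeneratorState:
    """Tracks which model currently generates outputs for one patched function."""
    model_name: str = ""  # empty until the first generation happens
    align_examples: List[str] = field(default_factory=list)

def record_generator(states: Dict[str, OutputGeneratorState], func_hash: str,
                     new_model_name: str) -> None:
    state = states.setdefault(func_hash, OutputGeneratorState())
    if state.model_name and state.model_name != new_model_name:
        logging.info(f"Switching output generation from {state.model_name} to {new_model_name}")
    state.model_name = new_model_name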
9 changes: 7 additions & 2 deletions src/tanuki/validator.py
@@ -117,9 +117,10 @@ def check_type(self, value: Any, type_definition: Any) -> bool:

# Handle tuples
if origin == tuple:
if not isinstance(value, tuple) or (args and len(value) != len(args)):
if not isinstance(value, tuple):
return False
return all(self.check_type(v, t) for v, t in zip(value, args))
item_type = args[0] if args else Any
return all(self.check_type(v, item_type) for v in value)

# Handle lists
if origin == list:
@@ -175,6 +176,8 @@ def check_type(self, value: Any, type_definition: Any) -> bool:
if self.is_pydantic_model(origin):
try:
#temp_model = create_model('TempModel', **value)
if isinstance(value, origin):
return True
#return isinstance(temp_model, origin)
# check if value is dict
if not isinstance(value, dict):
@@ -480,6 +483,8 @@ def instantiate(self, data: Any, target_type: Type) -> Any:
if self._is_subclass_of_generic(target_type, list) and not self._is_generic(target_type):
return target_type(instantiated_items)

return instantiated_items

# Handle tuples
if self._is_tuple_like(target_type) or (isinstance(origin, type) and issubclass(origin, tuple)):
base, item_types = self._find_generic_base_and_args(target_type)
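Note: a rough illustration of the behaviour these validator changes appear to enable; the import path and exact semantics are assumptions based on this diff and the tests below.

from typing import List, Tuple
from pydantic import BaseModel
from tanuki.validator import Validator  # assumed import path

class Point(BaseModel):
    x: int
    y: int

validator = Validator()

# Variable-length tuples: each element is checked against the single declared type.
assert validator.check_type((1, 2, 3), Tuple[int, ...])
assert not validator.check_type((1, "two"), Tuple[int, ...])

# Lists of Pydantic models: generic List[...] targets now return the instantiated items.
points = validator.instantiate([{"x": 1, "y": 2}, {"x": 3, "y": 4}], List[Point])
assert all(isinstance(p, Point) for p in points)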
12 changes: 8 additions & 4 deletions tests/test_token_counter.py
@@ -36,7 +36,8 @@ def test_token_counter_finetunable():
prompt, distilled_model, suitable_for_distillation, is_distilled_model = lang_model.get_generation_case(args,
kwargs,
function_description,
{})
{},
"")
assert suitable_for_distillation
assert is_distilled_model
assert distilled_model.model_name == "test_ft_1"
@@ -54,7 +55,8 @@ def test_token_counter_non_finetunable_1():
prompt, distilled_model, suitable_for_distillation, is_distilled_model = lang_model.get_generation_case(args,
kwargs,
function_description,
{})
{},
"")
assert not suitable_for_distillation
assert not is_distilled_model
assert distilled_model.model_name == "gpt-4"
@@ -72,7 +74,8 @@ def test_token_counter_non_finetunable_2():
prompt, distilled_model, suitable_for_distillation, is_distilled_model = lang_model.get_generation_case(args,
kwargs,
function_description,
{})
{},
"")
assert not suitable_for_distillation
assert not is_distilled_model
assert distilled_model.model_name == "gpt-4-32k"
@@ -92,7 +95,8 @@ def test_error_raise():
prompt, distilled_model, suitable_for_distillation, is_distilled_model = lang_model.get_generation_case(args,
kwargs,
function_description,
{})
{},
"")
except ValueError:
error = True
assert error
20 changes: 18 additions & 2 deletions tests/test_validator/test_instantiate.py
@@ -15,6 +15,10 @@ class Person(BaseModel):
age: int
height: float
is_cool: bool
favourite_numbers: List[int]
even_more_favourite_numbers: tuple[int, ...]
favourite_dict: Dict[str, int]


def __eq__(self, other):
return self.model_dump() == other.model_dump()
@@ -27,10 +31,20 @@ def __hash__(self):
"name": "John",
"age": 20,
"height": 1.8,
"is_cool": True
"is_cool": True,
"favourite_numbers": [1, 2, 3],
"even_more_favourite_numbers": (1, 2, 3),
"favourite_dict": {"a": 1, "b": 2},
}
person_obj = validator.instantiate(person, Person)
assert isinstance(person_obj, Person)
# test lists
list_pydantic = [person, person]
person_obj = validator.instantiate(list_pydantic, List[Person])
assert isinstance(person_obj, list)
assert isinstance(person_obj[0], Person)
assert isinstance(person_obj[1], Person)
assert len(person_obj) == 2

# Nested data classes or Pydantic models.
@dataclass
@@ -76,8 +90,10 @@ def test_primitives():
assert validator.instantiate("1.0", str) != 1.0
assert validator.instantiate("true", str) != True
assert validator.instantiate({}, dict) == {}
assert validator.instantiate({"asd": 2, "bb": "ad"}, dict) == {"asd": 2, "bb": "ad"}
assert validator.instantiate([], list) == []
assert validator.instantiate((), tuple) == ()
assert validator.instantiate((1,2), tuple) == (1, 2)
assert validator.instantiate(set(), frozenset) == set()
assert validator.instantiate((), frozenset) == ()
assert validator.instantiate((), set) == ()
@@ -231,4 +247,4 @@ class ExtendedList(List[int]):
test_instantiate_dataclass()
test_primitives()
test_generics()
test_extended_generics()
test_extended_generics(Validator())