diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml
index 5db44d16..cedb6b55 100644
--- a/.github/workflows/python-package.yml
+++ b/.github/workflows/python-package.yml
@@ -5,9 +5,9 @@ name: Python package
 on:
   push:
-    branches: [ "main", "staging" ]
+    branches: [ "main", "staging", "pre-staging" ]
   pull_request:
-    branches: [ "main", "staging" ]
+    branches: [ "main", "staging", "pre-staging" ]
 
 jobs:
   build:
diff --git a/README.md b/README.md
index c1d0e529..ad6e4904 100644
--- a/README.md
+++ b/README.md
@@ -39,7 +39,7 @@ The design of the network's incentive mechanism is based on two important requir
 It is imperative that the validation process engages with miners in the same way as real users. The reasons for this are as follows:
 - Miners will compete and continuously improve at performing the validation task(s), and so this fine tuning should be aligned with the goals of the subnet.
-- It should not be possible to distinguish between validation and API client queries so that miners always serve requests (even when they do not recieve emissions for doing so).
+- It should not be possible to distinguish between validation and API client queries so that miners always serve requests (even when they do not receive emissions for doing so).
 
 In the context of this subnet, miners are required to be intelligent AI assistants that provide helpful and correct responses to a range of queries.
@@ -104,7 +104,7 @@ These validators are designed to run and update themselves automatically. To run
 pm2 start run.sh --name s1_validator_autoupdate -- --wallet.name <your-wallet-name> --wallet.hotkey <your-wallet-hot-key>
 ```
-This will run **two** PM2 process: one for the validator which is called `s1_validator_main_process` by default (you can change this in `run.sh`), and one for the run.sh script (in step 4, we named it `s1_validator_autoupdate`). The script will check for updates every 30 minutes, if there is an update then it will pull it, install it, restart `s1_validator_main_process` and then restart itself.
+This will run **two** PM2 processes: one for the validator, which is called `s1_validator_main_process` by default (you can change this in `run.sh`), and one for the `run.sh` script (in step 4, we named it `s1_validator_autoupdate`). The script will check for updates every 30 minutes; if there is an update, it will pull it, install it, restart `s1_validator_main_process`, and then restart itself.
diff --git a/neurons/miner.py b/neurons/miner.py
index 1f9557db..c5870e98 100644
--- a/neurons/miner.py
+++ b/neurons/miner.py
@@ -176,11 +176,11 @@ def log_event(self, timing: float, prompt: str, completion: str, system_prompt:
 
 # This is the main function, which runs the miner.
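
The hunk below replaces the bare "Miner running..." log with a rich heartbeat line built from live chain state. As a standalone sketch of the same formatting (the `neuron` argument is a stand-in for the miner object with a synced metagraph; attribute names match those read in the hunk):

```python
def format_status(neuron) -> str:
    # Build the one-line heartbeat the miner loop emits every 5 seconds.
    m, uid = neuron.metagraph, neuron.uid
    return (
        f"Miner running:: network: {neuron.subtensor.network} | "
        f"block: {neuron.block} | step: {neuron.step} | uid: {uid} | "
        f"last updated: {neuron.block - m.last_update[uid]} | "
        f"trust: {m.trust[uid]:.3f} | emission {m.emission[uid]:.3f}"
    )
```
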
if __name__ == "__main__": - with Miner() as miner: + with Miner() as m: while True: - bt.logging.info("Miner running...", time.time()) + bt.logging.info(f"Miner running:: network: {m.subtensor.network} | block: {m.block} | step: {m.step} | uid: {m.uid} | last updated: {m.block-m.metagraph.last_update[m.uid]} | trust: {m.metagraph.trust[m.uid]:.3f} | emission {m.metagraph.emission[m.uid]:.3f}") time.sleep(5) - if miner.should_exit: + if m.should_exit: bt.logging.warning("Ending miner...") break \ No newline at end of file diff --git a/neurons/miners/openai/miner.py b/neurons/miners/openai/miner.py index 67350eba..f6ba3ce6 100644 --- a/neurons/miners/openai/miner.py +++ b/neurons/miners/openai/miner.py @@ -53,16 +53,16 @@ def __init__(self, config=None): if self.config.wandb.on: self.identity_tags = ("openai_miner", ) + (self.config.neuron.model_id, ) - - _ = load_dotenv(find_dotenv()) - api_key = os.environ.get("OPENAI_API_KEY") + + _ = load_dotenv(find_dotenv()) + api_key = os.environ.get("OPENAI_API_KEY") # Set openai key and other args self.model = ChatOpenAI( api_key=api_key, model_name=self.config.neuron.model_id, max_tokens = self.config.neuron.max_tokens, - temperature = self.config.neuron.temperature, + temperature = self.config.neuron.temperature, ) self.system_prompt = "You are a friendly chatbot who always responds concisely and helpfully. You are honest about things you don't know." @@ -122,7 +122,7 @@ async def forward( role = synapse.roles[-1] message = synapse.messages[-1] - + bt.logging.debug(f"💬 Querying openai: {prompt}") response = chain.invoke( {"role": role, "input": message} @@ -133,7 +133,7 @@ async def forward( if self.config.wandb.on: self.log_event( - timing=synapse_latency, + timing=synapse_latency, prompt=message, completion=response, system_prompt=self.system_prompt, @@ -141,6 +141,8 @@ async def forward( ) bt.logging.debug(f"✅ Served Response: {response}") + self.step += 1 + return synapse except Exception as e: bt.logging.error(f"Error in forward: {e}") @@ -160,4 +162,4 @@ async def forward( if miner.should_exit: bt.logging.warning("Ending miner...") - break + break diff --git a/neurons/miners/test/echo.py b/neurons/miners/test/echo.py index 6a89c061..d62bd7b8 100644 --- a/neurons/miners/test/echo.py +++ b/neurons/miners/test/echo.py @@ -41,7 +41,7 @@ async def forward( ) -> PromptingSynapse: synapse.completion = synapse.messages[-1] - + self.step += 1 return synapse async def blacklist( diff --git a/neurons/miners/test/mock.py b/neurons/miners/test/mock.py index 753316b8..17170c4e 100644 --- a/neurons/miners/test/mock.py +++ b/neurons/miners/test/mock.py @@ -41,7 +41,7 @@ async def forward( ) -> PromptingSynapse: synapse.completion = f'Hey you reached mock miner {self.config.wallet.hotkey!r}. Please leave a message after the tone.. Beep!' - + self.step += 1 return synapse async def blacklist( diff --git a/neurons/miners/test/phrase.py b/neurons/miners/test/phrase.py index 39fcde4a..07200dec 100644 --- a/neurons/miners/test/phrase.py +++ b/neurons/miners/test/phrase.py @@ -54,7 +54,7 @@ async def forward( ) -> PromptingSynapse: synapse.completion = self.config.neuron.phrase - + self.step += 1 return synapse async def blacklist( diff --git a/neurons/miners/wiki_agent/miner.py b/neurons/miners/wiki_agent/miner.py index d2349c34..ccb5acb1 100644 --- a/neurons/miners/wiki_agent/miner.py +++ b/neurons/miners/wiki_agent/miner.py @@ -29,7 +29,7 @@ class WikipediaAgentMiner(Miner): """Langchain-based miner which uses OpenAI's API as the LLM. 
This uses the ReAct framework. - + You should also install the dependencies for this miner, which can be found in the requirements.txt file in this directory. """ @classmethod @@ -41,14 +41,14 @@ def add_args(cls, parser: argparse.ArgumentParser): def __init__(self, config=None): super().__init__(config=config) - + bt.logging.info(f"🤖📖 Initializing wikipedia agent with model {self.config.neuron.model_id}...") if self.config.wandb.on: self.identity_tags = ("wikipedia_agent_miner", ) + (self.config.neuron.model_id, ) - - _ = load_dotenv(find_dotenv()) - + + _ = load_dotenv(find_dotenv()) + self.agent = WikiAgent(self.config.neuron.model_id, self.config.neuron.temperature) self.accumulated_total_tokens = 0 self.accumulated_prompt_tokens = 0 @@ -99,11 +99,11 @@ async def forward( with get_openai_callback() as cb: t0 = time.time() bt.logging.debug(f"📧 Message received, forwarding synapse: {synapse}") - + message = synapse.messages[-1] - + bt.logging.debug(f"💬 Querying openai and wikipedia: {message}") - + response = self.agent.run(message) synapse.completion = response @@ -111,7 +111,7 @@ async def forward( if self.config.wandb.on: self.log_event( - timing=synapse_latency, + timing=synapse_latency, prompt=message, completion=response, system_prompt='', @@ -119,6 +119,8 @@ async def forward( ) bt.logging.debug(f"✅ Served Response: {response}") + self.step += 1 + return synapse except Exception as e: bt.logging.error(f"Error in forward: {e}") diff --git a/neurons/miners/zephyr/miner.py b/neurons/miners/zephyr/miner.py index 84d497e6..72445909 100644 --- a/neurons/miners/zephyr/miner.py +++ b/neurons/miners/zephyr/miner.py @@ -122,6 +122,7 @@ async def forward(self, synapse: PromptingSynapse) -> PromptingSynapse: bt.logging.debug(f"✅ Served Response: {response}") torch.cuda.empty_cache() + self.step += 1 except Exception as e: bt.logging.error(f"Error: {e}") diff --git a/neurons/validator.py b/neurons/validator.py index 8c1a7127..28cc9fd6 100644 --- a/neurons/validator.py +++ b/neurons/validator.py @@ -55,7 +55,7 @@ def __init__(self, config=None): if p > 0 ] # Load the reward pipeline - self.reward_pipeline = RewardPipeline(selected_tasks=self.active_tasks, device=self.device) + self.reward_pipeline = RewardPipeline(selected_tasks=self.active_tasks, device=self.device) async def forward(self): """ @@ -100,12 +100,12 @@ def __exit__(self, exc_type, exc_value, traceback): # The main function parses the configuration and runs the validator. if __name__ == "__main__": - with Validator() as validator: + with Validator() as v: while True: - bt.logging.info("Validator running...", time.time()) + bt.logging.info(f"Validator running:: network: {v.subtensor.network} | block: {v.block} | step: {v.step} | uid: {v.uid} | last updated: {v.block-v.metagraph.last_update[v.uid]} | vtrust: {v.metagraph.validator_trust[v.uid]:.3f} | emission {v.metagraph.emission[v.uid]:.3f}") time.sleep(5) - if validator.should_exit: + if v.should_exit: bt.logging.warning("Ending validator...") break - + diff --git a/prompting/__init__.py b/prompting/__init__.py index a9650148..0505294c 100644 --- a/prompting/__init__.py +++ b/prompting/__init__.py @@ -16,7 +16,7 @@ # DEALINGS IN THE SOFTWARE. # Define the version of the template module. 
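
The version bump below also changes `__spec_version__`, which packs the semver string into a single integer. Only the first term of the formula is visible in this hunk, so the weighting in this sketch is an assumption (the 10000/100/1 scheme is the usual subnet convention):

```python
__version__ = "1.1.0"

major, minor, patch = (int(part) for part in __version__.split("."))
# Assumed packing: 1.1.0 -> 10000*1 + 100*1 + 0 = 10100
__spec_version__ = 10000 * major + 100 * minor + patch
assert __spec_version__ == 10100
```
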
-__version__ = "1.0.4" +__version__ = "1.1.0" version_split = __version__.split(".") __spec_version__ = ( (10000 * int(version_split[0])) diff --git a/prompting/base/validator.py b/prompting/base/validator.py index 1f129272..e5571094 100644 --- a/prompting/base/validator.py +++ b/prompting/base/validator.py @@ -29,18 +29,18 @@ from prompting.base.neuron import BaseNeuron from prompting.mock import MockDendrite from prompting.utils.config import add_validator_args - +from prompting.utils.exceptions import MaxRetryError class BaseValidatorNeuron(BaseNeuron): """ Base class for Bittensor validators. Your validator should inherit from this class. """ - + @classmethod def add_args(cls, parser: argparse.ArgumentParser): super().add_args(parser) - add_validator_args(cls, parser) - + add_validator_args(cls, parser) + def __init__(self, config=None): super().__init__(config=config) @@ -127,10 +127,15 @@ def run(self): # Check that validator is registered on the network. self.sync() - - bt.logging.info( - f"Running validator {self.axon} on network: {self.config.subtensor.chain_endpoint} with netuid: {self.config.netuid}" - ) + + if not self.config.neuron.axon_off: + bt.logging.info( + f"Running validator {self.axon} on network: {self.config.subtensor.chain_endpoint} with netuid: {self.config.netuid}" + ) + else: + bt.logging.info( + f"Running validator on network: {self.config.subtensor.chain_endpoint} with netuid: {self.config.netuid}" + ) bt.logging.info(f"Validator starting at block: {self.block}") @@ -140,7 +145,14 @@ def run(self): bt.logging.info(f"step({self.step}) block({self.block})") # Run multiple forwards concurrently. - self.loop.run_until_complete(self.concurrent_forward()) + try: + self.loop.run_until_complete(self.concurrent_forward()) + except torch.cuda.OutOfMemoryError as e: + bt.logging.error(f"Out of memory error: {e}") + continue + except MaxRetryError as e: + bt.logging.error(f"MaxRetryError: {e}") + continue # Check if we should exit. if self.should_exit: @@ -161,8 +173,8 @@ def run(self): except Exception as err: bt.logging.error("Error during validation", str(err)) bt.logging.debug(print_exception(type(err), err, err.__traceback__)) - self.should_exit = True - + self.should_exit = True + def run_in_background_thread(self): """ Starts the validator's operations in a background thread upon entering the context. 
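
Among the changes above, the validator's forward loop now survives CUDA out-of-memory and retry-exhaustion errors instead of setting `should_exit`. A minimal sketch of that pattern, with a stub coroutine in place of the real `concurrent_forward` and a local `MaxRetryError` standing in for `prompting.utils.exceptions`:

```python
import asyncio

import torch


class MaxRetryError(Exception):
    """Stand-in for prompting.utils.exceptions.MaxRetryError."""


async def concurrent_forward():
    ...  # placeholder for the validator's concurrent forward passes


loop = asyncio.new_event_loop()
try:
    loop.run_until_complete(concurrent_forward())
except torch.cuda.OutOfMemoryError as e:
    print(f"Out of memory error: {e}")  # recoverable: log and move to the next step
except MaxRetryError as e:
    print(f"MaxRetryError: {e}")  # recoverable: log and move to the next step
```
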
@@ -323,6 +335,7 @@ def update_scores(self, rewards: torch.FloatTensor, uids: List[int]): # shape: [ metagraph.n ] alpha = self.config.neuron.moving_average_alpha self.scores = alpha * step_rewards + (1 - alpha) * self.scores + self.scores = (self.scores - self.config.neuron.decay_alpha).clamp(min=0) bt.logging.debug(f"Updated moving avg scores: {self.scores}") def save_state(self): diff --git a/prompting/conversation.py b/prompting/conversation.py index 1601398c..263d65db 100644 --- a/prompting/conversation.py +++ b/prompting/conversation.py @@ -8,9 +8,9 @@ ) from prompting.tools import ( WikiDataset, - CodingDataset, + HFCodingDataset, MathDataset, - DateQADataset, + WikiDateDataset, ) from transformers import Pipeline @@ -26,13 +26,13 @@ def create_task(llm_pipeline: Pipeline, task_name: str) -> Task: dataset = WikiDataset() elif task_name in coding_based_tasks: - dataset = CodingDataset() + dataset = HFCodingDataset() elif task_name == "math": dataset = MathDataset() elif task_name == "date_qa": - dataset = DateQADataset() + dataset = WikiDateDataset() if task_name == "summarization": task = SummarizationTask(llm_pipeline=llm_pipeline, context=dataset.next()) diff --git a/prompting/mock.py b/prompting/mock.py index e5862c27..8e7d9ae3 100644 --- a/prompting/mock.py +++ b/prompting/mock.py @@ -74,8 +74,12 @@ def preprocess(self, **kwargs): class MockSubtensor(bt.MockSubtensor): - def __init__(self, netuid, n=16, wallet=None, network="mock"): - super().__init__(network=network) + def __init__(self, netuid, n=16, wallet=None): + + super().__init__() + # reset the underlying subtensor state + self.chain_state = None + self.setup() if not self.subnet_exists(netuid): self.create_subnet(netuid) @@ -102,6 +106,10 @@ def __init__(self, netuid, n=16, wallet=None, network="mock"): class MockMetagraph(bt.metagraph): + + default_ip = "127.0.0.0" + default_port = 8091 + def __init__(self, netuid=1, network="mock", subtensor=None): super().__init__( netuid=netuid, network=network, sync=False @@ -112,17 +120,17 @@ def __init__(self, netuid=1, network="mock", subtensor=None): self.sync(subtensor=subtensor) for axon in self.axons: - axon.ip = "127.0.0.0" - axon.port = 8091 - - bt.logging.info(f"Metagraph: {self}") - bt.logging.info(f"Axons: {self.axons}") + axon.ip = self.default_ip + axon.port = self.default_port class MockDendrite(bt.dendrite): """ Replaces a real bittensor network request with a mock request that just returns some static completion for all axons that are passed and adds some random delay. 
""" + min_time: float = 0 + max_time: float = 1 + def __init__(self, wallet): super().__init__(wallet) @@ -145,24 +153,24 @@ async def query_all_axons(streaming: bool): async def single_axon_response(i, axon): """Queries a single axon for a response.""" - start_time = time.time() + t0 = time.time() s = synapse.copy() # Attach some more required data so it looks real s = self.preprocess_synapse_for_request(axon, s, timeout) # We just want to mock the response, so we'll just fill in some data - process_time = random.random() + process_time = random.random()*(self.max_time-self.min_time) + self.min_time + await asyncio.sleep(process_time) if process_time < timeout: - s.dendrite.process_time = str(time.time() - start_time) # Update the status code and status message of the dendrite to match the axon s.completion = f'Mock miner completion {i}' s.dendrite.status_code = 200 s.dendrite.status_message = "OK" - synapse.dendrite.process_time = str(process_time) else: s.completion = "" s.dendrite.status_code = 408 s.dendrite.status_message = "Timeout" - synapse.dendrite.process_time = str(timeout) + + s.dendrite.process_time = str(time.time() - t0) # Return the updated synapse object after deserializing if requested if deserialize: diff --git a/prompting/rewards/__init__.py b/prompting/rewards/__init__.py index 509ceb79..faeb1d98 100644 --- a/prompting/rewards/__init__.py +++ b/prompting/rewards/__init__.py @@ -10,4 +10,4 @@ from .rouge import RougeRewardModel from .float_diff import FloatDiffModel from .date import DateRewardModel -from .pipeline import RewardPipeline +from .pipeline import RewardPipeline, REWARD_MODELS \ No newline at end of file diff --git a/prompting/rewards/date.py b/prompting/rewards/date.py index 119f7040..bf7d9446 100644 --- a/prompting/rewards/date.py +++ b/prompting/rewards/date.py @@ -1,7 +1,11 @@ import time import torch +import re +import pandas as pd +import numpy as np from typing import List from prompting.rewards import BaseRewardModel, BatchRewardOutput, RewardModelTypeEnum +import bittensor as bt class DateRewardModel(BaseRewardModel): @@ -11,42 +15,81 @@ def name(self) -> str: def __init__(self, **kwargs): super().__init__() + + def date_diff(self, ref_date: tuple, comp_date: tuple) -> int: + """ + Calculates the absolute difference in days between two dates. + """ + try: + return abs(ref_date[0] - comp_date[0]).days + 365*abs(int(ref_date[1]) - int(comp_date[1])) + except Exception as e: + return 500 - def date_score(self, reference, completion): - # TODO: cleanup code - score = 1 - #Take the last 4 characters of the reference as the year - year = reference[-4:] - month = reference.split()[0].strip() - month_num = str(time.strptime(month, "%B").tm_mon) - day = reference.split()[1].strip(',') - number_list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] - not_in_month_day_year = set(str(month_num) + str(day) + str(year)) - numbers = [str(x) for x in number_list if str(x) not in not_in_month_day_year] - # Create a list of the months - month_list = ["January", "February", "March", "April", "May", "June", "July", "August", "September", "October", "November", "December"] - months = [x for x in month_list if x not in month] + def parse_dates_from_text(self, text: str) -> tuple: + """ + Parses dates from a body of text, handling various formats, and returns pandas datetime objects. + + Args: + text (str): The text to parse. + + Returns: + tuple: A tuple containing a datemtime object with they year set at 2000 and the actual year. 
+ """ + + date_patterns = [ + r"\b(\d{1,2})[/-](\d{1,2})[/-](\d{3,4})\b", # MM/DD/YYYY or DD/MM/YYYY + r"\b(\d{1,2})[-/](\d{1,2})[-/](\d{2})\b", # MM/DD/YY or DD/MM/YY + r"\b(\d{1,2}) (January|February|March|April|May|June|July|August|September|October|November|December) (\d{3,4})\b", # DD Month, YYYY + r"\b(January|February|March|April|May|June|July|August|September|October|November|December) (\d{1,2})(,\s*)?(\d{3,4})\b", # Month DD, YYYY + ] + + for pattern in date_patterns: + matches = re.findall(pattern, text) + for match in matches: + try: + # Attempt to create a datetime object with year 2000 (datetime objects cannot take dates more than 200 years in the past) + parsed_date = pd.to_datetime(match[0] + "/" + match[1] + "/" + "2000") + year = match[-1] + # Check if the year is a number + if year.isdigit(): + # If the year is a digit, return the parsed date and the year in a tuple + return (parsed_date, year) + else: + raise ValueError + except ValueError: + pass + + return + + def date_score(self, reference: str, completion: str) -> float: + """Assign a score based on the difference between two dates using a negative exponential function. - if not year in completion: - score -= 0.5 - if not (month_num in completion or month in completion): - score -= 0.25 - if not day in completion: - score -= 0.25 + Args: + reference (str): The reference date. + completion (str): The completion date. - if not score == 0: - # Check if numbers are in completion - for number in numbers: - if str(number) in completion: - return 0.0 - # Check if months are in completion - for month in months: - if month in completion: - return 0.0 + Returns: + float: The score.""" + score = 0 + if not completion: + return score + ref_date = self.parse_dates_from_text(reference) + comp_date = self.parse_dates_from_text(completion) + score =np.exp(-(self.date_diff(ref_date, comp_date)**2/1000)) + # Clip any very small scores + if score < 0.001: + score = 0 return score def reward(self, reference: str, completions: List[str]) -> BatchRewardOutput: - """Compute difference scores given a completion and reference pair.""" + """Compute difference scores given a completion and reference pair. + + Args: + reference (str): The reference date. + completions (List[str]): A list of completions. 
+ + Returns: + BatchRewardOutput: A BatchRewardOutput object containing the rewards and timings.""" rewards = [] timings = [] diff --git a/prompting/rewards/float_diff.py b/prompting/rewards/float_diff.py index a57d076b..b0e5e699 100644 --- a/prompting/rewards/float_diff.py +++ b/prompting/rewards/float_diff.py @@ -14,36 +14,40 @@ def __init__(self, **kwargs): super().__init__() @staticmethod - def extract_number(text): + def extract_number(text: str) -> float: + """Extract a number from a string.""" # loop over all words reversed and try to cast as a float, break when you find the first one - for word in text.split()[::-1]: + words = text.split() + for word in reversed(words): + cleaned = word.strip('.').replace(',', '') try: - # Convert the string to a float - return parse_expr(word.replace('$', '')) + return float(parse_expr(cleaned).evalf()) except Exception: - continue + # fall back to simpler parsing if required + try: + return float(cleaned) + except Exception: + continue @staticmethod - def math_score(reference, completion): - # Extract all the digits and numerical expressions from the completion and take only the last one (assuming it's the answer) - - # Convert the string to a float + def math_score(reference: str, completion: str) -> float: + """Compute a score based on the difference between a reference and a completion.""" + # Convert the strings to a float + reference = float(reference) pred = FloatDiffModel.extract_number(completion) if pred is None: return 0.0 try: - - # Convert reference to float (this is okay because we already checked that the reference is a float) - # TODO: More flexible parsing of the reference (just as with the completion) - ref = float(reference) - if pred == ref: + if pred == reference: return 1.0 # Compute the difference - diff = abs(ref - pred)/(ref + 1e-6) + diff = (reference - pred)/(reference + 1e-10) # Make sure the difference is between 0 and 1 diff = min(abs(diff), 1) - + # Clip any very small scores + if diff > 0.999: + diff = 1.0 return 1.0 - diff except Exception: return 0.0 diff --git a/prompting/rewards/pipeline.py b/prompting/rewards/pipeline.py index eee03513..2396e1f2 100644 --- a/prompting/rewards/pipeline.py +++ b/prompting/rewards/pipeline.py @@ -1,12 +1,6 @@ from typing import List -from prompting.tasks import ( - DebuggingTask, - SummarizationTask, - QuestionAnsweringTask, - MathTask, - DateQuestionAnsweringTask, -) +from prompting.tasks import TASKS from prompting.rewards import ( BaseRewardModel, RougeRewardModel, @@ -16,16 +10,6 @@ DateRewardModel, ) - -SUPPORTED_TASKS = { - "debugging": DebuggingTask, - "summarization": SummarizationTask, - "qa": QuestionAnsweringTask, - "math": MathTask, - "date_qa": DateQuestionAnsweringTask, -} - - REWARD_MODELS = { "rouge": RougeRewardModel, "relevance": RelevanceRewardModel, @@ -34,8 +18,6 @@ 'date': DateRewardModel, } - - class RewardPipeline: def __init__(self, selected_tasks: List[str], device): self.selected_tasks = selected_tasks @@ -53,11 +35,11 @@ def __repr__(self): return f'RewardPipeline({self.reward_models})' def validate_tasks(self): - + for task in self.selected_tasks: - if task not in SUPPORTED_TASKS: + if task not in TASKS: raise ValueError( - f"Task {task} not supported. Please choose from {SUPPORTED_TASKS.keys()}" + f"Task {task} not supported. 
Please choose from {TASKS.keys()}" ) # Check that the reward_definition and penalty_definition are lists of dictionaries whose weights sum to one self._check_weights(task, "reward_definition") @@ -67,10 +49,10 @@ def _check_weights(self, task, definition): total_weight = 0 - model_infos = getattr(SUPPORTED_TASKS[task], definition) - + model_infos = getattr(TASKS[task], definition) + for model_info in model_infos: - + if not isinstance(model_info, dict): raise ValueError(f"{definition} model {model_info} is not a dictionary.") if "weight" not in model_info: @@ -93,8 +75,8 @@ def load_pipeline(self): for task in self.selected_tasks: - active_reward_models += SUPPORTED_TASKS[task].reward_definition - active_reward_models += SUPPORTED_TASKS[task].penalty_definition + active_reward_models += TASKS[task].reward_definition + active_reward_models += TASKS[task].penalty_definition # Instantiate only the required reward models reward_models = {} diff --git a/prompting/tasks/__init__.py b/prompting/tasks/__init__.py index b429e227..64175e40 100644 --- a/prompting/tasks/__init__.py +++ b/prompting/tasks/__init__.py @@ -5,3 +5,12 @@ from .date_qa import DateQuestionAnsweringTask from .generic_instruction import GenericInstructionTask from .math import MathTask + + +TASKS = { + "qa": QuestionAnsweringTask, + "summarization": SummarizationTask, + "date_qa": DateQuestionAnsweringTask, + "debugging": DebuggingTask, + "math": MathTask, +} \ No newline at end of file diff --git a/prompting/tasks/date_qa.py b/prompting/tasks/date_qa.py index 3f48c6ae..8423ab4d 100644 --- a/prompting/tasks/date_qa.py +++ b/prompting/tasks/date_qa.py @@ -3,39 +3,34 @@ from prompting.cleaners.cleaner import CleanerPipeline +SECTION_MESSAGES = {'Births':' was born ', 'Deaths':' died ', 'Events':' '} + @dataclass class DateQuestionAnsweringTask(Task): + + name = "date-based question answering" + desc = "get help answering a specific date-based question" + goal = "to get the answer to the following date-based question" reward_definition = [ dict(name="date", weight=1.0), ] penalty_definition = [] + cleaning_pipeline = [ + dict(name="remove_quotes"), + dict(name="remove_roles"), + ] + static_reference = True + static_query = True def __init__(self, llm_pipeline, context, create_reference=True): - - self.name = "date-based question answering" - self.desc = "get help answering a specific date-based question" - self.goal = "to get the answer to the following date-based question" - - self.cleaning_pipeline = [ - dict(name="remove_quotes"), - dict(name="remove_roles"), - ] self.context = context - # The section is in {"Births", "Deaths", "Events"} - section = self.context["section"] - year, _, *event = self.context["event"].split() - event = " ".join(event) - - options = {'Births':' was born ', 'Deaths':' died ', 'Events':' '} + self.query = context.content + SECTION_MESSAGES[context.topic] + 'on what exact date?' + self.reference = self.context.title.replace('_',' ') + ", " + context.subtopic - self.query = event.strip(".") + options[section] + 'on what exact date?' 
- self.reference = self.context["date"] + ", " + year.strip() + self.topic = context.title + self.subtopic = context.topic + self.tags = context.tags - self.topic = section - self.subtopic = year - self.tags = [] - self.static_reference = True - self.static_query = True diff --git a/prompting/tasks/debugging.py b/prompting/tasks/debugging.py index 1ec1a3fb..a857abac 100644 --- a/prompting/tasks/debugging.py +++ b/prompting/tasks/debugging.py @@ -4,21 +4,6 @@ from prompting.tasks import Task import difflib -# The two options are: -# 1. Create a reference code and then introduce a bug to create the challenge code -# 2. Create a challenge code and then fix the bug to create the reference code - -QUERY_SYSTEM_PROMPT = """\ -You act as a coding teacher specializing in creating coding exercises by intentionally corrupting code snippets provided by the user. The purpose is to challenge the user to identify and fix the errors. When given a piece of code, analyze it, introduce errors or bugs, and then present the modified code as an exercise. The exercise will include the corrupted code and a question related to identifying or fixing the error. You should ensure that the errors introduced are logical or syntactical, suitable for educational purposes. It should not alter the core logic or purpose of the original code beyond recognition. Do not include comments or otherwise indicate in the code that it has been modified. The code should be within triple backticks (```). -""" - -# Used to obtain the query (which is a question about the context) -QUERY_PROMPT_TEMPLATE = """\ -Introduce a bug to the following {language} code snippet in triple backticks (```): - -# Code: -{context} -""" def corrupt( @@ -127,56 +112,37 @@ def diff(query, reference): @dataclass class DebuggingTask(Task): + + name = "debugging" + desc = "get help with debugging" + goal = "ask for help fixing broken code." + reward_definition = [ - dict(name="diff", lines=False, threshold=0.5, weight=1.0) + dict(name="diff", weight=1.0) ] + penalty_definition = [] - def __init__(self, llm_pipeline, context, create_reference=True): + static_reference = True + static_query = True - self.name = "debugging" - self.desc = "get help with debugging" - self.goal = "ask for help fixing the broken piece of code. When asking for help do not adjust the code in any way." 
+ def __init__(self, llm_pipeline, context, create_reference=True): self.context = context # No LLM involved in generating the query, we just apply some language-independent corruption to the code - self.query = self.generate_query() - - if create_reference: - self.reference = self.generate_reference() - - self.delimiter="```" - self.topic=self.context["repo_name"] - self.subtopic=self.context["path"] - self.tags=[self.context["language"]] - self.static_reference = True - self.static_query = True - - def generate_query( - self, - llm=None, - n_remove=1, - n_swap=1, - seed=0, - sep="", - min_length=1, - max_length=10, - ): self.query = corrupt( - self.context["code"], - n_remove=n_remove, - n_swap=n_swap, - seed=seed, - sep=sep, - min_length=min_length, - max_length=max_length, - ) - return self.query + context.content, + n_remove=random.randint(1, 3), + n_swap=random.randint(0, 2), + sep=random.choices([""," ","\n"],weights=[0.3,0.6,0.1],k=1)[0] + ) + self.reference = context.content + self.delimiter = "```" + self.topic = context.title + self.subtopic = context.subtopic + self.tags = context.tags - def generate_reference(self, llm=None): - """Overrides the default reference generation to just return the reference code""" - return self.context["code"] def format_challenge(self, challenge): return f'{challenge}\n{self.delimiter}\n{self.query}\n{self.delimiter}' \ No newline at end of file diff --git a/prompting/tasks/math.py b/prompting/tasks/math.py index eac4eb1a..cf9a9983 100644 --- a/prompting/tasks/math.py +++ b/prompting/tasks/math.py @@ -6,33 +6,25 @@ @dataclass class MathTask(Task): + + name="math" + desc="get help solving a math problem" + goal="to get the answer to the following math question" + reward_definition = [ dict(name='float_diff', weight = 1.0), ] penalty_definition = [] + static_reference=True + static_query=True + def __init__(self, llm_pipeline, context, create_reference=True): - - reference = context["solution"] - - try: - float(reference) - except: - raise ValueError(f"Solution {reference} is not a float.") - - self.name="math" - self.desc="get help solving a math problem" - self.goal="to get the answer to the following math question" - + self.context = context - query = "How can I solve, " + context["problem"] + "?" - - self.query=query - self.reference=str(reference) - self.topic=context["topic"] - self.subtopic=context["subtopic"] - self.tags=[] - self.static_reference=True - self.static_query=True - + self.query = "How can I solve the following problem, " + context.content + "? Make sure to include the whole problem when you ask your question." 
+ self.reference = context.extra['solution'] + self.topic = context.title + self.subtopic = context.topic + self.tags = context.tags diff --git a/prompting/tasks/qa.py b/prompting/tasks/qa.py index 14680142..9d4818d4 100644 --- a/prompting/tasks/qa.py +++ b/prompting/tasks/qa.py @@ -40,43 +40,43 @@ @dataclass class QuestionAnsweringTask(Task): - + + name = "question-answering" + desc = "get help on answering a question" + goal = "to get the answer to the following question" + reward_definition = [ dict(name="rouge", ngram="rouge-1", metric="f", weight=0.5), - dict(name="relevance", threshold=None, weight=0.5), + dict(name="relevance", weight=0.5), ] penalty_definition = [ dict(name="rouge", ngram="rouge-1", metric="f", weight=1.0), ] + cleaning_pipeline = [ + dict(name="remove_quotes"), + dict(name="prune_ending"), + dict(name="remove_roles"), + ] + def __init__(self, llm_pipeline, context, create_reference=True): - self.name = "question-answering" - self.desc = "get help on answering a question" - self.goal = "to get the answer to the following question" - self.cleaning_pipeline = [ - dict(name="remove_quotes"), - dict(name="prune_ending"), - dict(name="remove_roles"), - ] self.context = context self.query_system_prompt = QUERY_SYSTEM_PROMPT self.query_prompt = QUERY_PROMPT_TEMPLATE.format( - context = self.context["text"] + context = context.content ) self.query = self.generate_query(llm_pipeline) - self.reference_system_prompt = REFERENCE_SYSTEM_PROMPT self.reference_prompt = REFERENCE_PROMPT_TEMPLATE.format( - context = self.context["text"], question = self.query + context = context.content, question = self.query ) if create_reference: self.reference = self.generate_reference(llm_pipeline) - self.topic = self.context["title"] - self.subtopic = self.context["categories"][0] - self.tags = self.context["categories"] - + self.topic = context.title + self.subtopic = context.topic + self.tags = context.tags diff --git a/prompting/tasks/summarization.py b/prompting/tasks/summarization.py index 70f59c08..64c1db06 100644 --- a/prompting/tasks/summarization.py +++ b/prompting/tasks/summarization.py @@ -26,51 +26,44 @@ @dataclass class SummarizationTask(Task): - + + name = "summarization" + desc = "get help with summarization" + goal = "summarize the following topic" + reward_definition = [ dict(name="rouge", ngram="rouge-l", metric="f", weight=0.5), - dict(name="relevance", threshold=None, weight=0.5), + dict(name="relevance", weight=0.5), ] penalty_definition = [ dict(name="rouge", ngram="rouge-1", metric="f", weight=1.0) ] - def __init__(self, llm_pipeline: Pipeline, context: str, create_reference=True): - - self.name = "summarization" - self.desc = "get help with summarization" - self.goal = "summarize the following topic" - - self.context = context + # This is where you define cleaning procedures for the generation. + # Can be used when wanting to clean the challenge. + cleaning_pipeline = [ + dict(name="remove_quotes"), + dict(name="prune_ending"), + dict(name="remove_roles"), + ] - # Query is just the article title - self.query = self.context["title"] + static_query = True - # This is where you define cleaning procedures for the generation. - # Can be used when wanting to clean the challenge. 
- self.cleaning_pipeline = [ - dict(name="remove_quotes"), - dict(name="prune_ending"), - dict(name="remove_roles"), - ] + def __init__(self, llm_pipeline: Pipeline, context: str, create_reference=True): self.context = context - self.query_prompt = None - # NOTE: We do not perform an inference here and just use the article title as the query. - # This is because the article title is usually a good summary of the article itself. - # Query is just the article title. - query = self.context["title"] + # Query is just the article title and section name + self.query = context.title + ', ' + context.topic self.reference_system_prompt = SUMMARIZATION_SYSTEM_PROMPT self.reference_prompt = REFERENCE_PROMPT_TEMPLATE.format( - context = self.context["text"] + context = context.content ) if create_reference: self.reference = self.generate_reference(llm_pipeline) - self.topic = self.context["title"] - self.subtopic = self.context["categories"][0] - self.tags = self.context["categories"] - self.static_query = True + self.topic = context.title + self.subtopic = context.topic + self.tags = context.tags diff --git a/prompting/tasks/task.py b/prompting/tasks/task.py index 263e63e5..b273ed01 100644 --- a/prompting/tasks/task.py +++ b/prompting/tasks/task.py @@ -1,7 +1,7 @@ import time import bittensor as bt from abc import ABC -from dataclasses import dataclass +from dataclasses import dataclass, asdict from enum import Enum from typing import List, Union, Dict from prompting.llm import HuggingFaceLLM @@ -60,11 +60,10 @@ def __state_dict__(self, full=False): "reference_time": getattr(self, "reference_time", 0), "topic": self.topic, "subtopic": self.subtopic, - "context_time": self.context.get("fetch_time", 0.0), - # "tags": self.tags, + "context_time": self.context.stats.get("fetch_time", 0.0), } if full: - state.update(**self.context) + state.update(asdict(self.context)) return state diff --git a/prompting/tools/__init__.py b/prompting/tools/__init__.py index fb57c1ac..a8e395c2 100644 --- a/prompting/tools/__init__.py +++ b/prompting/tools/__init__.py @@ -1,8 +1,11 @@ -from .dataset import ( +from .datasets import ( + Context, + Dataset, MockDataset, - CodingDataset, + HFCodingDataset, WikiDataset, StackOverflowDataset, - DateQADataset, + WikiDateDataset, MathDataset, ) +from .selector import Selector diff --git a/prompting/tools/dataset.py b/prompting/tools/dataset.py deleted file mode 100644 index d46ebad1..00000000 --- a/prompting/tools/dataset.py +++ /dev/null @@ -1,512 +0,0 @@ -# The MIT License (MIT) -# Copyright © 2024 Yuma Rao -# Copyright © 2023 Opentensor Foundation - -# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated -# documentation files (the “Software”), to deal in the Software without restriction, including without limitation -# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, -# and to permit persons to whom the Software is furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all copies or substantial portions of -# the Software. - -# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO -# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL -# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -# DEALINGS IN THE SOFTWARE. - -import time -import random -import string -from typing import Dict -import requests -import datetime -import mathgenerator -import bittensor as bt -from datasets import load_dataset -from bs4 import BeautifulSoup -from sympy.parsing.latex import parse_latex - -# TODO: Use beautiful soup to parse things like wikipedia articles and stack overflow questions and answers -# TODO: Use decorators or a parent class to time the next method so that context always has a fetch_time field - - -class MockDataset: - def next(self): - return {"text": "What is the capital of Texas?"} - - -def chunk(text, sep, n_chunks=None): - # choose a random chunk from the article - chunks = [chunk for chunk in text.split(sep) if chunk.strip()] - # select a subsequence of paragraphs - if n_chunks is None: - n_chunks = random.randint(1, len(chunks)) - - start_chunk = random.randint(0, len(chunks) - n_chunks) - bt.logging.info(f"Choosing {n_chunks} chunks starting at index {start_chunk}.") - - return sep.join(chunks[start_chunk : start_chunk + n_chunks]) - - -class CodingDataset: - all_languages = { - "C++": [".cpp", ".hpp", ".c++", ".h++", ".cc", ".hh", ".C", ".H"], - "CSS": [".css"], - "Dockerfile": [".dockerfile", "Dockerfile"], - "HTML": [".html"], - "Java": [".java"], - "JavaScript": [".js"], - "Python": [".py"], - "SQL": [".sql"], - "Shell": [".sh", ".bash", ".command", ".zsh"], - } - - def __init__( - self, - dataset_id="codeparrot/github-code", - seed=None, - languages=None, - buffer_size=10000, - ): - if seed is None: - seed = random.randint(0, 1000) - self.seed = seed - - if languages is None: - languages = list(self.all_languages.keys()) - self.languages = languages - - self.dataset_id = dataset_id - self.dataset = iter( - load_dataset( - dataset_id, - split="train", - streaming=True, - languages=self.languages, - ).shuffle(seed=seed, buffer_size=buffer_size) - ) - - def next(self, min_lines=5, max_lines=100): - bt.logging.debug("Retrieving code from prompting.dataset...") - t0 = time.time() - while True: - code = next(self.dataset) - if min_lines <= len(code["code"].splitlines()) <= max_lines: - code["fetch_time"] = time.time() - t0 - return code - - -class WikiDataset: - def __init__( - self, - min_length_words: int = 250, - min_length_bytes: int = 1000, - max_tries: int = 10, - min_backlinks: int = 1, - ): - # Wikipedia API endpoint for a random article - self.url = "https://en.wikipedia.org/w/api.php" - self.min_length_words = min_length_words - self.min_length_bytes = min_length_bytes - self.max_tries = max_tries - self.min_backlinks = min_backlinks - - def get_random_wikipedia_article(self) -> Dict: - """sample random wikipedia article - - Args: - min_length (int, optional): min number of words in article. Defaults to 1000. - min_backlinks (int, optional): backlink is a hyperlink from one webpage to another webpage. Defaults to 1. 
- """ - - # Parameters for the API request - params = { - "action": "query", - "format": "json", - "prop": "info|linkshere|categories|categoryinfo|extracts", - "generator": "random", - "grnnamespace": 0, # Namespace 0 indicates articles - "grnlimit": 10, # Number of random articles to fetch - "inprop": "url|displaytitle|length", # Requesting URL, title, and length of the page - "lhprop": "pageid", # Properties for links here (backlinks) - "lhlimit": "max", # Maximum number of backlinks to retrieve - "exlimit": "max", # Get extracts for each page - "cllimit": "max", # Get all categories for each page - } - - tries = 0 - while tries < self.max_tries: - # TODO: to avoid blocking from Wikipedia, we should provide a headers argument, where headers = {'User-Agent': 'Bittensor/0.0 (https://Bittensor.org; someone@opentensor.dev)'} - response = requests.get(self.url, params=params) - tries += 1 - - data = response.json() - if not data.get("query"): - continue - - for page_id, page_info in data["query"]["pages"].items(): - length = page_info.get("length", 0) - backlinks = len(page_info.get("linkshere", [])) - categories = [ - cat.get("title", "").strip("Category:") - for cat in page_info.get("categories", [{}]) - ] - # filter out any mention of articles - categories = [cat for cat in categories if "article" not in cat.lower()] - extract = page_info.get("extract") - - if ( - length >= self.min_length_bytes - and backlinks >= self.min_backlinks - and extract - ): # and views >= min_views: - return { - "title": page_info["title"], - "url": page_info["fullurl"], - "length": length, - "extract": extract, - "backlinks": backlinks, - "categories": categories, - } - - raise Exception( - f"Could not find an article with length >= {self.min_length_bytes} and backlinks >= {self.min_backlinks} after {self.max_tries} tries." - ) - - def get_wikipedia_article_content(self, title: str) -> str: - """Return wikipedia article content - - Args: - title (str): title of the article - remove_headers (bool, optional): remove the headers in the content body. Defaults to False. - - Returns: - str: article content - """ - # Parameters for the API request to get article content - params = { - "action": "query", - "format": "json", - "titles": title, - "prop": "extracts", - "explaintext": True, # Get plain text content - } - - # Making the API request - # TODO: to avoid blocking from Wikipedia, we should provide a headers argument, where headers = {'User-Agent': 'Bittensor/0.0 (https://Bittensor.org; someone@opentensor.dev)'} - response = requests.get(self.url, params=params) - data = response.json() - - # Extracting the page content - page = next(iter(data["query"]["pages"].values())) - content = page.get("extract", "Content not found.") - - text = {None: ""} - section = None - for line in content.split("\n"): - if line.startswith("==") and line.endswith("=="): - section = line.strip("=").strip() - text[section] = "" - continue - text[section] += line + "\n" - - return text - - def next( - self, subset=False, chunk_sep="\n", n_chunks: int = None, info: Dict = None - ) -> Dict: - """Iterate through random wikipedia articles - - Args: - subset (bool, optional): Randomly sample a chunk the article . Defaults to False. - chunk_sep (str, optional): If subsetting, define the delimiter to separate on. Defaults to "\n". - n_chunks (int, optional): If subsetting, define the number of chunks you want. Defaults to None. - info (Dict, optional): Select a known wikipedia article. Defaults to None. 
- - Raises: - Exception: If minimum number of words is less than min_length_words after max_tries tries. - - Returns: - Dict: information about the article - """ - bt.logging.debug("Retrieving data from prompting.dataset...") - tries = 0 - t0 = time.time() - while tries < self.max_tries: - if info is None: - info = self.get_random_wikipedia_article() - - info["sections"] = self.get_wikipedia_article_content(info["title"]) - text = "\n".join(info["sections"].values()) - tries += 1 - - if len(text.split()) >= self.min_length_words: - break - else: - info = None - - if tries == self.max_tries: - raise Exception( - f"Could not find an article with length >= {self.min_length_words} words after {self.max_tries} tries." - ) - - if subset in info["sections"].keys(): - text = info["sections"][subset] - elif subset: - text = chunk(text, sep=chunk_sep, n_chunks=n_chunks) - - info["text"] = text - info["fetch_time"] = time.time() - t0 - return info - - -class StackOverflowDataset: - def __init__(self): - # Stack Overflow API endpoint for a random article - self.url = "https://api.stackexchange.com/2.3/questions" - self.questions = [] - - def get_stack_questions(self): - url = "https://api.stackexchange.com/2.3/questions" - params = { - "order": "desc", - "sort": "votes", # Sorting by votes means that it's likely that the same questions will be fetched again - "site": "stackoverflow", - "pagesize": 100, # Fetch 100 questions per API call - "page": random.randint(1, 5), - } - - # Fetch questions - response = requests.get(url, params=params) - response.raise_for_status() - - # Parse response - questions = response.json()["items"] - - # Filter questions by minimum upvotes - min_upvotes = 10 - filtered_questions = [q for q in questions if q["score"] >= min_upvotes] - # Shuffle the questions - random.shuffle(filtered_questions) - - # Add the questions to the list of questions - self.questions.extend(filtered_questions) - return - - def get_stack_question(self) -> dict: - # If the list of questions is empty, fetch more questions - if not self.questions: - self.get_stack_questions() - question = self.questions.pop() - # Fetch the highest voted answer for the selected question - answer = self.get_stack_answer(question) - return {"question": question["title"], "answer": answer} - - def get_stack_answer(self, question): - question_id = question["question_id"] - url_answers = ( - f"https://api.stackexchange.com/2.3/questions/{question_id}/answers" - ) - params_answers = { - "order": "desc", - "sort": "votes", - "site": "stackoverflow", - "filter": "withbody", #'!9_bDDxJY5' - } - response_answers = requests.get(url_answers, params=params_answers) - response_answers.raise_for_status() - answers = response_answers.json()["items"] - if not answers: - bt.logging.warning("No answers found for the question!") - - highest_voted_answer = answers[0] # The first answer is the highest voted - soup = BeautifulSoup(highest_voted_answer["body"], "html.parser") - full_content = soup.get_text(separator="\n") - return full_content - - def next(self): - bt.logging.debug("Retrieving data from prompting.dataset...") - t0 = time.time() - info = self.get_stack_question() - info["fetch_time"] = time.time() - t0 - return info - - -class DateQADataset: - def __init__(self, max_tries: int = 10, seed=None): - self.max_tries = max_tries - self.seed = seed - self.rng = random.Random(seed) - - def get_random_event(self) -> Dict: - tries = 0 - while tries < self.max_tries: - # TODO: to avoid blocking from Wikipedia, we should provide a headers 
argument, where headers = {'User-Agent': 'Bittensor/0.0 (https://Bittensor.org; someone@opentensor.dev)'} - tries += 1 - - # Step 1: Generate a random date - year = 2000 - month = self.rng.randint(1, 12) - - max_days = 31 if month in (1, 3, 5, 7, 8, 10, 12) else 30 - max_days = max_days if month != 2 else 28 + int(year % 4 == 0) - day = self.rng.randint(1, max_days) - random_date = datetime.date(year, month, day) - - # Step 2: Format the date for Wikipedia URL - formatted_date = random_date.strftime("%B_%d") # E.g., "January_01" - - # Step 3: Scrape Wikipedia - url = f"https://en.wikipedia.org/wiki/{formatted_date}" - response = requests.get(url) - events = [] - - if response.status_code != 200: - bt.logging.debug( - f'Received status code {response.status_code} for URL "{url}". Retrying ({tries}/{self.max_tries})...' - ) - continue - - soup = BeautifulSoup(response.content, "html.parser") - available_sections = [] - for name in ["Events", "Births", "Deaths"]: - section = soup.find("span", id=name) - if section: - available_sections.append(name) - section = self.rng.choice(available_sections) - # Find the events section - events_list = soup.find("span", id=section).parent.find_next_sibling("ul") - - for li in events_list.find_all("li"): - events.append(li) - - # Step 4: Extract Event Information and Step 5: Select an Event - if not events: - continue - - selected_event = random.choice(events) - links = selected_event.find_all("a") - if links: - link = self.rng.choice(links) - - return { - "date": random_date.strftime("%B %d"), - "event": selected_event.get_text(), - "next_page": link.get("title"), - "section": section, - } - - def next(self): - bt.logging.debug("Retrieving data from prompting.dataset...") - t0 = time.time() - info = self.get_random_event() - info["fetch_time"] = time.time() - t0 - return info - - -class MathDataset: - topics_list = mathgenerator.getGenList() - - def __init__(self, seed=None): - # NOTE: Unfortunately, mathgenerator does not provide a way to seed the random number generator and get the same problem every time - - self.seed = seed - self.rng = random.Random(seed) - - def random_problem(self, parse): - if parse: - parseable_list = [ - 2, - 7, - 11, - 15, - 19, - 21, - 24, - 27, - 29, - 30, - 32, - 33, - 35, - 36, - 42, - 45, - 48, - 49, - 52, - 59, - 60, - 64, - 66, - 67, - 68, - 69, - 70, - 73, - 76, - 78, - 81, - 82, - 83, - 84, - 85, - 86, - 87, - 92, - 94, - 95, - 96, - 97, - 105, - 108, - 109, - 111, - 115, - 122, - 123, - ] - options = parseable_list - choice = self.rng.choice((options)) - # TODO: When the solution contains the symbol x we should specify the x value and substitute it in the solution - problem, solution = mathgenerator.genById(choice) - _, subtopic, _, _, topic, _ = self.topics_list[choice] - - subs = {} - # check if solution contains letters - if "x" in solution: - subs["x"] = 10 - bt.logging.warning( - "Coercing a symbolic expression to a numeric expression by substituting x=10" - ) - - # BUG: parse latex assumes that all letters are variables and so solutions like $No$ are interpreted as 'N * o' - solution_numeric = parse_latex( - str(solution).replace("$", "").strip() - ).evalf(subs=subs) - return { - "problem": problem, - "solution": solution_numeric, - "solution_raw": solution, - "topic": topic, - "subtopic": subtopic, - } - else: - options = mathgenerator.getGenList() - choice = self.rng.choice(range(len(options))) - problem, solution = mathgenerator.genById(choice) - _, subtopic, _, _, topic, _ = self.topics_list[choice] - return 
{
-                "problem": problem,
-                "solution": solution,
-                "topic": topic,
-                "subtopic": subtopic,
-            }
-
-    def next(self, parse=True):
-        t0 = time.time()
-        info = self.random_problem(parse)
-        info["fetch_time"] = time.time() - t0
-        return info
diff --git a/prompting/tools/datasets/__init__.py b/prompting/tools/datasets/__init__.py
new file mode 100644
index 00000000..ec4456ab
--- /dev/null
+++ b/prompting/tools/datasets/__init__.py
@@ -0,0 +1,6 @@
+from .context import Context
+from .base import Dataset
+from .code import HFCodingDataset, StackOverflowDataset
+from .math import MathDataset
+from .mock import MockDataset
+from .wiki import WikiDataset, WikiDateDataset
\ No newline at end of file
diff --git a/prompting/tools/datasets/base.py b/prompting/tools/datasets/base.py
new file mode 100644
index 00000000..e3b4944e
--- /dev/null
+++ b/prompting/tools/datasets/base.py
@@ -0,0 +1,80 @@
+# The MIT License (MIT)
+# Copyright © 2024 Yuma Rao
+# Copyright © 2023 Opentensor Foundation
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
+# documentation files (the “Software”), to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
+# and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all copies or substantial portions of
+# the Software.
+
+# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
+# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
+# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+
+import time
+from abc import ABC, abstractmethod
+from typing import Dict
+import bittensor as bt
+
+from ..selector import Selector
+from .context import Context
+
+
+class Dataset(ABC):
+    """Base class for datasets."""
+
+    max_tries: int = 10
+
+    @abstractmethod
+    def search(self, name):
+        ...
+
+    @abstractmethod
+    def random(self, name):
+        ...
+
+    @abstractmethod
+    def get(self, name):
+        ...
+
+    def next(self, method: str = 'random', selector: Selector = Selector(), **kwargs) -> Dict:
+        tries = 1
+        t0 = time.time()
+
+        while True:
+
+            # TODO: Multithread the get method so that we don't have to suffer nonexistent pages
+            info = {}
+            if method == 'random':
+                info = self.random(selector=selector, **kwargs)
+            elif method == 'search':
+                info = self.search(selector=selector, **kwargs)
+            elif method == 'get':
+                info = self.get(selector=selector, **kwargs)
+            else:
+                raise ValueError(f"Unknown dataset get method {method!r}")
+
+            if info:
+                break
+
+            bt.logging.warning(f"Could not find a sample which meets {self.__class__.__name__} requirements after {tries} tries. Retrying... ({self.max_tries - tries} tries remaining.)")
+
+            tries += 1
+            if tries == self.max_tries:
+                raise Exception(
+                    f"Could not find a sample which meets {self.__class__.__name__} requirements after {tries} tries."
+ ) + + info['stats'] = { + 'creator': self.__class__.__name__, + 'fetch_time': time.time() - t0, + 'num_tries': tries, + 'fetch_method': method, + 'next_kwargs': kwargs + } + return Context(**info) diff --git a/prompting/tools/datasets/code.py b/prompting/tools/datasets/code.py new file mode 100644 index 00000000..1c59eb96 --- /dev/null +++ b/prompting/tools/datasets/code.py @@ -0,0 +1,242 @@ +# The MIT License (MIT) +# Copyright © 2024 Yuma Rao +# Copyright © 2023 Opentensor Foundation + +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated +# documentation files (the “Software”), to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of +# the Software. + +# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +import re +import time +import random +import requests +import itertools + +import bittensor as bt +from bs4 import BeautifulSoup + +from .base import Dataset +from ..selector import Selector +from datasets import load_dataset + +LANGUAGES = { + "C++": { + 'keywords': ['auto', 'break', 'case', 'char', 'const', 'continue', 'default', 'do', 'double', 'else', 'enum', 'extern', 'float', 'for', 'goto', 'if', 'int', 'long', 'register', 'return', 'short', 'signed', 'sizeof', 'static', 'struct', 'switch', 'typedef', 'union', 'unsigned', 'void', 'volatile', 'while'], + 'libraries': ['iostream', 'fstream', 'string', 'vector', 'map', 'set', 'algorithm', 'cmath', 'cstdio', 'cstdlib', 'ctime', 'cstring', 'cassert', 'cctype', 'cerrno', 'cfloat', 'ciso646', 'climits', 'clocale', 'cmath', 'csetjmp', 'csignal', 'cstdarg', 'cstddef', 'cstdio', 'cstdlib', 'cstring', 'ctime', 'cwchar', 'cwctype', 'complex', 'deque', 'exception', 'fstream', 'functional', 'iomanip', 'ios', 'iosfwd', 'iostream', 'istream', 'iterator', 'limits', 'list', 'locale', 'map', 'memory', 'new', 'numeric', 'ostream', 'queue', 'set', 'sstream', 'stack', 'stdexcept', 'streambuf', 'string', 'typeinfo', 'utility', 'valarray', 'vector'], + 'comments': ['//', '/*', '*/'], + }, + "Dockerfile": { + 'keywords': ['from', 'maintainer', 'run', 'cmd', 'expose', 'env', 'add', 'copy', 'entrypoint', 'volume', 'user', 'workdir', 'onbuild'], + 'libraries': [], + 'comments': ['#'] + }, + "HTML": { + 'keywords': ['div', 'span', 'input', 'ul', 'body', 'tag', 'html', 'head', 'title', 'meta', 'link', 'script', 'style', 'a', 'img', 'table', 'label'], + 'libraries': [], + 'comments': ['<!--', '-->'] + }, + "Java": { + 'keywords': ['abstract', 'assert', 'boolean', 'break', 'byte', 'case', 'catch', 'char', 'class', 'continue', 'default', 'do', 'double', 'else', 'enum', 'extends', 'final', 'finally', 'float', 'for', 'if', 'implements', 'import', 'instanceof', 'int', 'interface', 'long', 'native', 'new', 'package', 'private', 'protected', 
'public', 'return', 'short', 'static', 'strictfp', 'super', 'switch', 'synchronized', 'this', 'throw', 'throws', 'transient', 'try', 'void', 'volatile', 'while'], + 'libraries': ['java.awt', 'java.awt.event', 'java.io', 'java.lang', 'java.math', 'java.net', 'java.text', 'java.util', 'javax.swing'], + 'comments': ['//', '/*', '*/', '*'], + }, + "JavaScript": { + 'keywords': ['abstract', 'arguments', 'boolean', 'break', 'byte', 'case', 'catch', 'char', 'class', 'const', 'continue', 'debugger', 'default', 'delete', 'do', 'double', 'else', 'enum', 'eval', 'export', 'extends', 'false', 'final', 'finally', 'float', 'for', 'function', 'goto', 'if', 'implements', 'import', 'in', 'instanceof', 'int', 'interface', 'let', 'long', 'native', 'module.exports', 'new', 'null', 'package', 'private', 'protected', 'public', 'return', 'short', 'static', 'super', 'switch', 'synchronized', 'this', 'throw', 'throws', 'transient', 'true', 'try', 'typeof', 'var', 'void', 'volatile', 'while', 'with', 'yield'], + 'libraries': ['react', 'express', 'mongoose', 'axios', 'redux', 'react-redux', 'react-router-dom', 'react-dom', 'react-scripts', 'material-ui'], + 'comments': ['//', '/*', '*/'] + }, + "Python": {'keywords': ['False', 'None', 'True', 'and', 'as', 'assert', 'break', 'class', 'continue', 'def', 'del', 'elif', 'else', 'except', 'finally', 'for', 'from', 'global', 'if', 'import', 'in', 'is', 'lambda', 'nonlocal', 'not', 'or', 'pass', 'raise', 'return', 'try', 'while', 'with', 'yield'], + 'libraries': ['numpy', 'pandas', 'matplotlib', 'seaborn', 'scipy', 'sklearn', 'tensorflow', 'keras', 'pytorch', 'django', 'flask', 'requests', 'bs4', 'selenium', 'pyautogui', 'pyperclip', 'pyinputplus', 'pillow'], + 'comments': ['#'] + }, + "SQL": {'keywords': ['add', 'all', 'alter', 'and', 'any', 'as', 'asc', 'backup', 'between', 'case', 'check', 'column', 'constraint', 'create', 'database', 'default', 'delete', 'desc', 'distinct', 'drop', 'exec', 'exists', 'foreign', 'from', 'full', 'group', 'having', 'in', 'index', 'inner', 'insert', 'into', 'is', 'join', 'key', 'left', 'like', 'limit', 'not', 'null', 'on', 'or', 'order', 'outer', 'primary', 'procedure', 'right', 'rownum', 'select', 'set', 'table', 'top', 'truncate', 'union', 'unique', 'update', 'values', 'view', 'where'], + 'comments': ['--', '/*', '*/'] + }, + "Shell": {'keywords': ['alias', 'bg', 'bind', 'break', 'builtin', 'caller', 'cd', 'command', 'compgen', 'complete', 'continue', 'declare', 'dirs', 'disown', 'echo', 'enable', 'eval', 'exec', 'exit', 'export', 'false', 'fc', 'fg', 'getopts', 'hash', 'help', 'history', 'jobs', 'kill', 'let', 'local', 'logout', 'popd', 'printf', 'pushd', 'pwd', 'read', 'readonly', 'return', 'set', 'shift', 'shopt', 'source', 'suspend', 'test', 'times', 'trap', 'true', 'type', 'typeset', 'ulimit', 'umask', 'unalias', 'unset', 'wait'], + 'comments': ['#'] + }, +} + +def filter_comments(code, language): + # TODO: multiline comments + # filter out comments + + # for start_tag, end_tag in languages[language]['multiline-comments']: + # code = re.sub(rf'{start_tag}.*?{end_tag}', '', code, flags=re.DOTALL) + + lines = [] + for line in code.splitlines(): + # TODO: use regex + if any(line.strip().startswith(symbol) for symbol in LANGUAGES[language]['comments']): + continue + + lines.append(line.lower()) + + return '\n'.join(lines) + + +# TODO: why not define the chain_in, chain_out logic in the class itself?
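+# Illustrative example of the assumed behaviour (not part of the public API): +# for Python input, filter_comments drops full-line comments and lowercases the +# remaining lines; inline trailing comments are kept, because only lines that +# start with a comment symbol are removed: +# +# filter_comments("# setup\nX = 1\ny = X # inline", "Python") +# -> 'x = 1\ny = x # inline'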
+class HFCodingDataset(Dataset): + + def __init__( + self, + dataset_id="codeparrot/github-code", + seed=None, + languages=None, + buffer_size=10000, + ): + if seed is None: + seed = random.randint(0, 1000) + self.seed = seed + + if languages is None: + languages = list(LANGUAGES.keys()) + self.languages = languages + + self.dataset_id = dataset_id + self.dataset = iter( + load_dataset( + dataset_id, + split="train", + streaming=True, + languages=self.languages, + ).shuffle(seed=seed, buffer_size=buffer_size) + ) + + def get(self, min_lines=5, max_lines=100, selector: Selector = None): + + info = next(self.dataset) + + if not (min_lines <= len(info["code"].splitlines()) <= max_lines): + return None + + present_keywords, present_libraries = self.get_special_contents(info["code"], info["language"]) + keywords = list(present_keywords) + list(present_libraries) + code_words = ['code','programming','coding','code reference','programming technique'] + external_links = [] + for bigram in itertools.combinations(keywords, 2): + words = list(bigram) + [selector(code_words) + ' ' + info['language']] + # shuffle the words e.g. ['react', 'promise', 'code reference'] -> 'code reference promise react' + external_links.append(' '.join(random.sample(words, len(words)))) + + return { + "title": info['repo_name'], # name of the repo + "topic": info["language"], # language of the code + 'subtopic': info['path'], + 'content': info["code"], + 'internal_links': [info['repo_name'], info['path'], info['language']], + 'external_links': external_links, + 'source': 'GitHub', + 'tags': [info['language'], info['repo_name'], info['path']], + 'extra': {'size': info['size'], 'license': info['license']} + } + + def search(self, query, min_lines=5, max_lines=100, selector: Selector = None, **kwargs): + # TODO: Would be great to be able to get other files from the same repo + raise NotImplementedError(f"Search is not implemented for {self.__class__.__name__}") + + def random(self, min_lines=5, max_lines=100, selector: Selector = None, **kwargs): + return self.get(min_lines, max_lines, selector) + + + def extract_keywords(self, code, language, field): + matches = set() + + # check which keywords and libraries are present in the code + for keyword in LANGUAGES[language].get(field, []): + if re.search(r'\b' + keyword + r'\b', code): + matches.add(keyword) + + return matches + + def get_special_contents(self, code, language, remove_comments=True): + + if remove_comments: + code = filter_comments(code, language) + + present_libraries = self.extract_keywords(code, language, 'libraries') + present_keywords = self.extract_keywords(code, language, 'keywords') + + return present_keywords, present_libraries + + + +class StackOverflowDataset: + def __init__(self): + # Stack Overflow API endpoint for a random article + self.url = "https://api.stackexchange.com/2.3/questions" + self.questions = [] + + def get_stack_questions(self, min_upvotes=10): + params = { + "order": "desc", + "sort": "votes", # Sorting by votes means that it's likely that the same questions will be fetched again + "site": "stackoverflow", + "pagesize": 100, # Fetch 100 questions per API call + "page": random.randint(1, 5), + } + + # Fetch questions + response = requests.get(self.url, params=params) + response.raise_for_status() + + # Parse response + questions = response.json()["items"] + + # Filter questions by minimum upvotes + filtered_questions = [q for q in questions if q["score"] >= min_upvotes] + # Shuffle the questions + random.shuffle(filtered_questions) + + # Add
the questions to the list of questions + self.questions.extend(filtered_questions) + return + + def get_stack_question(self) -> dict: + # If the list of questions is empty, fetch more questions + if not self.questions: + self.get_stack_questions() + question = self.questions.pop() + # Fetch the highest voted answer for the selected question + answer = self.get_stack_answer(question) + return {"question": question["title"], "answer": answer} + + def get_stack_answer(self, question): + question_id = question["question_id"] + url_answers = ( + f"https://api.stackexchange.com/2.3/questions/{question_id}/answers" + ) + params_answers = { + "order": "desc", + "sort": "votes", + "site": "stackoverflow", + "filter": "withbody", #'!9_bDDxJY5' + } + response_answers = requests.get(url_answers, params=params_answers) + response_answers.raise_for_status() + answers = response_answers.json()["items"] + if not answers: + bt.logging.warning("No answers found for the question!") + return None + + highest_voted_answer = answers[0] # The first answer is the highest voted + soup = BeautifulSoup(highest_voted_answer["body"], "html.parser") + full_content = soup.get_text(separator="\n") + return full_content + + def next(self): + bt.logging.debug("Retrieving data from prompting.dataset...") + t0 = time.time() + info = self.get_stack_question() + info["fetch_time"] = time.time() - t0 + return info + diff --git a/prompting/tools/datasets/context.py b/prompting/tools/datasets/context.py new file mode 100644 index 00000000..10bb3d84 --- /dev/null +++ b/prompting/tools/datasets/context.py @@ -0,0 +1,18 @@ + +from typing import List +from dataclasses import dataclass + +@dataclass +class Context: + + # TODO: Pydantic model + title: str + topic: str + subtopic: str + content: str + internal_links: List[str] + external_links: List[str] + source: str + tags: List[str] = None + extra: dict = None # additional non-essential information + stats: dict = None # retrieval stats such as fetch time, number of tries, etc. diff --git a/prompting/tools/datasets/math.py b/prompting/tools/datasets/math.py new file mode 100644 index 00000000..7327563e --- /dev/null +++ b/prompting/tools/datasets/math.py @@ -0,0 +1,84 @@ +# The MIT License (MIT) +# Copyright © 2024 Yuma Rao +# Copyright © 2023 Opentensor Foundation + +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated +# documentation files (the “Software”), to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of +# the Software. + +# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE.
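+# MathDataset below wraps the external mathgenerator package: get() builds a +# named problem via mathgenerator.generate_context and keeps only problems +# whose reward_type is 'float', presumably so that answers can be scored by +# numeric comparison downstream.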
+ +import time +import random +import itertools +import mathgenerator +import bittensor as bt +from sympy.parsing.latex import parse_latex +from typing import Dict, Union, List, Tuple + + +from .base import Dataset +from ..selector import Selector + +class MathDataset(Dataset): + topics_list = mathgenerator.getGenList() + + def __init__(self, seed=None): + + self.seed = seed + self.rng = random.Random(seed) + + def get(self, name: str, selector: Selector = None, include: List = None, exclude: List = None, **kwargs) -> Dict: + """Get a math problem. + + Args: + name (str): Name of math problem to generate. + selector (Selector, optional): Selector instance to choose a specific problem. Defaults to None. + include (List, optional): Problem names to include (currently unused). Defaults to None. + exclude (List, optional): Problem names to exclude (currently unused). Defaults to None. + + Returns: + Dict: Problem context containing the problem statement, topic, subtopic and solution, or None if the problem does not have a float answer. + """ + bt.logging.info(f"Getting math problem {name!r}") + info = mathgenerator.generate_context(name, **kwargs) + if info['reward_type'] != 'float': + return None + + math_words = ['math','mathematics','mathematical','math problem','math technique'] + external_links = [] + # construct external links from randomly shuffled trigrams containing 2 words from the problem and 1 random math word + # e.g. binary_to_decimal -> ['binary to', 'to decimal'] + for bigram in itertools.combinations(info['forward_words'], 2): + words = list(bigram) + [random.choice(math_words)] + # shuffle the words e.g. ['binary', 'decimal', 'math problem'] -> 'decimal binary math problem' + external_links.append(' '.join(random.sample(words, len(words)))) + + return { + "title": info['topic'], # title of math problem + "topic": info['topic'], # title of problem topic + 'subtopic': info['subtopic'], # title of problem subtopic + 'content': info['problem'], # problem statement + 'internal_links': [info['topic'], info['subtopic']], # internal links + 'external_links': external_links, + "tags": info['forward_words'], + 'source': 'Mathgenerator', + 'extra': {'reward_type': info['reward_type'], 'solution': info['solution']} + } + + def search(self, name, selector: Selector, include: List = None, exclude: List = None) -> Dict: + raise NotImplementedError(f"Search is not implemented for {self.__class__.__name__}") + + + def random(self, selector: Selector, **kwargs): + """Create a random math problem.""" + return self.get(name=None, selector=selector, **kwargs) + diff --git a/prompting/tools/datasets/mock.py b/prompting/tools/datasets/mock.py new file mode 100644 index 00000000..4e5a507f --- /dev/null +++ b/prompting/tools/datasets/mock.py @@ -0,0 +1,26 @@ + + + +from .base import Dataset +# from ..selector import Selector + +class MockDataset(Dataset): + + def get(self, name, exclude=None, selector=None): + return { + 'title': name, + 'topic': 'Physics', + 'subtopic': 'Quantum_mechanics', + 'content': f'{name} is a fraud.
All of physics is a lie, the universe is a hologram, buy gold, bye!', + 'internal_links': ['Quantum_mechanics', 'General_relativity', 'Special_relativity', 'String_theory'], + 'external_links': ['Einstein', 'Bohr', 'Feynman', 'Hawking'], + "tags": ["fraud", "hologram", "gold"], + 'source': 'Mockpedia', + 'extra': {'solution': 'religion'}, + } + + def search(self, name, exclude=None, selector=None): + return self.get(name) + + def random(self, name='Physics', exclude=None, selector=None): + return self.get(name) \ No newline at end of file diff --git a/prompting/tools/datasets/wiki.py b/prompting/tools/datasets/wiki.py new file mode 100644 index 00000000..f0c711c8 --- /dev/null +++ b/prompting/tools/datasets/wiki.py @@ -0,0 +1,275 @@ +# The MIT License (MIT) +# Copyright © 2024 Yuma Rao +# Copyright © 2023 Opentensor Foundation + +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated +# documentation files (the “Software”), to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of +# the Software. + +# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +import re +import sys +import random +import datetime +import bittensor as bt +import wikipedia as wiki +from typing import Dict, Union, List, Tuple + +from functools import lru_cache +from .base import Dataset +from ..selector import Selector + +# speed up page loading +@lru_cache(maxsize=1000) +def _get_page(title, pageid=None, auto_suggest=False, redirect=True, seed=None) -> wiki.WikipediaPage: + """Cached Wikipedia page loading. + """ + try: + page = wiki.page(title=title, pageid=pageid, auto_suggest=auto_suggest, redirect=redirect) + # create sections manually if not found + if not page.sections: + page._sections = [line.strip('= ') for line in page.content.splitlines() if re.search(r'=+\s+.*\s+=+',line)] + return page + + except wiki.DisambiguationError as e: + bt.logging.debug(f"{e.__class__.__name__} loading page {title!r}: {e}") + # exc info contains a tuple of (requested_title: str, possible_matches: List[str]) + pages = sys.exc_info()[1].args[1] + if not isinstance(pages, list): + return None + title = random.Random(seed).choice(pages) + return _get_page(title, auto_suggest=auto_suggest, redirect=redirect) + + except wiki.PageError as e: + bt.logging.warning(f"{e.__class__.__name__} loading page {title!r}: {e}") + if not auto_suggest: + return _get_page(title, auto_suggest=True, redirect=redirect) + return None + +@lru_cache(maxsize=1000) +def _get_random_titles(pages=10, seed=42) -> List: + """Cached Wikipedia random titles. Approximately deterministic, which is useful for testing. + NOTE: the actual cached result will change each session, but the result will be the same within a session.
+ """ + return wiki.random(pages=pages) + +@lru_cache(maxsize=1000) +def _wiki_search(name, results) -> List: + """Cached Wikipedia search. + """ + return wiki.search(name, results=results) + +def process_page(page, valid_header: callable = None, valid_content: callable = None) -> Dict: + """Process a Wikipedia page and return a dictionary of sections with their content. + + Args: + page: wikipedia.WikipediaPage + valid_header: callable to determine if a section header is valid + valid_content: callable to determine if a section content is valid + Returns: + dict: dictionary of sections and their content. Note that keys are tuples (header, section_title) + """ + header = '' + sections = {} + + for section_title in page.sections: + content = page.section(section_title) + if not content: + header = section_title + continue + + # Filter out sections that don't match the headers and/or are not valid + if (valid_header and not valid_header(header)) or \ + (valid_content and not valid_content(content)): + continue + + key = (header, section_title) + sections[key] = content.splitlines() + + if not sections: + bt.logging.debug(f"No valid sections found in page {page.title!r} ({page.url})") + + return sections + + +def most_relevant_links(page, num_links=10, num_summary_words=50, return_scores=False): + """Return the most relevant links to a Wikipedia page based on the intersection over union (IOU) of the link and the page summary.""" + link_scores = {} + summary_words = set(page.summary.split()[:num_summary_words]) + for link in page.links: + link_words = set(link.split()) + iou = len(summary_words.intersection(link_words)) / len(summary_words.union(link_words)) + link_scores[link] = iou / len(link.split()) + + sorted_links = sorted(link_scores.items(), key=lambda x: x[1], reverse=True) + if return_scores: + return sorted_links[:num_links] + + return [link for link, _ in sorted_links[:num_links]] + +def filter_categories(categories, exclude=None, include=None): + """Filter categories based on a list of categories to exclude and/or include.""" + if exclude: + categories = [cat for cat in categories if not re.search('|'.join(exclude), cat,re.IGNORECASE)] + if include: + categories = [cat for cat in categories if re.search('|'.join(include), cat,re.IGNORECASE)] + return categories + +class WikiDataset(Dataset): + """Wikipedia dataset. Uses the wikipedia python api to fetch articles and sections.""" + + EXCLUDE_HEADERS = ('See also', 'References', 'Further reading', 'External links') + EXCLUDE_CATEGORIES = ('articles', 'wiki', 'pages', 'cs1') + + def __init__( + self, + min_length_words: int = 50, + max_links: int = 10, + ): + """ + Args: + min_length_words (int, optional): Minimum section length. Defaults to 50. + max_links (int, optional): _description_. Defaults to 10. + """ + self.min_length_words = min_length_words + self.max_links = max_links + + + def get(self, name: str, selector: Selector = None, include: List = None, exclude: List = None, **kwargs) -> Dict: + """Get a specified Wikipedia page and extract a section based on the selector. + + Args: + name (_type_): _description_ + pageid (_type_, optional): _description_. Defaults to None. + auto_suggest (bool, optional): _description_. Defaults to True. + redirect (bool, optional): _description_. Defaults to True. + selector (Selector, optional): _description_. Defaults to None. + include (List, optional): _description_. Defaults to None. + exclude (List, optional): _description_. Defaults to None. 
+ + Returns: + Dict: Context dictionary with the selected section and article metadata, or None if no valid section is found. + """ + + page = _get_page(title=name, **kwargs) + if page is None: + return None + + # Only return sections with a minimum number of words + exclude = (exclude or []) + list(self.EXCLUDE_HEADERS) + sections = process_page(page, + valid_header=lambda x: x not in exclude and (not include or x in include), + valid_content=lambda x: len(x.split())>=self.min_length_words + ) + if not sections: + return None + + key = header, section_title = selector(list(sections.keys())) + content = '\n'.join(sections[key]) + section_length = len(content.split()) + return { + "title": name, # title of wiki article + "topic": header or section_title, # title of wiki section + 'subtopic': section_title, + 'content': content, + 'internal_links': list(filter(lambda x: x not in exclude, page.sections)), + 'external_links': most_relevant_links(page, num_links=self.max_links), + 'tags': filter_categories(page.categories, exclude=self.EXCLUDE_CATEGORIES), + 'source': 'Wikipedia', + 'extra': {'url': page.url, 'page_length': len(page.content.split()), 'section_length': section_length}, + } + + def search(self, name, results=3, selector: Selector = None) -> Dict: + titles = _wiki_search(name, results=results) + title = selector(titles) + return self.get(title, selector=selector) + + def random(self, pages=10, seed=None, selector: Selector = None, **kwargs) -> Dict: + titles = wiki.random(pages=pages) if seed is None else _get_random_titles(pages=pages, seed=seed) + title = selector(titles) + return self.get(title, selector=selector) + + + + +class WikiDateDataset(Dataset): + + INCLUDE_HEADERS = ("Events", "Births", "Deaths") + MONTHS = ("January", "February", "March", "April", "May", "June", "July", "August", "September", "October", "November", "December") + EXCLUDE_CATEGORIES = ('articles', 'wiki', 'pages', 'cs1') + + def __init__(self, max_tries: int = 10, seed=None): + self.max_tries = max_tries + self.seed = seed + self.rng = random.Random(seed) + + def _random_date(self, year: int = None, month: int = None) -> str: + """Returns a random date in the format "Month_DD" (e.g., "January_01").""" + if year is None: + year = self.rng.randint(1, 2024) + if month is None: + month = self.rng.randint(1, 12) + + max_days = 31 if month in (1, 3, 5, 7, 8, 10, 12) else 30 + if month == 2: + # February: only allow 29 days in leap years so that datetime.date does not raise + max_days = 29 if (year % 4 == 0 and (year % 100 != 0 or year % 400 == 0)) else 28 + + day = self.rng.randint(1, max_days) + + random_date = datetime.date(year, month, day) + # Format the date for the Wikipedia URL + return random_date.strftime("%B_%d") # E.g., "January_01" + + def get(self, name, pageid=None, auto_suggest=True, redirect=True, selector: Selector = None) -> Dict: + + # Check that name is correctly formatted e.g., "January_01" + date = name.split('_') + assert len(date)==2, f"Date should be in the format 'Month_DD' (e.g., 'January_01'), but got {name!r}" + assert date[0] in self.MONTHS, f"Month should be one of {self.MONTHS}, but got {date[0]!r}" + assert date[1].isdigit(), f"Day should be a number, but got {date[1]!r}" + + page = _get_page(title=name, pageid=pageid, auto_suggest=auto_suggest, redirect=redirect) + if page is None: + return None + + # Only return sections which contain event-like content + # e.g.
"1999 - Some event happened" + sections = process_page(page, + valid_header=lambda x: x in self.INCLUDE_HEADERS, + valid_content=lambda x: any([re.search(r'^\d+',line) for line in x.splitlines()]) + ) + if not sections: + return None + + key = header, section_title = selector(list(sections.keys())) + line = selector(sections[key]) + year, *event = line.replace(u'\u2013', '-').split('-') + links = [link for link in page.links if link in line] + + return { + "title": name, # title of wiki article + "topic": header or section_title, # title of wiki section + 'subtopic': year.strip(), + 'content': '-'.join(event).strip('. '), + 'internal_links': list(sections.keys()), + 'external_links': links, + 'tags': filter_categories(page.categories, exclude=WikiDataset.EXCLUDE_CATEGORIES), + 'source': 'Wikipedia', + 'extra': {'url': page.url, 'year': year, 'event': event, 'line': line, 'date': date+[year], 'section_title': section_title}, + } + + def search(self, name, results=5, selector: Selector = None) -> Dict: + raise NotImplementedError(f"Search is not implemented for {self.__class__.__name__}") + + def random(self, selector: Selector = None, **kwargs) -> Dict: + date = self._random_date() + return self.get(date, selector=selector) + diff --git a/prompting/tools/selector.py b/prompting/tools/selector.py new file mode 100644 index 00000000..9a421d80 --- /dev/null +++ b/prompting/tools/selector.py @@ -0,0 +1,48 @@ +import random + +class Selector: + def __init__(self, seed=None): + self.seed = seed + self.rng = random.Random(seed) + + def __call__(self, items, weights=None): + return self.rng.choices(items, weights=weights)[0] + + +class PageRankSelector(Selector): + """Preferentially chooses the items at the top of the list, under the assumption that they are more important.""" + def __init__(self, seed=None, alpha=0.85): + super().__init__(seed) + self.alpha = alpha + + def __call__(self, items): + weights = [self.alpha**i for i in range(len(items))] + return self.rng.choices(items, weights=weights)[0] + + +class SimilaritySelector(Selector): + """Chooses the item most similar to the query.""" + def __init__(self, seed=None, similarity_fn=None): + super().__init__(seed) + self.similarity_fn = similarity_fn + + def __call__(self, query, items): + return max(items, key=lambda item: self.similarity_fn(query, item)) + + +class TopSelector(Selector): + """Chooses the top item.""" + def __init__(self, seed=None): + super().__init__(seed) + + def __call__(self, items): + return items[0] + + +if __name__ == "__main__": + + selector = Selector(seed=42) + items = range(10) + item = selector(items) + + assert item in items, "Selector should return one of the items" \ No newline at end of file diff --git a/prompting/utils/config.py b/prompting/utils/config.py index a6001206..1c5dbb84 100644 --- a/prompting/utils/config.py +++ b/prompting/utils/config.py @@ -263,7 +263,7 @@ def add_validator_args(cls, parser): type=float, nargs="+", help="The probability of sampling each task.", - default=[0.5, 0.5, 0.0, 0.0, 0.0], + default=[0.25, 0.25, 0.0, 0.25, 0.25], ) parser.add_argument( @@ -305,7 +305,14 @@ def add_validator_args(cls, parser): "--neuron.moving_average_alpha", type=float, help="Moving average alpha parameter, how much to add of the new observation.", - default=0.05, + default=0.1, + ) + + parser.add_argument( + "--neuron.decay_alpha", + type=float, + help="Constant decay rate for the moving average score.", + default=0.001, ) parser.add_argument( diff --git a/prompting/utils/exceptions.py 
b/prompting/utils/exceptions.py new file mode 100644 index 00000000..fcbf6688 --- /dev/null +++ b/prompting/utils/exceptions.py @@ -0,0 +1,7 @@ + +class MaxRetryError(Exception): + """Exception raised when the maximum number of retries is exceeded.""" + + def __init__(self, message="Maximum number of retries exceeded"): + self.message = message + super().__init__(self.message) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 06d47650..66f4ec4b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,7 @@ torch==2.1.1 torchmetrics transformers==4.36.2 pre-commit==3.3.2 -mathgenerator # TODO: Use our own fork +git+https://github.com/synapse-alpha/mathgenerator.git@main#egg=mathgenerator numpy==1.22.0 rouge scipy==1.10.1 @@ -16,3 +16,4 @@ wandb==0.15.10 tenacity antlr4-python3-runtime==4.11 wikipedia +wikipedia_sections diff --git a/setup.py b/setup.py index e358bc6f..dd66b226 100644 --- a/setup.py +++ b/setup.py @@ -53,7 +53,6 @@ def read_requirements(path): processed_requirements.append(req) return processed_requirements - requirements = read_requirements("requirements.txt") here = path.abspath(path.dirname(__file__)) diff --git a/tests/fixtures/dataset.py b/tests/fixtures/dataset.py index af82d9bf..df9bbc64 100644 --- a/tests/fixtures/dataset.py +++ b/tests/fixtures/dataset.py @@ -1,51 +1,38 @@ -from prompting.tools import MockDataset, CodingDataset, WikiDataset, StackOverflowDataset, DateQADataset, MathDataset +from prompting.tools.datasets import MockDataset, HFCodingDataset, WikiDataset, WikiDateDataset, MathDataset DATASETS = [ - # MockDataset, - CodingDataset, + MockDataset, + HFCodingDataset, WikiDataset, - # StackOverflowDataset, - DateQADataset, + WikiDateDataset, MathDataset, ] -WIKI_ARTICLE = { - 'title': 'Emilio Alvarez (bishop)', - 'url': 'https://en.wikipedia.org/wiki/Emilio_Alvarez_(bishop)', - 'length': 8185, - 'extract': '<p><b>Emilio Alvarez</b> (born January 16) is a religious leader in the United States, and founding bishop of the Union of Charismatic Orthodox Churches. 
He is also the founding director of the Institute for Paleo-Orthodox Christian Studies (formerly the certificate in Convergence Studies Program at New York Theological Seminary).', - 'backlinks': 7, - 'categories': [ - '21st-century American bishops', - '21st-century Puerto Rican peopl', - 'nvergence Movemen', - 'Living peopl', - 'People of Afro–Puerto Rican descen', - 'Puerto Rican bishops', - 'Religious leaders from New York (state)', - 'Short description matches Wikid', - 'Writers from New York (state)', - 'Year of birth missing (living people)' - ] - } -WIKI_CONTEXT = WikiDataset().next(info=WIKI_ARTICLE) -CODING_CONTEXT = CodingDataset(buffer_size=10).next() +MOCK_CONTEXT = MockDataset().next() +WIKI_CONTEXT = WikiDataset().next(name='Emilio Alvarez (bishop)', method='get') +CODING_CONTEXT = HFCodingDataset(buffer_size=1, seed=42).next() MATH_CONTEXT = MathDataset(seed=123).next() -DATEQA_CONTEXT = DateQADataset(seed=123).next() +DATEQA_CONTEXT = WikiDateDataset(seed=123).next() CONTEXTS = { + MockDataset: MOCK_CONTEXT, WikiDataset: WIKI_CONTEXT, - CodingDataset: CODING_CONTEXT, + HFCodingDataset: CODING_CONTEXT, MathDataset: MATH_CONTEXT, - DateQADataset: DATEQA_CONTEXT, + WikiDateDataset: DATEQA_CONTEXT, } - CONTEXT_FIELDS = { - WikiDataset: {"text", "title", "categories", "url", "sections", "fetch_time", "length", "backlinks", "extract"}, - CodingDataset: {"code", "repo_name", "path", "language", "size", "fetch_time", "license"}, - MathDataset: {"problem", "solution", 'topic', 'subtopic', "fetch_time", "solution_raw"}, - DateQADataset: {"section", "event", 'date', "next_page", "fetch_time"}, -} + 'title': str, + 'topic': str, + 'subtopic': str, + 'content': str, + 'internal_links': list, + 'external_links': list, + 'source': str, + 'tags': list, + 'extra': dict, + 'stats': dict, +} \ No newline at end of file diff --git a/tests/fixtures/task.py b/tests/fixtures/task.py index 20980aea..fd637295 100644 --- a/tests/fixtures/task.py +++ b/tests/fixtures/task.py @@ -1,4 +1,5 @@ from prompting.tasks import Task, QuestionAnsweringTask, SummarizationTask, DebuggingTask, MathTask, DateQuestionAnsweringTask +from prompting.tools import Context from .dataset import WIKI_CONTEXT, CODING_CONTEXT, MATH_CONTEXT, DATEQA_CONTEXT TASKS = [ @@ -9,7 +10,6 @@ DateQuestionAnsweringTask, ] -# TODO: Make fully deterministic CONTEXTS = { QuestionAnsweringTask: WIKI_CONTEXT, SummarizationTask: WIKI_CONTEXT, @@ -18,3 +18,26 @@ DateQuestionAnsweringTask: DATEQA_CONTEXT, } +TASK_FIELDS = { +'name': str, +'desc': str, +'goal': str, +'query': str, +'topic': str, +'subtopic': str, +'tags': list, +'context': Context, +'reward_definition': list, +'reference': str, +#'reward_threshold': float , +'penalty_definition': list, +# 'criteria': str = ("",), +'delimiter': str, +'complete': bool, +'static_reference': bool, +'static_query': bool, +'reference_system_prompt': str, +'reference_prompt': str, +'query_system_prompt': str, +'query_prompt': str, +} \ No newline at end of file diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 365b85e8..b23cf146 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -1,20 +1,63 @@ import pytest from .fixtures.dataset import DATASETS, CONTEXTS, CONTEXT_FIELDS +from prompting.tools.datasets import Dataset +from prompting.tools import Context @pytest.mark.parametrize('dataset', DATASETS) -def test_create_dataset(dataset): - data = dataset() - assert data is not None +def test_create_dataset(dataset: Dataset): + ds = dataset() + assert ds is not None 
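+# NOTE (assumption): building the module-level CONTEXTS fixtures and +# instantiating the non-mock datasets may require network access, since +# HFCodingDataset streams from HuggingFace and WikiDataset/WikiDateDataset +# query the Wikipedia API.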
@pytest.mark.parametrize('dataset', DATASETS) -def test_context_is_dict(dataset): - assert type(CONTEXTS[dataset]) == dict +def test_dataset_is_subclass_of_dataset_class(dataset: Dataset): + ds = dataset() + assert issubclass(type(ds), Dataset) + @pytest.mark.parametrize('dataset', DATASETS) -def test_dataset_context_contains_expected_fields(dataset): - assert set(CONTEXTS[dataset].keys()) == CONTEXT_FIELDS[dataset] +@pytest.mark.parametrize('method', ('next', 'get', 'random', 'search')) +def test_dataset_has_expected_methods(dataset: Dataset, method: str): + ds = dataset() + assert hasattr(ds, method) + assert callable(getattr(ds, method)) + + +@pytest.mark.skip(reason="Not implemented") +@pytest.mark.parametrize('dataset', DATASETS) +@pytest.mark.parametrize('method', ('next', 'get', 'random', 'search')) +def test_dataset_methods_return_contexts(dataset: Dataset, method: str): + ds = dataset() + assert hasattr(ds, method) + assert callable(getattr(ds, method)) + +@pytest.mark.parametrize('dataset', DATASETS) +def test_context_is_of_type_context_class(dataset: Dataset): + assert type(CONTEXTS[dataset]) == Context + + +@pytest.mark.parametrize('dataset', DATASETS) +@pytest.mark.parametrize('field', CONTEXT_FIELDS.keys()) +def test_context_contains_expected_field(dataset: Dataset, field: str): + assert hasattr(CONTEXTS[dataset], field) + +@pytest.mark.parametrize('dataset', DATASETS) +@pytest.mark.parametrize('field, expected_type', list(CONTEXT_FIELDS.items())) +def test_context_field_has_expected_types(dataset: Dataset, field: str, expected_type: type): + assert isinstance(getattr(CONTEXTS[dataset], field), expected_type) + + +@pytest.mark.parametrize('dataset', DATASETS) +@pytest.mark.parametrize('field', CONTEXT_FIELDS.keys()) +def test_context_field_is_not_null(dataset: Dataset, field: str): + assert getattr(CONTEXTS[dataset], field) + + +@pytest.mark.parametrize('dataset', DATASETS) +@pytest.mark.parametrize('field', ('creator', 'fetch_time', 'num_tries', 'fetch_method', 'next_kwargs')) +def test_context_stats_field_contains_expected_keys(dataset: Dataset, field: str): + assert field in CONTEXTS[dataset].stats \ No newline at end of file diff --git a/tests/test_mock.py b/tests/test_mock.py new file mode 100644 index 00000000..5bcba1ab --- /dev/null +++ b/tests/test_mock.py @@ -0,0 +1,94 @@ +import pytest +import asyncio +import bittensor as bt +from prompting.mock import MockDendrite, MockMetagraph, MockSubtensor +from prompting.protocol import PromptingSynapse + +wallet = bt.MockWallet() +wallet.create(coldkey_use_password=False) + +@pytest.mark.parametrize('netuid', [1, 2, 3]) +@pytest.mark.parametrize('n', [2, 4, 8, 16, 32, 64]) +@pytest.mark.parametrize('wallet', [wallet, None]) +def test_mock_subtensor(netuid, n, wallet): + + subtensor = MockSubtensor(netuid=netuid, n=n, wallet=wallet) + neurons = subtensor.neurons(netuid=netuid) + # Check netuid + assert subtensor.subnet_exists(netuid) + # Check network + assert subtensor.network == 'mock' + assert subtensor.chain_endpoint == 'mock_endpoint' + # Check number of neurons + assert len(neurons) == (n + 1 if wallet is not None else n) + # Check wallet + if wallet is not None: + assert subtensor.is_hotkey_registered(netuid=netuid, hotkey_ss58=wallet.hotkey.ss58_address) + + for neuron in neurons: + assert type(neuron) == bt.NeuronInfo + assert subtensor.is_hotkey_registered(netuid=netuid, hotkey_ss58=neuron.hotkey) + +@pytest.mark.parametrize('n', [16, 32, 64]) +def test_mock_metagraph(n): + mock_subtensor = MockSubtensor(netuid=1, 
n=n) + mock_metagraph = MockMetagraph(subtensor=mock_subtensor) + # Check axons + axons = mock_metagraph.axons + assert len(axons) == n + # Check ip and port + for axon in axons: + assert type(axon) == bt.AxonInfo + assert axon.ip == mock_metagraph.default_ip + assert axon.port == mock_metagraph.default_port + +def test_mock_reward_pipeline(): + pass + +def test_mock_neuron(): + pass + +@pytest.mark.parametrize('timeout', [0.1, 0.2]) +@pytest.mark.parametrize('min_time', [0, 0.05, 0.1]) +@pytest.mark.parametrize('max_time', [0.1, 0.15, 0.2]) +@pytest.mark.parametrize('n', [4, 16, 64]) +def test_mock_dendrite_timings(timeout, min_time, max_time, n): + mock_wallet = bt.MockWallet(config=None) + mock_dendrite = MockDendrite(mock_wallet) + mock_dendrite.min_time = min_time + mock_dendrite.max_time = max_time + mock_subtensor = MockSubtensor(netuid=1, n=n) + mock_metagraph = MockMetagraph(subtensor=mock_subtensor) + axons = mock_metagraph.axons + + async def run(): + return await mock_dendrite( + axons, + synapse = PromptingSynapse(roles=["user"], messages=["What is the capital of France?"]), + timeout = timeout + ) + + responses = asyncio.run(run()) + for synapse in responses: + assert hasattr(synapse, 'dendrite') and type(synapse.dendrite) == bt.TerminalInfo + + dendrite = synapse.dendrite + # check synapse.dendrite has (process_time, status_code, status_message) + for field in ('process_time', 'status_code', 'status_message'): + assert hasattr(dendrite, field) and getattr(dendrite, field) is not None + + # check that the dendrite takes between min_time and max_time + assert min_time <= dendrite.process_time + assert dendrite.process_time <= max_time + 0.1 + # check that responses which take longer than timeout have 408 status code + if dendrite.process_time >= timeout + 0.1: + assert dendrite.status_code == 408 + assert dendrite.status_message == 'Timeout' + assert synapse.completion == '' + # check that responses which take less than timeout have 200 status code + elif dendrite.process_time < timeout: + assert dendrite.status_code == 200 + assert dendrite.status_message == 'OK' + # check that completions are not empty for successful responses + assert type(synapse.completion) == str and len(synapse.completion) > 0 + # don't check for responses which take between timeout and max_time because they are not guaranteed to have a status code of 200 or 408 diff --git a/tests/test_scoring.py b/tests/test_scoring.py new file mode 100644 index 00000000..1751652b --- /dev/null +++ b/tests/test_scoring.py @@ -0,0 +1,71 @@ +import pytest +from datetime import datetime +from prompting.rewards import DateRewardModel, DiffRewardModel, RelevanceRewardModel, RougeRewardModel, FloatDiffModel, RewardPipeline + +date1 = datetime.strptime('2022-01-01','%Y-%m-%d') +date2 = datetime.strptime('2022-01-03','%Y-%m-%d') +date3 = datetime.strptime('2020-01-08','%Y-%m-%d') +date4 = datetime.strptime('2022-02-01','%Y-%m-%d') +ref = 'January 1, 2022' +date_formats = ['%B %d, %Y', '%m/%d/%Y', '%d %B %Y', '%m-%d-%Y'] +dates1 = [date1.strftime(format) for format in date_formats] +scores1 = [1.0]*len(dates1) +dates2 = [date2.strftime(format) for format in date_formats] +scores2 = [0.9960079893439915]*len(dates2) +dates3 = [date3.strftime(format) for format in date_formats] +scores3 = [0.0]*len(dates3) +dates4 = [date4.strftime(format) for format in date_formats] +scores4 = [0.38251018447178037]*len(dates4) +tuples = list( zip(dates1+dates2+dates3+dates4, scores1+scores2+scores3+scores4) ) +
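+# The (completion, expected_score) tuples above pair each formatted date with +# the score expected from DateRewardModel: an exact match scores 1.0, a 2-day +# offset ~0.996, a 1-month offset ~0.383 and a ~2-year offset 0.0; the score is +# assumed to be independent of the date format used.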
+@pytest.mark.parametrize('reference', dates1) +@pytest.mark.parametrize('completion, expected_result', tuples) +def test_score_dates_with_different_format(reference, completion, expected_result): + score = DateRewardModel().date_score(reference, completion) + assert score == expected_result + +completion = ['0.5', '1/2', '1-0.5', '2*0.25'] +expected_result = [1.0, 1.0, 1.0, 1.0] +reference = ['0.5']*len(completion) +@pytest.mark.parametrize('reference', reference) +@pytest.mark.parametrize('completion, expected_result', zip(completion, expected_result)) +def test_math_score_expression_parsing(reference, completion, expected_result): + score = FloatDiffModel().math_score(reference, completion) + assert score == expected_result + +completion = ['1e3', '-1e3', '1e-3', '-1e-3'] +expected_result = [1.0, 0.0, 0.0, 0.0] +reference = ['1000']*len(completion) +@pytest.mark.parametrize('reference', reference) +@pytest.mark.parametrize('completion, expected_result', zip(completion, expected_result)) +def test_math_score_expression_parsing_with_exponents(reference, completion, expected_result): + score = FloatDiffModel().math_score(reference, completion) + assert score == expected_result + +completion = ['1.0.', '1.0', '1.0.0', '1,', '0 1'] +expected_result = [1.0, 1.0, 0.0, 1.0, 1.0] +reference = ['1.0']*len(completion) +@pytest.mark.parametrize('reference', reference) +@pytest.mark.parametrize('completion, expected_result', zip(completion, expected_result)) +def test_math_score_expression_parsing_with_punctuation(reference, completion, expected_result): + score = FloatDiffModel().math_score(reference, completion) + assert score == expected_result + +completion = ['-20', '-23', '23', '20', '1000', '2*10+3'] +expected_result = [0.0, 0.0, 1.0, 0.8695652173918714, 0.0, 1.0] +reference = ['23']*len(completion) +@pytest.mark.parametrize('reference', reference) +@pytest.mark.parametrize('completion, expected_result', zip(completion, expected_result)) +def test_math_score_expression_parsing_with_negative_numbers(reference, completion, expected_result): + score = FloatDiffModel().math_score(reference, completion) + assert score == expected_result + +completion = ['0', '0.001', '-0.0', '-0.001', '0.0001'] +expected_result = [1.0, 0.0, 1.0, 0.0, 0.0] +reference = ['0']*len(completion) +@pytest.mark.parametrize('reference', reference) +@pytest.mark.parametrize('completion, expected_result', zip(completion, expected_result)) +def test_math_score_expression_parsing_with_zeros(reference, completion, expected_result): + score = FloatDiffModel().math_score(reference, completion) + assert score == expected_result + \ No newline at end of file diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 6e063201..7b1cfa61 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -1,88 +1,113 @@ import pytest - +import inspect +from inspect import signature from prompting.tasks import Task -from .fixtures.task import CONTEXTS, TASKS +from prompting.rewards import REWARD_MODELS +from .fixtures.task import CONTEXTS, TASKS, TASK_FIELDS from .fixtures.llm import LLM_PIPELINE -""" -What we want to test for each task: -- The task is initialized correctly -- The task contains a query -- The task contains a reference answer -- Task contains a query_time -- Task contains a reference_time -- The task formats correctly -- All task fields are present as expected -- Tasks have reward definitions -""" - +# TODO: Check if format_challenge is defined +# TODO: Ensure that when static_reference is True, reference_time is not 
defined. Same for query_time and static_query +# TODO: Ensure that when generate_reference=True and static_reference is True, there is still a reference +# TODO: Ensure that when generate_reference=False and static_reference is True, there is still a reference +# TODO: Ensure that when generate_reference=False and static_reference is False, there is NOT a reference -# TODO: Math task only works when solution is floatable -# TODO: DateQA only accepts section in {Births, Deaths, Events} -# TODO: DateQA expect wiki entry for event @pytest.mark.parametrize('task', TASKS) def test_create_task(task: Task): + task(llm_pipeline=LLM_PIPELINE, context=CONTEXTS[task]) - task(llm_pipeline=LLM_PIPELINE, context=CONTEXTS[task].copy()) - -@pytest.mark.parametrize('task', TASKS) -def test_task_contains_query(task: Task): - - task = task(llm_pipeline=LLM_PIPELINE, context=CONTEXTS[task].copy()) - assert task.query is not None @pytest.mark.parametrize('task', TASKS) -def test_task_contains_reference(task: Task): +@pytest.mark.parametrize('field', TASK_FIELDS.keys()) +def test_task_contains_expected_field(task: Task, field: str): + task = task(llm_pipeline=LLM_PIPELINE, context=CONTEXTS[task]) + assert hasattr(task, field) - task = task(llm_pipeline=LLM_PIPELINE, context=CONTEXTS[task].copy()) - assert task.reference is not None @pytest.mark.parametrize('task', TASKS) -def test_task_contains_reward_definition(task: Task): - - task = task(llm_pipeline=LLM_PIPELINE, context=CONTEXTS[task].copy()) - assert type(task.reward_definition) == list +@pytest.mark.parametrize('field, expected_type', list(TASK_FIELDS.items())) +def test_task_field_has_expected_type(task: Task, field: str, expected_type: type): + task = task(llm_pipeline=LLM_PIPELINE, context=CONTEXTS[task]) + assert isinstance(getattr(task, field), expected_type) @pytest.mark.parametrize('task', TASKS) -def test_task_contains_goal(task: Task): +@pytest.mark.parametrize('field', TASK_FIELDS.keys()) +def test_task_field_is_not_null(task: Task, field: str): + task = task(llm_pipeline=LLM_PIPELINE, context=CONTEXTS[task]) + assert getattr(task, field) is not None - task = task(llm_pipeline=LLM_PIPELINE, context=CONTEXTS[task].copy()) - assert task.goal is not None - -@pytest.mark.parametrize('task', TASKS) -def test_task_contains_desc(task: Task): - - task = task(llm_pipeline=LLM_PIPELINE, context=CONTEXTS[task].copy()) - assert task.desc is not None @pytest.mark.parametrize('task', TASKS) def test_task_complete_is_false_on_init(task: Task): - task = task(llm_pipeline=LLM_PIPELINE, context=CONTEXTS[task].copy()) + task = task(llm_pipeline=LLM_PIPELINE, context=CONTEXTS[task]) assert task.complete == False -@pytest.mark.parametrize('task', TASKS) -def test_task_contains_tags(task: Task): - - task = task(llm_pipeline=LLM_PIPELINE, context=CONTEXTS[task].copy()) - assert type(task.tags) == list @pytest.mark.parametrize('task', TASKS) -def test_task_contains_context(task: Task): - context = CONTEXTS[task].copy() - task = task(llm_pipeline=LLM_PIPELINE, context=CONTEXTS[task].copy()) - assert context == task.context +def test_task_contains_no_reference_if_not_static(task: Task): + task = task(llm_pipeline=LLM_PIPELINE, context=CONTEXTS[task], create_reference=False) + assert task.static_reference or not task.reference + @pytest.mark.parametrize('task', TASKS) def test_task_contains_query_time(task: Task): - task = task(llm_pipeline=LLM_PIPELINE, context=CONTEXTS[task].copy()) + task = task(llm_pipeline=LLM_PIPELINE, context=CONTEXTS[task]) assert task.static_reference or
task.reference_time>=0 + @pytest.mark.parametrize('task', TASKS) def test_task_contains_reference_time(task: Task): - task = task(llm_pipeline=LLM_PIPELINE, context=CONTEXTS[task].copy()) + task = task(llm_pipeline=LLM_PIPELINE, context=CONTEXTS[task]) assert task.static_query or task.query_time>=0 + + +@pytest.mark.parametrize('task', TASKS) +@pytest.mark.parametrize('full', (True, False)) +def test_task_state_dict(task: Task, full: bool): + + task = task(llm_pipeline=LLM_PIPELINE, context=CONTEXTS[task]) + assert type(task.__state_dict__(full)) == dict + + +@pytest.mark.parametrize('task', TASKS) +@pytest.mark.parametrize('definition', ('reward_definition', 'penalty_definition')) +def test_task_contains_required_definition(task: Task, definition: str): + + task = task(llm_pipeline=LLM_PIPELINE, context=CONTEXTS[task]) + model_infos = getattr(task, definition) + total_weight = 0 + for model_info in model_infos: + + assert isinstance(model_info, dict) + + name = model_info.get("name") + assert name is not None + assert name in REWARD_MODELS + + params = {k: v for k, v in model_info.items() if k not in ["name", "weight"]} + cls_params = signature(REWARD_MODELS[name]).parameters + # check that all the parameters are in the class (no extra parameters are allowed) + for k, v in params.items(): + assert k in cls_params + # check that the type of the parameter is correct or not annotated + assert cls_params[k].annotation == inspect._empty or isinstance(v, cls_params[k].annotation) + + # check that all class parameters without default values are in the model_info + for k, v in cls_params.items(): + # ignore self, device, args, kwargs + if v.default == inspect._empty and v.name not in ("self", "device", "args", "kwargs"): + assert k in params + + weight = model_info.get("weight") + assert weight is not None + assert isinstance(weight, (float, int)) + assert 0 <= weight <= 1 + + total_weight += weight + + assert not model_infos or total_weight == 1 \ No newline at end of file