SN1-419: MSRv2: Zero Sum Scoring Experiment #634

Draft: wants to merge 25 commits into base: staging

Changes from all commits (25 commits):
665bfc1  Seperate Prompting, Remove TTI Endpoint, Add Json Flag (bkb2135, Feb 28, 2025)
4563830  Initial draft (Feb 28, 2025)
6ae46ce  Precommit Changes (richwardle, Feb 28, 2025)
15c951d  Precommit Fix (bkb2135, Mar 3, 2025)
cb29ad8  Improving TTI Final Prompt, Add Unittests for Prompts (richwardle, Mar 3, 2025)
2464424  Merge branch 'SN1-423-restructure-prompting' of github.com:macrocosm-… (richwardle, Mar 3, 2025)
8d78f54  Finalising draft for new MSR task (Mar 3, 2025)
2f96a50  Await Reward Models (richwardle, Mar 3, 2025)
a6132c0  Add Next Action For Final Prompt (richwardle, Mar 3, 2025)
fc470bf  Add Detailed Log For Scoring Response Failed (richwardle, Mar 3, 2025)
629bf1e  Generating follow-up task in generator reward config (Mar 3, 2025)
5948dbb  Precommit Fixes (bkb2135, Mar 3, 2025)
1fad896  Fixing various import errors (Mar 4, 2025)
9a5ebb3  Simplify Prompt Structure (richwardle, Mar 4, 2025)
0245a2a  Fix Unittest and Precommit (bkb2135, Mar 4, 2025)
46fe23a  Merge branch 'SN1-423-restructure-prompting' into 'SN1-419-r-d-resear… (richwardle, Mar 4, 2025)
9d8abfc  Add Get Entry for DiscriminatorDataset Entry (richwardle, Mar 4, 2025)
a243288  Fix Merge Overwrite (richwardle, Mar 4, 2025)
dfaa788  Use Random In Discriminator Dataset (richwardle, Mar 4, 2025)
9c5b90b  Fixing bugs with task appending (Mar 5, 2025)
4d9d58b  Bug Fixes (bkb2135, Mar 8, 2025)
dc9401b  Remove Miner Logs (bkb2135, Mar 8, 2025)
583b02e  Restructuring and Formatting (richard-wardle, Mar 10, 2025)
8f2cee9  Remove Redundant Reward Model (richard-wardle, Mar 12, 2025)
66c584f  Extract Out Weighted Average Calculation (richard-wardle, Mar 12, 2025)
18 changes: 16 additions & 2 deletions neurons/miners/epistula_miner/miner.py
@@ -1,6 +1,11 @@
# ruff: noqa: E402
+import random

+from loguru import logger

from shared import settings

+logger.info("Loading settings as miner")
settings.shared_settings = settings.SharedSettings.load(mode="miner")
shared_settings = settings.shared_settings

@@ -15,11 +20,10 @@
from bittensor.core.axon import FastAPIThreadedServer
from bittensor.core.extrinsics.serving import serve_extrinsic
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request
-from loguru import logger
from starlette.background import BackgroundTask
from starlette.responses import StreamingResponse
-from web_retrieval import get_websites_with_similarity

+from neurons.miners.epistula_miner.web_retrieval import get_websites_with_similarity
from prompting.llms.hf_llm import ReproducibleHF
from shared.epistula import verify_signature
@@ -76,11 +80,21 @@ async def word_stream(body, headers):

        return StreamingResponse(word_stream(body, headers), media_type="text/event-stream")

+    async def create_discriminator_completion(self, request: Request):
+        async def choose_random():
+            data = {"choices": [{"delta": {"content": random.choice(["A", "B"])}, "index": 0, "finish_reason": None}]}
+            yield f"data: {json.dumps(data)}\n\n"
+            yield "data: [DONE]\n\n"
+
+        return StreamingResponse(choose_random(), media_type="text/event-stream")
+
    async def create_chat_completion(self, request: Request):
        data = await request.json()
        headers = request.headers
        if self.llm and request.headers.get("task", None) == "inference":
            return await self.create_inference_completion(request)
+        if request.headers.get("task", None) == "MultiStepReasoningTaskDiscriminator":
+            return await self.create_discriminator_completion(request)
        if request.headers.get("task", None) == "WebRetrievalTask":
            return await self.stream_web_retrieval(data, headers)
        req = self.client.build_request("POST", "chat/completions", json=await self.format_openai_query(request))
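The new create_discriminator_completion stub answers every discriminator challenge with a uniformly random "A" or "B", streamed as OpenAI-style SSE chunks. As a hedged sketch (not part of this PR), a consumer of that stream could parse the verdict roughly like this; the helper name is hypothetical:

    # Hypothetical consumer sketch: reads the SSE lines emitted by the stub
    # above and extracts the "A"/"B" verdict from the delta content.
    import json

    def parse_discriminator_stream(lines: list[str]) -> str | None:
        verdict = None
        for line in lines:
            if not line.startswith("data: "):
                continue
            payload = line[len("data: "):].strip()
            if payload == "[DONE]":
                break
            chunk = json.loads(payload)
            verdict = chunk["choices"][0]["delta"]["content"]
        return verdict

    # Feeding it the two chunks the stub miner yields:
    print(parse_discriminator_stream([
        'data: {"choices": [{"delta": {"content": "A"}, "index": 0, "finish_reason": null}]}',
        "data: [DONE]",
    ]))  # -> "A"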
15 changes: 7 additions & 8 deletions neurons/validator.py
@@ -1,3 +1,7 @@
+from shared import settings
+
+settings.shared_settings = settings.SharedSettings.load(mode="validator")
+
import asyncio
import multiprocessing as mp
import sys
@@ -9,17 +13,12 @@
import wandb
from bittensor.core.extrinsics.serving import serve_extrinsic

+from prompting.llms.utils import GPUInfo
from prompting.rewards.scoring import task_scorer

# ruff: noqa: E402
-from shared import settings
from shared.logging import init_wandb

-settings.shared_settings = settings.SharedSettings.load(mode="validator")
-
-
-from prompting.llms.utils import GPUInfo
-
# Add a handler to write logs to a file
loguru.logger.add("logfile.log", rotation="1000 MB", retention="10 days", level="DEBUG")
from loguru import logger
@@ -60,7 +59,7 @@ async def spawn_loops(task_queue, scoring_queue, reward_events):
    logger.info("Starting ModelScheduler...")
    asyncio.create_task(model_scheduler.start(scoring_queue), name="ModelScheduler"),
    logger.info("Starting TaskScorer...")
-    asyncio.create_task(task_scorer.start(scoring_queue, reward_events), name="TaskScorer"),
+    asyncio.create_task(task_scorer.start(scoring_queue, reward_events, task_queue), name="TaskScorer"),
    logger.info("Starting WeightSetter...")
    asyncio.create_task(weight_setter.start(reward_events))

@@ -154,7 +153,7 @@ async def main():
            step += 1

    except Exception as e:
-        logger.error(f"Main loop error: {e}")
+        logger.exception(f"Main loop error: {e}")
        raise
    finally:
        wandb.teardown()
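The import reordering above exists because shared.settings.shared_settings is read by downstream modules at import time, so the entrypoint must populate it before importing them. A minimal, self-contained reproduction of the constraint, with toy module names (assumptions, not the repo's API):

    # Toy reproduction of the import-order constraint (hypothetical module
    # "toy_settings"): a module-level slot must be populated by the entrypoint
    # before any consumer imports it, or the consumer captures the unset value.
    import importlib
    import sys
    import types

    toy_settings = types.ModuleType("toy_settings")
    toy_settings.shared_settings = None
    sys.modules["toy_settings"] = toy_settings

    # Entrypoint populates the slot FIRST (mirrors SharedSettings.load above)...
    toy_settings.shared_settings = {"mode": "validator"}

    # ...so a later import observes the loaded value instead of None.
    consumer_sees = importlib.import_module("toy_settings").shared_settings
    assert consumer_sees == {"mode": "validator"}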
141 changes: 141 additions & 0 deletions notebooks/demo.ipynb
@@ -1,5 +1,146 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Inner async\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/lib/python3.10/ast.py:50: RuntimeWarning: coroutine 'async_demo' was never awaited\n",
" return compile(source, filename, mode, flags,\n",
"RuntimeWarning: Enable tracemalloc to get the object allocation traceback\n"
]
},
{
"ename": "RuntimeError",
"evalue": "This event loop is already running",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[20], line 17\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[38;5;28;01masync\u001b[39;00m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21masync_demo\u001b[39m():\n\u001b[1;32m 13\u001b[0m sync_example()\n\u001b[0;32m---> 17\u001b[0m \u001b[38;5;28;01mawait\u001b[39;00m async_demo()\n",
"Cell \u001b[0;32mIn[20], line 13\u001b[0m, in \u001b[0;36masync_demo\u001b[0;34m()\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[38;5;28;01masync\u001b[39;00m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21masync_demo\u001b[39m():\n\u001b[0;32m---> 13\u001b[0m sync_example()\n",
"Cell \u001b[0;32mIn[20], line 6\u001b[0m, in \u001b[0;36msync_example\u001b[0;34m()\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21msync_example\u001b[39m():\n\u001b[1;32m 5\u001b[0m loop \u001b[38;5;241m=\u001b[39m asyncio\u001b[38;5;241m.\u001b[39mget_event_loop()\n\u001b[0;32m----> 6\u001b[0m loop\u001b[38;5;241m.\u001b[39mrun_until_complete(inner_async())\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSync example\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
"File \u001b[0;32m/usr/lib/python3.10/asyncio/base_events.py:625\u001b[0m, in \u001b[0;36mBaseEventLoop.run_until_complete\u001b[0;34m(self, future)\u001b[0m\n\u001b[1;32m 614\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Run until the Future is done.\u001b[39;00m\n\u001b[1;32m 615\u001b[0m \n\u001b[1;32m 616\u001b[0m \u001b[38;5;124;03mIf the argument is a coroutine, it is wrapped in a Task.\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 622\u001b[0m \u001b[38;5;124;03mReturn the Future's result, or raise its exception.\u001b[39;00m\n\u001b[1;32m 623\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 624\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_check_closed()\n\u001b[0;32m--> 625\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_check_running\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 627\u001b[0m new_task \u001b[38;5;241m=\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m futures\u001b[38;5;241m.\u001b[39misfuture(future)\n\u001b[1;32m 628\u001b[0m future \u001b[38;5;241m=\u001b[39m tasks\u001b[38;5;241m.\u001b[39mensure_future(future, loop\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m)\n",
"File \u001b[0;32m/usr/lib/python3.10/asyncio/base_events.py:584\u001b[0m, in \u001b[0;36mBaseEventLoop._check_running\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 582\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m_check_running\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 583\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mis_running():\n\u001b[0;32m--> 584\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mThis event loop is already running\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 585\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m events\u001b[38;5;241m.\u001b[39m_get_running_loop() \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 586\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m 587\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mCannot run the event loop while another loop is running\u001b[39m\u001b[38;5;124m'\u001b[39m)\n",
"\u001b[0;31mRuntimeError\u001b[0m: This event loop is already running"
]
}
],
"source": [
"import asyncio\n",
"\n",
"\n",
"def sync_example():\n",
" loop = asyncio.get_event_loop()\n",
" loop.run_until_complete(inner_async())\n",
" print(\"Sync example\")\n",
"\n",
"def inner_async():\n",
" print(\"Inner async\")\n",
"\n",
"async def async_demo():\n",
" sync_example()\n",
"\n",
"\n",
"\n",
"await async_demo()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Starting\n",
"Waiting\n",
"Waiting\n",
"Waiting\n",
"Waiting\n",
"Waiting\n",
"Waiting\n",
"Waiting\n",
"Waiting\n",
"Waiting\n",
"Waiting\n",
"Done\n"
]
}
],
"source": [
"from asyncio import create_task\n",
"import asyncio\n",
"\n",
"async def wrapped_wait():\n",
" async def wait():\n",
" for _ in range(10):\n",
" await asyncio.sleep(1)\n",
" print(\"Waiting\")\n",
" return \"Waited 10 seconds\"\n",
" # return create_task(wait())\n",
" return await wait()\n",
"print(\"Starting\")\n",
"await wrapped_wait()\n",
"print(\"Done\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Waited 0.2195 seconds\n",
"Waited 0.2276 seconds\n",
"Waited 0.2519 seconds\n",
"Waited 0.2570 seconds\n",
"Waited 0.2901 seconds\n",
"Waited 0.6218 seconds\n",
"Waited 0.6412 seconds\n",
"Waited 0.6476 seconds\n",
"Waited 0.8122 seconds\n",
"Waited 0.8261 seconds\n"
]
}
],
"source": [
"import asyncio\n",
"import random\n",
"\n",
"async def text_generator():\n",
" async def wait_random(min_seconds: float, max_seconds: float):\n",
" await asyncio.sleep(wait_time := random.uniform(min_seconds, max_seconds))\n",
" return f\"Waited {wait_time:.4f} seconds\"\n",
" \n",
" # Create tasks\n",
" tasks = [asyncio.create_task(wait_random(0.01, 1)) for _ in range(10)]\n",
" \n",
" # Yield results as they complete\n",
" for completed_task in asyncio.as_completed(tasks):\n",
" result = await completed_task\n",
" yield result\n",
"\n",
"async for result in text_generator():\n",
" print(result) "
]
},
{
"cell_type": "code",
"execution_count": 1,
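The first notebook cell deliberately reproduces the classic Jupyter failure: loop.run_until_complete() is called from code that is already running on the notebook's event loop, so _check_running() raises RuntimeError (and the sync inner_async additionally triggers the never-awaited coroutine warning). A sketch of the usual fix, keeping the call chain fully async instead of bridging through run_until_complete:

    import asyncio

    async def inner_async():
        print("Inner async")

    async def async_demo():
        # Await directly rather than calling loop.run_until_complete(...)
        # from inside an already-running loop.
        await inner_async()

    asyncio.run(async_demo())  # in a notebook cell, use `await async_demo()` instead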
28 changes: 28 additions & 0 deletions prompting/datasets/msr_v2_dataset.py
@@ -0,0 +1,28 @@
import random
from typing import ClassVar

from shared.base import BaseDataset, Context, DatasetEntry

dataset_entry_queue: list[Context] = []


class MSRDiscriminatorDatasetEntry(DatasetEntry):
    miner_response: str | None = None
    validator_reference: str
    miner_uid: int
    source: str | None = None


class MSRDiscriminatorDataset(BaseDataset):
    name: ClassVar[str] = "msr_discriminator"

    def random(self) -> Context:
        return random.choice(dataset_entry_queue)

    @classmethod
    def add_entry(cls, miner_response: str, validator_reference: str, miner_uid: int):
        dataset_entry_queue.append(
            MSRDiscriminatorDatasetEntry(
                miner_response=miner_response, validator_reference=validator_reference, miner_uid=miner_uid
            )
        )
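A usage sketch for the new dataset (the calling flow is an assumption based on this file alone): the generator side enqueues a miner response alongside the validator reference, and a discriminator task later samples one at random. Note that dataset_entry_queue is module-level state shared across instances, and random.choice raises IndexError when the queue is empty.

    from prompting.datasets.msr_v2_dataset import MSRDiscriminatorDataset

    # Generator side: stash a miner/reference pair for later discrimination.
    MSRDiscriminatorDataset.add_entry(
        miner_response="The answer is 42.",
        validator_reference="The answer is 42.",
        miner_uid=17,
    )

    # Discriminator side: sample an entry (IndexError if nothing was enqueued;
    # assumes BaseDataset can be instantiated without arguments).
    entry = MSRDiscriminatorDataset().random()
    print(entry.miner_uid)  # 17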
38 changes: 38 additions & 0 deletions prompting/rewards/discriminator.py
@@ -0,0 +1,38 @@
from typing import TYPE_CHECKING

import numpy as np

from prompting.rewards.reward import BaseRewardModel, BatchRewardOutput
from shared.dendrite import DendriteResponseEvent

if TYPE_CHECKING:
    from prompting.tasks.msr_task_v2 import MultiStepReasoningTaskDiscriminator


class DiscriminatorRewardModel(BaseRewardModel):
    """
    This reward model is used to reward the discriminator task by comparing the reference and the response.
    """

    async def reward(
        self,
        reference: str,
        response_event: DendriteResponseEvent,
        task: "MultiStepReasoningTaskDiscriminator",
        **kwargs,
    ) -> BatchRewardOutput:
        completions: list[str] = response_event.completions

        # Get miner_uid from either original_miner_uid or dataset_entry
        miner_uid = task.original_miner_uid  # if task.original_miner_uid is not None else task.dataset_entry.miner_uid
        rewards: list[float] = []

        for completion in completions:
            rewards.append(1 / len(completions) if completion == reference else 0)

        generator_reward = 1 - np.sum(rewards)
        # Convert to list and use the miner_uid we retrieved
        uids = [float(miner_uid)] + list(response_event.uids)
        rewards = [generator_reward] + rewards

        return BatchRewardOutput(rewards=np.array(rewards), timings=np.array([0] * len(rewards)), uids=uids)
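This is where the experiment's zero-sum split lives: with N discriminators of which k match the reference, each correct discriminator earns 1/N, the generator earns 1 - k/N, and the pool always sums to 1. Prepending the generator's uid to the returned uids is what lets this model reward a miner that never appears in response_event (see the BatchRewardOutput.uids override in prompting/rewards/reward.py below). A small worked check mirroring the loop above:

    # Worked check of the zero-sum split: 4 discriminators, 3 correct.
    import numpy as np

    completions = ["A", "B", "A", "A"]  # hypothetical discriminator answers
    reference = "A"                     # validator's ground truth

    rewards = [1 / len(completions) if c == reference else 0 for c in completions]
    generator_reward = 1 - np.sum(rewards)

    print(rewards)                               # [0.25, 0, 0.25, 0.25]
    print(generator_reward)                      # 0.25
    print(np.sum([generator_reward, *rewards]))  # 1.0, the pool is zero-sum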
4 changes: 2 additions & 2 deletions prompting/rewards/inference_reward_model.py
@@ -14,5 +14,5 @@ async def reward(
    ) -> BatchRewardOutput:
        """Gives an exact reward of 1 if the response matches the reference, 0 otherwise"""
        if model_id:
-            return ExactMatchRewardModel().reward(reference, response_event)
-        return RelevanceRewardModel().reward(reference, response_event)
+            return await ExactMatchRewardModel().reward(reference, response_event)
+        return await RelevanceRewardModel().reward(reference, response_event)
File renamed without changes.
10 changes: 8 additions & 2 deletions prompting/rewards/reward.py
@@ -49,6 +49,7 @@ class BatchRewardOutput(BaseModel):
    threshold: float | None = None
    extra_info: dict = {}
    model_config = ConfigDict(arbitrary_types_allowed=True)
+    uids: list[float] | None = None

    @model_validator(mode="after")
    def validate_rewards_and_timings(cls, v):
@@ -79,11 +80,14 @@ async def apply(
        challenge: str | None = None,
        reward_type: Literal["reward", "penalty"] = "reward",
        task: BaseTextTask | None = None,
+        task_queue: list = None,
        **kwargs,
    ) -> WeightedRewardEvent:
        t0 = time.time()
        comparator = reference if reward_type == "reward" else challenge
-        batch_rewards_output: BatchRewardOutput = await self.reward(comparator, response_event, task=task, **kwargs)
+        batch_rewards_output: BatchRewardOutput = await self.reward(
+            comparator, response_event, task=task, task_queue=task_queue, **kwargs
+        )
        batch_rewards_time = time.time() - t0

        return WeightedRewardEvent(
@@ -97,7 +101,7 @@ async def apply(
            threshold=batch_rewards_output.threshold,
            timings=batch_rewards_output.timings,
            extra_info=kwargs,
-            uids=response_event.uids,
+            uids=batch_rewards_output.uids or response_event.uids,
        )


@@ -143,6 +147,7 @@ async def apply(
        challenge: str | None = None,
        model_id: str | None = None,
        task: BaseTextTask | None = None,
+        task_queue: list = None,
    ) -> list[WeightedRewardEvent]:
        reward_events = []
        for weighted_reward in cls.reward_definitions:
@@ -154,6 +159,7 @@
                    reward_type="reward",
                    model_id=model_id,
                    task=task,
+                    task_queue=task_queue,
                ),
            )
        return reward_events
5 changes: 4 additions & 1 deletion prompting/rewards/scoring.py
@@ -24,12 +24,14 @@ class TaskScorer(AsyncLoopRunner):
    interval: int = 0
    scoring_queue: list | None = None
    reward_events: list | None = None
+    task_queue: list | None = None

    model_config = ConfigDict(arbitrary_types_allowed=True)

-    async def start(self, scoring_queue, reward_events, name: str | None = None):
+    async def start(self, scoring_queue, reward_events, task_queue, name: str | None = None):
        self.scoring_queue = scoring_queue
        self.reward_events = reward_events
+        self.task_queue = task_queue
        return await super().start(name=name)

    def add_to_queue(
@@ -82,6 +84,7 @@ async def run_step(self) -> RewardLoggingEvent:
            reference=scoring_config.task.reference,
            model_id=scoring_config.task.llm_model,
            task=scoring_config.task,
+            task_queue=self.task_queue,
        )
        self.reward_events.append(reward_events)

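Threading task_queue from the validator's spawn_loops through TaskScorer and WeightedRewardModel.apply into each reward model's reward(...) lets scoring enqueue follow-up work; per the commit history ("Generating follow-up task in generator reward config"), the generator's reward step uses it to append the discriminator task. A hedged sketch of that hook; only the add_entry signature and queue plumbing come from this diff, the task construction is an assumption:

    from prompting.datasets.msr_v2_dataset import MSRDiscriminatorDataset
    from prompting.tasks.msr_task_v2 import MultiStepReasoningTaskDiscriminator

    # Sketch of a generator-side reward step (signature as extended in this PR):
    # stash each scored completion, then enqueue the follow-up discriminator task.
    async def reward(self, reference, response_event, task=None, task_queue=None, **kwargs):
        for uid, completion in zip(response_event.uids, response_event.completions):
            MSRDiscriminatorDataset.add_entry(
                miner_response=completion,
                validator_reference=reference,
                miner_uid=int(uid),
            )
        if task_queue is not None:
            task_queue.append(MultiStepReasoningTaskDiscriminator())  # assumed constructor
        ...  # the actual generator scoring continues here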