Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mock/tests #107

Merged
merged 6 commits into from
Feb 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ name: Python package

on:
push:
branches: [ "main", "staging" ]
branches: [ "main", "staging", "pre-staging" ]
pull_request:
branches: [ "main", "staging" ]
branches: [ "main", "staging", "pre-staging" ]

jobs:
build:
Expand Down
32 changes: 20 additions & 12 deletions prompting/mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,12 @@ def preprocess(self, **kwargs):


class MockSubtensor(bt.MockSubtensor):
def __init__(self, netuid, n=16, wallet=None, network="mock"):
super().__init__(network=network)
def __init__(self, netuid, n=16, wallet=None):

super().__init__()
# reset the underlying subtensor state
self.chain_state = None
self.setup()

if not self.subnet_exists(netuid):
self.create_subnet(netuid)
Expand All @@ -102,6 +106,10 @@ def __init__(self, netuid, n=16, wallet=None, network="mock"):


class MockMetagraph(bt.metagraph):

default_ip = "127.0.0.0"
default_port = 8091

def __init__(self, netuid=1, network="mock", subtensor=None):
super().__init__(
netuid=netuid, network=network, sync=False
Expand All @@ -112,17 +120,17 @@ def __init__(self, netuid=1, network="mock", subtensor=None):
self.sync(subtensor=subtensor)

for axon in self.axons:
axon.ip = "127.0.0.0"
axon.port = 8091

bt.logging.info(f"Metagraph: {self}")
bt.logging.info(f"Axons: {self.axons}")
axon.ip = self.default_ip
axon.port = self.default_port


class MockDendrite(bt.dendrite):
"""
Replaces a real bittensor network request with a mock request that just returns some static completion for all axons that are passed and adds some random delay.
"""
min_time: float = 0
max_time: float = 1

def __init__(self, wallet):
super().__init__(wallet)

Expand All @@ -145,24 +153,24 @@ async def query_all_axons(streaming: bool):
async def single_axon_response(i, axon):
"""Queries a single axon for a response."""

start_time = time.time()
t0 = time.time()
s = synapse.copy()
# Attach some more required data so it looks real
s = self.preprocess_synapse_for_request(axon, s, timeout)
# We just want to mock the response, so we'll just fill in some data
process_time = random.random()
process_time = random.random()*(self.max_time-self.min_time) + self.min_time
await asyncio.sleep(process_time)
if process_time < timeout:
s.dendrite.process_time = str(time.time() - start_time)
# Update the status code and status message of the dendrite to match the axon
s.completion = f'Mock miner completion {i}'
s.dendrite.status_code = 200
s.dendrite.status_message = "OK"
synapse.dendrite.process_time = str(process_time)
else:
s.completion = ""
s.dendrite.status_code = 408
s.dendrite.status_message = "Timeout"
synapse.dendrite.process_time = str(timeout)

s.dendrite.process_time = str(time.time() - t0)

# Return the updated synapse object after deserializing if requested
if deserialize:
Expand Down
94 changes: 94 additions & 0 deletions tests/test_mock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import pytest
import asyncio
import bittensor as bt
from prompting.mock import MockDendrite, MockMetagraph, MockSubtensor
from prompting.protocol import PromptingSynapse

wallet = bt.MockWallet()
wallet.create(coldkey_use_password=False)

@pytest.mark.parametrize('netuid', [1, 2, 3])
@pytest.mark.parametrize('n', [2, 4, 8, 16, 32, 64])
@pytest.mark.parametrize('wallet', [wallet, None])
def test_mock_subtensor(netuid, n, wallet):

subtensor = MockSubtensor(netuid=netuid, n=n, wallet=wallet)
neurons = subtensor.neurons(netuid=netuid)
# Check netuid
assert subtensor.subnet_exists(netuid)
# Check network
assert subtensor.network == 'mock'
assert subtensor.chain_endpoint == 'mock_endpoint'
# Check number of neurons
assert len(neurons) == (n + 1 if wallet is not None else n)
# Check wallet
if wallet is not None:
assert subtensor.is_hotkey_registered(netuid=netuid, hotkey_ss58=wallet.hotkey.ss58_address)

for neuron in neurons:
assert type(neuron) == bt.NeuronInfo
assert subtensor.is_hotkey_registered(netuid=netuid, hotkey_ss58=neuron.hotkey)

@pytest.mark.parametrize('n', [16, 32, 64])
def test_mock_metagraph(n):
mock_subtensor = MockSubtensor(netuid=1, n=n)
mock_metagraph = MockMetagraph(subtensor=mock_subtensor)
# Check axons
axons = mock_metagraph.axons
assert len(axons) == n
# Check ip and port
for axon in axons:
assert type(axon) == bt.AxonInfo
assert axon.ip == mock_metagraph.default_ip
assert axon.port == mock_metagraph.default_port

def test_mock_reward_pipeline():
pass

def test_mock_neuron():
pass

@pytest.mark.parametrize('timeout', [0.1, 0.2])
@pytest.mark.parametrize('min_time', [0, 0.05, 0.1])
@pytest.mark.parametrize('max_time', [0.1, 0.15, 0.2])
@pytest.mark.parametrize('n', [4, 16, 64])
def test_mock_dendrite_timings(timeout, min_time, max_time, n):
mock_wallet = bt.MockWallet(config=None)
mock_dendrite = MockDendrite(mock_wallet)
mock_dendrite.min_time = min_time
mock_dendrite.max_time = max_time
mock_subtensor = MockSubtensor(netuid=1, n=n)
mock_metagraph = MockMetagraph(subtensor=mock_subtensor)
axons = mock_metagraph.axons

async def run():
return await mock_dendrite(
axons,
synapse = PromptingSynapse(roles=["user"], messages=["What is the capital of France?"]),
timeout = timeout
)

responses = asyncio.run(run())
for synapse in responses:
assert hasattr(synapse, 'dendrite') and type(synapse.dendrite) == bt.TerminalInfo

dendrite = synapse.dendrite
# check synapse.dendrite has (process_time, status_code, status_message)
for field in ('process_time', 'status_code', 'status_message'):
assert hasattr(dendrite, field) and getattr(dendrite, field) is not None

# check that the dendrite take between min_time and max_time
assert min_time <= dendrite.process_time
assert dendrite.process_time <= max_time + 0.1
# check that responses which take longer than timeout have 408 status code
if dendrite.process_time >= timeout + 0.1:
assert dendrite.status_code == 408
assert dendrite.status_message == 'Timeout'
assert synapse.completion == ''
# check that responses which take less than timeout have 200 status code
elif dendrite.process_time < timeout:
assert dendrite.status_code == 200
assert dendrite.status_message == 'OK'
# check that completions are not empty for successful responses
assert type(synapse.completion) == str and len(synapse.completion) > 0
# dont check for responses which take between timeout and max_time because they are not guaranteed to have a status code of 200 or 408
Loading