diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 5db44d16..cedb6b55 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -5,9 +5,9 @@ name: Python package on: push: - branches: [ "main", "staging" ] + branches: [ "main", "staging", "pre-staging" ] pull_request: - branches: [ "main", "staging" ] + branches: [ "main", "staging", "pre-staging" ] jobs: build: diff --git a/prompting/mock.py b/prompting/mock.py index e5862c27..8e7d9ae3 100644 --- a/prompting/mock.py +++ b/prompting/mock.py @@ -74,8 +74,12 @@ def preprocess(self, **kwargs): class MockSubtensor(bt.MockSubtensor): - def __init__(self, netuid, n=16, wallet=None, network="mock"): - super().__init__(network=network) + def __init__(self, netuid, n=16, wallet=None): + + super().__init__() + # reset the underlying subtensor state + self.chain_state = None + self.setup() if not self.subnet_exists(netuid): self.create_subnet(netuid) @@ -102,6 +106,10 @@ def __init__(self, netuid, n=16, wallet=None, network="mock"): class MockMetagraph(bt.metagraph): + + default_ip = "127.0.0.0" + default_port = 8091 + def __init__(self, netuid=1, network="mock", subtensor=None): super().__init__( netuid=netuid, network=network, sync=False @@ -112,17 +120,17 @@ def __init__(self, netuid=1, network="mock", subtensor=None): self.sync(subtensor=subtensor) for axon in self.axons: - axon.ip = "127.0.0.0" - axon.port = 8091 - - bt.logging.info(f"Metagraph: {self}") - bt.logging.info(f"Axons: {self.axons}") + axon.ip = self.default_ip + axon.port = self.default_port class MockDendrite(bt.dendrite): """ Replaces a real bittensor network request with a mock request that just returns some static completion for all axons that are passed and adds some random delay. """ + min_time: float = 0 + max_time: float = 1 + def __init__(self, wallet): super().__init__(wallet) @@ -145,24 +153,24 @@ async def query_all_axons(streaming: bool): async def single_axon_response(i, axon): """Queries a single axon for a response.""" - start_time = time.time() + t0 = time.time() s = synapse.copy() # Attach some more required data so it looks real s = self.preprocess_synapse_for_request(axon, s, timeout) # We just want to mock the response, so we'll just fill in some data - process_time = random.random() + process_time = random.random()*(self.max_time-self.min_time) + self.min_time + await asyncio.sleep(process_time) if process_time < timeout: - s.dendrite.process_time = str(time.time() - start_time) # Update the status code and status message of the dendrite to match the axon s.completion = f'Mock miner completion {i}' s.dendrite.status_code = 200 s.dendrite.status_message = "OK" - synapse.dendrite.process_time = str(process_time) else: s.completion = "" s.dendrite.status_code = 408 s.dendrite.status_message = "Timeout" - synapse.dendrite.process_time = str(timeout) + + s.dendrite.process_time = str(time.time() - t0) # Return the updated synapse object after deserializing if requested if deserialize: diff --git a/tests/test_mock.py b/tests/test_mock.py new file mode 100644 index 00000000..5bcba1ab --- /dev/null +++ b/tests/test_mock.py @@ -0,0 +1,94 @@ +import pytest +import asyncio +import bittensor as bt +from prompting.mock import MockDendrite, MockMetagraph, MockSubtensor +from prompting.protocol import PromptingSynapse + +wallet = bt.MockWallet() +wallet.create(coldkey_use_password=False) + +@pytest.mark.parametrize('netuid', [1, 2, 3]) +@pytest.mark.parametrize('n', [2, 4, 8, 16, 32, 64]) +@pytest.mark.parametrize('wallet', [wallet, None]) +def test_mock_subtensor(netuid, n, wallet): + + subtensor = MockSubtensor(netuid=netuid, n=n, wallet=wallet) + neurons = subtensor.neurons(netuid=netuid) + # Check netuid + assert subtensor.subnet_exists(netuid) + # Check network + assert subtensor.network == 'mock' + assert subtensor.chain_endpoint == 'mock_endpoint' + # Check number of neurons + assert len(neurons) == (n + 1 if wallet is not None else n) + # Check wallet + if wallet is not None: + assert subtensor.is_hotkey_registered(netuid=netuid, hotkey_ss58=wallet.hotkey.ss58_address) + + for neuron in neurons: + assert type(neuron) == bt.NeuronInfo + assert subtensor.is_hotkey_registered(netuid=netuid, hotkey_ss58=neuron.hotkey) + +@pytest.mark.parametrize('n', [16, 32, 64]) +def test_mock_metagraph(n): + mock_subtensor = MockSubtensor(netuid=1, n=n) + mock_metagraph = MockMetagraph(subtensor=mock_subtensor) + # Check axons + axons = mock_metagraph.axons + assert len(axons) == n + # Check ip and port + for axon in axons: + assert type(axon) == bt.AxonInfo + assert axon.ip == mock_metagraph.default_ip + assert axon.port == mock_metagraph.default_port + +def test_mock_reward_pipeline(): + pass + +def test_mock_neuron(): + pass + +@pytest.mark.parametrize('timeout', [0.1, 0.2]) +@pytest.mark.parametrize('min_time', [0, 0.05, 0.1]) +@pytest.mark.parametrize('max_time', [0.1, 0.15, 0.2]) +@pytest.mark.parametrize('n', [4, 16, 64]) +def test_mock_dendrite_timings(timeout, min_time, max_time, n): + mock_wallet = bt.MockWallet(config=None) + mock_dendrite = MockDendrite(mock_wallet) + mock_dendrite.min_time = min_time + mock_dendrite.max_time = max_time + mock_subtensor = MockSubtensor(netuid=1, n=n) + mock_metagraph = MockMetagraph(subtensor=mock_subtensor) + axons = mock_metagraph.axons + + async def run(): + return await mock_dendrite( + axons, + synapse = PromptingSynapse(roles=["user"], messages=["What is the capital of France?"]), + timeout = timeout + ) + + responses = asyncio.run(run()) + for synapse in responses: + assert hasattr(synapse, 'dendrite') and type(synapse.dendrite) == bt.TerminalInfo + + dendrite = synapse.dendrite + # check synapse.dendrite has (process_time, status_code, status_message) + for field in ('process_time', 'status_code', 'status_message'): + assert hasattr(dendrite, field) and getattr(dendrite, field) is not None + + # check that the dendrite take between min_time and max_time + assert min_time <= dendrite.process_time + assert dendrite.process_time <= max_time + 0.1 + # check that responses which take longer than timeout have 408 status code + if dendrite.process_time >= timeout + 0.1: + assert dendrite.status_code == 408 + assert dendrite.status_message == 'Timeout' + assert synapse.completion == '' + # check that responses which take less than timeout have 200 status code + elif dendrite.process_time < timeout: + assert dendrite.status_code == 200 + assert dendrite.status_message == 'OK' + # check that completions are not empty for successful responses + assert type(synapse.completion) == str and len(synapse.completion) > 0 + # dont check for responses which take between timeout and max_time because they are not guaranteed to have a status code of 200 or 408