Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor to use hotkeys not uids for miner identification. #4

Merged
1 commit merged into from
Dec 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion model/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class ModelId(BaseModel):
# Makes the object "Immutable" once created.
class Config:
frozen = True
extra = "forbid"

# TODO add pydantic validations on underlying fields.
path: str = Field(
Expand All @@ -31,7 +32,7 @@ def from_compressed_str(cls, cs: str) -> Type["ModelId"]:
return cls(
path=tokens[0],
name=tokens[1],
rev=tokens[2],
commit=tokens[2],
hash=tokens[3],
)

Expand Down
20 changes: 8 additions & 12 deletions model/storage/chain/chain_model_metadata_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,24 +21,22 @@ def __init__(
wallet # Wallet is only needed to write to the chain, not to read.
)
self.subnet_uid = subnet_uid
self.metagraph = self.subtensor.metagraph(self.subnet_uid)

# TODO actually make this asynchronous with threadpools etc.
async def store_model_metadata(self, uid: int, model_id: ModelId):
async def store_model_metadata(self, hotkey: str, model_id: ModelId):
"""Stores model metadata on this subnet for a specific wallet."""
if self.wallet is None:
raise ValueError("No wallet available to write to the chain.")

# TODO: Confirm that the hotkey matches the wallet using self.metagraph.hotkeys[uid]
# TODO: Confirm that the hotkey matches the wallet
self.subtensor.commit(
wallet=self.wallet,
netuid=self.subnet_uid,
data=model_id.to_compressed_str(),
)

async def retrieve_model_metadata(self, uid: int) -> ModelMetadata:
async def retrieve_model_metadata(self, hotkey: str) -> ModelMetadata:
"""Retrieves model metadata on this subnet for specific hotkey"""
hotkey = self.metagraph.hotkeys[uid]
metadata = bt.extrinsics.serving.get_metadata(
self.subtensor, self.subnet_uid, hotkey
)
Expand Down Expand Up @@ -68,7 +66,6 @@ async def test_store_model_metadata():
coldkey = os.getenv("TEST_COLDKEY")
hotkey = os.getenv("TEST_HOTKEY")
net_uid = int(os.getenv("TEST_SUBNET_UID"))
uid = int(os.getenv("TEST_UID"))

wallet = bt.wallet(name=coldkey, hotkey=hotkey)

Expand All @@ -77,7 +74,7 @@ async def test_store_model_metadata():
)

# Store the metadata on chain.
await metadata_store.store_model_metadata(uid=uid, model_id=model_id)
await metadata_store.store_model_metadata(hotkey=hotkey, model_id=model_id)

print(f"Finished storing {model_id} on the chain.")

Expand All @@ -94,15 +91,15 @@ async def test_retrieve_model_metadata():

# Uses .env configured hotkey/uid for the test.
net_uid = int(os.getenv("TEST_SUBNET_UID"))
uid = int(os.getenv("TEST_UID"))
hotkey = os.getenv("TEST_HOTKEY")

# Do not require a wallet for retrieving data.
metadata_store = ChainModelMetadataStore(
subtensor=subtensor, wallet=None, subnet_uid=net_uid
)

# Retrieve the metadata from the chain.
model_metadata = await metadata_store.retrieve_model_metadata(uid)
model_metadata = await metadata_store.retrieve_model_metadata(hotkey)

print(f"Expecting matching model id: {expected_model_id == model_metadata.id}")

Expand All @@ -122,7 +119,6 @@ async def test_roundtrip_model_metadata():
coldkey = os.getenv("TEST_COLDKEY")
hotkey = os.getenv("TEST_HOTKEY")
net_uid = int(os.getenv("TEST_SUBNET_UID"))
uid = int(os.getenv("TEST_UID"))

wallet = bt.wallet(name=coldkey, hotkey=hotkey)

Expand All @@ -131,13 +127,13 @@ async def test_roundtrip_model_metadata():
)

# Store the metadata on chain.
await metadata_store.store_model_metadata(uid=uid, model_id=model_id)
await metadata_store.store_model_metadata(hotkey=hotkey, model_id=model_id)

# May need to use the underlying publish_metadata function with wait_for_inclusion: True to pass here.
# Otherwise it defaults to False and we only wait for finalization not necessarily inclusion.

# Retrieve the metadata from the chain.
model_metadata = await metadata_store.retrieve_model_metadata(uid)
model_metadata = await metadata_store.retrieve_model_metadata(hotkey)

print(f"Expecting matching metadata: {model_id == model_metadata.id}")

Expand Down
14 changes: 8 additions & 6 deletions model/storage/hugging_face/hugging_face_model_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
class HuggingFaceModelStore(ModelStore):
"""Hugging Face based implementation for storing and retrieving a model."""

async def store_model(self, uid: int, model: Model) -> ModelId:
async def store_model(self, hotkey: str, model: Model) -> ModelId:
"""Stores a trained model in Hugging Face."""
token = os.getenv("HF_ACCESS_TOKEN")
if not token:
Expand All @@ -32,7 +32,7 @@ async def store_model(self, uid: int, model: Model) -> ModelId:
)

# TODO actually make this asynchronous with threadpools etc.
async def retrieve_model(self, uid: int, model_id: ModelId) -> Model:
async def retrieve_model(self, hotkey: str, model_id: ModelId) -> Model:
"""Retrieves a trained model from Hugging Face."""
if not model_id.commit:
raise ValueError("No Hugging Face commit id found to read from the hub.")
Expand All @@ -41,7 +41,7 @@ async def retrieve_model(self, uid: int, model_id: ModelId) -> Model:
model = AutoModel.from_pretrained(
pretrained_model_name_or_path=model_id.path + "/" + model_id.name,
revision=model_id.commit,
cache_dir=utils.get_local_model_dir(uid, model_id),
cache_dir=utils.get_local_model_dir(hotkey, model_id),
use_safetensors=True,
)

Expand All @@ -68,10 +68,12 @@ async def test_roundtrip_model():
hf_model_store = HuggingFaceModelStore()

# Store the model in hf getting back the id with commit.
model.id = await hf_model_store.store_model(uid=0, model=model)
model.id = await hf_model_store.store_model(hotkey="hotkey0", model=model)

# Retrieve the model from hf.
retrieved_model = await hf_model_store.retrieve_model(uid=0, model_id=model.id)
retrieved_model = await hf_model_store.retrieve_model(
hotkey="hotkey0", model_id=model.id
)

# Check that they match.
# TODO create appropriate equality check.
Expand All @@ -92,7 +94,7 @@ async def test_retrieve_model():
hf_model_store = HuggingFaceModelStore()

# Retrieve the model from hf (first run) or cache.
model = await hf_model_store.retrieve_model(uid=0, model_id=model_id)
model = await hf_model_store.retrieve_model(hotkey="hotkey0", model_id=model_id)

print(f"Finished retrieving the model with id: {model.id}")

Expand Down
26 changes: 15 additions & 11 deletions model/storage/local/local_model_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,31 +9,33 @@
class LocalModelStore(ModelStore):
"""Local storage based implementation for storing and retrieving a model."""

async def clear_miner_directory(self, uid: int):
async def clear_miner_directory(self, hotkey: str):
"""Clears out the directory for a given uid."""
shutil.rmtree(path=utils.get_local_miner_dir(uid), ignore_errors=True)
shutil.rmtree(path=utils.get_local_miner_dir(hotkey), ignore_errors=True)

async def clear_model_directory(self, uid: int, model_id: ModelId):
async def clear_model_directory(self, hotkey: str, model_id: ModelId):
"""Clears out the directory for a given model."""
shutil.rmtree(path=utils.get_local_model_dir(uid, model_id), ignore_errors=True)
shutil.rmtree(
path=utils.get_local_model_dir(hotkey, model_id), ignore_errors=True
)

async def store_model(self, uid: int, model: Model) -> ModelId:
async def store_model(self, hotkey: str, model: Model) -> ModelId:
"""Stores a trained model locally."""

model.pt_model.save_pretrained(
save_directory=utils.get_local_model_dir(uid, model.id),
save_directory=utils.get_local_model_dir(hotkey, model.id),
safe_serialization=True,
)

# Return the same model id used as we do not edit the commit information.
return model.id

# TODO actually make this asynchronous with threadpools etc.
async def retrieve_model(self, uid: int, model_id: ModelId) -> Model:
async def retrieve_model(self, hotkey: str, model_id: ModelId) -> Model:
"""Retrieves a trained model locally."""

model = AutoModel.from_pretrained(
pretrained_model_name_or_path=utils.get_local_model_dir(uid, model_id),
pretrained_model_name_or_path=utils.get_local_model_dir(hotkey, model_id),
revision=model_id.commit,
local_files_only=True,
use_safetensors=True,
Expand All @@ -60,13 +62,15 @@ async def test_roundtrip_model():
local_model_store = LocalModelStore()

# Clear the local storage
await local_model_store.clear_model_directory(uid=0, model_id=model_id)
await local_model_store.clear_model_directory(hotkey="hotkey0", model_id=model_id)

# Store the model locally.
await local_model_store.store_model(uid=0, model=model)
await local_model_store.store_model(hotkey="hotkey0", model=model)

# Retrieve the model locally.
retrieved_model = await local_model_store.retrieve_model(uid=0, model_id=model_id)
retrieved_model = await local_model_store.retrieve_model(
hotkey="hotkey0", model_id=model_id
)

# Check that they match.
# TODO create appropriate equality check.
Expand Down
4 changes: 2 additions & 2 deletions model/storage/model_metadata_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ class ModelMetadataStore(abc.ABC):
"""An abstract base class for storing and retrieving model metadata."""

@abc.abstractmethod
async def store_model_metadata(self, uid: int, model_id: ModelId):
async def store_model_metadata(self, hotkey: str, model_id: ModelId):
"""Stores model metadata on this subnet for a specific miner."""
pass

@abc.abstractmethod
async def retrieve_model_metadata(self, uid: int) -> ModelMetadata:
async def retrieve_model_metadata(self, hotkey: str) -> ModelMetadata:
"""Retrieves model metadata + block information on this subnet for specific miner"""
pass
4 changes: 2 additions & 2 deletions model/storage/model_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ class ModelStore(abc.ABC):
"""An abstract base class for storing and retrieving a pre trained model."""

@abc.abstractmethod
async def store_model(self, uid: int, model: Model) -> ModelId:
async def store_model(self, hotkey: str, model: Model) -> ModelId:
"""Stores a trained model in the appropriate location based on implementation."""
pass

@abc.abstractmethod
async def retrieve_model(self, uid: int, pt_model_id: ModelId) -> Model:
async def retrieve_model(self, hotkey: str, pt_model_id: ModelId) -> Model:
"""Retrieves a trained model from the appropriate location based on implementation."""
pass
8 changes: 4 additions & 4 deletions model/storage/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@


# TODO make this configurable.
def get_local_miner_dir(uid: int) -> str:
return os.path.join("local-models", str(uid))
def get_local_miner_dir(hotkey: str) -> str:
return os.path.join("local-models", hotkey)


def get_local_model_dir(uid: int, model_id: ModelId) -> str:
return os.path.join(get_local_miner_dir(uid), model_id.path, model_id.name)
def get_local_model_dir(hotkey: str, model_id: ModelId) -> str:
return os.path.join(get_local_miner_dir(hotkey), model_id.path, model_id.name)