# The MIT License (MIT)
# Copyright © 2021 Yuma Rao
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the “Software”), to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of
# the Software.
# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
# Utils for checkpointing and saving the model.
import torch
import wandb
import copy
import bittensor as bt
import openvalidators
from openvalidators.misc import ttl_get_block
from openvalidators.reward import MockRewardModel


def should_reinit_wandb(self):
    # Check if wandb run needs to be rolled over.
    return not self.config.wandb.off and self.step and self.step % self.config.wandb.run_step_length == 0
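

# Worked example for should_reinit_wandb (hypothetical numbers, illustrative
# only): with run_step_length = 100 the wandb run rolls over at steps 100,
# 200, 300, ...; step 0 never triggers because `self.step` is falsy there.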


def init_wandb(self, reinit=False):
    """Starts a new wandb run."""
    tags = [
        self.wallet.hotkey.ss58_address,
        openvalidators.__version__,
        str(openvalidators.__spec_version__),
        f"netuid_{self.metagraph.netuid}",
    ]

    if self.config.mock:
        tags.append("mock")
    if self.config.neuron.use_custom_gating_model:
        tags.append("custom_gating_model")
    for fn in self.reward_functions:
        if not isinstance(fn, MockRewardModel):
            tags.append(str(fn.name))
    if self.config.neuron.disable_set_weights:
        tags.append("disable_set_weights")
    if self.config.neuron.disable_log_rewards:
        tags.append("disable_log_rewards")

    wandb_config = {key: copy.deepcopy(self.config.get(key, None)) for key in ("neuron", "reward", "netuid", "wandb")}
    wandb_config["neuron"].pop("full_path", None)

    self.wandb = wandb.init(
        anonymous="allow",
        reinit=reinit,
        project=self.config.wandb.project_name,
        entity=self.config.wandb.entity,
        config=wandb_config,
        mode="offline" if self.config.wandb.offline else "online",
        dir=self.config.neuron.full_path,
        tags=tags,
        notes=self.config.wandb.notes,
    )
    bt.logging.success(
        prefix="Started a new wandb run",
        sufix=f"<blue> {self.wandb.name} </blue>",
    )


def reinit_wandb(self):
    """Reinitializes wandb, rolling over the run."""
    self.wandb.finish()
    init_wandb(self, reinit=True)


def should_checkpoint(self):
    # Check if enough epoch blocks have elapsed since the last checkpoint.
    return (
        ttl_get_block(self) % self.config.neuron.checkpoint_block_length
        < self.prev_block % self.config.neuron.checkpoint_block_length
    )
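

# Illustrative sketch, not called by the validator: a standalone form of the
# wrap-around test in should_checkpoint, with hypothetical block numbers.
def _crossed_checkpoint_boundary(block: int, prev_block: int, length: int) -> bool:
    # True once the block counter has wrapped past a multiple of `length`.
    return block % length < prev_block % length


assert _crossed_checkpoint_boundary(201, 199, 100)  # 1 < 99: wrapped past block 200
assert not _crossed_checkpoint_boundary(150, 149, 100)  # 50 < 49 is False: same window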


def checkpoint(self):
    """Checkpoints the training process."""
    bt.logging.info("checkpoint()")
    resync_metagraph(self)
    save_state(self)


def resync_metagraph(self: "openvalidators.neuron.neuron"):
    """Resyncs the metagraph and updates the hotkeys and moving averages based on the new metagraph."""
    bt.logging.info("resync_metagraph()")

    # Copies state of metagraph before syncing.
    previous_metagraph = copy.deepcopy(self.metagraph)

    # Sync the metagraph.
    self.metagraph.sync(subtensor=self.subtensor)

    # Check if the metagraph axon info has changed.
    metagraph_axon_info_updated = previous_metagraph.axons != self.metagraph.axons

    if metagraph_axon_info_updated:
        bt.logging.info("Metagraph updated, re-syncing hotkeys, dendrite pool and moving averages")

        # Reconstruct the dendrite pool with the new endpoints.
        self.dendrite_pool.resync(self.metagraph)

        # Zero out all hotkeys that have been replaced.
        for uid, hotkey in enumerate(self.hotkeys):
            if hotkey != self.metagraph.hotkeys[uid]:
                self.moving_averaged_scores[uid] = 0  # hotkey has been replaced

        # Check to see if the metagraph has changed size.
        # If so, we need to add new hotkeys and moving averages.
        if len(self.hotkeys) < len(self.metagraph.hotkeys):
            # Update the size of the moving average scores.
            new_moving_average = torch.zeros((self.metagraph.n)).to(self.device)
            new_moving_average[: len(self.hotkeys)] = self.moving_averaged_scores
            self.moving_averaged_scores = new_moving_average

        # Resize the gating model.
        bt.logging.info("Re-syncing gating model")
        self.gating_model.resync(previous_metagraph, self.metagraph)

    # Update the hotkeys.
    self.hotkeys = copy.deepcopy(self.metagraph.hotkeys)
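

# Illustrative sketch, not called by the validator: the resize step in
# resync_metagraph preserves existing moving averages and starts new uids at
# zero. The uid counts (3 -> 5) are hypothetical.
def _pad_moving_averages(scores: torch.Tensor, new_n: int) -> torch.Tensor:
    # Existing uids keep their running score; the appended uids begin at 0.
    padded = torch.zeros(new_n)
    padded[: len(scores)] = scores
    return padded


_old_scores = torch.tensor([0.2, 0.5, 0.1])
_padded_scores = _pad_moving_averages(_old_scores, 5)
assert torch.equal(_padded_scores[:3], _old_scores)
assert torch.equal(_padded_scores[3:], torch.zeros(2))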


def resync_linear_layer(
    linear_layer: torch.nn.Module,
    previous_metagraph: "bt.metagraph.Metagraph",
    metagraph: "bt.metagraph.Metagraph",
):
    """Resyncs the linear layer with the latest state of the network.

    Args:
        linear_layer (:obj:`torch.nn.Module`): Linear layer to be resynced.
        previous_metagraph (:obj:`bt.metagraph.Metagraph`):
            State of the metagraph before the resync.
        metagraph (:obj:`bt.metagraph.Metagraph`):
            Latest state of the metagraph, with updated uids and hotkeys.
    """
    uids_hotkeys_state_dict = dict(zip(previous_metagraph.uids.tolist(), previous_metagraph.hotkeys))
    latest_uids_hotkeys_state_dict = dict(zip(metagraph.uids.tolist(), metagraph.hotkeys))

    updated_uids_indices = []
    for uid, latest_hotkey in latest_uids_hotkeys_state_dict.items():
        if uids_hotkeys_state_dict.get(uid) != latest_hotkey:
            updated_uids_indices.append(uid)

    for index in updated_uids_indices:
        # Reinitialize the bias of the selected index of the linear layer.
        torch.nn.init.zeros_(linear_layer.bias[index])
        # Clone the weights of the selected index of the linear layer.
        weights = linear_layer.weight[index].clone()
        # Add a dimension so the tensor is 2-D, as xavier_uniform_ requires;
        # the in-place init writes through the view into `weights`.
        torch.nn.init.xavier_uniform_(weights.unsqueeze(0))
        reinitialized_weights = weights.squeeze(0)
        # Copy the reinitialized weights back to the selected index of the linear layer.
        linear_layer.weight[index].data.copy_(reinitialized_weights)
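

# Illustrative, runnable sketch (not part of the validator): exercises
# resync_linear_layer with stand-in metagraphs built from SimpleNamespace.
# The layer sizes and hotkey strings are hypothetical.
def _demo_resync_linear_layer():
    from types import SimpleNamespace

    layer = torch.nn.Linear(in_features=4, out_features=3)
    weights_before = layer.weight.detach().clone()
    previous = SimpleNamespace(uids=torch.tensor([0, 1, 2]), hotkeys=["key_a", "key_b", "key_c"])
    current = SimpleNamespace(uids=torch.tensor([0, 1, 2]), hotkeys=["key_a", "key_x", "key_c"])  # uid 1 replaced

    resync_linear_layer(layer, previous, current)

    # Row 1 was reinitialized (bias zeroed, weights re-drawn); rows 0 and 2 are untouched.
    assert torch.equal(layer.weight[0], weights_before[0])
    assert torch.equal(layer.weight[2], weights_before[2])
    assert layer.bias[1].item() == 0.0
    # Run manually, e.g. in a REPL: _demo_resync_linear_layer()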


def check_uid_availability(metagraph: "bt.metagraph.Metagraph", uid: int, vpermit_tao_limit: int) -> bool:
    """Check if a uid is available. A uid is available if its axon is serving and, when it
    holds a validator permit, its stake does not exceed vpermit_tao_limit.

    Args:
        metagraph (:obj:`bt.metagraph.Metagraph`): Metagraph object.
        uid (int): uid to be checked.
        vpermit_tao_limit (int): Validator permit tao limit.
    Returns:
        bool: True if the uid is available, False otherwise.
    """
    # Filter out non-serving axons.
    if not metagraph.axons[uid].is_serving:
        return False
    # Filter out validator-permitted uids whose stake exceeds the limit.
    if metagraph.validator_permit[uid]:
        if metagraph.S[uid] > vpermit_tao_limit:
            return False
    # Available otherwise.
    return True
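

# Illustrative sketch (hypothetical helper, not part of this module's API):
# how a validator might gather every available uid before sampling candidates
# to query. Assumes `metagraph.n` is a scalar tensor, as in bittensor.
def _available_uids(metagraph: "bt.metagraph.Metagraph", vpermit_tao_limit: int) -> list:
    # Keep uids that are serving and within the validator-permit stake limit.
    return [
        uid
        for uid in range(int(metagraph.n.item()))
        if check_uid_availability(metagraph, uid, vpermit_tao_limit)
    ]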


def save_state(self):
    r"""Save hotkeys, gating model, neuron model and moving average scores to filesystem."""
    bt.logging.info("save_state()")
    try:
        neuron_state_dict = {
            "neuron_weights": self.moving_averaged_scores.to("cpu").tolist(),
            "neuron_hotkeys": self.hotkeys,
        }
        torch.save(neuron_state_dict, f"{self.config.neuron.full_path}/model.torch")
        bt.logging.success(
            prefix="Saved model",
            sufix=f"<blue>{self.config.neuron.full_path}/model.torch</blue>",
        )
    except Exception as e:
        bt.logging.warning(f"Failed to save model with error: {e}")

    try:
        # Save the gating model.
        gating_model_linear_layer_dict = self.gating_model.linear.state_dict()
        gating_model_name = self.config.gating.model_name.replace("/", "_")
        gating_model_file_path = f"{self.config.neuron.full_path}/{gating_model_name}_gating_linear_layer.pth"
        torch.save(gating_model_linear_layer_dict, gating_model_file_path)

        if not self.config.wandb.off:
            wandb.log({"step": self.step, "block": ttl_get_block(self), **neuron_state_dict})
        if not self.config.wandb.off and self.config.wandb.track_gating_model:
            model_artifact = wandb.Artifact(f"{gating_model_name}_gating_linear_layer", type="model")
            model_artifact.add_file(gating_model_file_path)
            self.wandb.log_artifact(model_artifact)

        bt.logging.success(prefix="Saved gating model", sufix=f"<blue>{gating_model_file_path}</blue>")
    except Exception as e:
        bt.logging.warning(f"Failed to save gating model with error: {e}")

    try:
        # Save diversity model.
        diversity_model_dict = {"historic_embeddings": self.diversity_model.historic_embeddings.to("cpu")}
        diversity_model_file_path = f"{self.config.neuron.full_path}/diversity_model.pth"
        torch.save(diversity_model_dict, diversity_model_file_path)
        bt.logging.success(
            prefix="Saved diversity model",
            sufix=f"<blue>{diversity_model_file_path}</blue> {list(self.diversity_model.historic_embeddings.shape)}",
        )
    except Exception as e:
        bt.logging.warning(f"Failed to save diversity model with error: {e}")

    # Empty the CUDA cache.
    torch.cuda.empty_cache()


def load_state(self):
    r"""Load hotkeys and moving average scores from filesystem."""
    bt.logging.info("load_state()")
    try:
        state_dict = torch.load(f"{self.config.neuron.full_path}/model.torch")
        # Check for nans in the saved state dict; only restore scores if none are found.
        neuron_weights = torch.tensor(state_dict["neuron_weights"])
        if not torch.isnan(neuron_weights).any():
            self.moving_averaged_scores = neuron_weights.to(self.device)
        self.hotkeys = state_dict["neuron_hotkeys"]
        bt.logging.success(
            prefix="Reloaded model",
            sufix=f"<blue>{self.config.neuron.full_path}/model.torch</blue>",
        )
    except Exception as e:
        bt.logging.warning(f"Failed to load model with error: {e}")

    try:
        # Load diversity model.
        diversity_model_file_path = f"{self.config.neuron.full_path}/diversity_model.pth"
        diversity_model_dict = torch.load(diversity_model_file_path)
        self.diversity_model.historic_embeddings = diversity_model_dict["historic_embeddings"].to(self.device)
        bt.logging.success(
            prefix="Reloaded diversity model",
            sufix=f"<blue>{diversity_model_file_path}</blue> {list(self.diversity_model.historic_embeddings.shape)}",
        )
    except Exception as e:
        bt.logging.warning(f"Failed to load diversity model with error: {e}")