Skip to content

Commit

Permalink
update task fitting example to use 8m
Browse files Browse the repository at this point in the history
  • Loading branch information
holgerroth committed Feb 14, 2025
1 parent 57b01a0 commit 66850ee
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 258 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __init__(
intime_model_selector: Optional[IntimeModelSelector] = None,
convert_to_fed_event: Optional[ConvertToFedEvent] = None,
analytics_receiver: Optional[AnalyticsReceiver] = None,
embedding_dimensions: int = 320 # embedding dimensions of ESM2-8m
):
"""PyTorch BaseFedJob.
Expand All @@ -65,6 +66,7 @@ def __init__(
if not provided, a ConvertToFedEvent object will be created.
analytics_receiver (AnalyticsReceiver, optional): Receive analytics.
If not provided, a TBAnalyticsReceiver will be configured.
embedding_dimensions: embedding dimensions of ESM2 model. Defaults to 320, the embedding dimensions of ESM2-8m.
"""
super().__init__(
name=name,
Expand Down Expand Up @@ -103,7 +105,7 @@ def __init__(
obj=analytics_receiver,
)

self.to_server(id="persistor", obj=BioNeMoMLPModelPersistor())
self.to_server(id="persistor", obj=BioNeMoMLPModelPersistor(embedding_dimensions=embedding_dimensions))

self.to_server(id="locator", obj=PTFileModelLocator(pt_persistor_id="persistor"))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ def __init__(
analytic_sender_id: str = "analytic_sender",
batch_size: int = 128,
num_workers: int = 0,
warm_start: bool = True,
embedding_dimensions: int = 320 # embedding dimensions of ESM2-8m
):
"""Simple CIFAR-10 Trainer.
"""BioNeMo MLP Trainer.
Args:
data_path: data file with labels in csv format.
Expand All @@ -60,7 +60,7 @@ def __init__(
If configured, TensorBoard events will be fired. Defaults to "analytic_sender".
batch_size: batch size for training and validation.
num_workers: number of workers for data loaders.
warm_start: Use `True` in federated learning and `False` when simulating local training only.
embedding_dimensions: embedding dimensions of ESM2 model. Defaults to 320, the embedding dimensions of ESM2-8m.
Returns:
an FLModel with the updated local model differences after running `train()`, the metrics after `validate()`,
Expand All @@ -78,7 +78,7 @@ def __init__(
self.batch_size = batch_size
self.num_workers = num_workers
self.analytic_sender_id = analytic_sender_id
self.warm_start = warm_start
self.embedding_dimensions = embedding_dimensions

self.sim_local = strtobool(os.getenv("SIM_LOCAL", "False"))

Expand Down Expand Up @@ -167,7 +167,7 @@ def initialize(self):
]
_X, _y = [], []
for label in class_labels:
_X.append(np.random.rand(1280)) # embedding dimensions of ESM2-650m
_X.append(np.random.rand(self.embedding_dimensions))
_y.append(label)
self.model.partial_fit(_X, _y, classes=class_labels)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __init__(
best_global_model_file_name=DefaultCheckpointFileName.BEST_GLOBAL_MODEL,
source_ckpt_file_full_name=None,
filter_id: str = None,
embedding_dimensions: int = 320 # embedding dimensions of ESM2-8m
):
"""Persist sklearn-based model to/from file system.
Expand All @@ -50,12 +51,14 @@ def __init__(
source_ckpt_file_full_name (str, optional): full file name for source model checkpoint file. Defaults to None.
filter_id: Optional string that defines a filter component that is applied to prepare the model to be saved,
e.g. for serialization of custom Python objects.
embedding_dimensions: embedding dimensions of ESM2 model. Defaults to 320, the embedding dimensions of ESM2-8m.
Raises:
ValueError: when source_ckpt_file_full_name does not exist
"""
super().__init__(
filter_id=filter_id,
)
self.embedding_dimensions = embedding_dimensions
self.model = MLPClassifier(solver="adam", hidden_layer_sizes=(512, 256, 128), random_state=10, max_iter=1)
self.log_dir = None
self.ckpt_preload_path = None
Expand Down Expand Up @@ -87,7 +90,7 @@ def _initialize(self, fl_ctx: FLContext):
]
_X, _y = [], []
for label in class_labels:
_X.append(np.random.rand(1280)) # embedding dimensions of ESM2-650m
_X.append(np.random.rand(self.embedding_dimensions))
_y.append(label)
self.model.fit(_X, _y)
self.log_info(
Expand Down
Loading

0 comments on commit 66850ee

Please sign in to comment.