Skip to content

Commit

Permalink
TensorFlow Taskrunner Workspace with Keras 3 (#1330)
Browse files Browse the repository at this point in the history
* code changes

Signed-off-by: yes <[email protected]>

* code changes

Signed-off-by: yes <[email protected]>

* code changes

Signed-off-by: yes <[email protected]>

* code changes

Signed-off-by: yes <[email protected]>

* formatting changes

Signed-off-by: yes <[email protected]>

* code changes

Signed-off-by: yes <[email protected]>

* code changes

Signed-off-by: yes <[email protected]>

* code changes

Signed-off-by: yes <[email protected]>

* changed dockstring

Signed-off-by: yes <[email protected]>

---------

Signed-off-by: yes <[email protected]>
  • Loading branch information
tanwarsh authored Feb 11, 2025
1 parent b2a6728 commit 9cc8790
Show file tree
Hide file tree
Showing 11 changed files with 409 additions and 0 deletions.
2 changes: 2 additions & 0 deletions openfl-workspace/keras/tensorflow/mnist/.workspace
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
current_plan_name: default

5 changes: 5 additions & 0 deletions openfl-workspace/keras/tensorflow/mnist/plan/cols.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (C) 2020-2025 Intel Corporation
# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you.

collaborators:

7 changes: 7 additions & 0 deletions openfl-workspace/keras/tensorflow/mnist/plan/data.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Copyright (C) 2020-2025 Intel Corporation
# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you.

# collaborator_name,data_directory_path
one,1


2 changes: 2 additions & 0 deletions openfl-workspace/keras/tensorflow/mnist/plan/defaults
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
../../workspace/plan/defaults

42 changes: 42 additions & 0 deletions openfl-workspace/keras/tensorflow/mnist/plan/plan.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright (C) 2020-2025 Intel Corporation
# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you.

aggregator :
defaults : plan/defaults/aggregator.yaml
template : openfl.component.Aggregator
settings :
init_state_path : save/init.pbuf
best_state_path : save/best.pbuf
last_state_path : save/last.pbuf
rounds_to_train : 10

collaborator :
defaults : plan/defaults/collaborator.yaml
template : openfl.component.Collaborator
settings :
delta_updates : false
opt_treatment : RESET

data_loader :
defaults : plan/defaults/data_loader.yaml
template : src.dataloader.MNISTInMemory
settings :
collaborator_count : 2
data_group_name : mnist
batch_size : 256

task_runner :
defaults : plan/defaults/task_runner.yaml
template : src.taskrunner.CNNTaskruner

network :
defaults : plan/defaults/network.yaml

assigner :
defaults : plan/defaults/assigner.yaml

tasks :
defaults : plan/defaults/tasks_keras.yaml

compression_pipeline :
defaults : plan/defaults/compression_pipeline.yaml
2 changes: 2 additions & 0 deletions openfl-workspace/keras/tensorflow/mnist/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
keras==3.8.0
tensorflow==2.18.0
3 changes: 3 additions & 0 deletions openfl-workspace/keras/tensorflow/mnist/src/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Copyright (C) 2020-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""You may copy this file as the starting point of your own model."""
42 changes: 42 additions & 0 deletions openfl-workspace/keras/tensorflow/mnist/src/dataloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright (C) 2020-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""You may copy this file as the starting point of your own model."""

from openfl.federated import KerasDataLoader
from .mnist_utils import load_mnist_shard


class MNISTInMemory(KerasDataLoader):
"""Data Loader for MNIST Dataset."""

def __init__(self, data_path, batch_size, **kwargs):
"""
Initialize.
Args:
data_path: File path for the dataset
batch_size (int): The batch size for the data loader
**kwargs: Additional arguments, passed to super init and load_mnist_shard
"""
super().__init__(batch_size, **kwargs)

try:
int(data_path)
except:
raise ValueError(
"Expected `%s` to be representable as `int`, as it refers to the data shard " +
"number used by the collaborator.",
data_path
)

_, num_classes, X_train, y_train, X_valid, y_valid = load_mnist_shard(
shard_num=int(data_path), **kwargs
)

self.X_train = X_train
self.y_train = y_train
self.X_valid = X_valid
self.y_valid = y_valid

self.num_classes = num_classes
118 changes: 118 additions & 0 deletions openfl-workspace/keras/tensorflow/mnist/src/mnist_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# Copyright (C) 2020-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""You may copy this file as the starting point of your own model."""

from logging import getLogger

import numpy as np
from tensorflow.python.keras.utils.data_utils import get_file

logger = getLogger(__name__)


def one_hot(labels, classes):
"""
One Hot encode a vector.
Args:
labels (list): List of labels to onehot encode
classes (int): Total number of categorical classes
Returns:
np.array: Matrix of one-hot encoded labels
"""
return np.eye(classes)[labels]


def _load_raw_datashards(shard_num, collaborator_count):
"""
Load the raw data by shard.
Returns tuples of the dataset shard divided into training and validation.
Args:
shard_num (int): The shard number to use
collaborator_count (int): The number of collaborators in the federation
Returns:
2 tuples: (image, label) of the training, validation dataset
"""
origin_folder = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/'
path = get_file('mnist.npz',
origin=origin_folder + 'mnist.npz',
file_hash='731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1')

with np.load(path) as f:
# get all of mnist
X_train_tot = f['x_train']
y_train_tot = f['y_train']

X_valid_tot = f['x_test']
y_valid_tot = f['y_test']

# create the shards
shard_num = int(shard_num)
X_train = X_train_tot[shard_num::collaborator_count]
y_train = y_train_tot[shard_num::collaborator_count]

X_valid = X_valid_tot[shard_num::collaborator_count]
y_valid = y_valid_tot[shard_num::collaborator_count]

return (X_train, y_train), (X_valid, y_valid)


def load_mnist_shard(shard_num, collaborator_count, categorical=True,
channels_last=True, **kwargs):
"""
Load the MNIST dataset.
Args:
shard_num (int): The shard to use from the dataset
collaborator_count (int): The number of collaborators in the federation
categorical (bool): True = convert the labels to one-hot encoded
vectors (Default = True)
channels_last (bool): True = The input images have the channels
last (Default = True)
**kwargs: Additional parameters to pass to the function
Returns:
list: The input shape
int: The number of classes
numpy.ndarray: The training data
numpy.ndarray: The training labels
numpy.ndarray: The validation data
numpy.ndarray: The validation labels
"""
img_rows, img_cols = 28, 28
num_classes = 10

(X_train, y_train), (X_valid, y_valid) = _load_raw_datashards(
shard_num, collaborator_count
)

if channels_last:
X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 1)
X_valid = X_valid.reshape(X_valid.shape[0], img_rows, img_cols, 1)
input_shape = (img_rows, img_cols, 1)
else:
X_train = X_train.reshape(X_train.shape[0], 1, img_rows, img_cols)
X_valid = X_valid.reshape(X_valid.shape[0], 1, img_rows, img_cols)
input_shape = (1, img_rows, img_cols)

X_train = X_train.astype('float32')
X_valid = X_valid.astype('float32')
X_train /= 255
X_valid /= 255

logger.info(f'MNIST > X_train Shape : {X_train.shape}')
logger.info(f'MNIST > y_train Shape : {y_train.shape}')
logger.info(f'MNIST > Train Samples : {X_train.shape[0]}')
logger.info(f'MNIST > Valid Samples : {X_valid.shape[0]}')

if categorical:
# convert class vectors to binary class matrices
y_train = one_hot(y_train, num_classes)
y_valid = one_hot(y_valid, num_classes)

return input_shape, num_classes, X_train, y_train, X_valid, y_valid
96 changes: 96 additions & 0 deletions openfl-workspace/keras/tensorflow/mnist/src/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Copyright (C) 2020-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""You may copy this file as the starting point of your own model."""

import tensorflow as tf
import keras

class CNNModel(keras.Model):
"""
Custom Keras Model for a Convolutional Neural Network (CNN) with custom training and testing steps.
This model showcase how to define a custom training and testing step for a Keras model with Tensorflow.
Methods
-------
train_step(data)
Performs a single training step, including forward pass, loss computation, gradient calculation,
and weight updates. Also updates the metrics.
test_step(data)
Performs a single testing step, including forward pass, loss computation, and metric updates.
"""

def train_step(self, data):
"""
Perform a single training step.
Args:
data (tuple): A tuple containing the input data and labels. If the tuple has three elements,
it should be (x, y, sample_weight). Otherwise, it should be (x, y).
Returns:
dict: A dictionary mapping metric names to their current values. This includes the loss and
any other metrics configured in `compile()`.
Notes:
- The loss function and metrics are configured in the `compile()` method.
- The optimizer is used to apply the computed gradients to the model's trainable variables.
"""

# Unpack the data. Its structure depends on your model and
# on what you pass to `fit()`.
if len(data) == 3:
x, y, sample_weight = data
else:
sample_weight = None
x, y = data

with tf.GradientTape() as tape:
y_pred = self(x, training=True) # Forward pass
# Compute the loss value.
# The loss function is configured in `compile()`.
loss = self.compute_loss(
y=y,
y_pred=y_pred,
sample_weight=sample_weight,
)

# Compute gradients
trainable_vars = self.trainable_variables
gradients = tape.gradient(loss, trainable_vars)

# Update weights
self.optimizer.apply(gradients, trainable_vars)

# Update the metrics.
# Metrics are configured in `compile()`.
for metric in self.metrics:
if metric.name == "loss":
metric.update_state(loss)
else:
metric.update_state(y, y_pred, sample_weight=sample_weight)

# Return a dict mapping metric names to current value.
# Note that it will include the loss (tracked in self.metrics).
return {m.name: m.result() for m in self.metrics}

def test_step(self, data):
"""
Perform a single test step.
Args:
data (tuple): A tuple containing the input data (x) and the true labels (y).
Returns:
dict: A dictionary mapping metric names to their current values. This includes the loss and other metrics tracked in self.metrics.
"""

# Unpack the data
x, y = data
# Compute predictions
y_pred = self(x, training=False)
# Updates the metrics tracking the loss
loss = self.compute_loss(y=y, y_pred=y_pred)
# Update the metrics.
for metric in self.metrics:
if metric.name == "loss":
metric.update_state(loss)
else:
metric.update_state(y, y_pred)
# Return a dict mapping metric names to current value.
# Note that it will include the loss (tracked in self.metrics).
return {m.name: m.result() for m in self.metrics}
Loading

0 comments on commit 9cc8790

Please sign in to comment.