Skip to content

Commit

Permalink
Experimental KERAS 3
Browse files Browse the repository at this point in the history
Signed-off-by: Chaurasiya, Payal <[email protected]>
  • Loading branch information
payalcha committed Nov 15, 2024
1 parent 1d7c9bf commit f62ef63
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 27 deletions.
3 changes: 2 additions & 1 deletion openfl-workspace/keras_cnn_mnist/requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
tensorflow==2.13
keras==3.6
tensorflow==2.18
22 changes: 10 additions & 12 deletions openfl-workspace/keras_cnn_mnist/src/keras_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@

"""You may copy this file as the starting point of your own model."""

import tensorflow.keras as ke
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Flatten
import keras as ke
from keras import Sequential
from keras.layers import Conv2D, Dense, Flatten

from openfl.federated import KerasTaskRunner

Expand Down Expand Up @@ -72,14 +70,14 @@ def build_model(self,

model.add(Dense(num_classes, activation='softmax'))

model.compile(loss=ke.losses.categorical_crossentropy,
optimizer=ke.optimizers.legacy.Adam(),
model.compile(loss=ke.losses.BinaryCrossentropy,
optimizer=ke.optimizers.Adam(),
metrics=['accuracy'])

# initialize the optimizer variables
opt_vars = model.optimizer.variables()

for v in opt_vars:
v.initializer.run(session=self.sess)
# # initialize the optimizer variables
# opt_vars = model.optimizer.variables()

# for v in opt_vars:
# v.initializer.run(session=self.sess)
# # ke.backend.get_session().run(v.initializer)
return model
18 changes: 4 additions & 14 deletions openfl-workspace/keras_cnn_mnist/src/mnist_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
"""You may copy this file as the starting point of your own model."""

from logging import getLogger

import numpy as np
from tensorflow.python.keras.utils.data_utils import get_file
import keras

logger = getLogger(__name__)

Expand Down Expand Up @@ -38,18 +37,7 @@ def _load_raw_datashards(shard_num, collaborator_count):
Returns:
2 tuples: (image, label) of the training, validation dataset
"""
origin_folder = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/'
path = get_file('mnist.npz',
origin=origin_folder + 'mnist.npz',
file_hash='731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1')

with np.load(path) as f:
# get all of mnist
X_train_tot = f['x_train']
y_train_tot = f['y_train']

X_valid_tot = f['x_test']
y_valid_tot = f['y_test']
(X_train_tot, y_train_tot), (X_valid_tot, y_valid_tot ) = keras.datasets.mnist.load_data(path="mnist.npz")

# create the shards
shard_num = int(shard_num)
Expand Down Expand Up @@ -116,3 +104,5 @@ def load_mnist_shard(shard_num, collaborator_count, categorical=True,
y_valid = one_hot(y_valid, num_classes)

return input_shape, num_classes, X_train, y_train, X_valid, y_valid

# _, num_classes, X_train, y_train, X_valid, y_valid = load_mnist_shard(0,2)

0 comments on commit f62ef63

Please sign in to comment.