diff --git a/convert_model/mnist_eg/__init__.py b/convert_model/mnist_eg/__init__.py index 511de2e..cdca76c 100644 --- a/convert_model/mnist_eg/__init__.py +++ b/convert_model/mnist_eg/__init__.py @@ -5,29 +5,13 @@ n_classes = 10 -def conv_layer(): - return k.layers.Conv2D( - 32, - (3, 3), - activation="relu", - kernel_initializer="he_uniform", - ) - - -def pool_layer(): - return k.layers.MaxPool2D((2, 2), strides=(2, 2)) - - def mnist_model(): model = k.Sequential( [ k.Input(shape=in_shape), - conv_layer(), - pool_layer(), - conv_layer(), - pool_layer(), k.layers.Flatten(), - k.layers.Dense(500, activation="relu", kernel_initializer="he_uniform"), + k.layers.Dense(128, activation="relu"), + k.layers.Dense(500, activation="relu"), k.layers.Dense(n_classes), ] ) diff --git a/convert_model/mnist_eg/cm.py b/convert_model/mnist_eg/cm.py index 299f842..51118c6 100644 --- a/convert_model/mnist_eg/cm.py +++ b/convert_model/mnist_eg/cm.py @@ -29,7 +29,7 @@ def main(): mlmodel = convert(model) builder = nn_builder(mlmodel) config_builder(builder) - try_make_layers_updatable(builder, 2) + try_make_layers_updatable(builder) builder.inspect_layers() save_builder(builder, COREML_FILE) print(f"Successfully converted to Core ML model at {COREML_FILE}.") diff --git a/upload_mnist_models.py b/upload_mnist_models.py index ca3ee0f..3cc9fdc 100644 --- a/upload_mnist_models.py +++ b/upload_mnist_models.py @@ -3,14 +3,12 @@ tflite_file = "mnist.tflite" coreml_file = "mnist.mlmodel" name = "mnist_unified" -tflite_layers = [1152, 128, 36864, 128, 1600000, 2000, 20000, 40] +tflite_layers = [401408, 512, 256000, 2000, 20000, 40] coreml_layers = [ - {"name": "sequential/conv2d/BiasAdd", "type": "weights", "updatable": False}, - {"name": "sequential/conv2d/BiasAdd", "type": "bias", "updatable": False}, - {"name": "sequential/conv2d_1/BiasAdd", "type": "weights", "updatable": False}, - {"name": "sequential/conv2d_1/BiasAdd", "type": "bias", "updatable": False}, {"name": "sequential/dense/BiasAdd", "type": "weights", "updatable": True}, {"name": "sequential/dense/BiasAdd", "type": "bias", "updatable": True}, + {"name": "sequential/dense_1/BiasAdd", "type": "weights", "updatable": True}, + {"name": "sequential/dense_1/BiasAdd", "type": "bias", "updatable": True}, {"name": "Identity", "type": "weights", "updatable": True}, {"name": "Identity", "type": "bias", "updatable": True}, ]