-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathcreate_maxim_model.py
34 lines (27 loc) · 1.05 KB
/
create_maxim_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from tensorflow import keras
from maxim import maxim
from maxim.configs import MAXIM_CONFIGS
def Model(variant=None, input_resolution=256, **kw) -> keras.Model:
    """Build a MAXIM Keras model, optionally starting from a preset variant.

    Every model file should expose this ``Model()`` factory; the function name
    is fixed so models can be constructed uniformly.

    Args:
        variant: Optional key into ``MAXIM_CONFIGS`` (documented options:
            'S-1' | 'S-2' | 'S-3' | 'M-1' | 'M-2' | 'M-3'). When given, the
            preset's entries are used as defaults; any value passed explicitly
            in ``**kw`` takes precedence over the preset.
        input_resolution: Height and width of the square, 3-channel input.
        **kw: Keyword arguments forwarded to ``maxim.MAXIM``. A ``name`` entry
            is consumed here (not forwarded) and used to name the Keras model;
            it defaults to ``"maxim"`` when neither the caller nor the preset
            supplies one.

    Returns:
        A ``keras.Model`` named ``f"{name}_model"`` whose input has shape
        ``(input_resolution, input_resolution, 3)`` and whose outputs are those
        of ``maxim.MAXIM(**kw)`` applied to that input.

    Raises:
        KeyError: If ``variant`` is not a key of ``MAXIM_CONFIGS``.
    """
    if variant is not None:
        try:
            config = MAXIM_CONFIGS[variant]
        except KeyError:
            # Re-raise with the list of valid presets instead of a bare key.
            raise KeyError(
                f"Unknown MAXIM variant {variant!r}; "
                f"available variants: {list(MAXIM_CONFIGS)}"
            ) from None
        # Preset values are only defaults: explicit caller kwargs win.
        for key, value in config.items():
            kw.setdefault(key, value)

    # `variant` is bound to the named parameter, so a "variant" entry in kw can
    # only come from the merged preset; it is dropped rather than forwarded to
    # MAXIM().
    kw.pop("variant", None)
    # "name" is used only for the Keras model name, never forwarded to MAXIM().
    model_name = kw.pop("name", "maxim")

    inputs = keras.Input((input_resolution, input_resolution, 3))
    outputs = maxim.MAXIM(**kw)(inputs)
    return keras.Model(inputs, outputs, name=f"{model_name}_model")