Skip to content

Commit

Permalink
path to the labels file during classifier saving gets joined with scr…
Browse files Browse the repository at this point in the history
…ipt dir so it works in the installed version. (#456)

added error Log for training and saving model
  • Loading branch information
max-mauermann authored Oct 8, 2024
1 parent fd40dce commit 35079be
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 25 deletions.
4 changes: 2 additions & 2 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ def saveLinearClassifier(classifier, model_path: str, labels: list[str], mode="r
open(model_path, "wb").write(tflite_model)

if mode == "append":
labels = [*utils.readLines(cfg.LABELS_FILE), *labels]
labels = [*utils.readLines(os.path.join(SCRIPT_DIR, cfg.LABELS_FILE)), *labels]

# Save labels
with open(model_path.replace(".tflite", "_Labels.txt"), "w", encoding="utf-8") as f:
Expand Down Expand Up @@ -399,7 +399,7 @@ def basic(self, inputs):
tf.saved_model.save(smodel, model_path, signatures=signatures)

if mode == "append":
labels = [*utils.readLines(cfg.LABELS_FILE), *labels]
labels = [*utils.readLines(os.path.join(SCRIPT_DIR, cfg.LABELS_FILE)), *labels]

# Save label file
labelIds = [label[:4].replace(" ", "") + str(i) for i, label in enumerate(labels, 1)]
Expand Down
58 changes: 35 additions & 23 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,34 +321,46 @@ def run_trial(self, trial, *args, **kwargs):

# Train model
print("Training model...", flush=True)
classifier, history = model.trainLinearClassifier(
classifier,
x_train,
y_train,
epochs=cfg.TRAIN_EPOCHS,
batch_size=cfg.TRAIN_BATCH_SIZE,
learning_rate=cfg.TRAIN_LEARNING_RATE,
val_split=cfg.TRAIN_VAL_SPLIT,
upsampling_ratio=cfg.UPSAMPLING_RATIO,
upsampling_mode=cfg.UPSAMPLING_MODE,
train_with_mixup=cfg.TRAIN_WITH_MIXUP,
train_with_label_smoothing=cfg.TRAIN_WITH_LABEL_SMOOTHING,
on_epoch_end=on_epoch_end,
)
try:
classifier, history = model.trainLinearClassifier(
classifier,
x_train,
y_train,
epochs=cfg.TRAIN_EPOCHS,
batch_size=cfg.TRAIN_BATCH_SIZE,
learning_rate=cfg.TRAIN_LEARNING_RATE,
val_split=cfg.TRAIN_VAL_SPLIT,
upsampling_ratio=cfg.UPSAMPLING_RATIO,
upsampling_mode=cfg.UPSAMPLING_MODE,
train_with_mixup=cfg.TRAIN_WITH_MIXUP,
train_with_label_smoothing=cfg.TRAIN_WITH_LABEL_SMOOTHING,
on_epoch_end=on_epoch_end,
)
except Exception as e:
utils.writeErrorLog(e)
raise Exception("Error training model")

print("...Done.", flush=True)

# Best validation AUPRC (at minimum validation loss)
best_val_auprc = history.history["val_AUPRC"][np.argmin(history.history["val_loss"])]
best_val_auroc = history.history["val_AUROC"][np.argmin(history.history["val_loss"])]

if cfg.TRAINED_MODEL_OUTPUT_FORMAT == "both":
model.save_raven_model(classifier, cfg.CUSTOM_CLASSIFIER, labels)
model.saveLinearClassifier(classifier, cfg.CUSTOM_CLASSIFIER, labels, mode=cfg.TRAINED_MODEL_SAVE_MODE)
elif cfg.TRAINED_MODEL_OUTPUT_FORMAT == "tflite":
model.saveLinearClassifier(classifier, cfg.CUSTOM_CLASSIFIER, labels, mode=cfg.TRAINED_MODEL_SAVE_MODE)
elif cfg.TRAINED_MODEL_OUTPUT_FORMAT == "raven":
model.save_raven_model(classifier, cfg.CUSTOM_CLASSIFIER, labels)
else:
raise ValueError(f"Unknown model output format: {cfg.TRAINED_MODEL_OUTPUT_FORMAT}")
print("Saving model...", flush=True)

try:
if cfg.TRAINED_MODEL_OUTPUT_FORMAT == "both":
model.save_raven_model(classifier, cfg.CUSTOM_CLASSIFIER, labels)
model.saveLinearClassifier(classifier, cfg.CUSTOM_CLASSIFIER, labels, mode=cfg.TRAINED_MODEL_SAVE_MODE)
elif cfg.TRAINED_MODEL_OUTPUT_FORMAT == "tflite":
model.saveLinearClassifier(classifier, cfg.CUSTOM_CLASSIFIER, labels, mode=cfg.TRAINED_MODEL_SAVE_MODE)
elif cfg.TRAINED_MODEL_OUTPUT_FORMAT == "raven":
model.save_raven_model(classifier, cfg.CUSTOM_CLASSIFIER, labels)
else:
raise ValueError(f"Unknown model output format: {cfg.TRAINED_MODEL_OUTPUT_FORMAT}")
except Exception as e:
utils.writeErrorLog(e)
raise Exception("Error saving model")

print(f"...Done. Best AUPRC: {best_val_auprc}, Best AUROC: {best_val_auroc}", flush=True)

Expand Down

0 comments on commit 35079be

Please sign in to comment.