Skip to content

Commit

Permalink
Updated train.py script
Browse files Browse the repository at this point in the history
  • Loading branch information
clamytoe committed Dec 18, 2022
1 parent b983bcb commit 26da4e0
Showing 1 changed file with 31 additions and 19 deletions.
50 changes: 31 additions & 19 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
fn_col="image", # type: ignore
shuffle=True,
label_col="label", # type: ignore
item_tfms=Resize(480),
item_tfms=Resize(460),
batch_ftms=list(tfms),
device=device("cuda"),
)
Expand All @@ -71,28 +71,40 @@
)
keep_path = learn.path
learn.path = DATA_DIR # type: ignore
learn.fit_one_cycle(
10,
cbs=[
EarlyStoppingCallback(
monitor="error_rate",
min_delta=0.000001,
patience=3,
),
SaveModelCallback(
monitor="accuracy",
min_delta=0.000001,
),
],

print("Finding best learning rate...")
lrs = learn.lr_find(suggest_funcs=(minimum, steep, valley, slide))
print(f"{lrs=}")

# print("Performing initial training...")
# learn.fit_one_cycle(3, lrs.valley)

save_best = SaveModelCallback(
monitor="valid_loss",
min_delta=0.000001,
fname="best_model",
)
early_stop = EarlyStoppingCallback(
monitor="valid_loss",
min_delta=0.000001,
patience=3,
)

print("Training the Model...")
learn.fit(10, lrs.valley, cbs=[save_best, early_stop])
learn.path = keep_path # type: ignore
learn.validate()

# Find optimal learning rate
lrs = learn.lr_find(suggest_funcs=(minimum, steep, valley, slide))
print("Model validation:", learn.validate())

# Fine tune the model
learn.fine_tune(12, lrs.valley)
# print("Unfreezing the last two layers...")
# learn.unfreeze()

# # Fine tune the model
# print("Fine tuning the model...")
# learn.fine_tune(12, lrs.valley)

# Save the model
print("Saving model as: fastai_model.pkl...")
learn.export("fastai_model.pkl")

print("Training completed!")

0 comments on commit 26da4e0

Please sign in to comment.