Skip to content

Commit

Permalink
Modified script to use extra training images
Browse files Browse the repository at this point in the history
  • Loading branch information
clamytoe committed Dec 19, 2022
1 parent a6bfead commit 2e2e7f1
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

# Model to use
MODEL = "convnext_nano"
MODEL_FILE = "fastai_model.pkl"

# Load data and set variables
SEED = 42
Expand All @@ -34,6 +35,8 @@

TRAIN_DF = pd.read_csv(DATA_DIR / "train.csv")
TEST_DF = pd.read_csv(DATA_DIR / "test.csv")
EXTRA_DF = pd.read_csv(DATA_DIR / "extra.csv")
TRAIN_DF = pd.concat([TRAIN_DF, EXTRA_DF], ignore_index=True)
TRAIN_DF["image"] = TRAIN_DF["Id"].map(lambda x: f"{x:0>4}.jpg")
TEST_DF["image"] = TEST_DF["Id"].map(lambda x: f"{IMG_DIR}/{x:0>4}.jpg")

Expand All @@ -54,8 +57,8 @@
path=str(IMG_DIR),
valid_pct=0.2,
seed=42,
bs=16,
val_bs=16,
bs=8,
val_bs=8,
fn_col="image", # type: ignore
shuffle=True,
label_col="label", # type: ignore
Expand Down Expand Up @@ -97,7 +100,8 @@
print("Model validation:", learn.validate())

# Save the model
print("Saving model as: fastai_model.pkl...")
learn.export("fastai_model.pkl")
print(f"Saving model as: {MODEL_FILE}...")
learn.export(MODEL_FILE)

print("Training completed!")
print(learn.summary())

0 comments on commit 2e2e7f1

Please sign in to comment.