Skip to content

Commit

Permalink
Added check for model before running prediction
Browse files Browse the repository at this point in the history
  • Loading branch information
clamytoe committed Dec 19, 2022
1 parent 8012346 commit 07ec70f
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
PROJECT_DIR = Path("kitchenware_classifier")
DATA_DIR = PROJECT_DIR / "data"
TEST_FILE = DATA_DIR / "test.csv"
TEST_DF = pd.read_csv(TEST_FILE)
IMG_DIR = DATA_DIR / "images"
TEST_DF = pd.read_csv(TEST_FILE)
TEST_DF["image"] = TEST_DF["Id"].map(lambda x: f"{IMG_DIR}/{x:0>4}.jpg")
MODEL_FILE = "fastai_model.pkl"
OUTPUT_FILE = "submission.csv"

MODEL_FILE = Path("fastai_model.pkl")
EXPORT_FILE = Path("submission.csv")


def process_images(df, model):
Expand All @@ -26,14 +27,18 @@ def generate_submission(tta, dls):
vocab = np.array(dls.vocab)
sub = pd.read_csv(TEST_FILE)
sub["label"] = vocab[idxs]
sub.to_csv(OUTPUT_FILE, index=False)
sub.to_csv(EXPORT_FILE, index=False)
return sub


def main():
tta, dls = process_images(TEST_DF, MODEL_FILE)
sub = generate_submission(tta, dls)
print(sub.head())
if not MODEL_FILE.exists():
print(f"Model {MODEL_FILE} not found!")
print("Please run train.py first.")
else:
tta, dls = process_images(TEST_DF, MODEL_FILE)
sub = generate_submission(tta, dls)
print(sub.head())


if __name__ == "__main__":
Expand Down

0 comments on commit 07ec70f

Please sign in to comment.