-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.py
40 lines (31 loc) · 1.6 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import gradio as gr
import torch
from torchvision import transforms, models
from torch import nn
# Checkpoint holding the trained weights; loaded below via load_state_dict.
path = "seefood89.pth"

# DenseNet-121 backbone built without downloading pretrained ImageNet weights;
# all weights (backbone and head) come from the local checkpoint instead.
model = models.densenet121(pretrained=False)

# Replacement head: 1024 -> 1024 -> 512 -> 2, ending in LogSoftmax, so the
# model outputs log-probabilities over the two classes in ``classes``.
classifier = nn.Sequential(
    nn.Linear(1024, 1024), nn.ReLU(), nn.Dropout(p=0.3),
    nn.Linear(1024, 512), nn.ReLU(), nn.Dropout(p=0.3),
    nn.Linear(512, 2), nn.LogSoftmax(dim=1),
)
model.classifier = classifier
model.load_state_dict(torch.load(path, map_location='cpu'))
model.eval()  # disable Dropout for deterministic inference

# This script only serves predictions, so freeze every parameter (head
# included). The attribute is ``requires_grad``; the earlier misspelling
# ``require_grad`` silently created an unused attribute and froze nothing.
for params in model.parameters():
    params.requires_grad = False

# Maps the arg-max class index of the model output to the label shown in the UI.
classes = {0:'✅ hot dog 🌭 ', 1:'❌ not hot dog 🌭 '}
# Preprocessing for every uploaded image, built once at import time rather
# than on each request. Same steps as before, plus an RGB conversion so
# RGBA or greyscale uploads still yield the 3 channels DenseNet expects
# (a no-op copy for images that are already RGB).
_TRANSFORM = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Lambda(lambda pil_img: pil_img.convert("RGB")),
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])


def get_prediction(img):
    """Classify one image as hot dog / not hot dog.

    Args:
        img: The image from the Gradio ``"image"`` input, in any form accepted
            by ``transforms.ToPILImage``. Presumably an HxWxC uint8 numpy
            array; confirm against the installed Gradio version.

    Returns:
        The label string from ``classes`` for the highest-scoring class.

    Raises:
        ValueError: If no image was supplied (``img`` is None).
    """
    if img is None:
        raise ValueError("No image provided; please upload an image.")
    # Add a batch dimension: (3, 224, 224) -> (1, 3, 224, 224).
    batch = _TRANSFORM(img).unsqueeze(0)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        output = model(batch)
    prediction = torch.argmax(output, dim=1)
    return classes[prediction.item()]
# Text shown in the Gradio page header and footer.
title = "SEEFOOD"
description = "<p style='text-align: center'>It's shazam for food but only hotdogs (from HBO's Silicon Valley) , made using transfer learning ( Densenet121)</p>"
article = "<p style='text-align: center'><a href='https://github.com/vinayakj02/SEEFOOD-classifier' target='_blank'>Github</a></p>"

# Sample images offered as clickable examples beneath the input (relative paths).
example_images = [
    "examples/fries.jpg",
    "examples/hotdog1.jpg",
    "examples/pizza.jpg",
    "examples/cream.jpg",
]

# Single-image-in, label-out web UI wrapping the classifier.
demo = gr.Interface(
    fn=get_prediction,
    inputs="image",
    outputs="label",
    examples=example_images,
    title=title,
    description=description,
    article=article,
)

# Bind to all network interfaces (0.0.0.0) on port 7000 so other hosts can reach the app.
demo.launch(server_name="0.0.0.0", server_port=7000)