-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Dmitri
committed
Apr 24, 2018
0 parents
commit 3861b31
Showing
14 changed files
with
2,536 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# AWS Greengrass ML Inference Playground |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# -*- mode: ruby -*- | ||
# vi: set ft=ruby : | ||
|
||
hostname = ENV['HOSTNAME'] ? ENV['HOSTNAME'] : 'ml-infer' | ||
box = ENV['BOX'] ? ENV['BOX'] : 'ubuntu/xenial64' | ||
|
||
# Vagrantfile API/syntax version. Don't touch unless you know what you're doing! | ||
VAGRANTFILE_API_VERSION = "2" | ||
|
||
Vagrant.configure(VAGRANTFILE_API_VERSION) do |config| | ||
config.vm.define "ml-infer" do |gg| | ||
# Box details | ||
gg.vm.box = "#{box}" | ||
gg.vm.hostname = "#{hostname}" | ||
|
||
# Box Specifications | ||
gg.vm.provider :virtualbox do |vb| | ||
vb.name = "#{hostname}" | ||
vb.memory = 2048 | ||
vb.cpus = 2 | ||
end | ||
|
||
# NFS-synced directory for pack development | ||
# Change "/path/to/directory/on/host" to point to existing directory on your laptop/host and uncomment: | ||
# config.vm.synced_folder "/path/to/directory/on/host", "/opt/stackstorm/packs", :nfs => true, :mount_options => ['nfsvers=3'] | ||
|
||
# Configure a private network | ||
gg.vm.network :private_network, ip: "192.168.16.31" | ||
|
||
# Public (bridged) network may come handy for external access to VM (e.g. sensor development) | ||
# See https://www.vagrantup.com/docs/networking/public_network.html | ||
# gg.vm.network "public_network", bridge: 'en0: Wi-Fi (AirPort)' | ||
|
||
gg.vm.provision "shell" do |s| | ||
s.path = "scripts/install.sh" | ||
s.privileged = false | ||
end | ||
end | ||
|
||
end |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# Greeengrass Group definition file | ||
Group: | ||
name: ML_infer | ||
Cores: | ||
- name: ML_infer_core_1 | ||
key_path: ./certs | ||
config_path: ./config | ||
SyncShadow: False | ||
|
||
Lambdas: | ||
- name: GreengrassImageClassification | ||
handler: run.handler | ||
package: lambdas/GreengrassImageClassification | ||
alias: dev | ||
greengrassConfig: | ||
MemorySize: 128000 # Kb, ask AWS why | ||
Timeout: 10 # Sec | ||
Pinned: True # Set True for long-lived functions | ||
Environment: | ||
AccessSysfs: False | ||
# TODO: add access to local file system for input | ||
# ResourceAccessPolicies: | ||
# - ResourceId: /vagrant/input | ||
# Permission: 'rw' | ||
Variables: | ||
WATCH_PATTERN: /input/*.jpeg | ||
|
||
- name: HelloLongRunning | ||
handler: function.handler | ||
package: lambdas/HelloLongRunning | ||
alias: dev | ||
greengrassConfig: | ||
MemorySize: 128000 # Kb, ask AWS why | ||
Timeout: 10 # Sec | ||
Pinned: True | ||
Environment: | ||
Variables: | ||
INTERVAL: '5' | ||
|
||
# Subscriptions: # not implemented | ||
# - Source: Lambda::GreengrassImageClassification | ||
# Subject: hello/world | ||
# Target: cloud | ||
# - Source: cloud | ||
# Subject: hello/world | ||
# Target: Lambda::GreengrassImageClassification | ||
|
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
import mxnet as mx | ||
import numpy as np | ||
# import picamera | ||
import time | ||
import io | ||
import cv2 | ||
import urllib2 | ||
from collections import namedtuple | ||
Batch = namedtuple('Batch', ['data']) | ||
|
||
|
||
class ImagenetModel(object): | ||
|
||
# Loads a pre-trained model locally or from an external URL | ||
# and returns an MXNet graph that is ready for prediction | ||
def __init__(self, synset_path, network_prefix, params_url=None, | ||
symbol_url=None, synset_url=None, context=mx.cpu(), | ||
label_names=['prob_label'], input_shapes=[('data', (1, 3, 224, 224))]): | ||
|
||
# Download the symbol set and network if URLs are provided | ||
if params_url is not None: | ||
fetched_file = urllib2.urlopen(params_url) | ||
with open(network_prefix + "-0000.params", 'wb') as output: | ||
output.write(fetched_file.read()) | ||
|
||
if symbol_url is not None: | ||
fetched_file = urllib2.urlopen(symbol_url) | ||
with open(network_prefix + "-symbol.json", 'wb') as output: | ||
output.write(fetched_file.read()) | ||
|
||
if synset_url is not None: | ||
fetched_file = urllib2.urlopen(synset_url) | ||
with open(synset_path, 'wb') as output: | ||
output.write(fetched_file.read()) | ||
|
||
# Load the symbols for the networks | ||
with open(synset_path, 'r') as f: | ||
self.synsets = [l.rstrip() for l in f] | ||
|
||
# Load the network parameters from default epoch 0 | ||
sym, arg_params, aux_params = mx.model.load_checkpoint(network_prefix, 0) | ||
|
||
# Load the network into an MXNet module and bind the corresponding parameters | ||
self.mod = mx.mod.Module(symbol=sym, label_names=label_names, context=context) | ||
self.mod.bind(for_training=False, data_shapes=input_shapes) | ||
self.mod.set_params(arg_params, aux_params) | ||
self.camera = None | ||
|
||
def predict_from_image(self, image, reshape=(224, 224), top=5): | ||
top_n = [] | ||
|
||
# Construct a numpy array from the stream | ||
data = np.fromstring(image, dtype=np.uint8) | ||
# "Decode" the image from the array, preserving colour | ||
cvimage = cv2.imdecode(data, 1) | ||
|
||
# Switch RGB to BGR format (which ImageNet networks take) | ||
img = cv2.cvtColor(cvimage, cv2.COLOR_BGR2RGB) | ||
if img is None: | ||
return top_n | ||
|
||
# Resize image to fit network input | ||
img = cv2.resize(img, reshape) | ||
img = np.swapaxes(img, 0, 2) | ||
img = np.swapaxes(img, 1, 2) | ||
img = img[np.newaxis, :] | ||
|
||
# Run forward on the image | ||
self.mod.forward(Batch([mx.nd.array(img)])) | ||
prob = self.mod.get_outputs()[0].asnumpy() | ||
prob = np.squeeze(prob) | ||
|
||
# Extract the top N predictions from the softmax output | ||
a = np.argsort(prob)[::-1] | ||
for i in a[0:top]: | ||
top_n.append((prob[i], self.synsets[i])) | ||
return top_n | ||
|
||
# Captures an image from the PiCamera, then sends it for prediction | ||
def predict_from_cam(self, capfile='cap.jpg', reshape=(224, 224), top=5): | ||
if self.camera is None: | ||
self.camera = picamera.PiCamera() | ||
|
||
stream = io.BytesIO() | ||
self.camera.start_preview() | ||
time.sleep(2) | ||
self.camera.capture(stream, format='jpeg') | ||
return self.predict_from_image(stream.getvalue(), reshape, top) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import sys | ||
import glob | ||
import time | ||
import os | ||
import model | ||
|
||
WATCH_PATTERN = os.environ.get('WATCH_PATTERN', './input/*.jpeg') | ||
MODEL_PATH = os.environ.get('MODEL_PATH', './squeezenet/') | ||
|
||
global_model = model.ImagenetModel(MODEL_PATH + 'synset.txt', MODEL_PATH + 'squeezenet_v1.1') | ||
|
||
|
||
def run_image_classification(): | ||
while True: | ||
for fn in glob.glob(WATCH_PATTERN): | ||
|
||
try: | ||
with open(fn, 'r') as f: | ||
image = f.read() | ||
prediction = global_model.predict_from_image(image) | ||
print prediction | ||
except: | ||
e = sys.exc_info()[0] | ||
print("Exception occured during prediction: %s" % e) | ||
os.remove(fn) | ||
|
||
time.sleep(2) | ||
|
||
|
||
# When run as Lambda, this will be invoked | ||
run_image_classification() | ||
|
||
|
||
# Dummy function handler, not invoked for long-running lambda | ||
def handler(event, context): | ||
return |
Binary file added
BIN
+4.72 MB
lambdas/GreengrassImageClassification/squeezenet/squeezenet_v1.1-0000.params
Binary file not shown.
Oops, something went wrong.