This is a very slightly modified version of the ``Predict with pre-trained models'' example from http://mxnet.io/tutorials/python/predict_imagenet.html. The only changes are the addition of a function that loads an image from a local directory and a configuration that does not require a GPU.
From the MXNet page: "This is a demo for predicting with a pre-trained model on the full imagenet dataset, which contains over 10 million images and 10 thousands classes. For a more detailed explanation, please refer to predict.ipynb."
This was run using the AWS Machine Learning AMI
First we download the model.
import os, urllib
import mxnet as mx
def download(url, prefix=''):
    """Fetch ``url`` to a local file unless that file already exists.

    The local file name is ``prefix`` followed by the last ``/``-separated
    component of ``url``.

    Args:
        url: Address of the file to fetch.
        prefix: String prepended to the local file name.

    Returns:
        None.
    """
    # ``urlretrieve`` lives in ``urllib.request`` on Python 3 and directly in
    # ``urllib`` on Python 2; use whichever is available.
    try:
        from urllib.request import urlretrieve
    except ImportError:
        from urllib import urlretrieve

    filename = prefix + url.split("/")[-1]
    if os.path.exists(filename):
        return

    # Download under a temporary name and rename on success, so that an
    # interrupted transfer is never mistaken for a complete file next time.
    partial = filename + '.part'
    try:
        urlretrieve(url, partial)
        os.rename(partial, filename)
    finally:
        if os.path.exists(partial):
            os.remove(partial)
# Base URL that the model files below are fetched from.
path = 'http://data.mxnet.io/models/imagenet-11k/'

# Fetch the symbol file, the parameter file and the synset list; each one is
# stored locally under a 'full-' prefix.
for remote_name in (
    'resnet-152/resnet-152-symbol.json',
    'resnet-152/resnet-152-0000.params',
    'synset.txt',
):
    download(path + remote_name, 'full-')

# One synset entry per line, with trailing whitespace removed.
with open('full-synset.txt', 'r') as synset_file:
    synsets = [entry.rstrip() for entry in synset_file]

# Load checkpoint 0 of 'full-resnet-152' and wrap it in a module bound for
# inference on a single 3x224x224 input.
sym, arg_params, aux_params = mx.model.load_checkpoint('full-resnet-152', 0)
mod = mx.mod.Module(symbol=sym)
mod.bind(
    for_training=False,
    data_shapes=[('data', (1, 3, 224, 224))],
)
mod.set_params(arg_params, aux_params)
%matplotlib inline
import matplotlib
# Set the "savefig" dpi rc parameter to 100.
matplotlib.rc("savefig", dpi=100)
import matplotlib.pyplot as plt
import cv2
import numpy as np
from collections import namedtuple
# Single-field record; predict() wraps its input array list in a Batch before
# handing it to mod.forward().
Batch = namedtuple('Batch', ['data'])
def get_image(url, show=True):
    """Download an image into the working directory and optionally show it.

    Args:
        url: Address of the image; its last ``/``-separated component is used
            as the local file name.
        show: When true and the image could be read, draw it with matplotlib
            with the axes hidden.

    Returns:
        The local file name, whether or not the image could be read.
    """
    # ``urlretrieve`` lives in ``urllib.request`` on Python 3 and directly in
    # ``urllib`` on Python 2; use whichever is available.
    try:
        from urllib.request import urlretrieve
    except ImportError:
        from urllib import urlretrieve

    filename = url.split("/")[-1]
    urlretrieve(url, filename)
    img = cv2.imread(filename)
    if img is None:
        # Nothing to display: report the failure instead of passing None on
        # to cv2.cvtColor, which would raise.
        print('failed to download ' + url)
    elif show:
        plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        plt.axis('off')
    return filename
def get_local(filename, show=True):
    """Read an image from local disk and optionally show it.

    Args:
        filename: Path of the image file to read.
        show: When true and the image could be read, draw it with matplotlib
            with the axes hidden.

    Returns:
        ``filename`` unchanged, whether or not the image could be read.
    """
    img = cv2.imread(filename)
    if img is None:
        # Nothing to display: report the failure instead of passing None on
        # to cv2.cvtColor, which would raise.
        print('failed to load ' + filename)
    elif show:
        plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        plt.axis('off')
    return filename
This function does the prediction: it converts the image to RGB, resizes it to 224x224, and prints the five most probable classes.
def predict(filename, mod, synsets, top_k=5):
    """Classify an image file and print the most probable classes.

    Args:
        filename: Path of the image file to classify.
        mod: Module used for the forward pass; it is called through
            ``forward`` and ``get_outputs``.
        synsets: Sequence of label strings indexed by class id.
        top_k: Number of highest-probability classes to print.

    Returns:
        None. Nothing is printed when the image cannot be read.
    """
    bgr = cv2.imread(filename)
    # Check before converting: cv2.cvtColor raises on a None input, so the
    # guard has to come first to be reachable.
    if bgr is None:
        return None

    img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (224, 224))
    # Move the last axis to the front, (224, 224, 3) -> (3, 224, 224), then
    # add a leading axis of length 1.
    img = np.transpose(img, (2, 0, 1))
    img = img[np.newaxis, :]

    mod.forward(Batch([mx.nd.array(img)]))
    prob = np.squeeze(mod.get_outputs()[0].asnumpy())

    # Class indices ordered from most to least probable.
    ranked = np.argsort(prob)[::-1]
    for i in ranked[:top_k]:
        label = synsets[i]
        # Print the label from its first space onward; a label without a
        # space is printed whole rather than cut to its last character.
        split_at = label.find(' ')
        text = label[split_at:] if split_at >= 0 else label
        print('p=%2.2f,%s' % (prob[i], text))
Now run the predictor. Use a local JPG file with get_local, or modify the call to use get_image with the URL of a JPG file.
Set the path (or URL) below and run.
# Placeholder: replace with the path of an image file on local disk.
url = 'path to loal image'
# get_local() shows the image and returns its path, which predict() then
# classifies.
predict(get_local(url), mod, synsets)