Skip to content

Commit

Permalink
add learning to remember rare events
Browse files Browse the repository at this point in the history
  • Loading branch information
ofirnachum committed Mar 2, 2017
1 parent bc70271 commit 6a9c0da
Show file tree
Hide file tree
Showing 5 changed files with 1,231 additions and 0 deletions.
55 changes: 55 additions & 0 deletions learning_to_remember_rare_events/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
Code for the Memory Module as described
in "Learning to Remember Rare Events" by
Lukasz Kaiser, Ofir Nachum, Aurko Roy, and Samy Bengio
published as a conference paper at ICLR 2017.

Requirements:
* TensorFlow (see tensorflow.org for how to install)
* Some basic command-line utilities (git, unzip).

Description:

The general memory module is located in memory.py.
Some code is provided to see the memory module in
action on the standard Omniglot dataset.
Download and setup the dataset using data_utils.py
and then run the training script train.py
(see example commands below).

Note that the structure and parameters of the model
are optimized for the data preparation as provided.

Quick Start:

First download and set-up Omniglot data by running

```
python data_utils.py
```

Then run the training script:

```
python train.py --memory_size=8192 \
--batch_size=16 --validation_length=50 \
--episode_width=5 --episode_length=30
```

The first validation batch may look like this (although it is noisy):
```
0-shot: 0.040, 1-shot: 0.404, 2-shot: 0.516, 3-shot: 0.604,
4-shot: 0.656, 5-shot: 0.684
```
At step 500 you may see something like this:
```
0-shot: 0.036, 1-shot: 0.836, 2-shot: 0.900, 3-shot: 0.940,
4-shot: 0.944, 5-shot: 0.916
```
At step 4000 you may see something like this:
```
0-shot: 0.044, 1-shot: 0.960, 2-shot: 1.000, 3-shot: 0.988,
4-shot: 0.972, 5-shot: 0.992
```

Maintained by Ofir Nachum (ofirnachum) and
Lukasz Kaiser (lukaszkaiser).
242 changes: 242 additions & 0 deletions learning_to_remember_rare_events/data_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
"""Data loading and other utilities.
Use this file to first copy over and pre-process the Omniglot dataset.
Simply call
python data_utils.py
"""

import cPickle as pickle
import logging
import os
import subprocess

import numpy as np
from scipy.misc import imresize
from scipy.misc import imrotate
from scipy.ndimage import imread
import tensorflow as tf


MAIN_DIR = ''
REPO_LOCATION = 'https://github.com/brendenlake/omniglot.git'
REPO_DIR = os.path.join(MAIN_DIR, 'omniglot')
DATA_DIR = os.path.join(REPO_DIR, 'python')
TRAIN_DIR = os.path.join(DATA_DIR, 'images_background')
TEST_DIR = os.path.join(DATA_DIR, 'images_evaluation')
DATA_FILE_FORMAT = os.path.join(MAIN_DIR, '%s_omni.pkl')

TRAIN_ROTATIONS = True # augment training data with rotations
TEST_ROTATIONS = False # augment testing data with rotations
IMAGE_ORIGINAL_SIZE = 105
IMAGE_NEW_SIZE = 28


def get_data():
"""Get data in form suitable for episodic training.
Returns:
Train and test data as dictionaries mapping
label to list of examples.
"""
with tf.gfile.GFile(DATA_FILE_FORMAT % 'train') as f:
processed_train_data = pickle.load(f)
with tf.gfile.GFile(DATA_FILE_FORMAT % 'test') as f:
processed_test_data = pickle.load(f)

train_data = {}
test_data = {}

for data, processed_data in zip([train_data, test_data],
[processed_train_data, processed_test_data]):
for image, label in zip(processed_data['images'],
processed_data['labels']):
if label not in data:
data[label] = []
data[label].append(image.reshape([-1]).astype('float32'))

intersection = set(train_data.keys()) & set(test_data.keys())
assert not intersection, 'Train and test data intersect.'
ok_num_examples = [len(ll) == 20 for _, ll in train_data.iteritems()]
assert all(ok_num_examples), 'Bad number of examples in train data.'
ok_num_examples = [len(ll) == 20 for _, ll in test_data.iteritems()]
assert all(ok_num_examples), 'Bad number of examples in test data.'

logging.info('Number of labels in train data: %d.', len(train_data))
logging.info('Number of labels in test data: %d.', len(test_data))

return train_data, test_data


def crawl_directory(directory, augment_with_rotations=False,
first_label=0):
"""Crawls data directory and returns stuff."""
label_idx = first_label
images = []
labels = []
info = []

# traverse root directory
for root, _, files in os.walk(directory):
logging.info('Reading files from %s', root)
fileflag = 0
for file_name in files:
full_file_name = os.path.join(root, file_name)
img = imread(full_file_name, flatten=True)
for i, angle in enumerate([0, 90, 180, 270]):
if not augment_with_rotations and i > 0:
break

images.append(imrotate(img, angle))
labels.append(label_idx + i)
info.append(full_file_name)

fileflag = 1

if fileflag:
label_idx += 4 if augment_with_rotations else 1

return images, labels, info


def resize_images(images, new_width, new_height):
"""Resize images to new dimensions."""
resized_images = np.zeros([images.shape[0], new_width, new_height],
dtype=np.float32)

for i in range(images.shape[0]):
resized_images[i, :, :] = imresize(images[i, :, :],
[new_width, new_height],
interp='bilinear',
mode=None)
return resized_images


def write_datafiles(directory, write_file,
resize=True, rotate=False,
new_width=IMAGE_NEW_SIZE, new_height=IMAGE_NEW_SIZE,
first_label=0):
"""Load and preprocess images from a directory and write them to a file.
Args:
directory: Directory of alphabet sub-directories.
write_file: Filename to write to.
resize: Whether to resize the images.
rotate: Whether to augment the dataset with rotations.
new_width: New resize width.
new_height: New resize height.
first_label: Label to start with.
Returns:
Number of new labels created.
"""

# these are the default sizes for Omniglot:
imgwidth = IMAGE_ORIGINAL_SIZE
imgheight = IMAGE_ORIGINAL_SIZE

logging.info('Reading the data.')
images, labels, info = crawl_directory(directory,
augment_with_rotations=rotate,
first_label=first_label)

images_np = np.zeros([len(images), imgwidth, imgheight], dtype=np.bool)
labels_np = np.zeros([len(labels)], dtype=np.uint32)
for i in xrange(len(images)):
images_np[i, :, :] = images[i]
labels_np[i] = labels[i]

if resize:
logging.info('Resizing images.')
resized_images = resize_images(images_np, new_width, new_height)

logging.info('Writing resized data in float32 format.')
data = {'images': resized_images,
'labels': labels_np,
'info': info}
with tf.gfile.GFile(write_file, 'w') as f:
pickle.dump(data, f)
else:
logging.info('Writing original sized data in boolean format.')
data = {'images': images_np,
'labels': labels_np,
'info': info}
with tf.gfile.GFile(write_file, 'w') as f:
pickle.dump(data, f)

return len(np.unique(labels_np))


def maybe_download_data():
"""Download Omniglot repo if it does not exist."""
if os.path.exists(REPO_DIR):
logging.info('It appears that Git repo already exists.')
else:
logging.info('It appears that Git repo does not exist.')
logging.info('Cloning now.')

subprocess.check_output('git clone %s' % REPO_LOCATION, shell=True)

if os.path.exists(TRAIN_DIR):
logging.info('It appears that train data has already been unzipped.')
else:
logging.info('It appears that train data has not been unzipped.')
logging.info('Unzipping now.')

subprocess.check_output('unzip %s.zip -d %s' % (TRAIN_DIR, DATA_DIR),
shell=True)

if os.path.exists(TEST_DIR):
logging.info('It appears that test data has already been unzipped.')
else:
logging.info('It appears that test data has not been unzipped.')
logging.info('Unzipping now.')

subprocess.check_output('unzip %s.zip -d %s' % (TEST_DIR, DATA_DIR),
shell=True)


def preprocess_omniglot():
"""Download and prepare raw Omniglot data.
Downloads the data from GitHub if it does not exist.
Then load the images, augment with rotations if desired.
Resize the images and write them to a pickle file.
"""

maybe_download_data()

directory = TRAIN_DIR
write_file = DATA_FILE_FORMAT % 'train'
num_labels = write_datafiles(
directory, write_file, resize=True, rotate=TRAIN_ROTATIONS,
new_width=IMAGE_NEW_SIZE, new_height=IMAGE_NEW_SIZE)

directory = TEST_DIR
write_file = DATA_FILE_FORMAT % 'test'
write_datafiles(directory, write_file, resize=True, rotate=TEST_ROTATIONS,
new_width=IMAGE_NEW_SIZE, new_height=IMAGE_NEW_SIZE,
first_label=num_labels)


def main(unused_argv):
logging.basicConfig(level=logging.INFO)
preprocess_omniglot()


if __name__ == '__main__':
tf.app.run()
Loading

0 comments on commit 6a9c0da

Please sign in to comment.