Skip to content

Commit

Permalink
First commit
Browse files Browse the repository at this point in the history
  • Loading branch information
adamcavendish committed May 7, 2018
0 parents commit 4f4f1b0
Show file tree
Hide file tree
Showing 27 changed files with 2,194 additions and 0 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
logging/
error.log

__pycache__/
31 changes: 31 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# StegNet Paper: Mega-Image-Steganography-Capacity-with-Deep-Convolutional-Network
---

## How to create ImageNet Dataset used by StegNet

[Read the LMDB Creator Doc](./lmdb_creator/README.md)


## How to run the StegNet Model

Step 1. Set up environment variables:

```bash
export ILSVRC2012_MDB_PATH="<Your Path to Created 'ILSVRC2012_image_train.mdb' Directory>"
```

Step 2. Run the code

```bash
python ./main.py
```

The command-line arguments can be tweaked:
```
-h, --help
--train_max_epoch TRAIN_MAX_EPOCH
--batch_size BATCH_SIZE
--restart # Restart from scratch
--global_mode {train,inference}
```

17 changes: 17 additions & 0 deletions dataset_tools/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import pathlib
import io
import math
import multiprocessing as mp

import numpy as np
from PIL import Image

from . import ilsvrc2012

# Registry mapping a dataset name to the class that implements it.
_dispatcher = {
    'ILSVRC2012': ilsvrc2012.DatasetILSVRC2012,
}


def get_dataset_by_name(name):
    '''
    Look up a dataset class by its registered name.

    name: str
        One of the keys of `_dispatcher`, e.g. 'ILSVRC2012'.

    Returns the dataset class itself (not an instance).
    Raises KeyError, listing the registered names, when `name` is unknown.
    '''
    try:
        return _dispatcher[name]
    except KeyError:
        # Still a KeyError for existing callers, but with a usable message.
        raise KeyError('Unknown dataset %r; available datasets: %s' % (
            name, ', '.join(sorted(_dispatcher)))) from None
24 changes: 24 additions & 0 deletions dataset_tools/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
class Dataset:
    '''
    Interface shared by the dataset backends.

    Nothing is implemented here: the constructor and every method raise
    NotImplementedError, so a backend has to override each one it wants
    to be usable.
    '''

    def __init__(self, train_ratio=0.8, seed=42):
        '''Construct the dataset; not implemented in the base class.'''
        raise NotImplementedError()

    def get_name(self):
        '''Name of the dataset; not implemented in the base class.'''
        raise NotImplementedError()

    def get_shape(self):
        '''Shape of one sample; not implemented in the base class.'''
        raise NotImplementedError()

    def get_whole_size(self):
        '''Size of the whole dataset; not implemented in the base class.'''
        raise NotImplementedError()

    def get_train_size(self):
        '''Size of the training part; not implemented in the base class.'''
        raise NotImplementedError()

    def get_valid_size(self):
        '''Size of the validation part; not implemented in the base class.'''
        raise NotImplementedError()

    def fetch_train_data(self, batch_size):
        '''Fetch training data; not implemented in the base class.'''
        raise NotImplementedError()

    def fetch_valid_data(self, batch_size):
        '''Fetch validation data; not implemented in the base class.'''
        raise NotImplementedError()
61 changes: 61 additions & 0 deletions dataset_tools/ilsvrc2012.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import io
import os

import numpy as np
from PIL import Image

import lmdb
import msgpack

from .dataset import Dataset

# Location of the LMDB directory; importing this module fails with KeyError
# when the environment variable is not set.
ILSVRC2012_MDB_PATH = os.environ['ILSVRC2012_MDB_PATH']


class DatasetILSVRC2012(Dataset):
    '''
    ILSVRC2012 images read from the LMDB at ILSVRC2012_MDB_PATH.

    Records are looked up under zero-padded 8-digit decimal keys and are
    msgpack maps whose 'image' entry holds the encoded image bytes.
    Indices [0, train_size) form the training split and
    [train_size, whole_size) the validation split.
    '''

    def __init__(self, train_ratio=0.8, seed=42):
        '''
        train_ratio: float
            Fraction of the entries assigned to the training split
        seed: int
            Seed applied to NumPy's global RNG whenever a fetch generator starts
        '''
        # Dataset.__init__ is not called on purpose: it raises NotImplementedError.
        self.seed = seed
        self.mdb_path = ILSVRC2012_MDB_PATH
        self.env = lmdb.open(self.mdb_path, readonly=True)
        self.whole_size = self.env.stat()['entries']
        self.train_size = int(self.whole_size * train_ratio)
        self.valid_size = self.whole_size - self.train_size
        self.inrows, self.incols, self.incnls = 64, 64, 3

    def get_name(self):
        '''Return the dataset name, 'ILSVRC2012'.'''
        return 'ILSVRC2012'

    def get_shape(self):
        '''Return the per-image shape as (rows, cols, channels).'''
        return self.inrows, self.incols, self.incnls

    def get_whole_size(self):
        '''Return the number of entries in the LMDB.'''
        return self.whole_size

    def get_train_size(self):
        '''Return the number of entries in the training split.'''
        return self.train_size

    def get_valid_size(self):
        '''Return the number of entries in the validation split.'''
        return self.valid_size

    def _fetch_data_in_range(self, batch_size, lower_bound, upper_bound):
        '''
        Endlessly yield batches of images drawn at random (with replacement)
        from the index range [lower_bound, upper_bound).

        Each batch is a float array of shape
        (batch_size, inrows, incols, incnls) normalized to [-1, 1].
        '''
        # NOTE(review): this reseeds NumPy's *global* RNG, so generators started
        # with the same seed in separate processes presumably draw the same
        # index sequence -- confirm that is intended.
        np.random.seed(self.seed)
        rand_range = np.arange(lower_bound, upper_bound)
        # PIL expects (width, height), i.e. (cols, rows).
        target_size = (self.incols, self.inrows)
        with self.env.begin() as txn:
            while True:
                image_v = np.zeros(shape=(batch_size, self.inrows, self.incols, self.incnls))
                image_idx = np.random.choice(rand_range, size=batch_size)
                for index, image_id in enumerate(image_idx):
                    image_rawd = txn.get('{:08d}'.format(image_id).encode())
                    # raw=False decodes msgpack strings as UTF-8; it replaces the
                    # `encoding` argument that msgpack 1.0 removed.
                    image_info = msgpack.unpackb(image_rawd, raw=False)
                    with Image.open(io.BytesIO(image_info['image'])) as im:
                        # LANCZOS is the filter formerly exposed as ANTIALIAS,
                        # which Pillow 10 removed.
                        resized = im.resize(target_size, Image.LANCZOS)
                        image_v[index, :, :, :] = np.array(resized)
                # Map pixel values from [0, 255] to [-1, 1]
                yield image_v / 255. * 2 - 1

    def fetch_train_data(self, batch_size):
        '''Return an endless generator of batches from the training split.'''
        return self._fetch_data_in_range(batch_size, 0, self.train_size)

    def fetch_valid_data(self, batch_size):
        '''Return an endless generator of batches from the validation split.'''
        return self._fetch_data_in_range(batch_size, self.train_size, self.whole_size)
5 changes: 5 additions & 0 deletions generators/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
'''
Generators
'''

from .dataset_generator import dataset_generator
32 changes: 32 additions & 0 deletions generators/dataset_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
'''
Input Image Data Generators
'''

import contextlib
import queue as Q
import time

import params
import utils


def dataset_generator(queue, dataset, mode, role, batch_size):
    '''
    Generate image data from dataset and feed it into a queue

    queue: mapping of queue-name to queue object
        The queue is looked up under '<role>/<mode>' and must provide
        `put(item, timeout=...)` and `cancel_join_thread()`
    dataset: object providing `fetch_train_data` and `fetch_valid_data`
    mode: str
        'train' selects the training data; anything else the validation data
    role: str
        First part of the queue-name, e.g. 'covr' or 'hide'
    batch_size: int
        Passed through to the dataset's fetch method

    Runs until the dataset generator is exhausted or
    `params.SHOULD_FINISH.value` becomes true.
    '''
    if mode == 'train':
        dgen = dataset.fetch_train_data
    else:
        dgen = dataset.fetch_valid_data

    # queue-name: one of ['covr/train', 'hide/train', 'covr/valid', 'hide/valid']
    qname = '{}/{}'.format(role, mode)

    for image in dgen(batch_size):
        delivered = False
        # Retry the *same* batch while the queue is full instead of dropping it;
        # the timeout exists only so that SHOULD_FINISH is re-checked regularly.
        while not delivered and not params.SHOULD_FINISH.value:
            try:
                queue[qname].put(image, timeout=params.QUEUE_TIMEOUT)
                delivered = True
            except Q.Full:
                continue
        if not delivered:
            # Shutdown was requested before this batch could be queued
            break
    # Setup queue to allow exit without flushing all the data to the pipe
    queue[qname].cancel_join_thread()
    utils.eprint('dataset_generator(%s/%s): exit' % (role, mode))
19 changes: 19 additions & 0 deletions lmdb_creator/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Create lmdb files of ImageNet 2012 Contest

Step 1. Download and Extract ILSVRC2012 Dataset

Step 2. Make sure that you have the `ILSVRC2012_devkit_t12` and `ILSVRC2012_img_train` directories

Step 3. Set up environment variables:

```bash
export IMAGE_DIR="<Your ILSVRC2012_img_train Directory Path>"
export DK_DIR="<Your ILSVRC2012_devkit_t12 Directory Path>"
export MDB_OUT_DIR="<Your Expected Directory for Generating LMDB File>" # Note: Reserve 60GB at least
```

Step 4. Run the python script

```bash
python ./images2lmdb.py
```
171 changes: 171 additions & 0 deletions lmdb_creator/images2lmdb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
'''
Generate ILSVRC2012 Dataset LMDB file
'''
import io
import os
import pathlib
import struct
import sys
import time

import PIL.Image
import lmdb
import msgpack

import scipy.io
import cytoolz as tz
import numpy as np

# Prepare: all locations come from the environment (KeyError when unset)
IMAGE_DIR = os.environ['IMAGE_DIR']
DK_DIR = os.environ['DK_DIR']
MDB_OUT_DIR = os.environ['MDB_OUT_DIR']

# Seed NumPy's global RNG (used for the shuffle of the image list)
seed = 42
np.random.seed(seed)

lmdb_map_size = 50*1024*1024*1024  # maximum LMDB size in bytes (50 GiB)
lmdb_txn_size = 500                # number of images written per transaction

# Setup PATHs
META_PATH = os.path.join(DK_DIR, 'data', 'meta.mat')
META_MP_PATH = os.path.join(MDB_OUT_DIR, 'meta.msgpack')
LMDB_PATH = os.path.join(MDB_OUT_DIR, 'ILSVRC2012_image_train.mdb')

# The output directory may not exist yet; create it so that writing
# meta.msgpack below does not fail.
os.makedirs(MDB_OUT_DIR, exist_ok=True)

# Generate meta.msgpack
meta = scipy.io.loadmat(META_PATH, squeeze_me=True)
synsets = meta['synsets']

# Convert every synset record into plain Python types so msgpack can pack it
meta_info = [{
    'ILSVRC2012_ID': int(s['ILSVRC2012_ID']),
    'WNID': str(s['WNID']),
    'words': str(s['words']),
    'gloss': str(s['gloss']),
    'wordnet_height': int(s['wordnet_height']),
    'num_train_images': int(s['num_train_images']),
} for s in synsets]

meta_info_packed = msgpack.packb(meta_info, use_bin_type=True)

with open(META_MP_PATH, 'wb') as f:
    f.write(meta_info_packed)

# Generate LMDB
def make_context():
    '''
    Create the progress context for the LMDB generation

    Returns a dict with the next image id ('image_id', starting at 0) and
    two timestamps ('clock_beg', 'clock_end') taken at creation time.
    '''
    ctx = dict(image_id=0)
    ctx['clock_beg'] = time.time()
    ctx['clock_end'] = time.time()
    return ctx


def process_image_one(txn, image_id, wordnet_id, label, image_abspath):
    '''
    Convert one image to a 256x256 RGB WebP and store it in the LMDB

    txn: lmdb transaction object
    image_id: int
        The image id, increasing index; used as zero-padded 8-digit key
    wordnet_id: str
        The wordnet id, i.e. n07711569
    label: int
        The label stored with the image
    image_abspath: str
        The image's absolute path

    The stored 'rows'/'cols' are the dimensions of the source image, i.e.
    before it is resized.
    '''
    with PIL.Image.open(image_abspath) as im, io.BytesIO() as bio:
        if im.mode != 'RGB':
            im = im.convert('RGB')
        # PIL reports size as (width, height), i.e. (cols, rows)
        cols, rows = im.size
        cnls = 3
        # resize() returns a new image instead of modifying `im` in place
        im = im.resize((256, 256))
        im.save(bio, format='webp')
        image_bytes = bio.getvalue()

    # Drop only the '.JPEG' suffix; str.rstrip would strip a *set* of
    # characters and could eat trailing letters of the name itself.
    filename = os.path.basename(image_abspath)
    if filename.endswith('.JPEG'):
        filename = filename[:-len('.JPEG')]

    info = {
        'wordnet_id': wordnet_id,
        'filename': filename,
        'image': image_bytes,
        'rows': rows,
        'cols': cols,
        'cnls': cnls,
        'label': label,
    }
    key = '{:08d}'.format(image_id).encode()
    txn.put(key, msgpack.packb(info, use_bin_type=True))


def imagenet_walk(wnid_meta_map, image_Dir):
    '''
    Collect every image file below the per-category sub-directories

    wnid_meta_map: dict
        Maps a wordnet id (the sub-directory name) to its meta record, which
        provides 'words', 'gloss' and 'ILSVRC2012_ID'
    image_Dir: pathlib.Path
        Directory containing one sub-directory per wordnet id

    Returns a list of dicts with 'label', 'wordnet_id' and 'image_abspath'.
    Categories and files are visited in sorted order: iterdir() yields
    entries in arbitrary order, which would make the result (and any seeded
    shuffle of it) differ between machines.
    Raises KeyError when a sub-directory name is missing from wnid_meta_map.
    '''
    image_files = []
    categories = sorted(d for d in image_Dir.iterdir() if d.is_dir())

    for count, category_Path in enumerate(categories):
        wordnet_id = category_Path.name
        metainfo = wnid_meta_map[wordnet_id]
        words = metainfo['words']
        gloss = metainfo['gloss']
        label = metainfo['ILSVRC2012_ID']

        print('Process count=%d, label=%d, wordnet_id=%s' % (count, label, wordnet_id))
        print(' %s: %s' % (words, gloss))

        # Only regular files directly inside the category directory count
        image_files.extend(
            {
                'label': label,
                'wordnet_id': wordnet_id,
                'image_abspath': str(f.absolute()),
            }
            for f in sorted(category_Path.iterdir()) if f.is_file()
        )

    return image_files


def _split_hms(seconds):
    '''Split a duration in seconds into whole (hours, minutes, seconds).'''
    whole = int(seconds)
    return whole // 60 // 60, whole // 60 % 60, whole % 60


def process_images(ctx, lmdb_env, image_infos, image_total):
    '''
    Write one batch of images in a single LMDB transaction and print progress

    ctx: dict
        Progress context; 'image_id' and 'clock_beg' are read, 'image_id'
        and 'clock_end' are updated
    lmdb_env: lmdb environment object
    image_infos: sequence of dict
        Each with 'wordnet_id', 'label' and 'image_abspath'
    image_total: int
        Total number of images, used for the remaining-time estimate
    '''
    image_id = ctx['image_id']

    with lmdb_env.begin(write=True) as txn:
        for image_info in image_infos:
            process_image_one(
                txn, image_id,
                image_info['wordnet_id'],
                image_info['label'],
                image_info['image_abspath'])
            image_id = image_id + 1

    clock_end = time.time()
    elapse = clock_end - ctx['clock_beg']
    elapse_h, elapse_m, elapse_s = _split_hms(elapse)

    # Remaining time is extrapolated from the average time per image so far;
    # without any processed image there is nothing to extrapolate from.
    if image_id > 0:
        estmt = (image_total - image_id) / image_id * elapse
    else:
        estmt = 0
    estmt_h, estmt_m, estmt_s = _split_hms(estmt)

    labels = [image_info['label'] for image_info in image_infos]
    print('ImageId: {:8d}/{:8d}, time: {:2d}h/{:2d}m/{:2d}s, remain: {:2d}h/{:2d}m/{:2d}s, Sample: {} ...'.format(
        image_id, image_total,
        elapse_h, elapse_m, elapse_s,
        estmt_h, estmt_m, estmt_s,
        str(labels)[:80]))

    ctx['image_id'] = image_id
    ctx['clock_end'] = clock_end


# Index the meta records by wordnet id (the category directory name)
wnid_meta_map = {m['WNID']: m for m in meta_info}

image_train_env = lmdb.open(LMDB_PATH, map_size=lmdb_map_size)
try:
    image_infos = imagenet_walk(wnid_meta_map, pathlib.Path(IMAGE_DIR))
    image_total = len(image_infos)
    # Shuffle in place (global NumPy RNG) so neighbouring keys mix categories
    np.random.shuffle(image_infos)

    ctx = make_context()
    # One write transaction per chunk of lmdb_txn_size images
    for image_infos_partial in tz.partition_all(lmdb_txn_size, image_infos):
        process_images(ctx, image_train_env, image_infos_partial, image_total)
finally:
    # Release the environment even when the conversion fails half-way
    image_train_env.close()
Loading

0 comments on commit 4f4f1b0

Please sign in to comment.