-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 4f4f1b0
Showing
27 changed files
with
2,194 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
logging/ | ||
error.log | ||
|
||
__pycache__/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
# StegNet Paper: Mega-Image-Steganography-Capacity-with-Deep-Convolutional-Network | ||
--- | ||
|
||
## How to create ImageNet Dataset used by StegNet | ||
|
||
[Read the LMDB Creator Doc](./lmdb_creator/README.md) | ||
|
||
|
||
## How to run the StegNet Model | ||
|
||
Step 1. Set up environment variables: | ||
|
||
```bash | ||
export ILSVRC2012_MDB_PATH="<Your Path to Created 'ILSVRC2012_image_train.mdb' Directory>" | ||
``` | ||
|
||
Step 2. Run the code | ||
|
||
```bash | ||
python ./main.py | ||
``` | ||
|
||
The command line arguments can be tweaked: | ||
``` | ||
-h, --help | ||
--train_max_epoch TRAIN_MAX_EPOCH | ||
--batch_size BATCH_SIZE | ||
--restart # Restart from scratch | ||
--global_mode {train,inference} | ||
``` | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
import pathlib | ||
import io | ||
import math | ||
import multiprocessing as mp | ||
|
||
import numpy as np | ||
from PIL import Image | ||
|
||
from . import ilsvrc2012 | ||
|
||
# Registry mapping a dataset name to the class that implements it.
_dispatcher = {
    'ILSVRC2012': ilsvrc2012.DatasetILSVRC2012,
}


def get_dataset_by_name(name):
    '''
    Look up the dataset class registered under `name`.

    Returns the class itself, not an instance.
    Raises KeyError when `name` is not a key of the registry.
    '''
    dataset_cls = _dispatcher[name]
    return dataset_cls
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
class Dataset(object):
    '''
    Interface shared by dataset implementations.

    Every method here, the constructor included, raises NotImplementedError;
    a concrete dataset must override all of them.
    '''

    def __init__(self, train_ratio=0.8, seed=42):
        # Subclasses define their own constructor with this signature.
        raise NotImplementedError()

    def get_name(self):
        '''Name of the dataset (to be provided by the subclass).'''
        raise NotImplementedError()

    def get_shape(self):
        '''Shape of one sample (to be provided by the subclass).'''
        raise NotImplementedError()

    def get_whole_size(self):
        '''Total number of samples (to be provided by the subclass).'''
        raise NotImplementedError()

    def get_train_size(self):
        '''Number of training samples (to be provided by the subclass).'''
        raise NotImplementedError()

    def get_valid_size(self):
        '''Number of validation samples (to be provided by the subclass).'''
        raise NotImplementedError()

    def fetch_train_data(self, batch_size):
        '''Training data in batches of `batch_size` (to be provided by the subclass).'''
        raise NotImplementedError()

    def fetch_valid_data(self, batch_size):
        '''Validation data in batches of `batch_size` (to be provided by the subclass).'''
        raise NotImplementedError()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import io | ||
import os | ||
|
||
import numpy as np | ||
from PIL import Image | ||
|
||
import lmdb | ||
import msgpack | ||
|
||
from .dataset import Dataset | ||
|
||
# Directory of the LMDB database; must be set in the environment before import.
ILSVRC2012_MDB_PATH = os.environ['ILSVRC2012_MDB_PATH']


class DatasetILSVRC2012(Dataset):
    '''
    ILSVRC2012 images served from an LMDB database.

    Records are looked up by zero-padded 8-digit decimal keys and hold a
    msgpack map whose 'image' entry is the encoded image. The first
    `train_ratio` share of the keys is the training split, the rest is the
    validation split.
    '''

    def __init__(self, train_ratio=0.8, seed=42):
        '''
        train_ratio: float
            Share of the records used for training
        seed: int
            Seed of the random sampling done by the fetch_* generators
        '''
        self.seed = seed
        self.mdb_path = ILSVRC2012_MDB_PATH
        self.env = lmdb.open(self.mdb_path, readonly=True)
        self.whole_size = self.env.stat()['entries']
        self.train_size = int(self.whole_size * train_ratio)
        self.valid_size = self.whole_size - self.train_size
        # Every image is resized to this (rows, cols, channels) shape.
        self.inrows, self.incols, self.incnls = 64, 64, 3

    def get_name(self):
        return 'ILSVRC2012'

    def get_shape(self):
        return self.inrows, self.incols, self.incnls

    def get_whole_size(self):
        return self.whole_size

    def get_train_size(self):
        return self.train_size

    def get_valid_size(self):
        return self.valid_size

    def _fetch_data_in_range(self, batch_size, lower_bound, upper_bound):
        '''
        Endless generator of random image batches.

        Record ids are drawn with replacement from [lower_bound, upper_bound).
        Each batch is an array of shape (batch_size, inrows, incols, incnls)
        with values normalized to [-1, 1].
        '''
        # A private RandomState yields the same stream as seeding the global
        # RNG, without resetting or being disturbed by other numpy users.
        rng = np.random.RandomState(self.seed)
        rand_range = np.arange(lower_bound, upper_bound)
        batch_shape = (batch_size, self.inrows, self.incols, self.incnls)
        with self.env.begin() as txn:
            while True:
                image_v = np.zeros(shape=batch_shape)
                image_idx = rng.choice(rand_range, size=batch_size)
                for index, record_id in enumerate(image_idx):
                    image_rawd = txn.get('{:08d}'.format(record_id).encode())
                    # raw=False decodes msgpack strings to str; it replaces the
                    # `encoding` argument that msgpack 1.0 removed.
                    image_info = msgpack.unpackb(image_rawd, raw=False)
                    with Image.open(io.BytesIO(image_info['image'])) as im:
                        # The batch buffer has exactly three channels.
                        if im.mode != 'RGB':
                            im = im.convert('RGB')
                        # PIL sizes are (width, height), i.e. (cols, rows).
                        # LANCZOS is the filter formerly named ANTIALIAS.
                        im = im.resize((self.incols, self.inrows), Image.LANCZOS)
                        image_v[index, :, :, :] = np.array(im)
                # Map [0, 255] to [-1, 1]
                yield image_v / 255. * 2 - 1

    def fetch_train_data(self, batch_size):
        '''Endless generator of batches sampled from the training split.'''
        return self._fetch_data_in_range(batch_size, 0, self.train_size)

    def fetch_valid_data(self, batch_size):
        '''Endless generator of batches sampled from the validation split.'''
        return self._fetch_data_in_range(batch_size, self.train_size, self.whole_size)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
''' | ||
Generators | ||
''' | ||
|
||
from .dataset_generator import dataset_generator |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
''' | ||
Input Image Data Generators | ||
''' | ||
|
||
import contextlib | ||
import queue as Q | ||
import time | ||
|
||
import params | ||
import utils | ||
|
||
|
||
def dataset_generator(queue, dataset, mode, role, batch_size):
    '''
    Feed image batches from `dataset` into the queue named '<role>/<mode>'.

    Batches come from the training data when `mode` is 'train' and from the
    validation data otherwise. The loop stops once params.SHOULD_FINISH is set.
    '''
    fetch_batches = dataset.fetch_train_data if mode == 'train' else dataset.fetch_valid_data

    # queue-name: one of ['covr/train', 'hide/train', 'covr/valid', 'hide/valid']
    qname = '{}/{}'.format(role, mode)

    for batch in fetch_batches(batch_size):
        if params.SHOULD_FINISH.value:
            break
        try:
            queue[qname].put(batch, timeout=params.QUEUE_TIMEOUT)
        except Q.Full:
            # Queue stayed full for the whole timeout: drop this batch and
            # re-check the finish flag on the next iteration.
            pass

    # Setup queue to allow exit without flushing all the data to the pipe
    queue[qname].cancel_join_thread()
    utils.eprint('dataset_generator(%s/%s): exit' % (role, mode))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# Create lmdb files of ImageNet 2012 Contest | ||
|
||
Step 1. Download and Extract ILSVRC2012 Dataset | ||
|
||
Step 2. Make sure that you have the `ILSVRC2012_devkit_t12` and `ILSVRC2012_img_train` directories | ||
|
||
Step 3. Set up environment variables: | ||
|
||
```bash | ||
export IMAGE_DIR="<Your ILSVRC2012_img_train Directory Path>" | ||
export DK_DIR="<Your ILSVRC2012_devkit_t12 Directory Path>" | ||
export MDB_OUT_DIR="<Your Expected Directory for Generating LMDB File>" # Note: Reserve at least 60GB | ||
``` | ||
|
||
Step 4. Run the python script | ||
|
||
```bash | ||
python ./images2lmdb.py | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,171 @@ | ||
''' | ||
Generate ILSVRC2012 Dataset LMDB file | ||
''' | ||
import io | ||
import os | ||
import pathlib | ||
import struct | ||
import sys | ||
import time | ||
|
||
import PIL.Image | ||
import lmdb | ||
import msgpack | ||
|
||
import scipy.io | ||
import cytoolz as tz | ||
import numpy as np | ||
|
||
# Prepare: all three locations come from the environment
IMAGE_DIR = os.environ['IMAGE_DIR']
DK_DIR = os.environ['DK_DIR']
MDB_OUT_DIR = os.environ['MDB_OUT_DIR']

# Fixed seed for the shuffle of the image list
seed = 42
np.random.seed(seed)

lmdb_map_size = 50*1024*1024*1024  # LMDB map size in bytes (50 GiB)
lmdb_txn_size = 500  # number of images written per LMDB transaction

# Setup PATHs
META_PATH = os.path.join(DK_DIR, 'data', 'meta.mat')
META_MP_PATH = os.path.join(MDB_OUT_DIR, 'meta.msgpack')
LMDB_PATH = os.path.join(MDB_OUT_DIR, 'ILSVRC2012_image_train.mdb')

# The output directory may not exist yet; without it the write below fails.
os.makedirs(MDB_OUT_DIR, exist_ok=True)

# Generate meta.msgpack
meta = scipy.io.loadmat(META_PATH, squeeze_me=True)
synsets = meta['synsets']

# Convert each synset record to plain Python types so msgpack can serialize it.
meta_info = [{
    'ILSVRC2012_ID': int(s['ILSVRC2012_ID']),
    'WNID': str(s['WNID']),
    'words': str(s['words']),
    'gloss': str(s['gloss']),
    'wordnet_height': int(s['wordnet_height']),
    'num_train_images': int(s['num_train_images'])
} for s in synsets]

meta_info_packed = msgpack.packb(meta_info, use_bin_type=True)

with open(META_MP_PATH, 'wb') as f:
    f.write(meta_info_packed)
|
||
# Generate LMDB | ||
def make_context():
    '''
    Build the mutable progress record used while the images are processed.

    'image_id' starts at 0; 'clock_beg' and 'clock_end' both start at the
    current time.
    '''
    ctx = {'image_id': 0}
    ctx['clock_beg'] = time.time()
    ctx['clock_end'] = time.time()
    return ctx
|
||
|
||
def process_image_one(txn, image_id, wordnet_id, label, image_abspath):
    '''
    Encode one image and store it in the LMDB transaction.

    The image is converted to RGB, resized to 256x256 and saved as WebP.
    The record is a msgpack map written under the zero-padded 8-digit key of
    `image_id`; its 'rows'/'cols' entries hold the size of the image before
    the resize.

    txn: lmdb transaction object
    image_id: int
        The image id, increasing index
    wordnet_id: str
        The wordnet id, i.e. n07711569
    label: int
        The class label stored alongside the image
    image_abspath: str
        The image's absolute path
    '''
    with PIL.Image.open(image_abspath) as im, io.BytesIO() as bio:
        if im.mode != 'RGB':
            im = im.convert('RGB')
        # PIL reports size as (width, height), i.e. (cols, rows).
        cols, rows = im.size
        cnls = 3
        # resize() returns a new image; it has to be rebound to take effect.
        im = im.resize((256, 256))
        im.save(bio, format='webp')
        image_bytes = bio.getvalue()

    # Drop only a literal '.JPEG' suffix; str.rstrip would treat the argument
    # as a character set and could eat trailing characters of the stem.
    filename = os.path.basename(image_abspath)
    if filename.endswith('.JPEG'):
        filename = filename[:-len('.JPEG')]

    info = {
        'wordnet_id': wordnet_id,
        'filename': filename,
        'image': image_bytes,
        'rows': rows,
        'cols': cols,
        'cnls': cnls,
        'label': label,
    }
    key = '{:08d}'.format(image_id).encode()
    txn.put(key, msgpack.packb(info, use_bin_type=True))
|
||
|
||
def imagenet_walk(wnid_meta_map, image_Dir):
    '''
    List every image file found one level below `image_Dir`.

    wnid_meta_map: dict
        Maps a wordnet id to its meta record ('words', 'gloss', 'ILSVRC2012_ID')
    image_Dir: pathlib.Path
        Directory holding one sub-directory per wordnet id

    Returns a list of dicts with keys 'label', 'wordnet_id' and
    'image_abspath'. Categories and files are visited in sorted order:
    iterdir() order is arbitrary, so without sorting a seeded shuffle of the
    result would not be reproducible.
    Raises KeyError for a sub-directory whose name is not in `wnid_meta_map`.
    '''
    def list_image_abspaths(category_Path):
        return sorted(str(f.absolute()) for f in category_Path.iterdir() if f.is_file())

    categories = sorted(
        (d for d in image_Dir.iterdir() if d.is_dir()),
        key=lambda d: d.name)

    image_files = []
    for count, category_Path in enumerate(categories):
        wordnet_id = category_Path.name
        metainfo = wnid_meta_map[wordnet_id]
        words = metainfo['words']
        gloss = metainfo['gloss']
        label = metainfo['ILSVRC2012_ID']

        print('Process count=%d, label=%d, wordnet_id=%s' % (count, label, wordnet_id))
        print(' %s: %s' % (words, gloss))
        image_files.extend(
            {
                'label': label,
                'wordnet_id': wordnet_id,
                'image_abspath': image_abspath,
            }
            for image_abspath in list_image_abspaths(category_Path)
        )
    return image_files
|
||
|
||
def process_images(ctx, lmdb_env, image_infos, image_total):
    '''
    Write one batch of images in a single LMDB write transaction and print
    a progress line.

    ctx: dict
        Progress record; 'image_id' is read for the first id of the batch and
        updated afterwards, 'clock_beg' is read, 'clock_end' is updated
    lmdb_env: lmdb environment object
    image_infos: sequence of dicts with 'wordnet_id', 'label', 'image_abspath'
    image_total: int
        Total number of images, used for the remaining-time estimate
    '''
    def split_hms(seconds):
        # Whole seconds split into (hours, minutes, seconds)
        minutes, secs = divmod(int(seconds), 60)
        hours, minutes = divmod(minutes, 60)
        return hours, minutes, secs

    next_id = ctx['image_id']

    with lmdb_env.begin(write=True) as txn:
        for info in image_infos:
            wordnet_id = info['wordnet_id']
            label = info['label']
            image_abspath = info['image_abspath']
            process_image_one(txn, next_id, wordnet_id, label, image_abspath)
            next_id += 1

    now = time.time()
    elapse = now - ctx['clock_beg']
    # Remaining time extrapolated from the average time per image so far
    estmt = (image_total - next_id) / next_id * elapse

    elapse_h, elapse_m, elapse_s = split_hms(elapse)
    estmt_h, estmt_m, estmt_s = split_hms(estmt)

    labels = [info['label'] for info in image_infos]
    print('ImageId: {:8d}/{:8d}, time: {:2d}h/{:2d}m/{:2d}s, remain: {:2d}h/{:2d}m/{:2d}s, Sample: {} ...'.format(
        next_id, image_total,
        elapse_h, elapse_m, elapse_s,
        estmt_h, estmt_m, estmt_s,
        str(labels)[:80]))

    ctx['image_id'] = next_id
    ctx['clock_end'] = now
|
||
|
||
# Index the meta records by wordnet id for lookup during the directory walk.
wnid_meta_map = { m['WNID']: m for m in meta_info }

# The environment is used as a context manager so it is closed even when
# processing fails part-way through.
with lmdb.open(LMDB_PATH, map_size=lmdb_map_size) as image_train_env:
    image_infos = imagenet_walk(wnid_meta_map, pathlib.Path(IMAGE_DIR))
    image_total = len(image_infos)
    # Shuffled in place with the numpy RNG seeded at the top of the script.
    np.random.shuffle(image_infos)

    ctx = make_context()
    # One write transaction per chunk of lmdb_txn_size images.
    for image_infos_partial in tz.partition_all(lmdb_txn_size, image_infos):
        process_images(ctx, image_train_env, image_infos_partial, image_total)
Oops, something went wrong.