Commit 24b15d3 (initial commit; 0 parents), committed by atiyo on Jan 6, 2018.
Showing 223 changed files with 996 additions and 0 deletions.

README.md:

# PyTorch Deep Image Prior

An implementation of image reconstruction methods from [Deep Image Prior (Ulyanov et al., 2017)](https://arxiv.org/abs/1711.10925) in PyTorch.

The point of the paper is to carry out common image manipulation tasks using neural networks that have not been trained on any data prior to use.

The architectures here differ from those used in the actual paper, where the authors use specific networks for specific tasks. This repo uses two alternative architectures to produce similar results: one upsamples by pixel shuffling, the other by transposed convolutions. Pixel shuffling results in some hotspots that do not disappear with further training.
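
For a sense of the difference, here is a minimal sketch (illustrative shapes only; these are not the exact layers used in the script) contrasting the two upsampling strategies:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

x = Variable(torch.randn(1, 64, 32, 32))

# pixel shuffle: a stride-1 conv produces 4x the target channels, then
# F.pixel_shuffle rearranges each group of 4 channels into a 2x2 spatial block
conv = nn.Conv2d(64, 16 * 4, 5, stride=1, padding=2)
up_shuffle = F.pixel_shuffle(conv(x), 2)        # shape (1, 16, 64, 64)

# transposed convolution: a single learned layer performs the upsampling
deconv = nn.ConvTranspose2d(64, 16, 4, stride=2, padding=1)
up_deconv = deconv(x)                           # shape (1, 16, 64, 64)
```

Both routes produce the same output shape; they differ in how the learned parameters are arranged, which is what drives the different artifacts described under Results below.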

## Requirements

Python 3 with PyTorch, torchvision and NumPy. CUDA and cuDNN are optional (settable within the script in a self-explanatory way) but strongly recommended.

## To use

It's relatively easy to play around with the settings from within the scripts. To reproduce the results in the repo, do the following.

Make a directory to hold the network output:

```bash
mkdir output
```

Generate output images with:

```bash
python3 deep_image_prior.py
```

Consolidate the output images into a training gif and sample some actual data with:

```bash
python3 parse_ec2_results.py
```

## Results

Note that the images here have been downsampled so that they format sensibly in the README. Full-sized samples are in the repo if you would like a closer look.

Training was done over 25k iterations on an Amazon GPU instance; it takes roughly an hour on 512x512 images.

### Upsampling with transposed convolutions:

Note the grid-like speckles during training. These are caused by convolutional kernels overlapping with one another during upsampling.

Ground truth | Input | Output | Training
------------ | ----- | ------ | --------
*(images omitted in this view; full-sized samples are in the repo)* | | |
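
For intuition on where grid patterns can come from, here is a toy sketch (a hypothetical one-channel layer, not one of the layers in the script) in which a transposed convolution with all-ones weights upsamples an all-ones input. Because the kernel size (5) is not divisible by the stride (2), output positions receive unequal numbers of contributions:

```python
import torch
import torch.nn as nn
from torch.autograd import Variable

# hypothetical layer: kernel size 5 is not a multiple of stride 2,
# so kernel applications overlap unevenly across output positions
deconv = nn.ConvTranspose2d(1, 1, 5, stride=2, padding=2, bias=False)
deconv.weight.data.fill_(1.)

out = deconv(Variable(torch.ones(1, 1, 8, 8)))
# interior values alternate 9/6/4 in a checkerboard: even output positions
# overlap three kernel taps per axis, odd positions only two
print(out.data[0, 0, 2:6, 2:6])
```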

### Upsampling with pixel shuffling:

No speckles, but there is a hotspot (in the out-of-focus region towards the bunny's behind) that becomes a black spot. These hotspots seem common to both architectures, but the extra smoothness given by the convolution transpose layers repairs them more effectively.

Ground truth | Input | Output | Training
------------ | ----- | ------ | --------
*(images omitted in this view; full-sized samples are in the repo)* | | |

deep_image_prior.py:

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import torchvision
from PIL import Image

#use cuda, or not? be prepared for a long wait if you don't have cuda capabilities.
use_cuda = True
#input image. the architectures have been designed for 512x512 colour images
ground_truth_path = 'bunny_512.jpg'
#proportion of pixels to black out.
prop = 0.5
#standard deviation of the noise added to the input after each training step
sigma = 1./30
#number of training iterations
num_steps = 25001
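#(25001 rather than a round 25000, presumably so the final step, 25000, still
#satisfies step % save_frequency == 0 and its output is saved.)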
#number of steps to take before saving an output image
save_frequency = 250
#where to put the output
output_name = 'output/output'
#choose either 'pixel_shuffle' or 'deconv' as the architecture used.
method = 'pixel_shuffle'

#accept a file path to a jpg, return a torch tensor
def jpg_to_tensor(filepath=ground_truth_path):
    #open the filepath argument rather than the module-level constant
    pil = Image.open(filepath)
    pil_to_tensor = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
    if use_cuda:
        tensor = pil_to_tensor(pil).cuda()
    else:
        tensor = pil_to_tensor(pil)
    #prepend a batch dimension
    return tensor.view([1] + list(tensor.shape))

#accept a torch tensor, convert it to a jpg at a certain path
def tensor_to_jpg(tensor, filename):
    #drop the batch dimension before converting back to PIL
    tensor = tensor.view(tensor.shape[1:])
    if use_cuda:
        tensor = tensor.cpu()
    tensor_to_pil = torchvision.transforms.Compose([torchvision.transforms.ToPILImage()])
    pil = tensor_to_pil(tensor)
    pil.save(filename)

#function which zeros out a random proportion of pixels from an image tensor.
def zero_out_pixels(tensor, prop=prop):
    #draw a single-channel random mask, then repeat it across the colour channels
    if use_cuda:
        mask = torch.rand([1] + [1] + list(tensor.shape[2:])).cuda()
    else:
        mask = torch.rand([1] + [1] + list(tensor.shape[2:]))
    mask[mask < prop] = 0
    mask[mask != 0] = 1
    mask = mask.repeat(1, 3, 1, 1)
    deconstructed = tensor * mask
    return mask, deconstructed
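
#e.g. with prop=0.5 roughly half the pixels are blacked out; the mask is returned
#alongside the deconstructed image because the training loss below only scores the
#network output at pixels that were actually observed.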

#define an encoder decoder network with pixel shuffle upsampling
class pixel_shuffle_hourglass(nn.Module):
    def __init__(self):
        super(pixel_shuffle_hourglass, self).__init__()
        self.d_conv_1 = nn.Conv2d(3, 8, 5, stride=2, padding=2)
        self.d_bn_1 = nn.BatchNorm2d(8)

        self.d_conv_2 = nn.Conv2d(8, 16, 5, stride=2, padding=2)
        self.d_bn_2 = nn.BatchNorm2d(16)

        self.d_conv_3 = nn.Conv2d(16, 32, 5, stride=2, padding=2)
        self.d_bn_3 = nn.BatchNorm2d(32)
        self.s_conv_3 = nn.Conv2d(32, 4, 5, stride=1, padding=2)

        self.d_conv_4 = nn.Conv2d(32, 64, 5, stride=2, padding=2)
        self.d_bn_4 = nn.BatchNorm2d(64)
        self.s_conv_4 = nn.Conv2d(64, 4, 5, stride=1, padding=2)

        self.d_conv_5 = nn.Conv2d(64, 128, 5, stride=2, padding=2)
        self.d_bn_5 = nn.BatchNorm2d(128)
        self.s_conv_5 = nn.Conv2d(128, 4, 5, stride=1, padding=2)

        self.d_conv_6 = nn.Conv2d(128, 256, 5, stride=2, padding=2)
        self.d_bn_6 = nn.BatchNorm2d(256)

        self.u_conv_5 = nn.Conv2d(68, 128, 5, stride=1, padding=2)
        self.u_bn_5 = nn.BatchNorm2d(128)

        self.u_conv_4 = nn.Conv2d(36, 64, 5, stride=1, padding=2)
        self.u_bn_4 = nn.BatchNorm2d(64)

        self.u_conv_3 = nn.Conv2d(20, 32, 5, stride=1, padding=2)
        self.u_bn_3 = nn.BatchNorm2d(32)

        self.u_conv_2 = nn.Conv2d(8, 16, 5, stride=1, padding=2)
        self.u_bn_2 = nn.BatchNorm2d(16)

        self.u_conv_1 = nn.Conv2d(4, 16, 5, stride=1, padding=2)
        self.u_bn_1 = nn.BatchNorm2d(16)

        self.out_conv = nn.Conv2d(4, 3, 5, stride=1, padding=2)
        self.out_bn = nn.BatchNorm2d(3)
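
    #channel bookkeeping for the decoder: F.pixel_shuffle with factor 2 divides the
    #channel count by 4, so u_conv_5 sees 256/4 = 64 shuffled channels plus the 4
    #channels of skip_5, i.e. 68 inputs; 36, 20, 8 and 4 follow the same pattern.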

    def forward(self, noise):
        down_1 = self.d_conv_1(noise)
        down_1 = self.d_bn_1(down_1)
        down_1 = F.leaky_relu(down_1)

        down_2 = self.d_conv_2(down_1)
        down_2 = self.d_bn_2(down_2)
        down_2 = F.leaky_relu(down_2)

        down_3 = self.d_conv_3(down_2)
        down_3 = self.d_bn_3(down_3)
        down_3 = F.leaky_relu(down_3)
        skip_3 = self.s_conv_3(down_3)

        down_4 = self.d_conv_4(down_3)
        down_4 = self.d_bn_4(down_4)
        down_4 = F.leaky_relu(down_4)
        skip_4 = self.s_conv_4(down_4)

        down_5 = self.d_conv_5(down_4)
        down_5 = self.d_bn_5(down_5)
        down_5 = F.leaky_relu(down_5)
        skip_5 = self.s_conv_5(down_5)

        down_6 = self.d_conv_6(down_5)
        down_6 = self.d_bn_6(down_6)
        down_6 = F.leaky_relu(down_6)

        up_5 = F.pixel_shuffle(down_6, 2)
        up_5 = torch.cat([up_5, skip_5], 1)
        up_5 = self.u_conv_5(up_5)
        up_5 = self.u_bn_5(up_5)
        up_5 = F.leaky_relu(up_5)

        up_4 = F.pixel_shuffle(up_5, 2)
        up_4 = torch.cat([up_4, skip_4], 1)
        up_4 = self.u_conv_4(up_4)
        up_4 = self.u_bn_4(up_4)
        up_4 = F.leaky_relu(up_4)

        up_3 = F.pixel_shuffle(up_4, 2)
        up_3 = torch.cat([up_3, skip_3], 1)
        up_3 = self.u_conv_3(up_3)
        up_3 = self.u_bn_3(up_3)
        up_3 = F.leaky_relu(up_3)

        up_2 = F.pixel_shuffle(up_3, 2)
        up_2 = self.u_conv_2(up_2)
        up_2 = self.u_bn_2(up_2)
        up_2 = F.leaky_relu(up_2)

        up_1 = F.pixel_shuffle(up_2, 2)
        up_1 = self.u_conv_1(up_1)
        up_1 = self.u_bn_1(up_1)
        up_1 = F.leaky_relu(up_1)

        out = F.pixel_shuffle(up_1, 2)
        out = self.out_conv(out)
        out = self.out_bn(out)
        out = F.sigmoid(out)
        return out

#define an encoder decoder network with convolution transpose upsampling.
class deconv_hourglass(nn.Module):
    def __init__(self):
        super(deconv_hourglass, self).__init__()
        self.d_conv_1 = nn.Conv2d(3, 8, 5, stride=2, padding=2)
        self.d_bn_1 = nn.BatchNorm2d(8)

        self.d_conv_2 = nn.Conv2d(8, 16, 5, stride=2, padding=2)
        self.d_bn_2 = nn.BatchNorm2d(16)

        self.d_conv_3 = nn.Conv2d(16, 32, 5, stride=2, padding=2)
        self.d_bn_3 = nn.BatchNorm2d(32)
        self.s_conv_3 = nn.Conv2d(32, 4, 5, stride=1, padding=2)

        self.d_conv_4 = nn.Conv2d(32, 64, 5, stride=2, padding=2)
        self.d_bn_4 = nn.BatchNorm2d(64)
        self.s_conv_4 = nn.Conv2d(64, 4, 5, stride=1, padding=2)

        self.d_conv_5 = nn.Conv2d(64, 128, 5, stride=2, padding=2)
        self.d_bn_5 = nn.BatchNorm2d(128)
        self.s_conv_5 = nn.Conv2d(128, 4, 5, stride=1, padding=2)

        self.d_conv_6 = nn.Conv2d(128, 256, 5, stride=2, padding=2)
        self.d_bn_6 = nn.BatchNorm2d(256)

        #deconv output channels run 4 short of each round number so that the
        #concatenated skip connections bring the totals to 128, 64 and 32
        self.u_deconv_5 = nn.ConvTranspose2d(256, 124, 4, stride=2, padding=1)
        self.u_bn_5 = nn.BatchNorm2d(128)

        self.u_deconv_4 = nn.ConvTranspose2d(128, 60, 4, stride=2, padding=1)
        self.u_bn_4 = nn.BatchNorm2d(64)

        self.u_deconv_3 = nn.ConvTranspose2d(64, 28, 4, stride=2, padding=1)
        self.u_bn_3 = nn.BatchNorm2d(32)

        self.u_deconv_2 = nn.ConvTranspose2d(32, 16, 4, stride=2, padding=1)
        self.u_bn_2 = nn.BatchNorm2d(16)

        self.u_deconv_1 = nn.ConvTranspose2d(16, 8, 4, stride=2, padding=1)
        self.u_bn_1 = nn.BatchNorm2d(8)

        self.out_deconv = nn.ConvTranspose2d(8, 3, 4, stride=2, padding=1)
        self.out_bn = nn.BatchNorm2d(3)

    def forward(self, noise):
        down_1 = self.d_conv_1(noise)
        down_1 = self.d_bn_1(down_1)
        down_1 = F.leaky_relu(down_1)

        down_2 = self.d_conv_2(down_1)
        down_2 = self.d_bn_2(down_2)
        down_2 = F.leaky_relu(down_2)

        down_3 = self.d_conv_3(down_2)
        down_3 = self.d_bn_3(down_3)
        down_3 = F.leaky_relu(down_3)
        skip_3 = self.s_conv_3(down_3)

        down_4 = self.d_conv_4(down_3)
        down_4 = self.d_bn_4(down_4)
        down_4 = F.leaky_relu(down_4)
        skip_4 = self.s_conv_4(down_4)

        down_5 = self.d_conv_5(down_4)
        down_5 = self.d_bn_5(down_5)
        down_5 = F.leaky_relu(down_5)
        skip_5 = self.s_conv_5(down_5)

        down_6 = self.d_conv_6(down_5)
        down_6 = self.d_bn_6(down_6)
        down_6 = F.leaky_relu(down_6)

        up_5 = self.u_deconv_5(down_6)
        up_5 = torch.cat([up_5, skip_5], 1)
        up_5 = self.u_bn_5(up_5)
        up_5 = F.leaky_relu(up_5)

        up_4 = self.u_deconv_4(up_5)
        up_4 = torch.cat([up_4, skip_4], 1)
        up_4 = self.u_bn_4(up_4)
        up_4 = F.leaky_relu(up_4)

        up_3 = self.u_deconv_3(up_4)
        up_3 = torch.cat([up_3, skip_3], 1)
        up_3 = self.u_bn_3(up_3)
        up_3 = F.leaky_relu(up_3)

        up_2 = self.u_deconv_2(up_3)
        up_2 = self.u_bn_2(up_2)
        up_2 = F.leaky_relu(up_2)

        up_1 = self.u_deconv_1(up_2)
        up_1 = self.u_bn_1(up_1)
        up_1 = F.leaky_relu(up_1)

        out = self.out_deconv(up_1)
        out = self.out_bn(out)
        out = F.sigmoid(out)

        return out
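
    #note that, unlike the pixel shuffle network, batch norm here is applied to the
    #concatenated (deconv + skip) tensor, which is why u_bn_5 normalises
    #128 = 124 + 4 channels.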

if __name__ == '__main__':
    #import image
    truth = jpg_to_tensor(ground_truth_path)
    #deconstruct image
    mask, deconstructed = zero_out_pixels(truth)
    #save the deconstructed image
    tensor_to_jpg(deconstructed, 'deconstructed.jpg')
    #convert the image and mask to variables.
    mask = Variable(mask)
    deconstructed = Variable(deconstructed)

    #input of the network is noise
    if use_cuda:
        noise = Variable(torch.randn(deconstructed.shape).cuda())
    else:
        noise = Variable(torch.randn(deconstructed.shape))

    #initialise the network with the chosen architecture
    if method == 'pixel_shuffle':
        net = pixel_shuffle_hourglass()
    elif method == 'deconv':
        net = deconv_hourglass()

    #bind the network to the gpu if cuda is enabled
    if use_cuda:
        net.cuda()
    #network optimizer set up
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)

    #dummy index to provide names to output files
    save_img_ind = 0
    for step in range(num_steps):
        #get the network output
        output = net(noise)
        #we are only concerned with the output where we have the image available.
        masked_output = output * mask
        #calculate the l2 loss over the masked output and take an optimizer step
        optimizer.zero_grad()
        loss = torch.sum((masked_output - deconstructed)**2)
        loss.backward()
        optimizer.step()
        print('At step {}, loss is {}'.format(step, loss.data.cpu()))
        #every save_frequency steps, save a jpg
        if step % save_frequency == 0:
            tensor_to_jpg(output.data, output_name + '_{}.jpg'.format(save_img_ind))
            save_img_ind += 1
        if use_cuda:
            noise.data += sigma * torch.randn(noise.shape).cuda()
        else:
            noise.data += sigma * torch.randn(noise.shape)
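
        #the in-place jitter on noise.data perturbs the network input between
        #steps; sigma (set above) controls its scale and acts as a mild
        #regulariser on the optimisation.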

    #clean up any mess we're leaving on the gpu
    if use_cuda:
        torch.cuda.empty_cache()