-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathpostprocessing_functions.py
83 lines (59 loc) · 2.85 KB
/
postprocessing_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import torch
import numpy as np
import utils.data_format_utils as df_utils
from data_processing.camera_pipeline import apply_gains, apply_ccm, apply_smoothstep, gamma_compression
class SimplePostProcess:
    """Store post-processing switches and apply them to linear RGB images.

    Every constructor argument is kept verbatim on the instance and handed to
    ``process_linear_image_rgb`` each time :meth:`process` is called.
    """

    def __init__(self, gains=True, ccm=True, gamma=True, smoothstep=True, return_np=False):
        # Which pipeline stages the caller wants enabled.
        self.gains, self.ccm = gains, ccm
        self.gamma, self.smoothstep = gamma, smoothstep
        # Whether the final result should be converted to a numpy image.
        self.return_np = return_np

    def process(self, image, meta_info):
        """Run ``process_linear_image_rgb`` on *image* with the stored flags."""
        return process_linear_image_rgb(
            image,
            meta_info,
            gains=self.gains,
            ccm=self.ccm,
            gamma=self.gamma,
            smoothstep=self.smoothstep,
            return_np=self.return_np,
        )
def process_linear_image_rgb(image, meta_info, gains=True, ccm=True, gamma=True, smoothstep=True, return_np=False):
    """Turn a linear RGB image into a display-ready image.

    Stages run in a fixed order: gains, colour-correction matrix, gamma
    compression, smoothstep tone curve, and finally a clamp to [0, 1].

    Args:
        image: Linear RGB image tensor. Only ``.clamp`` is called on it here;
            its expected layout is whatever the ``camera_pipeline`` helpers
            accept (presumably (3, H, W) -- TODO confirm).
        meta_info: Per-image metadata mapping. Keys read:
            'rgb_gain', 'red_gain' and 'blue_gain' for the gains stage,
            'cam2rgb' for the ccm stage, and 'gamma' / 'smoothstep' as
            per-image switches. Each key is consulted only when the
            matching stage flag is enabled.
        gains: Apply ``apply_gains``.
        ccm: Apply ``apply_ccm``.
        gamma: Apply ``gamma_compression`` if ``meta_info['gamma']`` is also truthy.
        smoothstep: Apply ``apply_smoothstep`` if ``meta_info['smoothstep']`` is also truthy.
        return_np: Convert the result with ``df_utils.torch_to_npimage``.

    Returns:
        The clamped tensor, or the output of ``df_utils.torch_to_npimage`` when
        ``return_np`` is true.

    Raises:
        KeyError: If an enabled stage needs a metadata key that is missing.
    """
    if gains:
        image = apply_gains(image, meta_info['rgb_gain'], meta_info['red_gain'], meta_info['blue_gain'])
    if ccm:
        image = apply_ccm(image, meta_info['cam2rgb'])
    # Check the caller's flag first so that a disabled stage never requires
    # its metadata key to be present.
    if gamma and meta_info['gamma']:
        image = gamma_compression(image)
    if smoothstep and meta_info['smoothstep']:
        image = apply_smoothstep(image)
    image = image.clamp(0.0, 1.0)
    if return_np:
        image = df_utils.torch_to_npimage(image)
    return image
class BurstSRPostProcess:
    """Keep settings for ``process_burstsr_image_rgb`` and apply them on request.

    The constructor arguments are stored unchanged; :meth:`process` forwards
    them, together with an optional per-call normalisation factor.
    """

    def __init__(self, no_white_balance=False, gamma=True, smoothstep=True, return_np=False):
        self.no_white_balance, self.gamma, self.smoothstep, self.return_np = (
            no_white_balance,
            gamma,
            smoothstep,
            return_np,
        )

    def process(self, image, meta_info, external_norm_factor=None):
        """Run ``process_burstsr_image_rgb`` on *image* with the stored settings."""
        options = dict(
            no_white_balance=self.no_white_balance,
            gamma=self.gamma,
            smoothstep=self.smoothstep,
            return_np=self.return_np,
        )
        return process_burstsr_image_rgb(image, meta_info, external_norm_factor=external_norm_factor, **options)
def process_burstsr_image_rgb(im, meta_info, return_np=False, external_norm_factor=None, gamma=True, smoothstep=True,
                              no_white_balance=False):
    """Convert a BurstSR RGB tensor into a normalised, tone-mapped image.

    Steps:
      1. Scale by ``meta_info['norm_factor']`` (default 1.0).
      2. Subtract the black level, unless ``meta_info['black_level_subtracted']``.
      3. Apply camera white-balance gains, unless they were already applied
         or ``no_white_balance`` is set.
      4. Normalise by ``external_norm_factor``, or by the image maximum when it
         is None. An image whose maximum is not positive (e.g. all zeros) is
         left unnormalised instead of being divided by zero.
      5. Clamp to [0, 1], then optionally apply gamma (1/2.2) and a
         smoothstep curve (3x^2 - 2x^3).

    Args:
        im: Channels-first tensor with 3 channels (the per-channel constants
            are broadcast with ``view(3, 1, 1)``). It may live on any device.
        meta_info: Metadata mapping. Keys read:
            'norm_factor' (optional), 'black_level_subtracted' (optional),
            'black_level', 'white_balance_applied' / 'while_balance_applied'
            (optional; either spelling is honoured) and 'cam_wb'.
            For 'black_level' and 'cam_wb', entries 0, 1 and the last are used
            (presumably R, G, B of a 4-entry RGGB list -- TODO confirm).
        return_np: Return an (H, W, 3) ``uint8`` numpy array instead of a tensor.
        external_norm_factor: Divisor used instead of the image maximum.
        gamma: Apply ``x ** (1 / 2.2)`` after clamping.
        smoothstep: Apply ``3x^2 - 2x^3`` after gamma.
        no_white_balance: Skip the white-balance step entirely.

    Returns:
        The processed tensor, or a ``uint8`` numpy array when ``return_np``.
    """
    im = im * meta_info.get('norm_factor', 1.0)

    if not meta_info.get('black_level_subtracted', False):
        # Build on the image's device so CUDA inputs do not hit a device mismatch.
        black_level = torch.as_tensor(meta_info['black_level'], device=im.device)
        im = im - black_level[[0, 1, -1]].view(3, 1, 1)

    # 'while_balance_applied' is a historical misspelling. Keep honouring it for
    # existing metadata, and also accept the correctly spelled key.
    wb_applied = meta_info.get('white_balance_applied', False) or meta_info.get('while_balance_applied', False)
    if not wb_applied and not no_white_balance:
        cam_wb = torch.as_tensor(meta_info['cam_wb'], device=im.device)
        # Gains are expressed relative to the green channel (entry 1).
        im = im * cam_wb[[0, 1, -1]].view(3, 1, 1) / cam_wb[1]

    if external_norm_factor is None:
        max_val = im.max()
        # Dividing by a zero or negative maximum would produce NaNs or flip signs.
        if max_val > 0:
            im = im / max_val
    else:
        im = im / external_norm_factor

    im = im.clamp(0.0, 1.0)
    if gamma:
        im = im ** (1.0 / 2.2)
    if smoothstep:
        # Smooth S-shaped tone curve.
        im = 3 * im ** 2 - 2 * im ** 3

    if return_np:
        # detach/cpu make the numpy conversion work for grad-tracking or GPU tensors.
        im = im.detach().cpu().permute(1, 2, 0).numpy() * 255.0
        im = im.astype(np.uint8)
    return im