-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathLoss.py
156 lines (128 loc) · 6.53 KB
/
Loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# -*- coding: utf-8 -*-
"""DNPC/Loss.py: DNPC training loss function."""
import kornia
import torch
import torchmetrics
from Cameras.Perspective import PerspectiveCamera
from Cameras.utils import CameraProperties
from Datasets.Base import BaseDataset
from Framework import ConfigParameterList
from Methods.DNPC.Model import DNPCModel
from Optim.Losses.Base import BaseLoss
from Optim.Losses.FusedDSSIM import fused_dssim
from Optim.Losses.Robust import RobustLoss
from Optim.Losses.VGG import VGGLoss
from Methods.DNPC.CudaExtensions.PointBasedDistortionLoss import point_based_distortion_loss
def monocularDepthLoss(outputs: dict[str, torch.Tensor | None], gt_depth: torch.Tensor | None, fg_mask: torch.Tensor,
                       camera: PerspectiveCamera) -> torch.Tensor:
    """Loss function on monocular depth estimates.

    Computes the absolute difference between the rendered dynamic depth and the
    monocular depth estimate, weights it element-wise by the foreground mask,
    averages over all elements and divides by the camera depth range
    (far_plane - near_plane).

    Args:
        outputs (dict[str, torch.Tensor | None]): Dictionary of model outputs; 'depth_dynamic' is read.
        gt_depth (torch.Tensor | None): 'Ground truth' metric monocular depth estimates, or None if
            no estimate is available for the current view.
        fg_mask (torch.Tensor): 'Ground truth' foreground mask, multiplied onto the per-element error.
        camera (PerspectiveCamera): Camera object used by dataset; provides near_plane and far_plane.

    Returns:
        torch.Tensor: Scalar output loss value. A zero tensor is returned when either the rendered
            depth or the monocular depth estimate is missing.
    """
    depth = outputs['depth_dynamic']
    # the signature allows both depths to be absent; without a target (or a prediction)
    # there is nothing to supervise, so contribute nothing instead of raising a TypeError
    if depth is None or gt_depth is None:
        return torch.tensor(0.0)
    depth_range = camera.far_plane - camera.near_plane
    # note: the mean runs over all elements, not only those selected by the mask
    return torch.mean(torch.abs(depth - gt_depth) * fg_mask) / depth_range
class DistortionLoss:
    """Callable that evaluates the point-based distortion loss on the dynamic points."""

    def __call__(self, outputs: dict[str, torch.Tensor | None], camera: PerspectiveCamera) -> torch.Tensor:
        """Python wrapper for point-based distortion loss CUDA extension.

        Projects the dynamic points with the camera, keeps the points flagged valid by
        the projection, and hands their blending weights, normalized depths and linear
        pixel indices to the CUDA extension, ordered by pixel index and, within a pixel,
        by increasing depth.

        Args:
            outputs (dict[str, torch.Tensor | None]): Model outputs; reads
                outputs['extras']['positions_dynamic'] and outputs['extras']['blending_weights_dynamic'].
            camera (PerspectiveCamera): Camera object used by dataset.

        Returns:
            torch.Tensor: Scalar output loss value. A zero tensor is returned when there are no
                dynamic points or when none of them survives the projection's validity mask.
        """
        positions = outputs['extras']['positions_dynamic']
        if positions.numel() == 0:
            return torch.tensor(0.0)
        screen_xy, valid_mask, d = camera.projectPoints(positions)
        w = outputs['extras']['blending_weights_dynamic'][valid_mask]
        if w.numel() == 0:
            # all points were rejected by the projection: same situation as having no
            # points at all, so do not pass empty tensors to the CUDA extension
            return torch.tensor(0.0)
        # integer pixel coordinates of the surviving points
        screen_xy = screen_xy[valid_mask].long()
        # map depth linearly so that near_plane -> 0 and far_plane -> 1
        d = (d[valid_mask] - camera.near_plane) / (camera.far_plane - camera.near_plane)
        # row-major linear pixel index: x + y * width
        pixel_id = screen_xy[:, 0] + screen_xy[:, 1] * camera.properties.width
        n_pixels = camera.properties.width * camera.properties.height
        # sort by depth
        d, sort_idx = d.sort(stable=True)
        w = w[sort_idx]
        pixel_id = pixel_id[sort_idx]
        # sort so that the pixel_id is increasing; the sort is stable, so the depth
        # ordering established above is preserved inside every pixel
        pixel_id, sort_idx = pixel_id.sort(stable=True)
        w = w[sort_idx]
        d = d[sort_idx]
        return point_based_distortion_loss(w, d, pixel_id, n_pixels)
def binaryEntropy(x: torch.Tensor, exponent: float = 1.0) -> torch.Tensor:
    """Element-wise binary entropy (in bits) of the input values.

    The input is clamped from below, raised to the given power, and clamped again
    into (0, 1) so that both logarithms stay finite.

    Args:
        x (torch.Tensor): Input values in range [0, 1].
        exponent (float, optional): Bias for asymmetric binary entropy. Defaults to 1.0.

    Returns:
        torch.Tensor: Output loss value, same shape as the input.
    """
    eps = 1e-6
    # lower clamp before the power, full clamp afterwards to keep log2 away from 0
    p = torch.clamp(torch.pow(torch.clamp_min(x, eps), exponent), eps, 1.0 - eps)
    q = 1.0 - p
    return -p * torch.log2(p) - q * torch.log2(q)
def dynamicWeightsLoss(outputs: dict[str, torch.Tensor | None]) -> torch.Tensor:
    """Mean asymmetric binary entropy of the fused dynamic weights.

    Args:
        outputs (dict[str, torch.Tensor | None]): Model outputs; reads
            outputs['extras']['fused_dynamic_weights'].

    Returns:
        torch.Tensor: Output loss value (mean over all weights).
    """
    fused_weights = outputs['extras']['fused_dynamic_weights']
    # exponent 0.5 applies a square root to the weights before the entropy is taken
    entropy = binaryEntropy(fused_weights, exponent=0.5)
    return entropy.mean()
class DNPCLoss(BaseLoss):
    """Full DNPC training loss: photometric terms, regularizers and a PSNR quality metric."""

    def __init__(self, loss_config: ConfigParameterList, model: DNPCModel) -> None:
        """Register all sub-losses and quality metrics of the DNPC loss.

        Args:
            loss_config (ConfigParameterList): Configuration parameters for loss function (located at TRAINING.LOSS).
            model (DNPCModel): DNPC model instance; its dynamic_grid supplies the weight decay terms.
        """
        super().__init__()
        dynamic_grid = model.dynamic_grid
        # (name, callable, third argument passed to addLossMetric - presumably the loss weight)
        loss_terms = (
            # photometric losses
            ('Pixel', RobustLoss(loss_config.ROBUST_LOSS_ALPHA, loss_config.ROBUST_LOSS_C), loss_config.LAMBDA_PIXEL),
            ('DSSIM', fused_dssim, loss_config.LAMBDA_DSSIM),
            ('VGG', VGGLoss(), loss_config.LAMBDA_VGG),
            # regularizers
            ('MonocularDepth', monocularDepthLoss, loss_config.LAMBDA_MONOCULAR_DEPTH),
            ('Distortion', DistortionLoss(), loss_config.LAMBDA_DISTORTION),
            ('DynamicWeights', dynamicWeightsLoss, 0.0),  # registered with a constant 0.0
            ('HashGridWeightDecay', dynamic_grid.gridWeightDecay, loss_config.LAMBDA_GRID_DECAY),
            ('MLPWeightDecay', dynamic_grid.networkWeightDecay, loss_config.LAMBDA_MLP_DECAY),
        )
        for term_name, term_fn, term_weight in loss_terms:
            self.addLossMetric(term_name, term_fn, term_weight)
        # quality metrics for plotting in wandb
        self.addQualityMetric('PSNR', torchmetrics.functional.peak_signal_noise_ratio)
        # square all-ones kernel for morphological operations on foreground mask
        # NOTE(review): created on the default device - confirm it matches the segmentation's device
        kernel_size = loss_config.EROSION_KERNEL_SIZE
        self.morphology_kernel = torch.ones(kernel_size, kernel_size)

    @torch.autocast(enabled=True, dtype=torch.float32, device_type="cuda")
    def forward(
        self,
        outputs: dict[str, torch.Tensor | None],
        camera_properties: CameraProperties,
        dataset: BaseDataset,
    ) -> torch.Tensor:
        """Evaluate all registered sub-losses and metrics for one training view.

        Args:
            outputs (dict[str, torch.Tensor | None]): Model outputs.
            camera_properties (CameraProperties): Camera properties object; provides rgb, depth and segmentation.
            dataset (BaseDataset): Dataset instance used for training; provides the camera.

        Returns:
            torch.Tensor: Scalar output loss value.
        """
        # compute foreground mask (no gradients needed)
        # NOTE(review): the kernel size comes from EROSION_KERNEL_SIZE but the mask is dilated - confirm intent
        with torch.no_grad():
            segmentation = camera_properties.segmentation[None]
            fg_mask = kornia.morphology.dilation(segmentation, self.morphology_kernel)[0]
        rendered = outputs['rgb']
        reference = camera_properties.rgb
        camera = dataset.camera
        # the three photometric terms share the same input/target pair (one dict each)
        sub_loss_args = {
            term_name: {'input': rendered, 'target': reference}
            for term_name in ('Pixel', 'DSSIM', 'VGG')
        }
        sub_loss_args.update({
            'MonocularDepth': {'outputs': outputs, 'gt_depth': camera_properties.depth, 'fg_mask': fg_mask, 'camera': camera},
            'Distortion': {'outputs': outputs, 'camera': camera},
            'DynamicWeights': {'outputs': outputs},
            'HashGridWeightDecay': {},
            'MLPWeightDecay': {},
            'PSNR': {'preds': rendered, 'target': reference, 'data_range': 1.0},
        })
        # run sublosses
        return super().forward(sub_loss_args)