Trainer.py
# -*- coding: utf-8 -*-
"""DNPC/Trainer.py: DNPC training code."""
import random
import kornia
import numpy as np
import torch
from Datasets.utils import saveImage
import Framework
from Datasets.Base import BaseDataset
from Logging import Logger
from Methods.Base.GuiTrainer import GuiTrainer
from Methods.Base.utils import preTrainingCallback, trainingCallback
from Visual.ColorMap import ColorMap
from Visual.utils import pseudoColorDepth
from Methods.DNPC.utils import logScene
from Cameras.NDC import NDCCamera
from Methods.DNPC.Loss import DNPCLoss
from Optim.Samplers.DatasetSamplers import DatasetSampler
from Datasets.Colmap import storePly
from Cameras.utils import RayPropertySlice


@Framework.Configurable.configure(
NUM_ITERATIONS=10000,
WARMUP_ITERATIONS=500,
WANDB=Framework.ConfigParameterList(
RENDER_SCENE=False,
),
PROBABILITY_FIELD=Framework.ConfigParameterList(
DISABLE_INIT=False,
PRUNE_THRESHOLD=0.05,
PRUNE_STRIDE=100,
),
OPTIMIZER=Framework.ConfigParameterList(
STATIC_MODEL=Framework.ConfigParameterList(
LR=1.0e-2,
LR_TARGET=1.0e-2 / 30,
LR_DELAY_ITERATIONS=0,
LR_DELAY_FACTOR=1.0,
),
DYNAMIC_MODEL=Framework.ConfigParameterList(
LR=1.0e-2,
LR_TARGET=1.0e-2 / 30,
LR_DELAY_ITERATIONS=0,
LR_DELAY_FACTOR=1.0,
),
UNET=Framework.ConfigParameterList(
LR=3.0e-4,
LR_TARGET=5.0e-5,
),
),
LOSS=Framework.ConfigParameterList(
LAMBDA_PIXEL=1.0,
ROBUST_LOSS_ALPHA=0.0,
ROBUST_LOSS_C=0.23570452081,
LAMBDA_DSSIM=0.916,
LAMBDA_VGG=0.0516,
LAMBDA_MONOCULAR_DEPTH=20.0,
DEPTH_LOSS_END_ITERATION=500,
LAMBDA_DISTORTION=1.0,
LAMBDA_DYNAMIC_WEIGHTS=0.001,
DYNAMIC_WEIGHTS_START_ITERATION=0,
LAMBDA_GRID_DECAY=2.46,
LAMBDA_MLP_DECAY=0.00098,
EROSION_KERNEL_SIZE=3,
),
DEPTH_ALIGNMENT=Framework.ConfigParameterList(
READ_FROM_DISK=True,
WRITE_TO_DISK=True,
DEGREE=1,
NUM_SAMPLES=16,
NUM_ITERATIONS=1000,
INLIER_TOLERANCE=1e-3,
USE_STATIC_MASK=True,
RECALCULATE_INLIERS=False,
GLOBAL_REFINEMENT_STEPS=0,
GLOBAL_REFINEMENT_WINDOW=2,
GLOBAL_REFINEMENT_DEPTH_SUBTRACTION=0.00,
),
BACKUP=Framework.ConfigParameterList(
INITIAL_SAMPLING_WEIGHTS=False,
)
)
class DNPCTrainer(GuiTrainer):
"""DNPC training code."""

    def __init__(self, **kwargs) -> None:
        """Initializes the DNPC trainer."""
super().__init__(**kwargs)
# build optimizer and scheduler
param_groups, schedulers = self.model.getOptimizerParamGroups(config=self.OPTIMIZER, max_iterations=self.NUM_ITERATIONS)
self.optimizer = torch.optim.AdamW(param_groups, 1.0, eps=1.0e-15, betas=(0.9, 0.99),
weight_decay=0.0, foreach=False, fused=True)
self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, schedulers)
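        # AMP loss scaling; growth_interval > NUM_ITERATIONS keeps the scale fixed at init_scale (it can still back off on overflow)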
self.scaler = torch.GradScaler(device='cuda', init_scale=128.0, growth_interval=self.NUM_ITERATIONS + 1)
# init loss
self.loss = DNPCLoss(loss_config=self.LOSS, model=self.model)
# dataset sampler
self.train_sampler = None

    @preTrainingCallback(priority=10000)
def createSampler(self, _, dataset: 'BaseDataset') -> None:
"""Creates the sampler."""
self.train_sampler = DatasetSampler(dataset=dataset.train(), random=True)
self.scheduler.lr_lambdas[1].set_decay_stride(len(dataset.train())) # decay dynamic lr only once per epoch

    @torch.no_grad()
    @preTrainingCallback(priority=5000)
    def initOcclusionMasks(self, _, dataset: 'BaseDataset') -> None:
        """Generates (dis-)occlusion masks for each frame using optical flow forward and backward consistency.
        Args:
            _ (int): Training iteration (unused)
            dataset (BaseDataset): Dataset used for training
        """
CONSISTENCY_THRESHOLD = 1.0
Logger.logInfo('Initializing occlusion masks...')
dataset.addCameraPropertyFields([('occlusion_mask', torch.Tensor, None)])
dataset.train()
for i in Logger.logProgressBar(range(len(dataset)), leave=False, desc='Initializing occlusion masks'):
sample_current = dataset[i]
dataset.camera.setProperties(sample_current)
x_direction, y_direction = dataset.camera.getPixelCoordinates()
sample_last = dataset[i - 1] if i > 0 else None
sample_next = dataset[i + 1] if i < len(dataset) - 1 else None
occlusion_mask = torch.zeros((1, sample_current.height, sample_current.width), device=Framework.config.GLOBAL.DEFAULT_DEVICE, dtype=torch.float32)
if sample_last is not None:
backward_flow = sample_current.backward_flow
forward_flow = sample_last.forward_flow
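                # warp the previous frame's forward flow into the current frame via the backward flow
                # (grid_sample expects sampling coordinates normalized to [-1, 1])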
grid = backward_flow.clone()
grid[:1] = ((((grid[:1] + x_direction) / sample_current.width) - 0.5) * 2.0)
grid[1:2] = ((((grid[1:2] + y_direction) / sample_current.height) - 0.5) * 2.0)
forward_flow_resampled = torch.nn.functional.grid_sample(forward_flow[None], grid.permute(1, 2, 0)[None])[0]
occlusion_mask += (backward_flow + forward_flow_resampled).norm(dim=0, keepdim=True)
if sample_next is not None:
backward_flow = sample_next.backward_flow
forward_flow = sample_current.forward_flow
grid = forward_flow.clone()
grid[:1] = ((((grid[:1] + x_direction) / sample_current.width) - 0.5) * 2.0)
grid[1:2] = ((((grid[1:2] + y_direction) / sample_current.height) - 0.5) * 2.0)
backward_flow_resampled = torch.nn.functional.grid_sample(backward_flow[None], grid.permute(1, 2, 0)[None])[0]
occlusion_mask += (forward_flow + backward_flow_resampled).norm(dim=0, keepdim=True)
dataset.data['train'][i].occlusion_mask = (occlusion_mask < CONSISTENCY_THRESHOLD).float().cpu()

    @preTrainingCallback(priority=1000)
    @torch.no_grad()
    def alignDepthMaps(self, _, dataset: 'BaseDataset') -> None:
        """Globally aligns depth maps with the sfm point cloud.
        Args:
            _ (int): Training iteration (unused)
            dataset (BaseDataset): Dataset used for training
        Raises:
            MethodError: If the dataset does not contain an sfm point cloud
        """
# check if dataset already contains metric depth maps
if dataset.train()[0].depth is not None:
Logger.logInfo('Depth maps already exist, skipping global alignment...')
return
# helper functions for RANSAC linear system solving
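        # solve() fits a degree-DEGREE polynomial mapping the raw monocular estimate to inverse sfm depth;
        # getError() compares the implied metric depths against the sfm depths and flags inliers within INLIER_TOLERANCE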
def solve(a, b):
elems = [a.pow(i) for i in range(self.DEPTH_ALIGNMENT.DEGREE + 1)]
return torch.linalg.lstsq(torch.stack(elems, dim=-1), b).solution
def apply(a, x):
b = (torch.stack([a.pow(i) for i in range(self.DEPTH_ALIGNMENT.DEGREE + 1)], dim=-1) * x).sum(dim=-1)
return b
def getError(depths, depth_estim, x):
errors = ((1.0 / (depths + 1e-8)) - (1.0 / (apply(depth_estim, x) + 1e-8))).abs()
return errors.mean(), errors < self.DEPTH_ALIGNMENT.INLIER_TOLERANCE
# load aligned depth maps from previous run, if available
scale_factor = dataset.IMAGE_SCALE_FACTOR if dataset.IMAGE_SCALE_FACTOR is not None else 1.0
prealigned_path = dataset.dataset_path / 'dnpc_aligned_depth' / f'scale_{scale_factor}'
if self.DEPTH_ALIGNMENT.READ_FROM_DISK and prealigned_path.exists():
Logger.logInfo('loading aligned depth maps...')
for i in Logger.logProgressBar(range(len(dataset)), leave=False, desc='Loading depth maps'):
aligned_depth = np.load(str(prealigned_path / f'{i:05d}.npy'))
dataset.data['train'][i].depth = torch.from_numpy(aligned_depth).cpu()
return
# align depth maps
else:
Logger.logInfo('aligning monocular disparity maps with sfm point cloud...')
dataset.train()
# check if point cloud exists
if dataset.point_cloud is None:
                raise Framework.MethodError('DNPC requires a dataset with an sfm point cloud to initialize depth maps')
sfm_points = dataset.point_cloud.positions.to(Framework.config.GLOBAL.DEFAULT_DEVICE)
for i in Logger.logProgressBar(range(len(dataset)), leave=False, desc='Initializing depth maps'):
# project point cloud to image and estimate inverse gt depth
dataset.camera.setProperties(dataset[i])
xy, mask, depths = dataset.camera.projectPoints(sfm_points)
xy = xy[mask].long()
depths = 1.0 / depths[mask]
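                # the raw (unaligned) monocular depth estimate is stored in the sample's _misc field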
depth_estim_raw = dataset.camera.properties._misc.to(Framework.config.GLOBAL.DEFAULT_DEVICE)
depth_estim = depth_estim_raw[0, xy[:, 1], xy[:, 0]]
# filter matches based on segmentation mask
m = torch.ones_like(depths, dtype=torch.bool)
if self.DEPTH_ALIGNMENT.USE_STATIC_MASK:
bg_mask = kornia.morphology.erosion(1.0 - dataset.camera.properties.segmentation[None], torch.ones(9, 9))[0]
m *= (bg_mask[0, xy[:, 1], xy[:, 0]] > 0.0)
depths = depths[m]
depth_estim = depth_estim[m]
                # estimate an initial solution using all matches
x_min = solve(depth_estim, depths)
e_min = getError(depths, depth_estim, x_min)
Logger.logDebug(f'{i:05d}: initial inliers {e_min[1].float().mean() * 100:.2f}%, x={x_min}')
# RANSAC to filter outliers
for _ in range(self.DEPTH_ALIGNMENT.NUM_ITERATIONS):
indices = random.sample(range(len(depths)), self.DEPTH_ALIGNMENT.NUM_SAMPLES)
x = solve(depth_estim[indices], depths[indices])
e = getError(depths, depth_estim, x)
if e[0] < e_min[0] and (x >= 0).all():
x_min = x
e_min = e
Logger.logDebug(f'{i:05d}: best ransac inliers {e_min[1].float().mean() * 100:.2f}%, x={x_min}')
# recalculate using all inliers
if self.DEPTH_ALIGNMENT.RECALCULATE_INLIERS:
x = solve(depth_estim[e_min[1]], depths[e_min[1]])
else:
x = x_min
# write depth map to dataset, clip to dataset far plane
if (x < 0).sum() > 0:
Logger.logWarning(f'{i:05d}: negative depth coefficients detected ({x})')
dataset.data['train'][i].depth = (1.0 / apply(depth_estim_raw, x)).clamp(0.0, dataset.camera.far_plane)
dataset.data['train'][i]._misc = None
# improve edges in depth maps
Logger.logInfo('correcting edges...')
for i, sample in enumerate(Logger.logProgressBar(dataset, leave=False, desc='sample')):
depth = sample.depth
# estimate edges in depth map and mark as invalid
depth_edges = kornia.filters.canny(depth[None])[1][0].bool()
depth_edges = kornia.morphology.dilation(depth_edges[None].float(), torch.ones(3, 3))[0].bool()
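                # depths are negated so that max-pooling propagates the nearest (smallest) depth into the masked edge regions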
values = -sample.depth.clone()
values[depth_edges] = -dataset.camera.far_plane
# iteratively fill via maxpool
while depth_edges.sum().item() > 0:
depth_edges_eroded = kornia.morphology.erosion(depth_edges[None].float(), torch.ones(3, 3))[0].bool()
depth_edges_diff = depth_edges_eroded ^ depth_edges # xor
values_filled = torch.nn.functional.max_pool2d(values[None], kernel_size=3, stride=1, padding=1)[0]
values[depth_edges_diff] = values_filled[depth_edges_diff]
depth_edges = depth_edges_eroded
                # reassign to dataset, keeping a 3-pixel border untouched to avoid boundary artifacts from the edge filter
dataset.data['train'][i].depth[:, 3:-3, 3:-3] = -values[:, 3:-3, 3:-3].cpu()
# global depth refinement, average depth of dynamic content over multiple frames to reduce jittering
if self.DEPTH_ALIGNMENT.GLOBAL_REFINEMENT_STEPS > 0:
Logger.logInfo('running global depth refinement...')
for _ in Logger.logProgressBar(range(self.DEPTH_ALIGNMENT.GLOBAL_REFINEMENT_STEPS), leave=False, desc='Refining depth maps'):
camera_properties_train = [i for i in dataset.train()]
dynamic_points_world = []
for sample in camera_properties_train:
dataset.camera.setProperties(sample)
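                        # keep only confidently dynamic, un-occluded pixels: eroded foreground mask gated by the occlusion mask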
fg_mask_eroded = (kornia.morphology.erosion(sample.segmentation[None], torch.ones(9, 9))[0] * sample.occlusion_mask).flatten() > 0.0
rays = dataset.camera.generateRays()
xyz = rays[:, RayPropertySlice.origin] + (rays[:, RayPropertySlice.direction] * rays[:, RayPropertySlice.depth])
xyz = xyz[fg_mask_eroded]
dynamic_points_world.append(xyz)
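                    # shift each frame's dynamic depth by the offset between its mean depth and the median over the temporal window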
for i in range(len(camera_properties_train)):
projected_depths = []
dataset.camera.setProperties(camera_properties_train[i])
current_depth = None
for j in range(i - self.DEPTH_ALIGNMENT.GLOBAL_REFINEMENT_WINDOW, i + self.DEPTH_ALIGNMENT.GLOBAL_REFINEMENT_WINDOW + 1):
if j < 0 or j >= len(camera_properties_train) or dynamic_points_world[j].shape[0] == 0:
continue
depths = dataset.camera.projectPoints(dynamic_points_world[j])[2]
depths = depths.mean()
if j == i:
current_depth = depths
projected_depths.append(depths)
if current_depth is not None:
median_depth = torch.tensor(projected_depths).median()
fg_mask_dilated = (kornia.morphology.dilation(camera_properties_train[i].segmentation[None], torch.ones(9, 9))[0]) > 0.0
depth_corrected = camera_properties_train[i].depth.clone()
depth_corrected[fg_mask_dilated] = depth_corrected[fg_mask_dilated] + (median_depth - current_depth)
dataset.data['train'][i].depth = depth_corrected.cpu()
            # subtract a constant offset from dynamic areas of the depth maps
            # (DepthAnything produces constant depth errors for some scenes, e.g. Skating and Truck)
for j, i in enumerate(dataset.train()):
i.depth[i.segmentation > 0] = i.depth[i.segmentation > 0] - self.DEPTH_ALIGNMENT.GLOBAL_REFINEMENT_DEPTH_SUBTRACTION * (dataset.camera.far_plane - dataset.camera.near_plane)
dataset.data['train'][j].depth = i.depth.cpu()
# save data to disk
if self.DEPTH_ALIGNMENT.WRITE_TO_DISK:
                Logger.logInfo('writing aligned depth maps to disk...')
                prealigned_path.mkdir(parents=True, exist_ok=True)
for i in Logger.logProgressBar(range(len(dataset)), leave=False, desc='Writing depth maps'):
                    # write aligned depth maps (overwrites previous alignments!)
np.save(str(prealigned_path / f'{i:05d}.npy'), dataset.data['train'][i].depth.cpu().numpy())
# visualize to ply (useful for debugging)
dataset.camera.setProperties(dataset[i])
rays = dataset.camera.generateRays()
xyz = rays[:, RayPropertySlice.origin] + (rays[:, RayPropertySlice.direction] * rays[:, RayPropertySlice.depth])
# ply
xyz = torch.cat((xyz, sfm_points), dim=0)
rgb = rays[:, RayPropertySlice.rgb]
rgb = torch.cat((rgb, torch.tensor((1.0, 0.0, 0.0))[None].repeat((sfm_points.shape[0], 1))), dim=0)
storePly(str(prealigned_path / f'{i:05d}.ply'), xyz.cpu().numpy(), (rgb.cpu().numpy() * 255.0).astype(np.uint8))
# img
depth_img = pseudoColorDepth(
color_map='SPECTRAL',
depth=dataset.camera.properties.depth,
near_far=(dataset.camera.near_plane, dataset.camera.far_plane),
alpha=None
)
saveImage(prealigned_path / f'{i:05d}.png', depth_img)

    @preTrainingCallback(priority=100)
@torch.no_grad()
def initSamplingGrid(self, _, dataset: 'BaseDataset') -> None:
"""Initializes the sampling grid based on segmentation and prealigned depth maps.
Args:
_ (int): Training iteration (unused)
dataset (BaseDataset): Training dataset containing segmentation and depth.
Raises:
MethodError: If the dataset uses an unsupported camera type.
"""
Logger.logInfo('Initializing sampling grid...')
if isinstance(dataset.camera, NDCCamera):
raise Framework.MethodError('DNPC does not support NDC camera datasets')
# recompute dataset bounding box
near = 1e8
far = 0.0
bounding_box = None
for properties in dataset.train():
dataset.camera.setProperties(properties)
near = min(near, properties.depth.min().item())
far = max(far, properties.depth.max().item())
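            # unproject every pixel to a world-space point using the aligned depth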
rays = dataset.camera.generateRays()
xyz = rays[:, RayPropertySlice.origin] + rays[:, RayPropertySlice.direction] * rays[:, RayPropertySlice.depth]
if bounding_box is None:
bounding_box = torch.stack((xyz.min(dim=0).values, xyz.max(dim=0).values), dim=0)
else:
torch.min(bounding_box[0], xyz.min(dim=0).values, out=bounding_box[0])
torch.max(bounding_box[1], xyz.max(dim=0).values, out=bounding_box[1])
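        # pad the bounding box by 10% around its center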
bounding_box_mean = bounding_box.mean(dim=0)
bounding_box = bounding_box_mean + (bounding_box - bounding_box_mean) * 1.1
dataset.camera.far_plane = far * 1.1
dataset.camera.near_plane = max(near * 0.9, 0.01)
Logger.logInfo(f'new far: {dataset.camera.far_plane}, new near: {dataset.camera.near_plane}')
dataset._bounding_box = bounding_box.cpu()
# initialize probability field
self.model.probability_field.initialize(dataset, self.PROBABILITY_FIELD.DISABLE_INIT)
# write initial weights to disk (useful for debugging)
if self.BACKUP.INITIAL_SAMPLING_WEIGHTS:
xyz = self.model.probability_field.centers
rgb = torch.index_select(ColorMap.get('VIRIDIS'), dim=0, index=(self.model.probability_field.global_weights * 255).int().flatten())
mask = self.model.probability_field.global_weights > 0.05
storePly(str(self.output_directory / 'initial_weights.ply'), xyz[mask].cpu().numpy(), (rgb[mask].cpu().numpy() * 255.0).astype(np.uint8))
# calc and set extent for static and dynamic appearance fields
bb_min, bb_max = dataset.getBoundingBox().to(Framework.config.GLOBAL.DEFAULT_DEVICE)
grid = self.model.probability_field
half_cell_size = grid.cell_size / 2.0
static_bb_min = grid.centers.min(dim=0).values - half_cell_size
torch.max(static_bb_min, bb_min, out=static_bb_min)
static_bb_max = grid.centers.max(dim=0).values + half_cell_size
torch.min(static_bb_max, bb_max, out=static_bb_max)
self.model.static_grid.setExtent(static_bb_min, static_bb_max)
if self.model.dynamic_grid is not None:
dynamic_centers = grid.centers[grid.global_dynamic_mask]
dynamic_bb_min = dynamic_centers.min(dim=0).values - half_cell_size
torch.max(dynamic_bb_min, bb_min, out=dynamic_bb_min)
dynamic_bb_max = dynamic_centers.max(dim=0).values + half_cell_size
torch.min(dynamic_bb_max, bb_max, out=dynamic_bb_max)
self.model.dynamic_grid.setExtent(dynamic_bb_min, dynamic_bb_max)

    @preTrainingCallback(priority=2)
def freeTemporaryMemory(self, *_) -> None:
"""Frees temporary torch memory."""
torch.cuda.empty_cache()

    @trainingCallback(priority=1000, start_iteration='WARMUP_ITERATIONS', iteration_stride='PROBABILITY_FIELD.PRUNE_STRIDE')
    @torch.no_grad()
    def pruneProbabilityField(self, *_) -> None:
        """Prunes the probability field, removing cells below a given threshold."""
self.model.probability_field.prune(self.PROBABILITY_FIELD.PRUNE_THRESHOLD)
self.model.probability_field.sparsify()

    @trainingCallback(priority=10, start_iteration='LOSS.DEPTH_LOSS_END_ITERATION', iteration_stride='NUM_ITERATIONS')
    def disableDepthLoss(self, *_) -> None:
        """Disables the monocular depth loss after a given iteration."""
for loss in self.loss.loss_metrics:
if loss.name == 'MonocularDepth':
loss.weight = 0.0

    @trainingCallback(priority=10, start_iteration='LOSS.DYNAMIC_WEIGHTS_START_ITERATION', iteration_stride='NUM_ITERATIONS')
    def enableDynamicWeightsLoss(self, *_) -> None:
        """Enables the binary entropy loss on dynamic weights after a given iteration."""
for loss in self.loss.loss_metrics:
if loss.name == 'DynamicWeights':
loss.weight = self.LOSS.LAMBDA_DYNAMIC_WEIGHTS

    @trainingCallback(priority=5)
def trainingIteration(self, iteration: int, dataset: 'BaseDataset') -> None:
"""Performs a single training iteration.
Args:
iteration (int): Current training iteration id.
dataset (BaseDataset): Dataset used for training.
"""
# set training mode
self.model.train()
dataset.train()
self.loss.train()
# clear gradients
self.optimizer.zero_grad()
# sample dataset
sample = self.train_sampler.get(dataset=dataset)
camera_properties = sample['camera_properties']
# sample_index = sample['sample_id']
dataset.camera.setProperties(camera_properties)
# update model
        with torch.autocast(device_type='cuda'):
# run rasterizer
outputs = self.renderer.renderImage(
camera=dataset.camera,
to_chw=True)
# calculate loss
loss = self.loss(outputs, camera_properties, dataset)
# backpropagate
self.scaler.scale(loss).backward()
self.scaler.step(self.optimizer)
self.scheduler.step()
self.scaler.update()
# update probability field
if iteration >= self.WARMUP_ITERATIONS:
with torch.no_grad():
self.model.probability_field.update(
outputs['extras']['blending_weights'],
outputs['extras']['fused_dynamic_weights'],
outputs['extras']['dynamic_mask'],
outputs['extras']['sample_indices'],
camera_properties.timestamp,
)

    @trainingCallback(active='WANDB.ACTIVATE', priority=500, iteration_stride='WANDB.INTERVAL')
@torch.no_grad()
def logWandB(self, iteration: int, dataset: 'BaseDataset') -> None:
"""Logs training data to wandb.
Args:
iteration (int): Current training iteration id.
dataset (BaseDataset): Dataset used for training.
"""
super().logWandB(iteration, dataset)
# visualize scene and sampling grid as point clouds
if self.WANDB.RENDER_SCENE:
logScene(self.model, iteration, dataset, None, 'scene')
# log cell distribution
Framework.wandb.log({'cell_distribution': self.model.probability_field.getDistribution()}, step=iteration)