Skip to content

Commit

Permalink
fix a visualization bug
Browse files Browse the repository at this point in the history
  • Loading branch information
QiuJueqin committed May 20, 2020
1 parent 15ff2b5 commit b0d80da
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 11 deletions.
14 changes: 7 additions & 7 deletions src/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,27 @@ def __init__(self, phase, cfg):

def __getitem__(self, index):
image, image_id = self.load_image(index)
gt_class_ids, gt_boxes = self.load_annotations(index)

image_meta = {'index': index,
'image_id': image_id,
'orig_size': np.array(image.shape, dtype=np.int32)}

gt_class_ids, gt_boxes = self.load_annotations(index)

image, image_meta, gt_boxes = self.preprocess(image, image_meta, gt_boxes)
gt = self.prepare_annotations(gt_class_ids, gt_boxes)

inp = {'image': image.transpose(2, 0, 1),
'image_meta': image_meta,
'gt': gt}

if self.cfg.debug == 1:
image = image * image_meta['rgb_std'] + image_meta['rgb_mean']
save_path = os.path.join(self.cfg.debug_dir, image_meta['image_id'] + '.png')
visualize_boxes(image, gt_class_ids, gt_boxes,
class_names=self.class_names,
save_path=save_path)

batch = {'image': image.transpose(2, 0, 1),
'image_meta': image_meta,
'gt': gt}

return batch
return inp

def __len__(self):
return len(self.sample_ids)
Expand Down
4 changes: 2 additions & 2 deletions src/model/squeezedet.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,9 @@ def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
if m is self.convdet:
nn.init.normal_(m.weight, mean=0.0, std=0.001)
nn.init.normal_(m.weight, mean=0.0, std=0.002)
else:
nn.init.normal_(m.weight, mean=0.0, std=0.01)
nn.init.normal_(m.weight, mean=0.0, std=0.005)
if m.bias is not None:
nn.init.constant_(m.bias, 0)

Expand Down
2 changes: 0 additions & 2 deletions src/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ def parse(self, args=''):
cfg.save_dir = os.path.join(cfg.exp_dir, cfg.exp_id)
cfg.debug_dir = os.path.join(cfg.save_dir, 'debug')
print('The results will be saved to ', cfg.save_dir)

return cfg

@staticmethod
Expand All @@ -137,4 +136,3 @@ def print(cfg):
for name in sorted(names):
if not name.startswith('_'):
print('{:<30} {}'.format(name, getattr(cfg, name)))
print('/n')

0 comments on commit b0d80da

Please sign in to comment.