Initial multigpu support #121

Closed
4 changes: 2 additions & 2 deletions cfg/coco.data
@@ -1,6 +1,6 @@
 classes=80
-train=../coco/trainvalno5k.txt
-valid=../coco/5k.txt
+train=data/coco/images/train2014
+valid=data/coco/images/val2014
 names=data/coco.names
 backup=backup/
 eval=coco
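
The data config now points train and valid at image directories instead of the trainvalno5k.txt and 5k.txt split lists. A minimal sketch of how a loader could accept either form (the helper name and glob pattern are illustrative assumptions, not code from this PR):

import os
from glob import glob

def resolve_image_paths(path):
    # Hypothetical helper: accept either an image directory (new-style
    # coco.data) or a .txt split list (old trainvalno5k.txt / 5k.txt style).
    if os.path.isdir(path):
        return sorted(glob(os.path.join(path, '*.jpg')))
    with open(path) as f:
        return [line.strip() for line in f if line.strip()]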
5 changes: 3 additions & 2 deletions detect.py
@@ -48,7 +48,6 @@ def detect(
     dataloader = LoadImages(images, img_size=img_size)

     # Get classes and colors
-    classes = load_classes(parse_data_cfg('cfg/coco.data')['names'])
     colors = [[random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)] for _ in range(len(classes))]

     for i, (path, img, im0) in enumerate(dataloader):
@@ -112,8 +111,10 @@ def detect(
     parser.add_argument('--img-size', type=int, default=32 * 13, help='size of each image dimension')
     parser.add_argument('--conf-thres', type=float, default=0.50, help='object confidence threshold')
     parser.add_argument('--nms-thres', type=float, default=0.45, help='iou threshold for non-maximum suppression')
+    parser.add_argument('--data-cfg', type=str, default='cfg/coco.data', help='coco.data file path')
     opt = parser.parse_args()
     print(opt)
+    classes = load_classes(opt.data_cfg)

     with torch.no_grad():
         detect(
@@ -123,4 +124,4 @@ def detect(
             img_size=opt.img_size,
             conf_thres=opt.conf_thres,
             nms_thres=opt.nms_thres
-        )
\ No newline at end of file
+        )
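
For context on the new --data-cfg flag: in this repo parse_data_cfg reads key=value pairs from a *.data file and load_classes reads one class name per line. A simplified sketch of both helpers, with details assumed rather than copied from the diff:

def parse_data_cfg(path):
    # Read key=value lines from a *.data file into a dict,
    # e.g. {'classes': '80', 'names': 'data/coco.names', ...}
    options = {}
    with open(path) as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith('#'):
                continue
            key, value = line.split('=', 1)
            options[key.strip()] = value.strip()
    return options


def load_classes(names_path):
    # Read class names, one per line (e.g. data/coco.names)
    with open(names_path) as f:
        return [x for x in f.read().splitlines() if x]

Under these assumptions, resolving the names file first, as the removed parse_data_cfg('cfg/coco.data')['names'] call did, keeps load_classes pointed at a names file rather than at the .data file itself.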
24 changes: 10 additions & 14 deletions models.py
@@ -123,7 +123,7 @@ def forward(self, p, img_size, targets=None, var=None):
         if ONNX_EXPORT:
             bs, nG = 1, self.nG  # batch size, grid size
         else:
-            bs, nG = p.shape[0], p.shape[-1]
+            bs, nG = p.shape[0], p.shape[-1]  # grid size defined by previous layer outputs

         if self.img_size != img_size:
             create_grids(self, img_size, nG)
@@ -132,7 +132,7 @@ def forward(self, p, img_size, targets=None, var=None):
                 self.grid_xy = self.grid_xy.cuda()
                 self.anchor_wh = self.anchor_wh.cuda()

-        # p.view(bs, 255, 13, 13) -- > (bs, 3, 13, 13, 80)  # (bs, anchors, grid, grid, classes + xywh)
+        # p.view(bs, 3, 80 + 5, 13, 13) -- > (bs, 3, 13, 13, 80 + 5)  # (bs, anchors, grid, grid, classes + xywh)
         p = p.view(bs, self.nA, self.nC + 5, nG, nG).permute(0, 1, 3, 4, 2).contiguous()  # prediction

         # xy, width and height
@@ -153,11 +153,11 @@ def forward(self, p, img_size, targets=None, var=None):
             txy, twh, mask, tcls = build_targets(targets, self.anchor_vec, self.nA, self.nC, nG)

             tcls = tcls[mask]
-            if p.is_cuda:
+            if xy.is_cuda:
                 txy, twh, mask, tcls = txy.cuda(), twh.cuda(), mask.cuda(), tcls.cuda()

             # Compute losses
-            nT = sum([len(x) for x in targets])  # number of targets
+            nT = sum([len(x[torch.unique(x.nonzero()[:, 0])]) for x in targets])  # number of targets
             nM = mask.sum().float()  # number of anchors (assigned to targets)
             k = 1  # nM / bs
             if nM > 0:
@@ -220,13 +220,14 @@ def __init__(self, cfg_path, img_size=416):
         self.module_defs = parse_model_cfg(cfg_path)
         self.module_defs[0]['cfg'] = cfg_path
         self.module_defs[0]['height'] = img_size
+        self.num_yolo = len([x for x in self.module_defs if x['type'] == 'yolo'])
         self.hyperparams, self.module_list = create_modules(self.module_defs)
         self.img_size = img_size
         self.loss_names = ['loss', 'xy', 'wh', 'conf', 'cls', 'nT']
         self.losses = []

     def forward(self, x, targets=None, var=0):
-        self.losses = defaultdict(float)
+        losses_b = torch.zeros(6).cuda() if x.is_cuda else torch.zeros(6)
         is_training = targets is not None
         img_size = x.shape[-1]
         layer_outputs = []
@@ -248,26 +249,21 @@ def forward(self, x, targets=None, var=0):
             elif mtype == 'yolo':
                 if is_training:  # get loss
                     x, *losses = module[0](x, img_size, targets, var)
-                    for name, loss in zip(self.loss_names, losses):
-                        self.losses[name] += loss
+                    for k, loss in enumerate(losses):
+                        losses_b[k] += loss
                 else:  # get detections
                     x = module[0](x, img_size)
                 output.append(x)
             layer_outputs.append(x)

         if is_training:
-            self.losses['nT'] /= 3
+            losses_b[-1] /= self.num_yolo

         if ONNX_EXPORT:
             output = torch.cat(output, 1)  # merge the 3 layers 85 x (507, 2028, 8112) to 85 x 10647
             return output[5:85].t(), output[:4].t()  # ONNX scores, boxes

-        return sum(output) if is_training else torch.cat(output, 1)
-
-
-def get_yolo_layers(model):
-    a = [module_def['type'] == 'yolo' for module_def in model.module_defs]
-    return [i for i, x in enumerate(a) if x]  # [82, 94, 106] for yolov3
+        return (sum(output), losses_b.unsqueeze(0)) if is_training else torch.cat(output, 1)


 def create_grids(self, img_size, nG):
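
The forward pass now accumulates YOLO-layer losses into a tensor and returns it with shape (1, 6) instead of storing a dict on the module. The point is multi-GPU gathering: nn.DataParallel concatenates each replica's outputs along dim 0, so the caller receives an (n_gpu, 6) matrix to reduce with .sum(0), and per-replica scalar losses likewise come back as a length-n_gpu vector, hence loss.sum().backward() in train.py. A minimal sketch of the pattern (the toy module and loss values are illustrative only):

import torch
import torch.nn as nn

class ToyModel(nn.Module):
    # Illustrative stand-in for Darknet: returns a training loss plus a
    # (1, 6) row of loss components that DataParallel can concatenate.
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.ones(3))

    def forward(self, x):
        loss = (x * self.w).mean()
        losses_b = torch.arange(6., device=x.device).unsqueeze(0)  # (1, 6)
        return loss, losses_b

model = ToyModel()
if torch.cuda.device_count() > 1:
    # Each replica returns one (1, 6) row; gather stacks them into
    # (n_gpu, 6), and per-replica scalar losses come back as a 1-D tensor.
    model = nn.DataParallel(model)

loss, losses_b = model(torch.randn(8, 3))
losses_b = losses_b.sum(0)   # reduce across replicas, as train.py does
loss.sum().backward()        # valid for both single- and multi-GPU shapes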
11 changes: 8 additions & 3 deletions test.py
@@ -3,6 +3,7 @@
 import time
 from pathlib import Path

+from torch.utils.data import DataLoader
 from models import *
 from utils.datasets import *
 from utils.utils import *
@@ -39,15 +40,19 @@ def test(

     # Get dataloader
     # dataloader = torch.utils.data.DataLoader(LoadImagesAndLabels(test_path), batch_size=batch_size)
-    dataloader = LoadImagesAndLabels(test_path, batch_size=batch_size, img_size=img_size)
+    dataset = ImageLabelDataset(test_path, batch_size=batch_size, img_size=img_size)
+    dataloader = DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)

     mean_mAP, mean_R, mean_P, seen = 0.0, 0.0, 0.0, 0
     print('%11s' * 5 % ('Image', 'Total', 'P', 'R', 'mAP'))
     outputs, mAPs, mR, mP, TP, confidence, pred_class, target_class, jdict = \
         [], [], [], [], [], [], [], [], []
     AP_accum, AP_accum_count = np.zeros(nC), np.zeros(nC)
     coco91class = coco80_to_coco91_class()
+    # classes = load_classes(data_cfg_dict['names'])
     for batch_i, (imgs, targets, paths, shapes) in enumerate(dataloader):
+        imgs.squeeze_(0)
+        targets.squeeze_(0)
         t = time.time()
         output = model(imgs.to(device))
         output = non_max_suppression(output, conf_thres=conf_thres, nms_thres=nms_thres)
@@ -131,7 +136,7 @@ def test(

         # Print image mAP and running mean mAP
         print(('%11s%11s' + '%11.3g' * 4 + 's') %
-              (seen, dataloader.nF, mean_P, mean_R, mean_mAP, time.time() - t))
+              (seen, dataset.nF, mean_P, mean_R, mean_mAP, time.time() - t))

     # Print mAP per class
     print('%11s' * 5 % ('Image', 'Total', 'P', 'R', 'mAP') + '\n\nmAP Per Class:')
@@ -141,7 +146,7 @@ def test(

     # Save JSON
     if save_json:
-        imgIds = [int(Path(x).stem.split('_')[-1]) for x in dataloader.img_files]
+        imgIds = [int(Path(x).stem.split('_')[-1]) for x in dataset.img_files]
         with open('results.json', 'w') as file:
             json.dump(jdict, file)
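
ImageLabelDataset yields whole pre-collated batches, so the DataLoader runs with batch_size=1 purely for worker-based prefetching; default collation prepends a singleton dimension, which the loop strips with squeeze_(0). A minimal sketch of that pattern (the dataset internals and the fixed 50-slot zero-padded labels are assumptions):

import torch
from torch.utils.data import Dataset, DataLoader

class PreBatchedDataset(Dataset):
    # Hypothetical stand-in for ImageLabelDataset: each item is a whole,
    # already-collated batch rather than a single image.
    def __init__(self, n_batches=4, batch_size=16, img_size=416):
        self.n_batches, self.batch_size, self.img_size = n_batches, batch_size, img_size

    def __len__(self):
        return self.n_batches

    def __getitem__(self, i):
        imgs = torch.zeros(self.batch_size, 3, self.img_size, self.img_size)
        targets = torch.zeros(self.batch_size, 50, 5)  # fixed, zero-padded label slots
        return imgs, targets

loader = DataLoader(PreBatchedDataset(), batch_size=1, shuffle=False, num_workers=0)
for imgs, targets in loader:
    imgs.squeeze_(0)     # (1, B, 3, H, W) -> (B, 3, H, W)
    targets.squeeze_(0)  # (1, B, 50, 5) -> (B, 50, 5)

Fixed-size zero-padded target tensors are also what let collation stack labels across samples; the nonzero()-based nT count added in models.py skips exactly those padding rows.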
65 changes: 46 additions & 19 deletions train.py
@@ -1,12 +1,17 @@
 import argparse
 import time

+from torch.utils.data import DataLoader
+
 import test  # Import test.py to get mAP after each epoch
 from models import *
 from utils.datasets import *
 from utils.utils import *


+torch.multiprocessing.set_sharing_strategy('file_system')
+
+
 def train(
         cfg,
         data_cfg,
@@ -17,8 +22,10 @@ def train(
         accumulated_batches=1,
         multi_scale=False,
         freeze_backbone=False,
+        num_workers=0,
         var=0,
 ):
+
     weights = 'weights' + os.sep
     latest = weights + 'latest.pt'
     best = weights + 'best.pt'
@@ -35,21 +42,27 @@ def train(
     # Initialize model
     model = Darknet(cfg, img_size)

-    # Get dataloader
-    dataloader = LoadImagesAndLabels(train_path, batch_size, img_size, multi_scale=multi_scale, augment=True)
+    # Get dataloader with multi-threading
+    train_loader = ImageLabelDataset(train_path, batch_size, img_size, multi_scale=multi_scale, augment=True)
+    dataloader = DataLoader(
+        dataset=train_loader,
+        batch_size=1,
+        shuffle=False,
+        num_workers=num_workers)
+
     lr0 = 0.001
     cutoff = -1  # backbone reaches to cutoff layer
     start_epoch = 0
     best_loss = float('inf')

     if resume:
         checkpoint = torch.load(latest, map_location='cpu')

         # Load weights to resume from
         model.load_state_dict(checkpoint['model'])

-        # if torch.cuda.device_count() > 1:
-        #     model = nn.DataParallel(model)
+        if torch.cuda.device_count() > 1:
+            model = nn.DataParallel(model)
         model.to(device).train()

         # Transfer learning (train only YOLO layers)
@@ -62,7 +75,7 @@ def train(
         start_epoch = checkpoint['epoch'] + 1
         if checkpoint['optimizer'] is not None:
             optimizer.load_state_dict(checkpoint['optimizer'])
-            best_loss = checkpoint['best_loss']
+            best_loss = checkpoint['best_loss'].to(device)

         del checkpoint  # current, saved

@@ -75,8 +88,8 @@ def train(
             load_darknet_weights(model, weights + 'yolov3-tiny.conv.15')
             cutoff = 15

-        # if torch.cuda.device_count() > 1:
-        #     model = nn.DataParallel(model)
+        if torch.cuda.device_count() > 1:
+            model = nn.DataParallel(model)
         model.to(device).train()

         # Set optimizer
@@ -88,7 +101,9 @@ def train(
     # Start training
     t0 = time.time()
     model_info(model)
-    n_burnin = min(round(dataloader.nB / 5 + 1), 1000)  # number of burn-in batches
+    n_burnin = min(round(train_loader.nB / 5), 1000)  # number of burn-in batches
+    loss_names = ['loss', 'xy', 'wh', 'conf', 'cls', 'nT']
+
     for epoch in range(epochs):
         epoch += start_epoch

@@ -116,20 +131,24 @@ def train(
         rloss = defaultdict(float)  # running loss
         optimizer.zero_grad()
         for i, (imgs, targets, _, _) in enumerate(dataloader):
-            if sum([len(x) for x in targets]) < 1:  # if no targets continue
+            if torch.all(targets == torch.zeros_like(targets)):  # if no targets continue
                 continue
-
+            imgs.squeeze_(0)
+            targets.squeeze_(0)
             # SGD burn-in
             if (epoch == 0) & (i <= n_burnin):
                 lr = lr0 * (i / n_burnin) ** 4
                 for g in optimizer.param_groups:
                     g['lr'] = lr

-            # Compute loss
-            loss = model(imgs.to(device), targets, var=var)
-
-            # Compute gradient
-            loss.backward()
+            # Compute loss, compute gradient, update parameters
+            model.train()
+            losses = defaultdict(float)
+            loss, losses_b = model(imgs.to(device), targets.to(device), var=var)
+            losses_b = losses_b.sum(0)
+            for k, name in enumerate(loss_names):
+                losses[name] += losses_b[k]
+            loss.sum().backward()

             # Accumulate gradient for x batches before optimizing
             if ((i + 1) % accumulated_batches == 0) or (i == len(dataloader) - 1):
@@ -138,15 +157,15 @@ def train(

             # Running epoch-means of tracked metrics
             ui += 1
-            for key, val in model.losses.items():
+            for key, val in losses.items():
                 rloss[key] = (rloss[key] * ui + val) / (ui + 1)

             s = ('%8s%12s' + '%10.3g' * 7) % (
                 '%g/%g' % (epoch, epochs - 1),
                 '%g/%g' % (i, len(dataloader) - 1),
                 rloss['xy'], rloss['wh'], rloss['conf'],
                 rloss['cls'], rloss['loss'],
-                model.losses['nT'], time.time() - t0)
+                losses['nT'], time.time() - t0)
             t0 = time.time()
             print(s)

@@ -156,9 +175,13 @@ def train(
             best_loss = loss_per_target

         # Save latest checkpoint
+        if type(model) is nn.DataParallel:
+            state_dict = model.module.state_dict()
+        else:
+            state_dict = model.state_dict()
         checkpoint = {'epoch': epoch,
                       'best_loss': best_loss,
-                      'model': model.state_dict(),
+                      'model': state_dict,
                       'optimizer': optimizer.state_dict()}
         torch.save(checkpoint, latest)

@@ -172,7 +195,7 @@ def train(

         # Calculate mAP
         with torch.no_grad():
-            mAP, R, P = test.test(cfg, data_cfg, weights=latest, batch_size=batch_size, img_size=img_size)
+            mAP, R, P = test.test(cfg, data_cfg, weights=latest, batch_size=batch_size, img_size=img_size)

         # Write epoch results
         with open('results.txt', 'a') as file:
@@ -189,11 +212,14 @@
     parser.add_argument('--multi-scale', action='store_true', help='random image sizes per batch 320 - 608')
     parser.add_argument('--img-size', type=int, default=32 * 13, help='pixels')
     parser.add_argument('--resume', action='store_true', help='resume training flag')
+    parser.add_argument('--num-workers', type=int, default=0, help='number of workers for dataloader')
     parser.add_argument('--var', type=float, default=0, help='test variable')
     opt = parser.parse_args()
     print(opt, end='\n\n')

     init_seeds()
+    data_cfg = parse_data_cfg(opt.data_cfg)
+    classes = load_classes(data_cfg['names'])

     train(
         opt.cfg,
@@ -204,5 +230,6 @@
         batch_size=opt.batch_size,
         accumulated_batches=opt.accumulated_batches,
         multi_scale=opt.multi_scale,
+        num_workers=opt.num_workers,
         var=opt.var,
     )
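
Because nn.DataParallel registers parameters under a module. prefix, the checkpoint code above unwraps the model before saving so a bare Darknet can load the weights back. A minimal sketch of that save/load round trip (function names and defaults here are illustrative):

import torch
import torch.nn as nn

def save_checkpoint(model, optimizer, epoch, best_loss, path='weights/latest.pt'):
    # Unwrap DataParallel so checkpoint keys are not prefixed with 'module.'
    state_dict = model.module.state_dict() if isinstance(model, nn.DataParallel) \
        else model.state_dict()
    torch.save({'epoch': epoch,
                'best_loss': best_loss,
                'model': state_dict,
                'optimizer': optimizer.state_dict()}, path)

def load_for_resume(model, path='weights/latest.pt'):
    # Load on CPU first, then wrap and move, mirroring the resume path above
    checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    return model, checkpoint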
Empty file added utils/__init__.py