mmdet support musa backend #12313

Open · 3 commits to merge into main
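Across the files below, the PR applies one recurring pattern: import torch_musa behind a try/except, then branch on that flag or on device.type == 'musa' wherever the code previously assumed CUDA. A condensed sketch of the pattern (illustrative names, not part of the diff):

import torch

try:
    import torch_musa  # extension that provides the 'musa' device and the torch.musa namespace
    IS_MUSA_AVAILABLE = True
except ImportError:
    IS_MUSA_AVAILABLE = False


def on_musa(t: torch.Tensor) -> bool:
    """True when the tensor already lives on a MUSA device."""
    return IS_MUSA_AVAILABLE and t.device.type == 'musa'
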
4 changes: 3 additions & 1 deletion mmdet/apis/inference.py
@@ -289,7 +289,9 @@ def inference_mot(model: nn.Module, img: np.ndarray, frame_id: int,
     test_pipeline = build_test_pipeline(cfg)
     data = test_pipeline(data)
 
-    if not next(model.parameters()).is_cuda:
+    if not next(model.parameters()).is_cuda and not (next(
+            model.parameters()).device.type == 'musa'):
+
         for m in model.modules():
             assert not isinstance(
                 m, RoIPool
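The new condition spells out a double negative; an equivalent, more direct form (a sketch with the names from inference_mot above, not the PR's literal code) would be:

# Run the RoIPool assertion only when the model sits on neither CUDA nor MUSA.
device_type = next(model.parameters()).device.type
if device_type not in ('cuda', 'musa'):
    for m in model.modules():
        assert not isinstance(m, RoIPool), \
            'CPU inference with RoIPool is not supported currently.'
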
11 changes: 8 additions & 3 deletions mmdet/models/backbones/csp_darknet.py
@@ -115,9 +115,14 @@ def __init__(self,
 
     def forward(self, x):
         x = self.conv1(x)
-        with torch.cuda.amp.autocast(enabled=False):
-            x = torch.cat(
-                [x] + [pooling(x) for pooling in self.poolings], dim=1)
+        if x.device.type == 'musa':
+            with torch_musa.core.amp.autocast(enabled=False):
+                x = torch.cat(
+                    [x] + [pooling(x) for pooling in self.poolings], dim=1)
+        else:
+            with torch.cuda.amp.autocast(enabled=False):
+                x = torch.cat(
+                    [x] + [pooling(x) for pooling in self.poolings], dim=1)
         x = self.conv2(x)
         return x

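The same two-branch autocast appears again in se_layer.py below; a shared helper could keep the torch.cat call in one place. A sketch (helper name hypothetical; it assumes torch_musa is importable whenever a tensor reports device.type == 'musa', as this diff already does):

from contextlib import nullcontext

import torch


def autocast_off(device_type: str):
    """Return a context manager that disables AMP autocast for the given backend."""
    if device_type == 'musa':
        import torch_musa
        return torch_musa.core.amp.autocast(enabled=False)
    if device_type == 'cuda':
        return torch.cuda.amp.autocast(enabled=False)
    return nullcontext()

# usage in SPPBottleneck.forward:
#     with autocast_off(x.device.type):
#         x = torch.cat([x] + [pooling(x) for pooling in self.poolings], dim=1)
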
9 changes: 7 additions & 2 deletions mmdet/models/layers/se_layer.py
@@ -155,8 +155,13 @@ def __init__(self, channels: int, init_cfg: OptMultiConfig = None) -> None:
 
     def forward(self, x: Tensor) -> Tensor:
         """Forward function for ChannelAttention."""
-        with torch.cuda.amp.autocast(enabled=False):
-            out = self.global_avgpool(x)
+        if x.device.type == 'musa':
+            with torch_musa.core.amp.autocast(enabled=False):
+                out = self.global_avgpool(x)
+        else:
+            """Forward function for ChannelAttention."""
+            with torch.cuda.amp.autocast(enabled=False):
+                out = self.global_avgpool(x)
         out = self.fc(out)
         out = self.act(out)
         return x * out
2 changes: 2 additions & 0 deletions mmdet/models/losses/focal_loss.py
@@ -234,6 +234,8 @@ def forward(self,
                 calculate_loss_func = py_sigmoid_focal_loss
             elif torch.cuda.is_available() and pred.is_cuda:
                 calculate_loss_func = sigmoid_focal_loss
+            elif torch.musa.is_available() and pred.device.type == 'musa':
+                calculate_loss_func = sigmoid_focal_loss
             else:
                 num_classes = pred.size(1)
                 target = F.one_hot(target, num_classes=num_classes + 1)
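torch.musa only exists as an attribute after torch_musa has been imported, so the new elif relies on that import having happened elsewhere in the process. A defensive variant (a sketch, not part of this diff):

import torch


def musa_focal_loss_ready(pred: torch.Tensor) -> bool:
    """True when torch_musa is importable and `pred` lives on a MUSA device."""
    try:
        import torch_musa  # noqa: F401
    except ImportError:
        return False
    return torch.musa.is_available() and pred.device.type == 'musa'
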
3 changes: 2 additions & 1 deletion mmdet/models/task_modules/assigners/iou2d_calculator.py
@@ -54,7 +54,8 @@ def __call__(self, bboxes1, bboxes2, mode='iou', is_aligned=False):
         bboxes1 = cast_tensor_type(bboxes1, self.scale, self.dtype)
         bboxes2 = cast_tensor_type(bboxes2, self.scale, self.dtype)
         overlaps = bbox_overlaps(bboxes1, bboxes2, mode, is_aligned)
-        if not overlaps.is_cuda and overlaps.dtype == torch.float16:
+        if (not overlaps.is_cuda and overlaps.device.type != 'musa'
+                and overlaps.dtype == torch.float16):
             # resume cpu float32
             overlaps = overlaps.float()
         return overlaps
29 changes: 22 additions & 7 deletions mmdet/models/task_modules/assigners/sim_ota_assigner.py
@@ -115,13 +115,28 @@ def assign(self,
 
         valid_pred_scores = valid_pred_scores.unsqueeze(1).repeat(1, num_gt, 1)
         # disable AMP autocast and calculate BCE with FP32 to avoid overflow
-        with torch.cuda.amp.autocast(enabled=False):
-            cls_cost = (
-                F.binary_cross_entropy(
-                    valid_pred_scores.to(dtype=torch.float32),
-                    gt_onehot_label,
-                    reduction='none',
-                ).sum(-1).to(dtype=valid_pred_scores.dtype))
+        try:
+            import torch_musa
+            IS_MUSA_AVAILABLE = True
+        except Exception:
+            IS_MUSA_AVAILABLE = False
+
+        if IS_MUSA_AVAILABLE:
+            with torch_musa.core.amp.autocast(enabled=False):
+                cls_cost = (
+                    F.binary_cross_entropy(
+                        valid_pred_scores.to(dtype=torch.float32),
+                        gt_onehot_label,
+                        reduction='none',
+                    ).sum(-1).to(dtype=valid_pred_scores.dtype))
+        else:
+            with torch.cuda.amp.autocast(enabled=False):
+                cls_cost = (
+                    F.binary_cross_entropy(
+                        valid_pred_scores.to(dtype=torch.float32),
+                        gt_onehot_label,
+                        reduction='none',
+                    ).sum(-1).to(dtype=valid_pred_scores.dtype))
 
         cost_matrix = (
             cls_cost * self.cls_weight + iou_cost * self.iou_weight +
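This guarded import runs on every assign() call, and the branch keys off whether torch_musa is installed rather than where valid_pred_scores actually lives, so a CUDA run with torch_musa present would take the MUSA autocast path. A sketch of an alternative (module-level flag plus device-based dispatch; helper name hypothetical):

import torch

try:
    import torch_musa
    IS_MUSA_AVAILABLE = True
except ImportError:
    IS_MUSA_AVAILABLE = False


def bce_autocast_off(scores: torch.Tensor):
    """Pick the FP32 (autocast-disabled) context from the scores tensor's device."""
    if IS_MUSA_AVAILABLE and scores.device.type == 'musa':
        return torch_musa.core.amp.autocast(enabled=False)
    return torch.cuda.amp.autocast(enabled=False)
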
8 changes: 7 additions & 1 deletion mmdet/models/task_modules/samplers/random_sampler.py
@@ -53,11 +53,17 @@ def random_choice(self, gallery: Union[Tensor, ndarray, list],
             Tensor or ndarray: sampled indices.
         """
         assert len(gallery) >= num
-
+        try:
+            import torch_musa
+            IS_MUSA_AVAILABLE = True
+        except Exception:
+            IS_MUSA_AVAILABLE = False
         is_tensor = isinstance(gallery, torch.Tensor)
         if not is_tensor:
             if torch.cuda.is_available():
                 device = torch.cuda.current_device()
+            elif IS_MUSA_AVAILABLE:
+                device = torch.musa.current_device()
             else:
                 device = 'cpu'
             gallery = torch.tensor(gallery, dtype=torch.long, device=device)
7 changes: 7 additions & 0 deletions mmdet/models/task_modules/samplers/score_hlr_sampler.py
@@ -89,10 +89,17 @@ def random_choice(gallery: Union[Tensor, ndarray, list],
         """
         assert len(gallery) >= num
 
+        try:
+            import torch_musa
+            IS_MUSA_AVAILABLE = True
+        except Exception:
+            IS_MUSA_AVAILABLE = False
         is_tensor = isinstance(gallery, torch.Tensor)
         if not is_tensor:
             if torch.cuda.is_available():
                 device = torch.cuda.current_device()
+            elif IS_MUSA_AVAILABLE:
+                device = torch.musa.current_device()
             else:
                 device = 'cpu'
             gallery = torch.tensor(gallery, dtype=torch.long, device=device)
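Both sampler hunks repeat the same per-call import guard inside random_choice; a module-level helper shared by the two samplers would do the probing once. A sketch (hypothetical name, not part of the diff):

import torch

try:
    import torch_musa  # noqa: F401
    IS_MUSA_AVAILABLE = True
except ImportError:
    IS_MUSA_AVAILABLE = False


def default_sampling_device():
    """Device used when converting a list/ndarray gallery to a tensor."""
    if torch.cuda.is_available():
        return torch.cuda.current_device()
    if IS_MUSA_AVAILABLE and torch.musa.is_available():
        return torch.musa.current_device()
    return 'cpu'
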
2 changes: 2 additions & 0 deletions mmdet/models/task_modules/tracking/aflink.py
@@ -162,6 +162,8 @@ def __init__(self,
         load_checkpoint(self.model, checkpoint)
         if torch.cuda.is_available():
             self.model.cuda()
+        elif torch.musa.is_available():
+            self.model.musa()
         self.model.eval()
 
         self.device = next(self.model.parameters()).device
2 changes: 1 addition & 1 deletion mmdet/structures/bbox/bbox_overlaps.py
@@ -3,7 +3,7 @@
 
 
 def fp16_clamp(x, min=None, max=None):
-    if not x.is_cuda and x.dtype == torch.float16:
+    if not x.is_cuda and x.device.type != 'musa' and x.dtype == torch.float16:
         # clamp for cpu float16, tensor fp16 has no clamp implementation
         return x.float().clamp(min, max).half()

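With this change fp16_clamp skips the float32 round-trip for MUSA tensors as well as CUDA ones; only the CPU path still upcasts. A small usage sketch:

import torch
from mmdet.structures.bbox.bbox_overlaps import fp16_clamp

x_cpu = -torch.ones(4, dtype=torch.float16)   # CPU fp16 tensor
y = fp16_clamp(x_cpu, min=0.0)                # CPU path: clamp via float32, cast back to fp16
# A CUDA tensor (and, after this change, a MUSA tensor) clamps natively without the round-trip.
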
103 changes: 73 additions & 30 deletions mmdet/utils/benchmark.py
@@ -13,6 +13,7 @@
 from mmengine import MMLogger
 from mmengine.config import Config
 from mmengine.device import get_max_cuda_memory
+from mmengine.device.utils import is_musa_available
 from mmengine.dist import get_world_size
 from mmengine.runner import Runner, load_checkpoint
 from mmengine.utils.dl_utils import set_multi_processing
@@ -193,14 +194,22 @@ def _init_model(self, checkpoint: str, is_fuse_conv_bn: bool) -> nn.Module:
         if is_fuse_conv_bn:
             model = fuse_conv_bn(model)
 
-        model = model.cuda()
-
-        if self.distributed:
-            model = DistributedDataParallel(
-                model,
-                device_ids=[torch.cuda.current_device()],
-                broadcast_buffers=False,
-                find_unused_parameters=False)
+        if is_musa_available():
+            model = model.musa()
+            if self.distributed:
+                model = DistributedDataParallel(
+                    model,
+                    device_ids=[torch.musa.current_device()],
+                    broadcast_buffers=False,
+                    find_unused_parameters=False)
+        else:
+            model = model.cuda()
+            if self.distributed:
+                model = DistributedDataParallel(
+                    model,
+                    device_ids=[torch.cuda.current_device()],
+                    broadcast_buffers=False,
+                    find_unused_parameters=False)
 
         model.eval()
         return model
@@ -209,37 +218,71 @@ def run_once(self) -> dict:
         """Executes the benchmark once."""
         pure_inf_time = 0
         fps = 0
-
-        for i, data in enumerate(self.data_loader):
-
-            if (i + 1) % self.log_interval == 0:
-                print_log('==================================', self.logger)
-
-            torch.cuda.synchronize()
-            start_time = time.perf_counter()
-
-            with torch.no_grad():
-                self.model.test_step(data)
-
-            torch.cuda.synchronize()
-            elapsed = time.perf_counter() - start_time
-
-            if i >= self.num_warmup:
-                pure_inf_time += elapsed
-                if (i + 1) % self.log_interval == 0:
-                    fps = (i + 1 - self.num_warmup) / pure_inf_time
-                    cuda_memory = get_max_cuda_memory()
-
-                    print_log(
-                        f'Done image [{i + 1:<3}/{self.max_iter}], '
-                        f'fps: {fps:.1f} img/s, '
-                        f'times per image: {1000 / fps:.1f} ms/img, '
-                        f'cuda memory: {cuda_memory} MB', self.logger)
-                    print_process_memory(self._process, self.logger)
-
-            if (i + 1) == self.max_iter:
-                fps = (i + 1 - self.num_warmup) / pure_inf_time
-                break
+        if is_musa_available():
+            for i, data in enumerate(self.data_loader):
+
+                if (i + 1) % self.log_interval == 0:
+                    print_log('==================================',
+                              self.logger)
+
+                torch.musa.synchronize()
+                start_time = time.perf_counter()
+
+                with torch.no_grad():
+                    self.model.test_step(data)
+
+                torch.musa.synchronize()
+                elapsed = time.perf_counter() - start_time
+
+                if i >= self.num_warmup:
+                    pure_inf_time += elapsed
+                    if (i + 1) % self.log_interval == 0:
+                        fps = (i + 1 - self.num_warmup) / pure_inf_time
+                        musa_memory = get_max_musa_memory()
+
+                        print_log(
+                            f'Done image [{i + 1:<3}/{self.max_iter}], '
+                            f'fps: {fps:.1f} img/s, '
+                            f'times per image: {1000 / fps:.1f} ms/img, '
+                            f'musa memory: {musa_memory} MB', self.logger)
+                        print_process_memory(self._process, self.logger)
+
+                if (i + 1) == self.max_iter:
+                    fps = (i + 1 - self.num_warmup) / pure_inf_time
+                    break
+        else:
+            for i, data in enumerate(self.data_loader):
+
+                if (i + 1) % self.log_interval == 0:
+                    print_log('==================================',
+                              self.logger)
+
+                torch.cuda.synchronize()
+                start_time = time.perf_counter()
+
+                with torch.no_grad():
+                    self.model.test_step(data)
+
+                torch.cuda.synchronize()
+                elapsed = time.perf_counter() - start_time
+
+                if i >= self.num_warmup:
+                    pure_inf_time += elapsed
+                    if (i + 1) % self.log_interval == 0:
+                        fps = (i + 1 - self.num_warmup) / pure_inf_time
+                        cuda_memory = get_max_cuda_memory()
+
+                        print_log(
+                            f'Done image [{i + 1:<3}/{self.max_iter}], '
+                            f'fps: {fps:.1f} img/s, '
+                            f'times per image: {1000 / fps:.1f} ms/img, '
+                            f'cuda memory: {cuda_memory} MB', self.logger)
+                        print_process_memory(self._process, self.logger)
+
+                if (i + 1) == self.max_iter:
+                    fps = (i + 1 - self.num_warmup) / pure_inf_time
+                    break
 
         return {'fps': fps}

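The MUSA loop in run_once calls get_max_musa_memory(), which the import hunk above does not bring in, and the two loops differ only in the synchronize call and the memory query. A sketch of small helpers that would let one loop serve both backends (names hypothetical; get_max_musa_memory assumed to exist as a MUSA counterpart of get_max_cuda_memory):

def _sync(device_type: str) -> None:
    """Synchronize the active accelerator around the timed region."""
    if device_type == 'musa':
        torch.musa.synchronize()
    elif device_type == 'cuda':
        torch.cuda.synchronize()


def _max_memory_mb(device_type: str) -> int:
    """Peak accelerator memory (MB) for the log line."""
    if device_type == 'musa':
        return get_max_musa_memory()  # assumed helper, mirrors get_max_cuda_memory
    return get_max_cuda_memory()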