diff --git a/softgroup/model/softgroup.py b/softgroup/model/softgroup.py
index 5665e21..adbed6e 100644
--- a/softgroup/model/softgroup.py
+++ b/softgroup/model/softgroup.py
@@ -2,6 +2,7 @@ import functools
 
 import spconv.pytorch as spconv
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
 
@@ -123,18 +124,14 @@ class SoftGroup(nn.Module):
                                                proposals_offset, instance_labels, instance_pointnum,
                                                instance_cls, instance_batch_idxs)
             losses.update(instance_loss)
-
-        # parse loss
-        loss = sum(v[0] for v in losses.values())
-        losses['loss'] = (loss, batch_idxs.size(0))
-        return loss, losses
+        return self.parse_losses(losses)
 
     def point_wise_loss(self, semantic_scores, pt_offsets, semantic_labels, instance_labels,
                         pt_offset_labels):
         losses = {}
         semantic_loss = F.cross_entropy(
             semantic_scores, semantic_labels, ignore_index=self.ignore_label)
-        losses['semantic_loss'] = (semantic_loss, semantic_scores.size(0))
+        losses['semantic_loss'] = semantic_loss
 
         pos_inds = instance_labels != self.ignore_label
         if pos_inds.sum() == 0:
@@ -142,7 +139,7 @@ class SoftGroup(nn.Module):
         else:
             offset_loss = F.l1_loss(
                 pt_offsets[pos_inds], pt_offset_labels[pos_inds], reduction='sum') / pos_inds.sum()
-        losses['offset_loss'] = (offset_loss, pos_inds.sum())
+        losses['offset_loss'] = offset_loss
         return losses
 
     @force_fp32(apply_to=('cls_scores', 'mask_scores', 'iou_scores'))
@@ -170,7 +167,7 @@ class SoftGroup(nn.Module):
         labels = fg_instance_cls.new_full((fg_ious_on_cluster.size(0), ), self.instance_classes)
         labels[pos_inds] = fg_instance_cls[pos_gt_inds]
         cls_loss = F.cross_entropy(cls_scores, labels)
-        losses['cls_loss'] = (cls_loss, labels.size(0))
+        losses['cls_loss'] = cls_loss
 
         # compute mask loss
         mask_cls_label = labels[instance_batch_idxs.long()]
@@ -184,7 +181,7 @@ class SoftGroup(nn.Module):
         mask_loss = F.binary_cross_entropy(
             mask_scores_sigmoid_slice, mask_label, weight=mask_label_weight, reduction='sum')
         mask_loss /= (mask_label_weight.sum() + 1)
-        losses['mask_loss'] = (mask_loss, mask_label_weight.sum())
+        losses['mask_loss'] = mask_loss
 
         # compute iou score loss
         ious = get_mask_iou_on_pred(proposals_idx, proposals_offset, instance_labels,
@@ -196,9 +193,19 @@ class SoftGroup(nn.Module):
         iou_score_slice = iou_scores[slice_inds, labels]
         iou_score_loss = F.mse_loss(iou_score_slice, gt_ious, reduction='none')
         iou_score_loss = (iou_score_loss * iou_score_weight).sum() / (iou_score_weight.sum() + 1)
-        losses['iou_score_loss'] = (iou_score_loss, iou_score_weight.sum())
+        losses['iou_score_loss'] = iou_score_loss
         return losses
 
+    def parse_losses(self, losses):
+        loss = sum(v for v in losses.values())
+        losses['loss'] = loss
+        for loss_name, loss_value in losses.items():
+            if dist.is_available() and dist.is_initialized():
+                loss_value = loss_value.data.clone()
+                dist.all_reduce(loss_value.div_(dist.get_world_size()))
+            losses[loss_name] = loss_value.item()
+        return loss, losses
+
     @cuda_cast
     def forward_test(self, batch_idxs, voxel_coords, p2v_map, v2p_map, coords_float, feats,
                      semantic_labels, instance_labels, pt_offset_labels, spatial_shape, batch_size,
diff --git a/softgroup/util/utils.py b/softgroup/util/utils.py
index 81374b5..9554818 100644
--- a/softgroup/util/utils.py
+++ b/softgroup/util/utils.py
@@ -5,14 +5,16 @@ from collections import OrderedDict
 from math import cos, pi
 
 import torch
+from torch import distributed as dist
 
-from .dist import master_only
+from .dist import get_dist_info, master_only
 
 
 class AverageMeter(object):
     """Computes and stores the average and current value."""
 
-    def __init__(self):
+    def __init__(self, apply_dist_reduce=False):
+        self.apply_dist_reduce = apply_dist_reduce
         self.reset()
 
     def reset(self):
@@ -21,6 +23,27 @@ class AverageMeter(object):
         self.sum = 0
         self.count = 0
 
+    def dist_reduce(self, val):
+        rank, world_size = get_dist_info()
+        if world_size == 1:
+            return val
+        if not isinstance(val, torch.Tensor):
+            val = torch.tensor(val, device='cuda')
+        dist.all_reduce(val)
+        return val.item() / world_size
+
+    def get_val(self):
+        if self.apply_dist_reduce:
+            return self.dist_reduce(self.val)
+        else:
+            return self.val
+
+    def get_avg(self):
+        if self.apply_dist_reduce:
+            return self.dist_reduce(self.avg)
+        else:
+            return self.avg
+
     def update(self, val, n=1):
         self.val = val
         self.sum += val * n
@@ -124,7 +147,10 @@ def load_checkpoint(checkpoint, logger, model, optimizer=None, strict=False):
 
 def get_max_memory():
     mem = torch.cuda.max_memory_allocated()
-    mem_mb = torch.tensor([int(mem) // (1024 * 1024)], dtype=torch.int)
+    mem_mb = torch.tensor([int(mem) // (1024 * 1024)], dtype=torch.int, device='cuda')
+    _, world_size = get_dist_info()
+    if world_size > 1:
+        dist.reduce(mem_mb, 0, op=dist.ReduceOp.MAX)
     return mem_mb.item()
 
 
diff --git a/train.py b/train.py
index 21708a9..827fa2e 100644
--- a/train.py
+++ b/train.py
@@ -82,8 +82,8 @@ if __name__ == '__main__':
     logger.info('Training')
     for epoch in range(start_epoch, cfg.epochs + 1):
         model.train()
-        iter_time = AverageMeter()
-        data_time = AverageMeter()
+        iter_time = AverageMeter(True)
+        data_time = AverageMeter(True)
         meter_dict = {}
         end = time.time()
 
@@ -102,7 +102,7 @@ if __name__ == '__main__':
             for k, v in log_vars.items():
                 if k not in meter_dict.keys():
                     meter_dict[k] = AverageMeter()
-                meter_dict[k].update(v[0], v[1])
+                meter_dict[k].update(v)
 
             # backward
             optimizer.zero_grad()
@@ -111,9 +111,7 @@ if __name__ == '__main__':
             scaler.update()
 
             # time and print
-            current_iter = (epoch - 1) * len(train_loader) + i
-            max_iter = cfg.epochs * len(train_loader)
-            remain_iter = max_iter - current_iter
+            remain_iter = len(train_loader) * (cfg.epochs - epoch + 1) - i
             iter_time.update(time.time() - end)
             end = time.time()
             remain_time = remain_iter * iter_time.avg
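
For reference, a minimal single-process sketch of what the patched `SoftGroup.parse_losses` returns; the function body is copied from the hunk above, while the toy loss values are made up for illustration. Without an initialized process group the `dist` branch is skipped, so the method simply sums the loss tensors for backprop and converts each logged entry to a plain float, which is why `train.py` can now call `meter_dict[k].update(v)` with a single value.

```python
import torch
import torch.distributed as dist


def parse_losses(losses):
    # Same logic as the SoftGroup.parse_losses method added in the patch.
    loss = sum(v for v in losses.values())
    losses['loss'] = loss
    for loss_name, loss_value in losses.items():
        if dist.is_available() and dist.is_initialized():
            loss_value = loss_value.data.clone()
            dist.all_reduce(loss_value.div_(dist.get_world_size()))
        losses[loss_name] = loss_value.item()
    return loss, losses


# Toy values, single process: dist.is_initialized() is False, so no all_reduce happens.
losses = {
    'semantic_loss': torch.tensor(0.8, requires_grad=True),
    'offset_loss': torch.tensor(0.3, requires_grad=True),
}
loss, log_vars = parse_losses(losses)
print(loss)      # scalar tensor with grad_fn, used for loss.backward()
print(log_vars)  # {'semantic_loss': ~0.8, 'offset_loss': ~0.3, 'loss': ~1.1} -- plain floats for logging
```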