Mirror of https://github.com/botastic/SoftGroup.git, synced 2025-10-16 11:45:42 +00:00
log with distributed parallel
This commit is contained in:
parent
f0802b75bb
commit
70c86093db
@@ -2,6 +2,7 @@ import functools
 import spconv.pytorch as spconv
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F

@@ -123,18 +124,14 @@ class SoftGroup(nn.Module):
                                               proposals_offset, instance_labels, instance_pointnum,
                                               instance_cls, instance_batch_idxs)
             losses.update(instance_loss)
-
-        # parse loss
-        loss = sum(v[0] for v in losses.values())
-        losses['loss'] = (loss, batch_idxs.size(0))
-        return loss, losses
+        return self.parse_losses(losses)

     def point_wise_loss(self, semantic_scores, pt_offsets, semantic_labels, instance_labels,
                         pt_offset_labels):
         losses = {}
         semantic_loss = F.cross_entropy(
             semantic_scores, semantic_labels, ignore_index=self.ignore_label)
-        losses['semantic_loss'] = (semantic_loss, semantic_scores.size(0))
+        losses['semantic_loss'] = semantic_loss

         pos_inds = instance_labels != self.ignore_label
         if pos_inds.sum() == 0:

@@ -142,7 +139,7 @@ class SoftGroup(nn.Module):
         else:
             offset_loss = F.l1_loss(
                 pt_offsets[pos_inds], pt_offset_labels[pos_inds], reduction='sum') / pos_inds.sum()
-        losses['offset_loss'] = (offset_loss, pos_inds.sum())
+        losses['offset_loss'] = offset_loss
         return losses

     @force_fp32(apply_to=('cls_scores', 'mask_scores', 'iou_scores'))

@@ -170,7 +167,7 @@
         labels = fg_instance_cls.new_full((fg_ious_on_cluster.size(0), ), self.instance_classes)
         labels[pos_inds] = fg_instance_cls[pos_gt_inds]
         cls_loss = F.cross_entropy(cls_scores, labels)
-        losses['cls_loss'] = (cls_loss, labels.size(0))
+        losses['cls_loss'] = cls_loss

         # compute mask loss
         mask_cls_label = labels[instance_batch_idxs.long()]

@@ -184,7 +181,7 @@
         mask_loss = F.binary_cross_entropy(
             mask_scores_sigmoid_slice, mask_label, weight=mask_label_weight, reduction='sum')
         mask_loss /= (mask_label_weight.sum() + 1)
-        losses['mask_loss'] = (mask_loss, mask_label_weight.sum())
+        losses['mask_loss'] = mask_loss

         # compute iou score loss
         ious = get_mask_iou_on_pred(proposals_idx, proposals_offset, instance_labels,

@@ -196,9 +193,19 @@
         iou_score_slice = iou_scores[slice_inds, labels]
         iou_score_loss = F.mse_loss(iou_score_slice, gt_ious, reduction='none')
         iou_score_loss = (iou_score_loss * iou_score_weight).sum() / (iou_score_weight.sum() + 1)
-        losses['iou_score_loss'] = (iou_score_loss, iou_score_weight.sum())
+        losses['iou_score_loss'] = iou_score_loss
         return losses

+    def parse_losses(self, losses):
+        loss = sum(v for v in losses.values())
+        losses['loss'] = loss
+        for loss_name, loss_value in losses.items():
+            if dist.is_available() and dist.is_initialized():
+                loss_value = loss_value.data.clone()
+                dist.all_reduce(loss_value.div_(dist.get_world_size()))
+            losses[loss_name] = loss_value.item()
+        return loss, losses
+
     @cuda_cast
     def forward_test(self, batch_idxs, voxel_coords, p2v_map, v2p_map, coords_float, feats,
                      semantic_labels, instance_labels, pt_offset_labels, spatial_shape, batch_size,

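Note on the scheme: each loss entry changes from a (value, normalizer) tuple to a bare scalar tensor because the new parse_losses takes over reduction and logging. It sums the entries into the total training loss, then averages every scalar across ranks with all_reduce so all workers log identical numbers. A minimal single-process sketch of how the new return value behaves (toy loss values; no process group initialized, so the dist branch is skipped):

import torch

losses = {'semantic_loss': torch.tensor(0.7, requires_grad=True),
          'offset_loss': torch.tensor(0.2, requires_grad=True)}
loss, log_vars = model.parse_losses(losses)  # model: a SoftGroup instance
loss.backward()   # the returned total still carries the autograd graph
print(log_vars)   # ~ {'semantic_loss': 0.7, 'offset_loss': 0.2, 'loss': 0.9}, all floats
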
@@ -5,14 +5,16 @@ from collections import OrderedDict
 from math import cos, pi

 import torch
+from torch import distributed as dist

-from .dist import master_only
+from .dist import get_dist_info, master_only


 class AverageMeter(object):
     """Computes and stores the average and current value."""

-    def __init__(self):
+    def __init__(self, apply_dist_reduce=False):
+        self.apply_dist_reduce = apply_dist_reduce
         self.reset()

     def reset(self):

@@ -21,6 +23,27 @@ class AverageMeter(object):
         self.sum = 0
         self.count = 0

+    def dist_reduce(self, val):
+        rank, world_size = get_dist_info()
+        if world_size == 1:
+            return val
+        if not isinstance(val, torch.Tensor):
+            val = torch.tensor(val, device='cuda')
+        dist.all_reduce(val)
+        return val.item() / world_size
+
+    def get_val(self):
+        if self.apply_dist_reduce:
+            return self.dist_reduce(self.val)
+        else:
+            return self.val
+
+    def get_avg(self):
+        if self.apply_dist_reduce:
+            return self.dist_reduce(self.avg)
+        else:
+            return self.avg
+
     def update(self, val, n=1):
         self.val = val
         self.sum += val * n

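With apply_dist_reduce=True a meter accumulates locally exactly as before; synchronization happens only when a value is read through get_val() or get_avg(), which all_reduce the local statistic and divide by the world size. A usage sketch (assumes a process group is already initialized, e.g. via torchrun; num_steps and train_one_step are illustrative stand-ins):

import time

iter_time = AverageMeter(apply_dist_reduce=True)
for step in range(num_steps):
    start = time.time()
    train_one_step()  # hypothetical per-iteration work
    iter_time.update(time.time() - start)
# every rank must call get_avg() together, since it performs a collective
print(f'avg iter time across ranks: {iter_time.get_avg():.3f}s')
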
@@ -124,7 +147,10 @@ def load_checkpoint(checkpoint, logger, model, optimizer=None, strict=False):

 def get_max_memory():
     mem = torch.cuda.max_memory_allocated()
-    mem_mb = torch.tensor([int(mem) // (1024 * 1024)], dtype=torch.int)
+    mem_mb = torch.tensor([int(mem) // (1024 * 1024)], dtype=torch.int, device='cuda')
+    _, world_size = get_dist_info()
+    if world_size > 1:
+        dist.reduce(mem_mb, 0, op=dist.ReduceOp.MAX)
     return mem_mb.item()
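get_max_memory now reports a job-wide peak instead of a rank-local one: the tensor must live on the GPU for the NCCL backend, and dist.reduce with ReduceOp.MAX leaves the maximum over all ranks on rank 0, the only rank that logs. The same pattern in isolation (assumes an initialized NCCL process group):

import torch
import torch.distributed as dist

# Each rank contributes its local peak in MB; after the reduce, rank 0
# holds the maximum over all ranks (other ranks keep their own value).
mem_mb = torch.tensor([torch.cuda.max_memory_allocated() // (1024 * 1024)],
                      dtype=torch.int, device='cuda')
dist.reduce(mem_mb, dst=0, op=dist.ReduceOp.MAX)
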
train.py (10 changed lines)

@@ -82,8 +82,8 @@ if __name__ == '__main__':
     logger.info('Training')
     for epoch in range(start_epoch, cfg.epochs + 1):
         model.train()
-        iter_time = AverageMeter()
-        data_time = AverageMeter()
+        iter_time = AverageMeter(True)
+        data_time = AverageMeter(True)
         meter_dict = {}
         end = time.time()

@@ -102,7 +102,7 @@ if __name__ == '__main__':
             for k, v in log_vars.items():
                 if k not in meter_dict.keys():
                     meter_dict[k] = AverageMeter()
-                meter_dict[k].update(v[0], v[1])
+                meter_dict[k].update(v)

             # backward
             optimizer.zero_grad()

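The second argument to update disappears because parse_losses already returns plain floats: every value in log_vars has been all_reduced and .item()'d on the model side, so there is no (value, count) pair left to unpack and the meter's default n=1 applies.
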
@@ -111,9 +111,7 @@
             scaler.update()

             # time and print
-            current_iter = (epoch - 1) * len(train_loader) + i
-            max_iter = cfg.epochs * len(train_loader)
-            remain_iter = max_iter - current_iter
+            remain_iter = len(train_loader) * (cfg.epochs - epoch + 1) - i
             iter_time.update(time.time() - end)
             end = time.time()
             remain_time = remain_iter * iter_time.avg

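The one-liner is algebraically identical to the three removed lines: cfg.epochs * len(train_loader) - ((epoch - 1) * len(train_loader) + i) simplifies to len(train_loader) * (cfg.epochs - epoch + 1) - i. A quick check with illustrative numbers:

epochs, iters_per_epoch = 20, 100   # stand-ins for cfg.epochs, len(train_loader)
epoch, i = 3, 40                    # current epoch (1-based) and iteration
old = epochs * iters_per_epoch - ((epoch - 1) * iters_per_epoch + i)
new = iters_per_epoch * (epochs - epoch + 1) - i
assert old == new == 1760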