From 3475ab88b955caaa6ea79d9b1ede4959645d0335 Mon Sep 17 00:00:00 2001 From: Thang Vu Date: Sun, 10 Apr 2022 02:57:11 +0000 Subject: [PATCH] support mix precision training --- configs/softgroup_scannet_backbone.yaml | 1 + configs/softgroup_scannet_backbone_fp16.yaml | 74 ++++++++++++++++++++ configs/softgroup_scannet_fp16.yaml | 74 ++++++++++++++++++++ softgroup/data/custom.py | 66 ++++++++--------- softgroup/data/s3dis.py | 32 ++++----- softgroup/data/scannetv2.py | 4 +- softgroup/model/softgroup.py | 62 +++++++--------- softgroup/util/__init__.py | 1 + softgroup/util/fp16.py | 66 +++++++++++++++++ train.py | 15 ++-- 10 files changed, 301 insertions(+), 94 deletions(-) create mode 100644 configs/softgroup_scannet_backbone_fp16.yaml create mode 100644 configs/softgroup_scannet_fp16.yaml create mode 100644 softgroup/util/fp16.py diff --git a/configs/softgroup_scannet_backbone.yaml b/configs/softgroup_scannet_backbone.yaml index e489273..fb354a6 100644 --- a/configs/softgroup_scannet_backbone.yaml +++ b/configs/softgroup_scannet_backbone.yaml @@ -66,6 +66,7 @@ optimizer: type: 'Adam' lr: 0.004 +fp16: False epochs: 128 step_epoch: 50 save_freq: 4 diff --git a/configs/softgroup_scannet_backbone_fp16.yaml b/configs/softgroup_scannet_backbone_fp16.yaml new file mode 100644 index 0000000..81699dd --- /dev/null +++ b/configs/softgroup_scannet_backbone_fp16.yaml @@ -0,0 +1,74 @@ +model: + channels: 32 + num_blocks: 7 + semantic_classes: 20 + instance_classes: 18 + sem2ins_classes: [] + semantic_only: True + ignore_label: -100 + grouping_cfg: + score_thr: 0.2 + radius: 0.04 + mean_active: 300 + class_numpoint_mean: [-1., -1., 3917., 12056., 2303., + 8331., 3948., 3166., 5629., 11719., + 1003., 3317., 4912., 10221., 3889., + 4136., 2120., 945., 3967., 2589.] + npoint_thr: 0.05 # absolute if class_numpoint == -1, relative if class_numpoint != -1 + ignore_classes: [0, 1] + instance_voxel_cfg: + scale: 50 + spatial_shape: 20 + train_cfg: + max_proposal_num: 200 + pos_iou_thr: 0.5 + test_cfg: + x4_split: False + cls_score_thr: 0.001 + mask_score_thr: -0.5 + min_npoint: 100 + fixed_modules: [] + +data: + train: + type: 'scannetv2' + data_root: 'dataset/scannetv2' + prefix: 'train' + suffix: '_inst_nostuff.pth' + training: True + repeat: 4 + voxel_cfg: + scale: 50 + spatial_shape: [128, 512] + max_npoint: 250000 + min_npoint: 5000 + test: + type: 'scannetv2' + data_root: 'dataset/scannetv2' + prefix: 'val' + suffix: '_inst_nostuff.pth' + training: False + voxel_cfg: + scale: 50 + spatial_shape: [128, 512] + max_npoint: 250000 + min_npoint: 5000 + +dataloader: + train: + batch_size: 4 + num_workers: 4 + test: + batch_size: 1 + num_workers: 1 + +optimizer: + type: 'Adam' + lr: 0.004 + +fp16: True +epochs: 128 +step_epoch: 50 +save_freq: 4 +pretrain: '' +work_dir: '' diff --git a/configs/softgroup_scannet_fp16.yaml b/configs/softgroup_scannet_fp16.yaml new file mode 100644 index 0000000..d21d1bc --- /dev/null +++ b/configs/softgroup_scannet_fp16.yaml @@ -0,0 +1,74 @@ +model: + channels: 32 + num_blocks: 7 + semantic_classes: 20 + instance_classes: 18 + sem2ins_classes: [] + semantic_only: False + ignore_label: -100 + grouping_cfg: + score_thr: 0.2 + radius: 0.04 + mean_active: 300 + class_numpoint_mean: [-1., -1., 3917., 12056., 2303., + 8331., 3948., 3166., 5629., 11719., + 1003., 3317., 4912., 10221., 3889., + 4136., 2120., 945., 3967., 2589.] + npoint_thr: 0.05 # absolute if class_numpoint == -1, relative if class_numpoint != -1 + ignore_classes: [0, 1] + instance_voxel_cfg: + scale: 50 + spatial_shape: 20 + train_cfg: + max_proposal_num: 200 + pos_iou_thr: 0.5 + test_cfg: + x4_split: False + cls_score_thr: 0.001 + mask_score_thr: -0.5 + min_npoint: 100 + fixed_modules: ['input_conv', 'unet', 'output_layer', 'semantic_linear', 'offset_linear'] + +data: + train: + type: 'scannetv2' + data_root: 'dataset/scannetv2' + prefix: 'train' + suffix: '_inst_nostuff.pth' + training: True + repeat: 4 + voxel_cfg: + scale: 50 + spatial_shape: [128, 512] + max_npoint: 250000 + min_npoint: 5000 + test: + type: 'scannetv2' + data_root: 'dataset/scannetv2' + prefix: 'val' + suffix: '_inst_nostuff.pth' + training: False + voxel_cfg: + scale: 50 + spatial_shape: [128, 512] + max_npoint: 250000 + min_npoint: 5000 + +dataloader: + train: + batch_size: 4 + num_workers: 4 + test: + batch_size: 1 + num_workers: 1 + +optimizer: + type: 'Adam' + lr: 0.004 + +fp16: True +epochs: 128 +step_epoch: 50 +save_freq: 4 +pretrain: 'work_dirs/softgroup_scannet_backbone_spconv2_dist/epoch_116.pth' +work_dir: '' diff --git a/softgroup/data/custom.py b/softgroup/data/custom.py index 9cf8ab6..d62ce70 100644 --- a/softgroup/data/custom.py +++ b/softgroup/data/custom.py @@ -70,7 +70,7 @@ class CustomDataset(Dataset): return x + g(x) * mag - def getInstanceInfo(self, xyz, instance_label, label): + def getInstanceInfo(self, xyz, instance_label, semantic_label): pt_mean = np.ones((xyz.shape[0], 3), dtype=np.float32) * -100.0 instance_pointnum = [] instance_cls = [] @@ -80,8 +80,8 @@ class CustomDataset(Dataset): xyz_i = xyz[inst_idx_i] pt_mean[inst_idx_i] = xyz_i.mean(0) instance_pointnum.append(inst_idx_i[0].size) - cls_loc = inst_idx_i[0][0] - instance_cls.append(label[cls_loc]) + cls_idx = inst_idx_i[0][0] + instance_cls.append(semantic_label[cls_idx]) pt_offset_label = pt_mean - xyz return instance_num, instance_pointnum, instance_cls, pt_offset_label @@ -122,7 +122,7 @@ class CustomDataset(Dataset): j += 1 return instance_label - def transform_train(self, xyz, rgb, label, instance_label): + def transform_train(self, xyz, rgb, semantic_label, instance_label): xyz_middle = self.dataAugment(xyz, True, True, True) xyz = xyz_middle * self.voxel_cfg.scale xyz = self.elastic(xyz, 6 * self.voxel_cfg.scale // 50, 40 * self.voxel_cfg.scale / 50) @@ -140,17 +140,17 @@ class CustomDataset(Dataset): xyz = xyz[valid_idxs] xyz_middle = xyz_middle[valid_idxs] rgb = rgb[valid_idxs] - label = label[valid_idxs] + semantic_label = semantic_label[valid_idxs] instance_label = self.getCroppedInstLabel(instance_label, valid_idxs) - return xyz, xyz_middle, rgb, label, instance_label + return xyz, xyz_middle, rgb, semantic_label, instance_label - def transform_test(self, xyz, rgb, label, instance_label): + def transform_test(self, xyz, rgb, semantic_label, instance_label): xyz_middle = self.dataAugment(xyz, False, True, True) xyz = xyz_middle * self.voxel_cfg.scale xyz -= xyz.min(0) valid_idxs = np.ones(xyz.shape[0], dtype=bool) instance_label = self.getCroppedInstLabel(instance_label, valid_idxs) # TODO remove this - return xyz, xyz_middle, rgb, label, instance_label + return xyz, xyz_middle, rgb, semantic_label, instance_label def __getitem__(self, index): filename = self.filenames[index] @@ -159,26 +159,26 @@ class CustomDataset(Dataset): data = self.transform_train(*data) if self.training else self.transform_test(*data) if data is None: return None - xyz, xyz_middle, rgb, label, instance_label = data - info = self.getInstanceInfo(xyz_middle, instance_label.astype(np.int32), label) + xyz, xyz_middle, rgb, semantic_label, instance_label = data + info = self.getInstanceInfo(xyz_middle, instance_label.astype(np.int32), semantic_label) inst_num, inst_pointnum, inst_cls, pt_offset_label = info - loc = torch.from_numpy(xyz).long() - loc_float = torch.from_numpy(xyz_middle) + coord = torch.from_numpy(xyz).long() + coord_float = torch.from_numpy(xyz_middle) feat = torch.from_numpy(rgb).float() if self.training: feat += torch.randn(3) * 0.1 - label = torch.from_numpy(label) + semantic_label = torch.from_numpy(semantic_label) instance_label = torch.from_numpy(instance_label) pt_offset_label = torch.from_numpy(pt_offset_label) - return (scan_id, loc, loc_float, feat, label, instance_label, inst_num, inst_pointnum, - inst_cls, pt_offset_label) + return (scan_id, coord, coord_float, feat, semantic_label, instance_label, inst_num, + inst_pointnum, inst_cls, pt_offset_label) def collate_fn(self, batch): scan_ids = [] - locs = [] - locs_float = [] + coords = [] + coords_float = [] feats = [] - labels = [] + semantic_labels = [] instance_labels = [] instance_pointnum = [] # (total_nInst), int @@ -190,15 +190,15 @@ class CustomDataset(Dataset): for data in batch: if data is None: continue - (scan_id, loc, loc_float, feat, label, instance_label, inst_num, inst_pointnum, - inst_cls, pt_offset_label) = data + (scan_id, coord, coord_float, feat, semantic_label, instance_label, inst_num, + inst_pointnum, inst_cls, pt_offset_label) = data instance_label[np.where(instance_label != -100)] += total_inst_num total_inst_num += inst_num scan_ids.append(scan_id) - locs.append(torch.cat([loc.new_full((loc.size(0), 1), batch_id), loc], 1)) - locs_float.append(loc_float) + coords.append(torch.cat([coord.new_full((coord.size(0), 1), batch_id), coord], 1)) + coords_float.append(coord_float) feats.append(feat) - labels.append(label) + semantic_labels.append(semantic_label) instance_labels.append(instance_label) instance_pointnum.extend(inst_pointnum) instance_cls.extend(inst_cls) @@ -209,29 +209,29 @@ class CustomDataset(Dataset): self.logger.info(f'batch is truncated from size {len(batch)} to {batch_id}') # merge all the scenes in the batch - locs = torch.cat(locs, 0) # long (N, 1 + 3), the batch item idx is put in locs[:, 0] - batch_idxs = locs[:, 0].int() - locs_float = torch.cat(locs_float, 0).to(torch.float32) # float (N, 3) + coords = torch.cat(coords, 0) # long (N, 1 + 3), the batch item idx is put in coords[:, 0] + batch_idxs = coords[:, 0].int() + coords_float = torch.cat(coords_float, 0).to(torch.float32) # float (N, 3) feats = torch.cat(feats, 0) # float (N, C) - labels = torch.cat(labels, 0).long() # long (N) + semantic_labels = torch.cat(semantic_labels, 0).long() # long (N) instance_labels = torch.cat(instance_labels, 0).long() # long (N) instance_pointnum = torch.tensor(instance_pointnum, dtype=torch.int) # int (total_nInst) instance_cls = torch.tensor(instance_cls, dtype=torch.long) # long (total_nInst) pt_offset_labels = torch.cat(pt_offset_labels).float() spatial_shape = np.clip( - locs.max(0)[0][1:].numpy() + 1, self.voxel_cfg.spatial_shape[0], None) - voxel_locs, v2p_map, p2v_map = voxelization_idx(locs, 1) + coords.max(0)[0][1:].numpy() + 1, self.voxel_cfg.spatial_shape[0], None) + voxel_coords, v2p_map, p2v_map = voxelization_idx(coords, 1) return { 'scan_ids': scan_ids, - 'locs': locs, + 'coords': coords, 'batch_idxs': batch_idxs, - 'voxel_locs': voxel_locs, + 'voxel_coords': voxel_coords, 'p2v_map': p2v_map, 'v2p_map': v2p_map, - 'locs_float': locs_float, + 'coords_float': coords_float, 'feats': feats, - 'labels': labels, + 'semantic_labels': semantic_labels, 'instance_labels': instance_labels, 'instance_pointnum': instance_pointnum, 'instance_cls': instance_cls, diff --git a/softgroup/data/s3dis.py b/softgroup/data/s3dis.py index 14e615b..0bf95e6 100644 --- a/softgroup/data/s3dis.py +++ b/softgroup/data/s3dis.py @@ -26,21 +26,21 @@ class S3DISDataset(CustomDataset): def load(self, filename): # TODO make file load results consistent - xyz, rgb, label, instance_label, _, _ = torch.load(filename) + xyz, rgb, semantic_label, instance_label, _, _ = torch.load(filename) # subsample data if self.training: N = xyz.shape[0] inds = np.random.choice(N, int(N * 0.25), replace=False) xyz = xyz[inds] rgb = rgb[inds] - label = label[inds] + semantic_label = semantic_label[inds] instance_label = self.getCroppedInstLabel(instance_label, inds) - return xyz, rgb, label, instance_label + return xyz, rgb, semantic_label, instance_label def crop(self, xyz, step=64): return super().crop(xyz, step=step) - def transform_test(self, xyz, rgb, label, instance_label): + def transform_test(self, xyz, rgb, semantic_label, instance_label): # devide into 4 piecies inds = np.arange(xyz.shape[0]) piece_1 = inds[::4] @@ -64,37 +64,37 @@ class S3DISDataset(CustomDataset): rgb = np.concatenate(rgb_list, 0) valid_idxs = np.ones(xyz.shape[0], dtype=bool) instance_label = self.getCroppedInstLabel(instance_label, valid_idxs) # TODO remove this - return xyz, xyz_middle, rgb, label, instance_label + return xyz, xyz_middle, rgb, semantic_label, instance_label def collate_fn(self, batch): if self.training: return super().collate_fn(batch) # assume 1 scan only - (scan_id, loc, loc_float, feat, label, instance_label, inst_num, inst_pointnum, inst_cls, - pt_offset_label) = batch[0] + (scan_id, coord, coord_float, feat, semantic_label, instance_label, inst_num, inst_pointnum, + inst_cls, pt_offset_label) = batch[0] scan_ids = [scan_id] - locs = loc.long() - batch_idxs = torch.zeros_like(loc[:, 0].int()) - locs_float = loc_float.float() + coords = coord.long() + batch_idxs = torch.zeros_like(coord[:, 0].int()) + coords_float = coord_float.float() feats = feat.float() - labels = label.long() + semantic_labels = semantic_label.long() instance_labels = instance_label.long() instance_pointnum = torch.tensor([inst_pointnum], dtype=torch.int) instance_cls = torch.tensor([inst_cls], dtype=torch.long) pt_offset_labels = pt_offset_label.float() - spatial_shape = np.clip((locs.max(0)[0][1:] + 1).numpy(), self.voxel_cfg.spatial_shape[0], + spatial_shape = np.clip((coords.max(0)[0][1:] + 1).numpy(), self.voxel_cfg.spatial_shape[0], None) - voxel_locs, v2p_map, p2v_map = voxelization_idx(locs, 4) + voxel_coords, v2p_map, p2v_map = voxelization_idx(coords, 4) return { 'scan_ids': scan_ids, 'batch_idxs': batch_idxs, - 'voxel_locs': voxel_locs, + 'voxel_coords': voxel_coords, 'p2v_map': p2v_map, 'v2p_map': v2p_map, - 'locs_float': locs_float, + 'coords_float': coords_float, 'feats': feats, - 'labels': labels, + 'semantic_labels': semantic_labels, 'instance_labels': instance_labels, 'instance_pointnum': instance_pointnum, 'instance_cls': instance_cls, diff --git a/softgroup/data/scannetv2.py b/softgroup/data/scannetv2.py index 6acabb1..c6affa0 100644 --- a/softgroup/data/scannetv2.py +++ b/softgroup/data/scannetv2.py @@ -7,8 +7,8 @@ class ScanNetDataset(CustomDataset): 'counter', 'desk', 'curtain', 'refrigerator', 'shower curtain', 'toilet', 'sink', 'bathtub', 'otherfurniture') - def getInstanceInfo(self, xyz, instance_label, label): - ret = super().getInstanceInfo(xyz, instance_label, label) + def getInstanceInfo(self, xyz, instance_label, semantic_label): + ret = super().getInstanceInfo(xyz, instance_label, semantic_label) instance_num, instance_pointnum, instance_cls, pt_offset_label = ret instance_cls = [x - 2 if x != -100 else x for x in instance_cls] return instance_num, instance_pointnum, instance_cls, pt_offset_label diff --git a/softgroup/model/softgroup.py b/softgroup/model/softgroup.py index 165ba99..20aa593 100644 --- a/softgroup/model/softgroup.py +++ b/softgroup/model/softgroup.py @@ -8,6 +8,7 @@ import torch.nn.functional as F from ..lib.softgroup_ops import (ballquery_batch_p, bfs_cluster, get_mask_iou_on_cluster, get_mask_iou_on_pred, get_mask_label, global_avg_pool, sec_max, sec_min, voxelization, voxelization_idx) +from ..util import force_fp32 from .blocks import MLP, ResidualBlock, UBlock @@ -80,25 +81,13 @@ class SoftGroup(nn.Module): def forward(self, batch, return_loss=False): if return_loss: - return self.forward_train(batch) + return self.forward_train(**batch) else: - return self.forward_test(batch) - - def forward_train(self, batch): - batch_idxs = batch['batch_idxs'].cuda() - voxel_coords = batch['voxel_locs'].cuda() - p2v_map = batch['p2v_map'].cuda() - v2p_map = batch['v2p_map'].cuda() - coords_float = batch['locs_float'].cuda() - feats = batch['feats'].cuda() - semantic_labels = batch['labels'].cuda() - instance_labels = batch['instance_labels'].cuda() - instance_pointnum = batch['instance_pointnum'].cuda() - instance_cls = batch['instance_cls'].cuda() - pt_offset_labels = batch['pt_offset_labels'].cuda() - spatial_shape = batch['spatial_shape'] - batch_size = batch['batch_size'] + return self.forward_test(**batch) + def forward_train(self, batch_idxs, voxel_coords, p2v_map, v2p_map, coords_float, feats, + semantic_labels, instance_labels, instance_pointnum, instance_cls, + pt_offset_labels, spatial_shape, batch_size, **kwargs): losses = {} feats = torch.cat((feats, coords_float), 1) voxel_feats = voxelization(feats, p2v_map) @@ -155,6 +144,7 @@ class SoftGroup(nn.Module): losses['offset_loss'] = (offset_loss, pos_inds.sum()) return losses + @force_fp32(apply_to=('cls_scores', 'mask_scores', 'iou_scores')) def instance_loss(self, cls_scores, mask_scores, iou_scores, proposals_idx, proposals_offset, instance_labels, instance_pointnum, instance_cls, instance_batch_idxs): losses = {} @@ -208,18 +198,9 @@ class SoftGroup(nn.Module): losses['iou_score_loss'] = (iou_score_loss, iou_score_weight.sum()) return losses - def forward_test(self, batch): - batch_idxs = batch['batch_idxs'].cuda() - voxel_coords = batch['voxel_locs'].cuda() - p2v_map = batch['p2v_map'].cuda() - v2p_map = batch['v2p_map'].cuda() - coords_float = batch['locs_float'].cuda() - feats = batch['feats'].cuda() - labels = batch['labels'].cuda() - instance_labels = batch['instance_labels'].cuda() - spatial_shape = batch['spatial_shape'] - batch_size = batch['batch_size'] - + def forward_test(self, batch_idxs, voxel_coords, p2v_map, v2p_map, coords_float, feats, + semantic_labels, instance_labels, spatial_shape, batch_size, scan_ids, + **kwargs): feats = torch.cat((feats, coords_float), 1) voxel_feats = voxelization(feats, p2v_map) input = spconv.SparseConvTensor(voxel_feats, voxel_coords.int(), spatial_shape, batch_size) @@ -227,7 +208,8 @@ class SoftGroup(nn.Module): input, v2p_map, coords_float, x4_split=self.test_cfg.x4_split) semantic_preds = semantic_scores.max(1)[1] ret = dict( - semantic_preds=semantic_preds.cpu().numpy(), semantic_labels=labels.cpu().numpy()) + semantic_preds=semantic_preds.cpu().numpy(), + semantic_labels=semantic_labels.cpu().numpy()) if not self.semantic_only: proposals_idx, proposals_offset = self.forward_grouping(semantic_scores, pt_offsets, batch_idxs, coords_float, @@ -236,10 +218,9 @@ class SoftGroup(nn.Module): output_feats, coords_float, **self.instance_voxel_cfg) _, cls_scores, iou_scores, mask_scores = self.forward_instance(inst_feats, inst_map) - pred_instances = self.get_instances(batch['scan_ids'][0], proposals_idx, - semantic_scores, cls_scores, iou_scores, - mask_scores) - gt_instances = self.get_gt_instances(labels, instance_labels) + pred_instances = self.get_instances(scan_ids[0], proposals_idx, semantic_scores, + cls_scores, iou_scores, mask_scores) + gt_instances = self.get_gt_instances(semantic_labels, instance_labels) ret.update(dict(pred_instances=pred_instances, gt_instances=gt_instances)) return ret @@ -289,6 +270,7 @@ class SoftGroup(nn.Module): x_new[p] = x_split[i] return x_new + @force_fp32(apply_to=('semantic_scores, pt_offsets')) def forward_grouping(self, semantic_scores, pt_offsets, @@ -350,6 +332,7 @@ class SoftGroup(nn.Module): return instance_batch_idxs, cls_scores, iou_scores, mask_scores + @force_fp32(apply_to=('semantic_scores', 'cls_scores', 'iou_scores', 'mask_scores')) def get_instances(self, scan_id, proposals_idx, semantic_scores, cls_scores, iou_scores, mask_scores): num_instances = cls_scores.size(0) @@ -402,19 +385,21 @@ class SoftGroup(nn.Module): instances.append(pred) return instances - def get_gt_instances(self, labels, instance_labels): + def get_gt_instances(self, semantic_labels, instance_labels): """Get gt instances for evaluation.""" # convert to evaluation format 0: ignore, 1->N: valid label_shift = self.semantic_classes - self.instance_classes - labels = labels - label_shift + 1 - labels[labels < 0] = 0 + semantic_labels = semantic_labels - label_shift + 1 + semantic_labels[semantic_labels < 0] = 0 instance_labels += 1 ignore_inds = instance_labels < 0 - gt_ins = labels * 1000 + instance_labels + # scannet encoding rule + gt_ins = semantic_labels * 1000 + instance_labels gt_ins[ignore_inds] = 0 gt_ins = gt_ins.cpu().numpy() return gt_ins + @force_fp32(apply_to='feats') def clusters_voxelization(self, clusters_idx, clusters_offset, @@ -466,6 +451,7 @@ class SoftGroup(nn.Module): assert batch_offsets[-1] == batch_idxs.shape[0] return batch_offsets + @force_fp32(apply_to=('x')) def global_pool(self, x, expand=False): indices = x.indices[:, 0] batch_counts = torch.bincount(indices) diff --git a/softgroup/util/__init__.py b/softgroup/util/__init__.py index 951c05b..f0c8d4d 100644 --- a/softgroup/util/__init__.py +++ b/softgroup/util/__init__.py @@ -1,4 +1,5 @@ from .dist import get_dist_info, init_dist +from .fp16 import force_fp32 from .logger import get_root_logger from .optim import build_optimizer from .utils import * diff --git a/softgroup/util/fp16.py b/softgroup/util/fp16.py new file mode 100644 index 0000000..692f218 --- /dev/null +++ b/softgroup/util/fp16.py @@ -0,0 +1,66 @@ +# Simplfied from mmcv. +# Directly use torch.cuda.amp.autocast for mix-precision and support sparse tensor +import functools +from collections import abc +from inspect import getfullargspec + +import spconv.pytorch as spconv +import torch + + +def cast_tensor_type(inputs, src_type, dst_type): + if isinstance(inputs, torch.Tensor): + return inputs.to(dst_type) if inputs.dtype == src_type else inputs + elif isinstance(inputs, spconv.SparseConvTensor): + if inputs.features.dtype == src_type: + features = inputs.features.to(dst_type) + inputs = inputs.replace_feature(features) + return inputs + elif isinstance(inputs, abc.Mapping): + return type(inputs)({k: cast_tensor_type(v, src_type, dst_type) for k, v in inputs.items()}) + elif isinstance(inputs, abc.Iterable): + return type(inputs)(cast_tensor_type(item, src_type, dst_type) for item in inputs) + else: + return inputs + + +def force_fp32(apply_to=None, out_fp16=False): + + def force_fp32_wrapper(old_func): + + @functools.wraps(old_func) + def new_func(*args, **kwargs): + if not isinstance(args[0], torch.nn.Module): + raise TypeError('@force_fp32 can only be used to decorate the ' + 'method of nn.Module') + # get the arg spec of the decorated method + args_info = getfullargspec(old_func) + # get the argument names to be casted + args_to_cast = args_info.args if apply_to is None else apply_to + # convert the args that need to be processed + new_args = [] + if args: + arg_names = args_info.args[:len(args)] + for i, arg_name in enumerate(arg_names): + if arg_name in args_to_cast: + new_args.append(cast_tensor_type(args[i], torch.half, torch.float)) + else: + new_args.append(args[i]) + # convert the kwargs that need to be processed + new_kwargs = dict() + if kwargs: + for arg_name, arg_value in kwargs.items(): + if arg_name in args_to_cast: + new_kwargs[arg_name] = cast_tensor_type(arg_value, torch.half, torch.float) + else: + new_kwargs[arg_name] = arg_value + with torch.cuda.amp.autocast(enabled=False): + output = old_func(*new_args, **new_kwargs) + # cast the results back to fp32 if necessary + if out_fp16: + output = cast_tensor_type(output, torch.float, torch.half) + return output + + return new_func + + return force_fp32_wrapper diff --git a/train.py b/train.py index 94762a1..a29e6a2 100644 --- a/train.py +++ b/train.py @@ -38,9 +38,9 @@ if __name__ == '__main__': init_dist() # work_dir & logger - if args.work_dir is not None: + if args.work_dir: cfg.work_dir = args.work_dir - elif cfg.get('work_dir', None) is None: + else: cfg.work_dir = osp.join('./work_dirs', osp.splitext(osp.basename(args.config))[0]) os.makedirs(osp.abspath(cfg.work_dir), exist_ok=True) timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) @@ -48,6 +48,7 @@ if __name__ == '__main__': logger = get_root_logger(log_file=log_file) logger.info(f'Config:\n{cfg_txt}') logger.info(f'Distributed: {args.dist}') + logger.info(f'Mix precision training: {cfg.fp16}') shutil.copy(args.config, osp.join(cfg.work_dir, osp.basename(args.config))) writer = SummaryWriter(cfg.work_dir) @@ -55,6 +56,7 @@ if __name__ == '__main__': model = SoftGroup(**cfg.model).cuda() if args.dist: model = DistributedDataParallel(model, device_ids=[torch.cuda.current_device()]) + scaler = torch.cuda.amp.GradScaler(enabled=cfg.fp16) # data train_set = build_dataset(cfg.data.train, logger) @@ -91,7 +93,9 @@ if __name__ == '__main__': data_time.update(time.time() - end) cosine_lr_after_step(optimizer, cfg.optimizer.lr, epoch - 1, cfg.step_epoch, cfg.epochs) - loss, log_vars = model(batch, return_loss=True) + + with torch.cuda.amp.autocast(enabled=cfg.fp16): + loss, log_vars = model(batch, return_loss=True) # meter_dict for k, v in log_vars.items(): @@ -101,8 +105,9 @@ if __name__ == '__main__': # backward optimizer.zero_grad() - loss.backward() - optimizer.step() + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() # time and print current_iter = (epoch - 1) * len(train_loader) + i