support distributed training

This commit is contained in:
Thang Vu 2022-04-09 14:42:07 +00:00
parent 4090322ae0
commit c620cfc435
8 changed files with 132 additions and 8 deletions

View File

@ -0,0 +1,73 @@
model:
channels: 32
num_blocks: 7
semantic_classes: 20
instance_classes: 18
sem2ins_classes: []
semantic_only: False
ignore_label: -100
grouping_cfg:
score_thr: 0.2
radius: 0.04
mean_active: 300
class_numpoint_mean: [-1., -1., 3917., 12056., 2303.,
8331., 3948., 3166., 5629., 11719.,
1003., 3317., 4912., 10221., 3889.,
4136., 2120., 945., 3967., 2589.]
npoint_thr: 0.05 # absolute if class_numpoint == -1, relative if class_numpoint != -1
ignore_classes: [0, 1]
instance_voxel_cfg:
scale: 50
spatial_shape: 20
train_cfg:
max_proposal_num: 200
pos_iou_thr: 0.5
test_cfg:
x4_split: False
cls_score_thr: 0.001
mask_score_thr: -0.5
min_npoint: 100
fixed_modules: ['input_conv', 'unet', 'output_layer', 'semantic_linear', 'offset_linear']
data:
train:
type: 'scannetv2'
data_root: 'dataset/scannetv2'
prefix: 'train'
suffix: '_inst_nostuff.pth'
training: True
repeat: 4
voxel_cfg:
scale: 50
spatial_shape: [128, 512]
max_npoint: 250000
min_npoint: 5000
test:
type: 'scannetv2'
data_root: 'dataset/scannetv2'
prefix: 'val'
suffix: '_inst_nostuff.pth'
training: False
voxel_cfg:
scale: 50
spatial_shape: [128, 512]
max_npoint: 250000
min_npoint: 5000
dataloader:
train:
batch_size: 4
num_workers: 4
test:
batch_size: 1
num_workers: 1
optimizer:
type: 'Adam'
lr: 0.004
epochs: 128
step_epoch: 50
save_freq: 4
pretrain: 'work_dirs/softgroup_scannet_backbone_spconv2_dist/epoch_116.pth'
work_dir: 'work_dirs/softgroup_scannet_spconv2_dist'

View File

@ -36,6 +36,7 @@ data:
prefix: 'train'
suffix: '_inst_nostuff.pth'
training: True
repeat: 4
voxel_cfg:
scale: 50
spatial_shape: [128, 512]
@ -63,10 +64,10 @@ dataloader:
optimizer:
type: 'Adam'
lr: 0.001
lr: 0.004
epochs: 512
step_epoch: 200
save_freq: 16
epochs: 128
step_epoch: 50
save_freq: 4
pretrain: ''
work_dir: 'work_dirs/softgroup_scannet_backbone'

6
dist_train.sh Executable file
View File

@ -0,0 +1,6 @@
#!/usr/bin/env bash
CONFIG=$1
GPUS=$2
PORT=${PORT:-29500}
OMP_NUM_THREADS=1 torchrun --nproc_per_node=$GPUS --master_port=$PORT ./train.py --dist $CONFIG ${@:3}

View File

@ -1,4 +1,5 @@
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from .s3dis import S3DISDataset
from .scannetv2 import ScanNetDataset
@ -19,14 +20,19 @@ def build_dataset(data_cfg, logger):
raise ValueError(f'Unknown {data_type}')
def build_dataloader(dataset, batch_size=1, num_workers=1, training=True):
def build_dataloader(dataset, batch_size=1, num_workers=1, training=True, dist=False):
shuffle = training
sampler = DistributedSampler(dataset, shuffle=shuffle) if dist else None
if sampler is not None:
shuffle = False
if training:
return DataLoader(
dataset,
batch_size=batch_size,
num_workers=num_workers,
collate_fn=dataset.collate_fn,
shuffle=True,
shuffle=shuffle,
sampler=sampler,
drop_last=True,
pin_memory=True)
else:

View File

@ -1,3 +1,4 @@
from .dist import get_dist_info, init_dist
from .logger import get_root_logger
from .optim import build_optimizer
from .utils import *

View File

@ -1,3 +1,7 @@
import functools
import os
import torch
from torch import distributed as dist
@ -9,3 +13,21 @@ def get_dist_info():
rank = 0
world_size = 1
return rank, world_size
def init_dist(backend='nccl', **kwargs):
rank = int(os.environ['RANK'])
num_gpus = torch.cuda.device_count()
torch.cuda.set_device(rank % num_gpus)
dist.init_process_group(backend=backend, **kwargs)
def master_only(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
rank, _ = get_dist_info()
if rank == 0:
return func(*args, **kwargs)
return wrapper

View File

@ -5,6 +5,8 @@ from math import cos, pi
import torch
from .dist import master_only
class AverageMeter(object):
"""Computes and stores the average and current value."""
@ -59,6 +61,7 @@ def weights_to_cpu(state_dict):
return state_dict_cpu
@master_only
def checkpoint_save(epoch, model, optimizer, work_dir, save_freq=16):
f = os.path.join(work_dir, f'epoch_{epoch}.pth')
checkpoint = {

View File

@ -12,15 +12,17 @@ from softgroup.data import build_dataloader, build_dataset
from softgroup.evaluation import ScanNetEval, evaluate_semantic_acc, evaluate_semantic_miou
from softgroup.model import SoftGroup
from softgroup.util import (AverageMeter, build_optimizer, checkpoint_save, cosine_lr_after_step,
get_max_memory, get_root_logger, is_multiple, is_power2,
get_max_memory, get_root_logger, init_dist, is_multiple, is_power2,
load_checkpoint)
from tensorboardX import SummaryWriter
from torch.nn.parallel import DistributedDataParallel
from tqdm import tqdm
def get_args():
parser = argparse.ArgumentParser('SoftGroup')
parser.add_argument('config', type=str, help='path to config file')
parser.add_argument('--dist', action='store_true', help='run with distributed parallel')
parser.add_argument('--resume', type=str, help='path to resume from')
parser.add_argument('--work_dir', type=str, help='working directory')
args = parser.parse_args()
@ -32,6 +34,9 @@ if __name__ == '__main__':
cfg_txt = open(args.config, 'r').read()
cfg = Munch.fromDict(yaml.safe_load(cfg_txt))
if args.dist:
init_dist()
# work_dir & logger
if args.work_dir is not None:
cfg.work_dir = args.work_dir
@ -42,16 +47,20 @@ if __name__ == '__main__':
log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
logger = get_root_logger(log_file=log_file)
logger.info(f'Config:\n{cfg_txt}')
logger.info(f'Distributed: {args.dist}')
shutil.copy(args.config, osp.join(cfg.work_dir, osp.basename(args.config)))
writer = SummaryWriter(cfg.work_dir)
# model
model = SoftGroup(**cfg.model).cuda()
if args.dist:
model = DistributedDataParallel(model, device_ids=[torch.cuda.current_device()])
# data
train_set = build_dataset(cfg.data.train, logger)
val_set = build_dataset(cfg.data.test, logger)
train_loader = build_dataloader(train_set, training=True, **cfg.dataloader.train)
train_loader = build_dataloader(
train_set, training=True, dist=args.dist, **cfg.dataloader.train)
val_loader = build_dataloader(val_set, training=False, **cfg.dataloader.test)
# optim
@ -75,6 +84,9 @@ if __name__ == '__main__':
meter_dict = {}
end = time.time()
if train_loader.sampler is not None and args.dist:
train_loader.sampler.set_epoch(epoch)
for i, batch in enumerate(train_loader, start=1):
data_time.update(time.time() - end)