mirror of
https://github.com/botastic/SoftGroup.git
synced 2025-10-16 11:45:42 +00:00
support distributed training
This commit is contained in:
parent
4090322ae0
commit
c620cfc435
73
configs/softgroup_scannet.yaml
Normal file
73
configs/softgroup_scannet.yaml
Normal file
@ -0,0 +1,73 @@
|
||||
model:
|
||||
channels: 32
|
||||
num_blocks: 7
|
||||
semantic_classes: 20
|
||||
instance_classes: 18
|
||||
sem2ins_classes: []
|
||||
semantic_only: False
|
||||
ignore_label: -100
|
||||
grouping_cfg:
|
||||
score_thr: 0.2
|
||||
radius: 0.04
|
||||
mean_active: 300
|
||||
class_numpoint_mean: [-1., -1., 3917., 12056., 2303.,
|
||||
8331., 3948., 3166., 5629., 11719.,
|
||||
1003., 3317., 4912., 10221., 3889.,
|
||||
4136., 2120., 945., 3967., 2589.]
|
||||
npoint_thr: 0.05 # absolute if class_numpoint == -1, relative if class_numpoint != -1
|
||||
ignore_classes: [0, 1]
|
||||
instance_voxel_cfg:
|
||||
scale: 50
|
||||
spatial_shape: 20
|
||||
train_cfg:
|
||||
max_proposal_num: 200
|
||||
pos_iou_thr: 0.5
|
||||
test_cfg:
|
||||
x4_split: False
|
||||
cls_score_thr: 0.001
|
||||
mask_score_thr: -0.5
|
||||
min_npoint: 100
|
||||
fixed_modules: ['input_conv', 'unet', 'output_layer', 'semantic_linear', 'offset_linear']
|
||||
|
||||
data:
|
||||
train:
|
||||
type: 'scannetv2'
|
||||
data_root: 'dataset/scannetv2'
|
||||
prefix: 'train'
|
||||
suffix: '_inst_nostuff.pth'
|
||||
training: True
|
||||
repeat: 4
|
||||
voxel_cfg:
|
||||
scale: 50
|
||||
spatial_shape: [128, 512]
|
||||
max_npoint: 250000
|
||||
min_npoint: 5000
|
||||
test:
|
||||
type: 'scannetv2'
|
||||
data_root: 'dataset/scannetv2'
|
||||
prefix: 'val'
|
||||
suffix: '_inst_nostuff.pth'
|
||||
training: False
|
||||
voxel_cfg:
|
||||
scale: 50
|
||||
spatial_shape: [128, 512]
|
||||
max_npoint: 250000
|
||||
min_npoint: 5000
|
||||
|
||||
dataloader:
|
||||
train:
|
||||
batch_size: 4
|
||||
num_workers: 4
|
||||
test:
|
||||
batch_size: 1
|
||||
num_workers: 1
|
||||
|
||||
optimizer:
|
||||
type: 'Adam'
|
||||
lr: 0.004
|
||||
|
||||
epochs: 128
|
||||
step_epoch: 50
|
||||
save_freq: 4
|
||||
pretrain: 'work_dirs/softgroup_scannet_backbone_spconv2_dist/epoch_116.pth'
|
||||
work_dir: 'work_dirs/softgroup_scannet_spconv2_dist'
|
||||
@ -36,6 +36,7 @@ data:
|
||||
prefix: 'train'
|
||||
suffix: '_inst_nostuff.pth'
|
||||
training: True
|
||||
repeat: 4
|
||||
voxel_cfg:
|
||||
scale: 50
|
||||
spatial_shape: [128, 512]
|
||||
@ -63,10 +64,10 @@ dataloader:
|
||||
|
||||
optimizer:
|
||||
type: 'Adam'
|
||||
lr: 0.001
|
||||
lr: 0.004
|
||||
|
||||
epochs: 512
|
||||
step_epoch: 200
|
||||
save_freq: 16
|
||||
epochs: 128
|
||||
step_epoch: 50
|
||||
save_freq: 4
|
||||
pretrain: ''
|
||||
work_dir: 'work_dirs/softgroup_scannet_backbone'
|
||||
|
||||
6
dist_train.sh
Executable file
6
dist_train.sh
Executable file
@ -0,0 +1,6 @@
|
||||
#!/usr/bin/env bash
|
||||
CONFIG=$1
|
||||
GPUS=$2
|
||||
PORT=${PORT:-29500}
|
||||
|
||||
OMP_NUM_THREADS=1 torchrun --nproc_per_node=$GPUS --master_port=$PORT ./train.py --dist $CONFIG ${@:3}
|
||||
@ -1,4 +1,5 @@
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
from .s3dis import S3DISDataset
|
||||
from .scannetv2 import ScanNetDataset
|
||||
@ -19,14 +20,19 @@ def build_dataset(data_cfg, logger):
|
||||
raise ValueError(f'Unknown {data_type}')
|
||||
|
||||
|
||||
def build_dataloader(dataset, batch_size=1, num_workers=1, training=True):
|
||||
def build_dataloader(dataset, batch_size=1, num_workers=1, training=True, dist=False):
|
||||
shuffle = training
|
||||
sampler = DistributedSampler(dataset, shuffle=shuffle) if dist else None
|
||||
if sampler is not None:
|
||||
shuffle = False
|
||||
if training:
|
||||
return DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
num_workers=num_workers,
|
||||
collate_fn=dataset.collate_fn,
|
||||
shuffle=True,
|
||||
shuffle=shuffle,
|
||||
sampler=sampler,
|
||||
drop_last=True,
|
||||
pin_memory=True)
|
||||
else:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from .dist import get_dist_info, init_dist
|
||||
from .logger import get_root_logger
|
||||
from .optim import build_optimizer
|
||||
from .utils import *
|
||||
|
||||
@ -1,3 +1,7 @@
|
||||
import functools
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch import distributed as dist
|
||||
|
||||
|
||||
@ -9,3 +13,21 @@ def get_dist_info():
|
||||
rank = 0
|
||||
world_size = 1
|
||||
return rank, world_size
|
||||
|
||||
|
||||
def init_dist(backend='nccl', **kwargs):
|
||||
rank = int(os.environ['RANK'])
|
||||
num_gpus = torch.cuda.device_count()
|
||||
torch.cuda.set_device(rank % num_gpus)
|
||||
dist.init_process_group(backend=backend, **kwargs)
|
||||
|
||||
|
||||
def master_only(func):
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
rank, _ = get_dist_info()
|
||||
if rank == 0:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
@ -5,6 +5,8 @@ from math import cos, pi
|
||||
|
||||
import torch
|
||||
|
||||
from .dist import master_only
|
||||
|
||||
|
||||
class AverageMeter(object):
|
||||
"""Computes and stores the average and current value."""
|
||||
@ -59,6 +61,7 @@ def weights_to_cpu(state_dict):
|
||||
return state_dict_cpu
|
||||
|
||||
|
||||
@master_only
|
||||
def checkpoint_save(epoch, model, optimizer, work_dir, save_freq=16):
|
||||
f = os.path.join(work_dir, f'epoch_{epoch}.pth')
|
||||
checkpoint = {
|
||||
|
||||
16
train.py
16
train.py
@ -12,15 +12,17 @@ from softgroup.data import build_dataloader, build_dataset
|
||||
from softgroup.evaluation import ScanNetEval, evaluate_semantic_acc, evaluate_semantic_miou
|
||||
from softgroup.model import SoftGroup
|
||||
from softgroup.util import (AverageMeter, build_optimizer, checkpoint_save, cosine_lr_after_step,
|
||||
get_max_memory, get_root_logger, is_multiple, is_power2,
|
||||
get_max_memory, get_root_logger, init_dist, is_multiple, is_power2,
|
||||
load_checkpoint)
|
||||
from tensorboardX import SummaryWriter
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser('SoftGroup')
|
||||
parser.add_argument('config', type=str, help='path to config file')
|
||||
parser.add_argument('--dist', action='store_true', help='run with distributed parallel')
|
||||
parser.add_argument('--resume', type=str, help='path to resume from')
|
||||
parser.add_argument('--work_dir', type=str, help='working directory')
|
||||
args = parser.parse_args()
|
||||
@ -32,6 +34,9 @@ if __name__ == '__main__':
|
||||
cfg_txt = open(args.config, 'r').read()
|
||||
cfg = Munch.fromDict(yaml.safe_load(cfg_txt))
|
||||
|
||||
if args.dist:
|
||||
init_dist()
|
||||
|
||||
# work_dir & logger
|
||||
if args.work_dir is not None:
|
||||
cfg.work_dir = args.work_dir
|
||||
@ -42,16 +47,20 @@ if __name__ == '__main__':
|
||||
log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
|
||||
logger = get_root_logger(log_file=log_file)
|
||||
logger.info(f'Config:\n{cfg_txt}')
|
||||
logger.info(f'Distributed: {args.dist}')
|
||||
shutil.copy(args.config, osp.join(cfg.work_dir, osp.basename(args.config)))
|
||||
writer = SummaryWriter(cfg.work_dir)
|
||||
|
||||
# model
|
||||
model = SoftGroup(**cfg.model).cuda()
|
||||
if args.dist:
|
||||
model = DistributedDataParallel(model, device_ids=[torch.cuda.current_device()])
|
||||
|
||||
# data
|
||||
train_set = build_dataset(cfg.data.train, logger)
|
||||
val_set = build_dataset(cfg.data.test, logger)
|
||||
train_loader = build_dataloader(train_set, training=True, **cfg.dataloader.train)
|
||||
train_loader = build_dataloader(
|
||||
train_set, training=True, dist=args.dist, **cfg.dataloader.train)
|
||||
val_loader = build_dataloader(val_set, training=False, **cfg.dataloader.test)
|
||||
|
||||
# optim
|
||||
@ -75,6 +84,9 @@ if __name__ == '__main__':
|
||||
meter_dict = {}
|
||||
end = time.time()
|
||||
|
||||
if train_loader.sampler is not None and args.dist:
|
||||
train_loader.sampler.set_epoch(epoch)
|
||||
|
||||
for i, batch in enumerate(train_loader, start=1):
|
||||
data_time.update(time.time() - end)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user