mirror of https://github.com/botastic/SoftGroup.git
synced 2025-10-16 11:45:42 +00:00

support distributed training

This commit is contained in:
parent 4090322ae0
commit c620cfc435
configs/softgroup_scannet.yaml  73  Normal file

@@ -0,0 +1,73 @@
model:
  channels: 32
  num_blocks: 7
  semantic_classes: 20
  instance_classes: 18
  sem2ins_classes: []
  semantic_only: False
  ignore_label: -100
  grouping_cfg:
    score_thr: 0.2
    radius: 0.04
    mean_active: 300
    class_numpoint_mean: [-1., -1., 3917., 12056., 2303.,
                          8331., 3948., 3166., 5629., 11719.,
                          1003., 3317., 4912., 10221., 3889.,
                          4136., 2120., 945., 3967., 2589.]
    npoint_thr: 0.05  # absolute if class_numpoint == -1, relative if class_numpoint != -1
    ignore_classes: [0, 1]
  instance_voxel_cfg:
    scale: 50
    spatial_shape: 20
  train_cfg:
    max_proposal_num: 200
    pos_iou_thr: 0.5
  test_cfg:
    x4_split: False
    cls_score_thr: 0.001
    mask_score_thr: -0.5
    min_npoint: 100
  fixed_modules: ['input_conv', 'unet', 'output_layer', 'semantic_linear', 'offset_linear']

data:
  train:
    type: 'scannetv2'
    data_root: 'dataset/scannetv2'
    prefix: 'train'
    suffix: '_inst_nostuff.pth'
    training: True
    repeat: 4
    voxel_cfg:
      scale: 50
      spatial_shape: [128, 512]
      max_npoint: 250000
      min_npoint: 5000
  test:
    type: 'scannetv2'
    data_root: 'dataset/scannetv2'
    prefix: 'val'
    suffix: '_inst_nostuff.pth'
    training: False
    voxel_cfg:
      scale: 50
      spatial_shape: [128, 512]
      max_npoint: 250000
      min_npoint: 5000

dataloader:
  train:
    batch_size: 4
    num_workers: 4
  test:
    batch_size: 1
    num_workers: 1

optimizer:
  type: 'Adam'
  lr: 0.004

epochs: 128
step_epoch: 50
save_freq: 4
pretrain: 'work_dirs/softgroup_scannet_backbone_spconv2_dist/epoch_116.pth'
work_dir: 'work_dirs/softgroup_scannet_spconv2_dist'
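
This new config appears to be the second (instance-refinement) stage of SoftGroup's two-step training: fixed_modules freezes the backbone (input_conv, unet, output_layer and the semantic/offset heads), and pretrain points at epoch_116.pth from the backbone-only run, so only the top-down refinement layers are optimized here.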
@@ -36,6 +36,7 @@ data:
     prefix: 'train'
     suffix: '_inst_nostuff.pth'
     training: True
+    repeat: 4
     voxel_cfg:
       scale: 50
       spatial_shape: [128, 512]
@@ -63,10 +64,10 @@ dataloader:
 
 optimizer:
   type: 'Adam'
-  lr: 0.001
+  lr: 0.004
 
-epochs: 512
-step_epoch: 200
-save_freq: 16
+epochs: 128
+step_epoch: 50
+save_freq: 4
 pretrain: ''
 work_dir: 'work_dirs/softgroup_scannet_backbone'
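
The two hunks above modify the companion backbone config (its file name is not captured in this view, but its work_dir is 'work_dirs/softgroup_scannet_backbone'). Adding repeat: 4 while cutting epochs from 512 to 128 keeps the total number of training iterations roughly unchanged, and the schedule is rescaled to match (step_epoch 200 -> 50, save_freq 16 -> 4). The learning-rate increase from 0.001 to 0.004 is consistent with linear scaling for multi-GPU (likely 4-GPU) training, though the commit itself does not state the GPU count.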
dist_train.sh  6  Executable file

@@ -0,0 +1,6 @@
#!/usr/bin/env bash
CONFIG=$1
GPUS=$2
PORT=${PORT:-29500}

OMP_NUM_THREADS=1 torchrun --nproc_per_node=$GPUS --master_port=$PORT ./train.py --dist $CONFIG ${@:3}
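
For example (an illustrative invocation), ./dist_train.sh configs/softgroup_scannet.yaml 4 starts four torchrun worker processes on the default rendezvous port 29500; anything after the GPU count is forwarded to train.py via ${@:3} (e.g. --resume or --work_dir), and exporting PORT selects a different port, which matters when several jobs share one machine.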
@@ -1,4 +1,5 @@
 from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
 
 from .s3dis import S3DISDataset
 from .scannetv2 import ScanNetDataset
@@ -19,14 +20,19 @@ def build_dataset(data_cfg, logger):
         raise ValueError(f'Unknown {data_type}')
 
 
-def build_dataloader(dataset, batch_size=1, num_workers=1, training=True):
+def build_dataloader(dataset, batch_size=1, num_workers=1, training=True, dist=False):
+    shuffle = training
+    sampler = DistributedSampler(dataset, shuffle=shuffle) if dist else None
+    if sampler is not None:
+        shuffle = False
     if training:
         return DataLoader(
             dataset,
             batch_size=batch_size,
             num_workers=num_workers,
             collate_fn=dataset.collate_fn,
-            shuffle=True,
+            shuffle=shuffle,
+            sampler=sampler,
             drop_last=True,
             pin_memory=True)
     else:
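
As a reference for what the new dist=True path does, here is a minimal, self-contained sketch in plain PyTorch (not SoftGroup code): DistributedSampler shards the dataset across ranks and takes over shuffling, so the DataLoader itself must be built with shuffle=False. The num_replicas=4, rank=0 and the 16-sample dataset are illustrative values; in real training they come from the process group that init_dist sets up.

# Minimal sketch of the dist=True behaviour; runnable without a process group
# because num_replicas/rank are given explicitly (illustrative values).
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(16))
sampler = DistributedSampler(dataset, num_replicas=4, rank=0, shuffle=True)
loader = DataLoader(dataset, batch_size=2, sampler=sampler, shuffle=False)

sampler.set_epoch(0)  # reseeds the per-epoch shuffle, as the training loop below now does
print([batch[0].tolist() for batch in loader])  # this rank sees 4 of the 16 samples

Passing both shuffle=True and a sampler would make DataLoader raise an error, which is why build_dataloader flips shuffle to False once a sampler exists.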
@@ -1,3 +1,4 @@
+from .dist import get_dist_info, init_dist
 from .logger import get_root_logger
 from .optim import build_optimizer
 from .utils import *
@@ -1,3 +1,7 @@
+import functools
+import os
+
+import torch
 from torch import distributed as dist
 
 
@@ -9,3 +13,21 @@ def get_dist_info():
         rank = 0
         world_size = 1
     return rank, world_size
+
+
+def init_dist(backend='nccl', **kwargs):
+    rank = int(os.environ['RANK'])
+    num_gpus = torch.cuda.device_count()
+    torch.cuda.set_device(rank % num_gpus)
+    dist.init_process_group(backend=backend, **kwargs)
+
+
+def master_only(func):
+
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        rank, _ = get_dist_info()
+        if rank == 0:
+            return func(*args, **kwargs)
+
+    return wrapper
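
These helpers (judging from the `from .dist import ...` lines elsewhere in the commit, they live in the util package's dist module) are deliberately thin: init_dist reads the RANK variable that torchrun sets, pins the process to GPU rank % num_gpus, and joins the NCCL process group; master_only wraps a function so it runs only on rank 0, which is how checkpoint writing below is limited to a single process.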
@@ -5,6 +5,8 @@ from math import cos, pi
 
 import torch
 
+from .dist import master_only
+
 
 class AverageMeter(object):
     """Computes and stores the average and current value."""
@@ -59,6 +61,7 @@ def weights_to_cpu(state_dict):
     return state_dict_cpu
 
 
+@master_only
 def checkpoint_save(epoch, model, optimizer, work_dir, save_freq=16):
     f = os.path.join(work_dir, f'epoch_{epoch}.pth')
     checkpoint = {
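
With @master_only applied, only rank 0 writes epoch_{epoch}.pth; without it every process would try to save the same file at once. On the other ranks the decorated call simply returns None, so callers should not rely on its return value.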
train.py  16
@@ -12,15 +12,17 @@ from softgroup.data import build_dataloader, build_dataset
 from softgroup.evaluation import ScanNetEval, evaluate_semantic_acc, evaluate_semantic_miou
 from softgroup.model import SoftGroup
 from softgroup.util import (AverageMeter, build_optimizer, checkpoint_save, cosine_lr_after_step,
-                            get_max_memory, get_root_logger, is_multiple, is_power2,
+                            get_max_memory, get_root_logger, init_dist, is_multiple, is_power2,
                             load_checkpoint)
 from tensorboardX import SummaryWriter
+from torch.nn.parallel import DistributedDataParallel
 from tqdm import tqdm
 
 
 def get_args():
     parser = argparse.ArgumentParser('SoftGroup')
     parser.add_argument('config', type=str, help='path to config file')
+    parser.add_argument('--dist', action='store_true', help='run with distributed parallel')
     parser.add_argument('--resume', type=str, help='path to resume from')
     parser.add_argument('--work_dir', type=str, help='working directory')
     args = parser.parse_args()
@@ -32,6 +34,9 @@ if __name__ == '__main__':
     cfg_txt = open(args.config, 'r').read()
     cfg = Munch.fromDict(yaml.safe_load(cfg_txt))
 
+    if args.dist:
+        init_dist()
+
     # work_dir & logger
     if args.work_dir is not None:
         cfg.work_dir = args.work_dir
@@ -42,16 +47,20 @@ if __name__ == '__main__':
     log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
     logger = get_root_logger(log_file=log_file)
     logger.info(f'Config:\n{cfg_txt}')
+    logger.info(f'Distributed: {args.dist}')
     shutil.copy(args.config, osp.join(cfg.work_dir, osp.basename(args.config)))
     writer = SummaryWriter(cfg.work_dir)
 
     # model
     model = SoftGroup(**cfg.model).cuda()
+    if args.dist:
+        model = DistributedDataParallel(model, device_ids=[torch.cuda.current_device()])
 
     # data
     train_set = build_dataset(cfg.data.train, logger)
     val_set = build_dataset(cfg.data.test, logger)
-    train_loader = build_dataloader(train_set, training=True, **cfg.dataloader.train)
+    train_loader = build_dataloader(
+        train_set, training=True, dist=args.dist, **cfg.dataloader.train)
     val_loader = build_dataloader(val_set, training=False, **cfg.dataloader.test)
 
     # optim
@@ -75,6 +84,9 @@ if __name__ == '__main__':
         meter_dict = {}
         end = time.time()
+
+        if train_loader.sampler is not None and args.dist:
+            train_loader.sampler.set_epoch(epoch)
         for i, batch in enumerate(train_loader, start=1):
             data_time.update(time.time() - end)
 
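
Taken together: torchrun spawns one process per GPU, init_dist joins them into an NCCL process group, the model is wrapped in DistributedDataParallel so gradients are averaged across ranks during backward, the DistributedSampler gives each rank its own shard of the training set (with set_epoch called every epoch so the shuffle differs between epochs), and master_only confines checkpoint writing to rank 0. The per-GPU batch size stays at the value in cfg.dataloader.train, so the effective global batch size is that value multiplied by the number of GPUs.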