This commit is contained in:
Thang Vu 2022-04-11 02:03:29 +00:00
parent c68be64a35
commit 08eb82816b
4 changed files with 8 additions and 7 deletions

View File

@ -66,8 +66,9 @@ optimizer:
type: 'Adam' type: 'Adam'
lr: 0.004 lr: 0.004
fp16: False
epochs: 128 epochs: 128
step_epoch: 50 step_epoch: 50
save_freq: 4 save_freq: 4
pretrain: 'work_dirs/softgroup_scannet_backbone_spconv2_dist/epoch_116.pth' pretrain: 'work_dirs/softgroup_scannet_backbone/epoch_120.pth'
work_dir: 'work_dirs/softgroup_scannet_spconv2_dist' work_dir: ''

View File

@ -71,4 +71,4 @@ epochs: 128
step_epoch: 50 step_epoch: 50
save_freq: 4 save_freq: 4
pretrain: '' pretrain: ''
work_dir: 'work_dirs/softgroup_scannet_backbone' work_dir: ''

View File

@ -1,9 +1,10 @@
cmake>=3.13.2 munch
pandas pandas
plyfile plyfile
pyyaml==5.4.1 pyyaml==5.4.1
scikit-learn scikit-learn
scipy scipy
six six
tensorboard
tensorboardX tensorboardX
torch==1.1 tqdm

View File

@ -33,10 +33,9 @@ if __name__ == '__main__':
cfg = Munch.fromDict(yaml.safe_load(cfg_txt)) cfg = Munch.fromDict(yaml.safe_load(cfg_txt))
logger = get_root_logger() logger = get_root_logger()
model = SoftGroup(**cfg.model) model = SoftGroup(**cfg.model).cuda()
logger.info(f'Load state dict from {args.checkpoint}') logger.info(f'Load state dict from {args.checkpoint}')
load_checkpoint(args.checkpoint, logger, model) load_checkpoint(args.checkpoint, logger, model)
model.cuda()
dataset = build_dataset(cfg.data.test, logger) dataset = build_dataset(cfg.data.test, logger)
dataloader = build_dataloader(dataset, training=False, **cfg.dataloader.test) dataloader = build_dataloader(dataset, training=False, **cfg.dataloader.test)