reorganize code

Thang Vu 2022-04-08 11:01:48 +00:00
parent 7a20b6d6ae
commit 192f4bdd89
57 changed files with 220 additions and 326 deletions

View File

@ -17,9 +17,8 @@ from scipy.spatial import cKDTree
# raise ImportError("must install `pointgroup_ops` from lib")
def random_sample(coords: np.ndarray, colors: np.ndarray,
semantic_labels: np.ndarray, instance_labels: np.ndarray,
ratio: float):
def random_sample(coords: np.ndarray, colors: np.ndarray, semantic_labels: np.ndarray,
instance_labels: np.ndarray, ratio: float):
num_points = coords.shape[0]
num_sample = int(num_points * ratio)
sample_ids = sample(range(num_points), num_sample)
@ -33,31 +32,25 @@ def random_sample(coords: np.ndarray, colors: np.ndarray,
return coords, colors, semantic_labels, instance_labels
def voxelize(coords: np.ndarray, colors: np.ndarray,
semantic_labels: np.ndarray, instance_labels: np.ndarray,
voxel_size: float):
def voxelize(coords: np.ndarray, colors: np.ndarray, semantic_labels: np.ndarray,
instance_labels: np.ndarray, voxel_size: float):
# move to positive area
coords_offset = coords.min(0)
coords -= coords_offset
origin_coords = coords.copy()
# begin voxelize
num_points = coords.shape[0]
voxelize_coords = torch.from_numpy(coords /
voxel_size).long() # [num_point, 3]
voxelize_coords = torch.cat(
[torch.zeros(num_points).view(-1, 1).long(), voxelize_coords],
1) # [num_point, 1 + 3]
voxelize_coords = torch.from_numpy(coords / voxel_size).long() # [num_point, 3]
voxelize_coords = torch.cat([torch.zeros(num_points).view(-1, 1).long(), voxelize_coords],
1) # [num_point, 1 + 3]
# mode=4 is mean pooling
voxelize_coords, p2v_map, v2p_map = pointgroup_ops.voxelization_idx(
voxelize_coords, 1, 4)
voxelize_coords, p2v_map, v2p_map = pointgroup_ops.voxelization_idx(voxelize_coords, 1, 4)
v2p_map = v2p_map.cuda()
coords = torch.from_numpy(coords).float().cuda()
coords = pointgroup_ops.voxelization(coords, v2p_map,
4).cpu().numpy() # [num_voxel, 3]
coords = pointgroup_ops.voxelization(coords, v2p_map, 4).cpu().numpy() # [num_voxel, 3]
coords += coords_offset
colors = torch.from_numpy(colors).float().cuda()
colors = pointgroup_ops.voxelization(colors, v2p_map,
4).cpu().numpy() # [num_voxel, 3]
colors = pointgroup_ops.voxelization(colors, v2p_map, 4).cpu().numpy() # [num_voxel, 3]
# processing labels individually (nearest search)
voxelize_coords = voxelize_coords[:, 1:].cpu().numpy() * voxel_size
@ -71,24 +64,16 @@ def voxelize(coords: np.ndarray, colors: np.ndarray,
def get_parser():
parser = argparse.ArgumentParser(
description="downsample s3dis by voxelization")
parser.add_argument("--data-dir",
type=str,
default="./preprocess",
help="directory save processed data")
parser.add_argument("--ratio",
type=float,
default=0.25,
help="random downsample ratio")
parser = argparse.ArgumentParser(description="downsample s3dis by voxelization")
parser.add_argument(
"--data-dir", type=str, default="./preprocess", help="directory save processed data")
parser.add_argument("--ratio", type=float, default=0.25, help="random downsample ratio")
parser.add_argument(
"--voxel-size",
type=float,
default=None,
help="voxelization size (priority is higher than voxel-size)")
parser.add_argument("--verbose",
action="store_true",
help="show partition information or not")
parser.add_argument("--verbose", action="store_true", help="show partition information or not")
args_cfg = parser.parse_args()
@ -129,5 +114,4 @@ if __name__ == "__main__":
coords, colors, semantic_labels, instance_labels = \
random_sample(coords, colors, semantic_labels, instance_labels, args.ratio)
torch.save((coords, colors, semantic_labels, instance_labels,
room_label, scene), save_path)
torch.save((coords, colors, semantic_labels, instance_labels, room_label, scene), save_path)
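
For orientation, a minimal usage sketch of the random_sample helper shown above; the toy arrays are fabricated for illustration, and only the signature and return value visible in this diff are relied on.

import numpy as np

# toy point cloud: 10k points with xyz, rgb and per-point labels (made-up data)
coords = np.random.rand(10000, 3).astype(np.float32)
colors = np.random.rand(10000, 3).astype(np.float32)
semantic_labels = np.zeros(10000, dtype=np.int64)
instance_labels = np.zeros(10000, dtype=np.int64)

# keep a random 25% of the points, matching the script's --ratio default
coords, colors, semantic_labels, instance_labels = random_sample(
    coords, colors, semantic_labels, instance_labels, ratio=0.25)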

View File

@ -103,24 +103,14 @@ def read_s3dis_format(area_id: str,
def get_parser():
parser = argparse.ArgumentParser(description="s3dis data prepare")
parser.add_argument("--data-root",
type=str,
default="./Stanford3dDataset_v1.2",
help="root dir save data")
parser.add_argument("--save-dir",
type=str,
default="./preprocess",
help="directory save processed data")
parser.add_argument(
"--patch",
action="store_true",
help="patch data or not (just patch at first time running)")
parser.add_argument("--align",
action="store_true",
help="processing aligned dataset or not")
parser.add_argument("--verbose",
action="store_true",
help="show processing room name or not")
"--data-root", type=str, default="./Stanford3dDataset_v1.2", help="root dir save data")
parser.add_argument(
"--save-dir", type=str, default="./preprocess", help="directory save processed data")
parser.add_argument(
"--patch", action="store_true", help="patch data or not (just patch at first time running)")
parser.add_argument("--align", action="store_true", help="processing aligned dataset or not")
parser.add_argument("--verbose", action="store_true", help="show processing room name or not")
args_cfg = parser.parse_args()
@ -141,14 +131,10 @@ if __name__ == "__main__":
f"patch -ruN -p0 -d {data_root} < {osp.join(osp.dirname(__file__), 's3dis_align.patch')}"
)
# rename to avoid room_name conflict
if osp.exists(
osp.join(data_root, "Area_6", "copyRoom_1",
"copy_Room_1.txt")):
if osp.exists(osp.join(data_root, "Area_6", "copyRoom_1", "copy_Room_1.txt")):
os.rename(
osp.join(data_root, "Area_6", "copyRoom_1",
"copy_Room_1.txt"),
osp.join(data_root, "Area_6", "copyRoom_1",
"copyRoom_1.txt"))
osp.join(data_root, "Area_6", "copyRoom_1", "copy_Room_1.txt"),
osp.join(data_root, "Area_6", "copyRoom_1", "copyRoom_1.txt"))
else:
os.system(
f"patch -ruN -p0 -d {data_root} < {osp.join(osp.dirname(__file__), 's3dis.patch')}"
@ -182,5 +168,4 @@ if __name__ == "__main__":
(xyz, rgb, semantic_labels, instance_labels,
room_label) = read_s3dis_format(area_id, room_name, data_root)
rgb = (rgb / 127.5) - 1
torch.save((xyz, rgb, semantic_labels, instance_labels, room_label,
scene), save_path)
torch.save((xyz, rgb, semantic_labels, instance_labels, room_label, scene), save_path)

View File

@ -3,7 +3,6 @@ Generate instance groundtruth .txt files (for evaluation)
modified by thang: fix label id
"""
import argparse
import numpy as np
import glob
@ -32,14 +31,10 @@ semantic_label_idxs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
def get_parser():
parser = argparse.ArgumentParser(description="s3dis data prepare")
parser.add_argument("--data-dir",
type=str,
default="./preprocess",
help="directory save processed data")
parser.add_argument("--save-dir",
type=str,
default="./val_gt",
help="directory save ground truth")
parser.add_argument(
"--data-dir", type=str, default="./preprocess", help="directory save processed data")
parser.add_argument(
"--save-dir", type=str, default="./val_gt", help="directory save ground truth")
args_cfg = parser.parse_args()
@ -58,8 +53,7 @@ if __name__ == "__main__":
for i, f in enumerate(files):
(xyz, rgb, semantic_labels, instance_labels, room_label,
scene) = torch.load(
f) # semantic label 0-12 instance_labels 0~instance_num-1 -100
scene) = torch.load(f) # semantic label 0-12 instance_labels 0~instance_num-1 -100
print(f"{i + 1}/{len(files)} {scene}")
instance_labels_new = np.zeros(
@ -75,9 +69,6 @@ if __name__ == "__main__":
sem_id = int(semantic_labels[instance_mask[0]])
if (sem_id == -100): sem_id = 0
semantic_label = semantic_label_idxs[sem_id]
instance_labels_new[
instance_mask] = semantic_label * 1000 + inst_id
instance_labels_new[instance_mask] = semantic_label * 1000 + inst_id
np.savetxt(os.path.join(save_dir, scene + ".txt"),
instance_labels_new,
fmt="%d")
np.savetxt(os.path.join(save_dir, scene + ".txt"), instance_labels_new, fmt="%d")

View File

@ -26,6 +26,7 @@ if opt.data_split != 'test':
assert len(files) == len(files3)
assert len(files) == len(files4), "{} {}".format(len(files), len(files4))
def f_test(fn):
print(fn)
@ -66,14 +67,17 @@ def f(fn):
with open(fn4) as jsondata:
d = json.load(jsondata)
for x in d['segGroups']:
if scannet_util.g_raw2scannetv2[x['label']] != 'wall' and scannet_util.g_raw2scannetv2[x['label']] != 'floor':
if scannet_util.g_raw2scannetv2[x['label']] != 'wall' and scannet_util.g_raw2scannetv2[
x['label']] != 'floor':
instance_segids.append(x['segments'])
labels.append(x['label'])
assert(x['label'] in scannet_util.g_raw2scannetv2.keys())
if(fn == 'val/scene0217_00_vh_clean_2.ply' and instance_segids[0] == instance_segids[int(len(instance_segids) / 2)]):
instance_segids = instance_segids[: int(len(instance_segids) / 2)]
assert (x['label'] in scannet_util.g_raw2scannetv2.keys())
if (fn == 'val/scene0217_00_vh_clean_2.ply'
and instance_segids[0] == instance_segids[int(len(instance_segids) / 2)]):
instance_segids = instance_segids[:int(len(instance_segids) / 2)]
check = []
for i in range(len(instance_segids)): check += instance_segids[i]
for i in range(len(instance_segids)):
check += instance_segids[i]
assert len(np.unique(check)) == len(check)
instance_labels = np.ones(sem_labels.shape[0]) * -100
@ -83,10 +87,11 @@ def f(fn):
for segid in segids:
pointids += segid_to_pointid[segid]
instance_labels[pointids] = i
assert(len(np.unique(sem_labels[pointids])) == 1)
assert (len(np.unique(sem_labels[pointids])) == 1)
torch.save((coords, colors, sem_labels, instance_labels), fn[:-15] + '_inst_nostuff.pth')
print('Saving to ' + fn[:-15] + '_inst_nostuff.pth')
torch.save((coords, colors, sem_labels, instance_labels), fn[:-15]+'_inst_nostuff.pth')
print('Saving to ' + fn[:-15]+'_inst_nostuff.pth')
# for fn in files:
# f(fn)

View File

@ -8,8 +8,11 @@ import torch
import os
semantic_label_idxs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39]
semantic_label_names = ['wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', 'bookshelf', 'picture', 'counter', 'desk', 'curtain', 'refrigerator', 'shower curtain', 'toilet', 'sink', 'bathtub', 'otherfurniture']
semantic_label_names = [
'wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', 'bookshelf',
'picture', 'counter', 'desk', 'curtain', 'refrigerator', 'shower curtain', 'toilet', 'sink',
'bathtub', 'otherfurniture'
]
if __name__ == '__main__':
split = 'val'
@ -20,22 +23,21 @@ if __name__ == '__main__':
os.mkdir(split + '_gt')
for i in range(len(rooms)):
xyz, rgb, label, instance_label = rooms[i] # label 0~19 -100; instance_label 0~instance_num-1 -100
xyz, rgb, label, instance_label = rooms[
i] # label 0~19 -100; instance_label 0~instance_num-1 -100
scene_name = files[i].split('/')[-1][:12]
print('{}/{} {}'.format(i + 1, len(rooms), scene_name))
instance_label_new = np.zeros(instance_label.shape, dtype=np.int32) # 0 for unannotated, xx00y: x for semantic_label, y for inst_id (1~instance_num)
instance_label_new = np.zeros(
instance_label.shape, dtype=np.int32
) # 0 for unannotated, xx00y: x for semantic_label, y for inst_id (1~instance_num)
instance_num = int(instance_label.max()) + 1
for inst_id in range(instance_num):
instance_mask = np.where(instance_label == inst_id)[0]
sem_id = int(label[instance_mask[0]])
if(sem_id == -100): sem_id = 0
if (sem_id == -100): sem_id = 0
semantic_label = semantic_label_idxs[sem_id]
instance_label_new[instance_mask] = semantic_label * 1000 + inst_id + 1
np.savetxt(os.path.join(split + '_gt', scene_name + '.txt'), instance_label_new, fmt='%d')
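
As a concrete check of the encoding used above (semantic_label * 1000 + inst_id + 1, with 0 reserved for unannotated points), a tiny worked example based on the label tables at the top of this file:

# 'chair' sits at index 4 of semantic_label_names, so semantic_label_idxs[4] == 5;
# its third instance in a scene (inst_id = 2) is therefore written out as 5003.
semantic_label = 5   # raw ScanNet id for 'chair'
inst_id = 2          # third instance (ids start at 0)
encoded = semantic_label * 1000 + inst_id + 1
assert encoded == 5003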

View File

@ -1,4 +1,9 @@
g_label_names = ['unannotated', 'wall', 'floor', 'chair', 'table', 'desk', 'bed', 'bookshelf', 'sofa', 'sink', 'bathtub', 'toilet', 'curtain', 'counter', 'door', 'window', 'shower curtain', 'refridgerator', 'picture', 'cabinet', 'otherfurniture']
g_label_names = [
'unannotated', 'wall', 'floor', 'chair', 'table', 'desk', 'bed', 'bookshelf', 'sofa', 'sink',
'bathtub', 'toilet', 'curtain', 'counter', 'door', 'window', 'shower curtain', 'refridgerator',
'picture', 'cabinet', 'otherfurniture'
]
def get_raw2scannetv2_label_map():
lines = [line.rstrip() for line in open('scannetv2-labels.combined.tsv')]
@ -20,4 +25,5 @@ def get_raw2scannetv2_label_map():
raw2scannet[raw_name] = nyu40_name
return raw2scannet
g_raw2scannetv2 = get_raw2scannetv2_label_map()
g_raw2scannetv2 = get_raw2scannetv2_label_map()

View File

@ -1,14 +0,0 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name='SOFTGROUP_OP',
ext_modules=[
CUDAExtension('SOFTGROUP_OP', [
'src/softgroup_api.cpp',
'src/softgroup_ops.cpp',
'src/cuda.cu'
], extra_compile_args={'cxx': ['-g'], 'nvcc': ['-O2']})
],
cmdclass={'build_ext': BuildExtension}
)

@ -1 +0,0 @@
Subproject commit 740a5b717fc576b222abc169ae6047ff1e95363f

View File

@ -3,14 +3,11 @@ import numpy as np
import os.path as osp
import scipy.interpolate
import scipy.ndimage
import sys
import torch
from glob import glob
from torch.utils.data import Dataset
sys.path.append('../')
from lib.softgroup_ops.functions import softgroup_ops # noqa
from ..lib.softgroup_ops import voxelization_idx
class CustomDataset(Dataset):
@ -223,7 +220,7 @@ class CustomDataset(Dataset):
spatial_shape = np.clip(
locs.max(0)[0][1:].numpy() + 1, self.voxel_cfg.spatial_shape[0], None)
voxel_locs, v2p_map, p2v_map = softgroup_ops.voxelization_idx(locs, 1)
voxel_locs, v2p_map, p2v_map = voxelization_idx(locs, 1)
return {
'scan_ids': scan_ids,
'locs': locs,

View File

@ -3,11 +3,8 @@ import torch
import numpy as np
from glob import glob
import os.path as osp
import sys
sys.path.append('../')
from lib.softgroup_ops.functions import softgroup_ops # noqa
from ..lib.softgroup_ops import voxelization_idx
class S3DISDataset(CustomDataset):
@ -87,7 +84,7 @@ class S3DISDataset(CustomDataset):
pt_offset_labels = pt_offset_label.float()
spatial_shape = np.clip((locs.max(0)[0][1:] + 1).numpy(), self.voxel_cfg.spatial_shape[0],
None)
voxel_locs, v2p_map, p2v_map = softgroup_ops.voxelization_idx(locs, 4)
voxel_locs, v2p_map, p2v_map = voxelization_idx(locs, 4)
return {
'scan_ids': scan_ids,
'batch_idxs': batch_idxs,

View File

@ -39,14 +39,10 @@ class ScanNetEval(object):
dist_confs = [self.distance_confs[0]]
# results: class x iou
ap = np.zeros(
(len(dist_threshes), len(self.eval_class_labels), len(ious)),
np.float)
rc = np.zeros(
(len(dist_threshes), len(self.eval_class_labels), len(ious)),
np.float)
for di, (min_region_size, distance_thresh, distance_conf) in enumerate(
zip(min_region_sizes, dist_threshes, dist_confs)):
ap = np.zeros((len(dist_threshes), len(self.eval_class_labels), len(ious)), np.float)
rc = np.zeros((len(dist_threshes), len(self.eval_class_labels), len(ious)), np.float)
for di, (min_region_size, distance_thresh,
distance_conf) in enumerate(zip(min_region_sizes, dist_threshes, dist_confs)):
for oi, iou_th in enumerate(ious):
pred_visited = {}
for m in matches:
@ -67,10 +63,8 @@ class ScanNetEval(object):
# filter groups in ground truth
gt_instances = [
gt for gt in gt_instances
if gt['instance_id'] >= 1000
and gt['vert_count'] >= min_region_size
and gt['med_dist'] <= distance_thresh
and gt['dist_conf'] >= distance_conf
if gt['instance_id'] >= 1000 and gt['vert_count'] >= min_region_size and
gt['med_dist'] <= distance_thresh and gt['dist_conf'] >= distance_conf
]
if gt_instances:
has_gt = True
@ -78,8 +72,7 @@ class ScanNetEval(object):
has_pred = True
cur_true = np.ones(len(gt_instances))
cur_score = np.ones(
len(gt_instances)) * (-float('inf'))
cur_score = np.ones(len(gt_instances)) * (-float('inf'))
cur_match = np.zeros(len(gt_instances), dtype=np.bool)
# collect matches
for (gti, gt) in enumerate(gt_instances):
@ -96,15 +89,12 @@ class ScanNetEval(object):
# the prediction with the lower score is
# automatically a FP
if cur_match[gti]:
max_score = max(
cur_score[gti], confidence)
min_score = min(
cur_score[gti], confidence)
max_score = max(cur_score[gti], confidence)
min_score = min(cur_score[gti], confidence)
cur_score[gti] = max_score
# append false positive
cur_true = np.append(cur_true, 0)
cur_score = np.append(
cur_score, min_score)
cur_score = np.append(cur_score, min_score)
cur_match = np.append(cur_match, True)
# otherwise set score
else:
@ -135,17 +125,14 @@ class ScanNetEval(object):
# small ground truth instances
if (gt['vert_count'] < min_region_size
or gt['med_dist'] > distance_thresh
or
gt['dist_conf'] < distance_conf):
or gt['dist_conf'] < distance_conf):
num_ignore += gt['intersection']
proportion_ignore = float(
num_ignore) / pred['vert_count']
proportion_ignore = float(num_ignore) / pred['vert_count']
# if not ignored append false positive
if proportion_ignore <= iou_th:
cur_true = np.append(cur_true, 0)
confidence = pred['confidence']
cur_score = np.append(
cur_score, confidence)
cur_score = np.append(cur_score, confidence)
# append to overall results
y_true = np.append(y_true, cur_true)
@ -162,8 +149,7 @@ class ScanNetEval(object):
y_true_sorted_cumsum = np.cumsum(y_true_sorted)
# unique thresholds
(thresholds, unique_indices) = np.unique(
y_score_sorted, return_index=True)
(thresholds, unique_indices) = np.unique(y_score_sorted, return_index=True)
num_prec_recall = len(unique_indices) + 1
# prepare precision recall
@ -173,8 +159,7 @@ class ScanNetEval(object):
recall = np.zeros(num_prec_recall)
# deal with the first point
y_true_sorted_cumsum = np.append(
y_true_sorted_cumsum, 0)
y_true_sorted_cumsum = np.append(y_true_sorted_cumsum, 0)
# deal with remaining
for idx_res, idx_scores in enumerate(unique_indices):
cumsum = y_true_sorted_cumsum[idx_scores - 1]
@ -195,12 +180,10 @@ class ScanNetEval(object):
# compute average of precision-recall curve
recall_for_conv = np.copy(recall)
recall_for_conv = np.append(recall_for_conv[0],
recall_for_conv)
recall_for_conv = np.append(recall_for_conv[0], recall_for_conv)
recall_for_conv = np.append(recall_for_conv, 0.)
stepWidths = np.convolve(recall_for_conv,
[-0.5, 0, 0.5], 'valid')
stepWidths = np.convolve(recall_for_conv, [-0.5, 0, 0.5], 'valid')
# integrate is now simply a dot product
ap_current = np.dot(precision, stepWidths)
@ -230,25 +213,19 @@ class ScanNetEval(object):
avg_dict['classes'] = {}
for (li, label_name) in enumerate(self.eval_class_labels):
avg_dict['classes'][label_name] = {}
avg_dict['classes'][label_name]['ap'] = np.average(aps[d_inf, li,
oAllBut25])
avg_dict['classes'][label_name]['ap50%'] = np.average(aps[d_inf,
li, o50])
avg_dict['classes'][label_name]['ap25%'] = np.average(aps[d_inf,
li, o25])
avg_dict['classes'][label_name]['rc'] = np.average(rcs[d_inf, li,
oAllBut25])
avg_dict['classes'][label_name]['rc50%'] = np.average(rcs[d_inf,
li, o50])
avg_dict['classes'][label_name]['rc25%'] = np.average(rcs[d_inf,
li, o25])
avg_dict['classes'][label_name]['ap'] = np.average(aps[d_inf, li, oAllBut25])
avg_dict['classes'][label_name]['ap50%'] = np.average(aps[d_inf, li, o50])
avg_dict['classes'][label_name]['ap25%'] = np.average(aps[d_inf, li, o25])
avg_dict['classes'][label_name]['rc'] = np.average(rcs[d_inf, li, oAllBut25])
avg_dict['classes'][label_name]['rc50%'] = np.average(rcs[d_inf, li, o50])
avg_dict['classes'][label_name]['rc25%'] = np.average(rcs[d_inf, li, o25])
return avg_dict
def assign_instances_for_scan(self, preds, gts):
"""get gt instances, only consider the valid class labels even in class
agnostic setting."""
gt_instances = get_instances(gts, self.valid_class_ids,
self.valid_class_labels, self.id2label)
gt_instances = get_instances(gts, self.valid_class_ids, self.valid_class_labels,
self.id2label)
# associate
if self.use_label:
gt2pred = deepcopy(gt_instances)
@ -292,8 +269,7 @@ class ScanNetEval(object):
continue # skip if empty
pred_instance = {}
pred_instance['filename'] = '{}_{}'.format(
pred['scan_id'], num_pred_instances) # dummy
pred_instance['filename'] = '{}_{}'.format(pred['scan_id'], num_pred_instances) # dummy
pred_instance['pred_id'] = num_pred_instances
pred_instance['label_id'] = label_id if self.use_label else None
pred_instance['vert_count'] = num
@ -314,13 +290,11 @@ class ScanNetEval(object):
pred_copy['intersection'] = intersection
iou = (
float(intersection) /
(gt_copy['vert_count'] + pred_copy['vert_count'] -
intersection))
(gt_copy['vert_count'] + pred_copy['vert_count'] - intersection))
gt_copy['iou'] = iou
pred_copy['iou'] = iou
matched_gt.append(gt_copy)
gt2pred[label_name][gt_num]['matched_pred'].append(
pred_copy)
gt2pred[label_name][gt_num]['matched_pred'].append(pred_copy)
pred_instance['matched_gt'] = matched_gt
num_pred_instances += 1
pred2gt[label_name].append(pred_instance)
@ -384,16 +358,12 @@ class ScanNetEval(object):
def write_result_file(self, avgs, filename):
_SPLITTER = ','
with open(filename, 'w') as f:
f.write(
_SPLITTER.join(['class', 'class id', 'ap', 'ap50', 'ap25']) +
'\n')
f.write(_SPLITTER.join(['class', 'class id', 'ap', 'ap50', 'ap25']) + '\n')
for class_name in self.eval_class_labels:
ap = avgs['classes'][class_name]['ap']
ap50 = avgs['classes'][class_name]['ap50%']
ap25 = avgs['classes'][class_name]['ap25%']
f.write(
_SPLITTER.join(
[str(x) for x in [class_name, ap, ap50, ap25]]) + '\n')
f.write(_SPLITTER.join([str(x) for x in [class_name, ap, ap50, ap25]]) + '\n')
def evaluate(self, pred_list, gt_list):
"""

View File

@ -46,8 +46,7 @@ def export_instance_ids_for_eval(filename, label_ids, instance_ids):
assert label_ids.shape[0] == instance_ids.shape[0]
output_mask_path_relative = 'pred_mask'
name = os.path.splitext(os.path.basename(filename))[0]
output_mask_path = os.path.join(
os.path.dirname(filename), output_mask_path_relative)
output_mask_path = os.path.join(os.path.dirname(filename), output_mask_path_relative)
if not os.path.isdir(output_mask_path):
os.mkdir(output_mask_path)
insts = np.unique(instance_ids)
@ -82,8 +81,7 @@ class Instance(object):
return
self.instance_id = int(instance_id)
self.label_id = int(self.get_label_id(instance_id))
self.vert_count = int(
self.get_instance_verts(mesh_vert_instances, instance_id))
self.vert_count = int(self.get_instance_verts(mesh_vert_instances, instance_id))
def get_label_id(self, instance_id):
return int(instance_id // 1000)
@ -92,8 +90,7 @@ class Instance(object):
return (mesh_vert_instances == instance_id).sum()
def to_json(self):
return json.dumps(
self, default=lambda o: o.__dict__, sort_keys=True, indent=4)
return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4)
def to_dict(self):
dict = {}
@ -134,8 +131,7 @@ def read_instance_prediction_file(filename, pred_path):
# check that mask_file lives inside prediction path
if os.path.commonprefix([mask_file, abs_pred_path]) != abs_pred_path:
print(('predicted mask {} in prediction text file {}' +
'points outside of prediction path.').format(
mask_file, filename))
'points outside of prediction path.').format(mask_file, filename))
info = {}
info['label_id'] = int(float(parts[1]))

View File

@ -0,0 +1 @@
from .functions.softgroup_ops import *
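
This new one-line __init__.py re-exports the compiled ops, which is what lets the dataset and model files later in this diff drop their sys.path hacks in favour of package-relative imports. A minimal sketch of the resulting import surface, assuming the file sits at softgroup/lib/softgroup_ops/__init__.py as the relative imports elsewhere in this commit suggest:

# old pattern (removed in this commit):
#   sys.path.append('../')
#   from lib.softgroup_ops.functions import softgroup_ops
# new pattern enabled by the wildcard re-export above (package path is an assumption):
from softgroup.lib.softgroup_ops import voxelization, voxelization_idx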

View File

@ -1,83 +1,11 @@
import torch
from torch.autograd import Function
import SOFTGROUP_OP
class HierarchicalAggregation(Function):
@staticmethod
def forward(ctx, cluster_numpoint_mean, semantic_label, coord_shift, ball_query_idxs, start_len, batch_idxs, training_mode, using_set_aggr, class_id):
'''
:param ctx:
:param semantic_label: (N_fg), int
:param coord_shift: (N_fg, 3), float
:param ball_query_idxs: (nActive), int
:param start_len: (N_fg, 2), int
:param batch_idxs: (N_fg), int
:return: cluster_idxs: int (sumNPoint, 2), [:, 0] for cluster_id, [:, 1] for corresponding point idxs in N
:return: cluster_offsets: int (nCluster + 1)
'''
N = start_len.size(0)
assert semantic_label.is_contiguous()
assert coord_shift.is_contiguous()
assert ball_query_idxs.is_contiguous()
assert start_len.is_contiguous()
fragment_idxs = semantic_label.new()
fragment_offsets = semantic_label.new()
fragment_centers = coord_shift.new() # float
cluster_idxs_kept = semantic_label.new()
cluster_offsets_kept = semantic_label.new()
cluster_centers_kept = coord_shift.new() # float
primary_idxs = semantic_label.new()
primary_offsets = semantic_label.new()
primary_centers = coord_shift.new() # float
primary_idxs_post = semantic_label.new()
primary_offsets_post = semantic_label.new()
training_mode_ = 1 if training_mode == 'train' else 0
using_set_aggr_ = int(using_set_aggr)
SOFTGROUP_OP.hierarchical_aggregation(cluster_numpoint_mean, semantic_label, coord_shift, batch_idxs, ball_query_idxs, start_len,
fragment_idxs, fragment_offsets, fragment_centers,
cluster_idxs_kept, cluster_offsets_kept, cluster_centers_kept,
primary_idxs, primary_offsets, primary_centers,
primary_idxs_post, primary_offsets_post,
N, training_mode_, using_set_aggr_, class_id)
if using_set_aggr_ == 0: # not set aggr
pass
else:
# cut off tails
primary_idxs_post = primary_idxs_post[:primary_offsets_post[-1]]
primary_idxs = primary_idxs_post
primary_offsets = primary_offsets_post
cluster_idxs = cluster_idxs_kept
cluster_offsets = cluster_offsets_kept
if primary_idxs.shape[0] != 0:
#add primary
primary_idxs[:, 0] += (cluster_offsets.size(0) - 1)
primary_offsets += cluster_offsets[-1]
cluster_idxs = torch.cat((cluster_idxs, primary_idxs), dim=0).cpu()
cluster_offsets = torch.cat((cluster_offsets, primary_offsets[1:])).cpu()
return cluster_idxs, cluster_offsets
@staticmethod
def backward(ctx, a=None):
return None
hierarchical_aggregation = HierarchicalAggregation.apply
from .. import SOFTGROUP_OP
class GetMaskIoUOnCluster(Function):
@staticmethod
def forward(ctx, proposals_idx, proposals_offset, instance_labels, instance_pointnum):
'''
@ -102,7 +30,8 @@ class GetMaskIoUOnCluster(Function):
assert instance_labels.is_contiguous() and instance_labels.is_cuda
assert instance_pointnum.is_contiguous() and instance_pointnum.is_cuda
SOFTGROUP_OP.get_mask_iou_on_cluster(proposals_idx, proposals_offset, instance_labels, instance_pointnum, proposals_iou, nInstance, nProposal)
SOFTGROUP_OP.get_mask_iou_on_cluster(proposals_idx, proposals_offset, instance_labels,
instance_pointnum, proposals_iou, nInstance, nProposal)
return proposals_iou
@ -110,12 +39,15 @@ class GetMaskIoUOnCluster(Function):
def backward(ctx, a=None):
return None, None, None, None
get_mask_iou_on_cluster = GetMaskIoUOnCluster.apply
class GetMaskIoUOnPred(Function):
@staticmethod
def forward(ctx, proposals_idx, proposals_offset, instance_labels, instance_pointnum, mask_scores_sigmoid):
def forward(ctx, proposals_idx, proposals_offset, instance_labels, instance_pointnum,
mask_scores_sigmoid):
'''
:param ctx:
:param proposals_idx: (sumNPoint), int
@ -139,7 +71,9 @@ class GetMaskIoUOnPred(Function):
assert instance_pointnum.is_contiguous() and instance_pointnum.is_cuda
assert mask_scores_sigmoid.is_contiguous() and mask_scores_sigmoid.is_cuda
SOFTGROUP_OP.get_mask_iou_on_pred(proposals_idx, proposals_offset, instance_labels, instance_pointnum, proposals_iou, nInstance, nProposal, mask_scores_sigmoid)
SOFTGROUP_OP.get_mask_iou_on_pred(proposals_idx, proposals_offset, instance_labels,
instance_pointnum, proposals_iou, nInstance, nProposal,
mask_scores_sigmoid)
return proposals_iou
@ -147,11 +81,15 @@ class GetMaskIoUOnPred(Function):
def backward(ctx, a=None):
return None, None, None, None
get_mask_iou_on_pred = GetMaskIoUOnPred.apply
class GetMaskLabel(Function):
@staticmethod
def forward(ctx, proposals_idx, proposals_offset, instance_labels, instance_cls, instance_pointnum, proposals_iou, iou_thr):
def forward(ctx, proposals_idx, proposals_offset, instance_labels, instance_cls,
instance_pointnum, proposals_iou, iou_thr):
'''
:param ctx:
:param proposals_idx: (sumNPoint), int
@ -174,7 +112,8 @@ class GetMaskLabel(Function):
assert instance_labels.is_contiguous() and instance_labels.is_cuda
assert instance_cls.is_contiguous() and instance_cls.is_cuda
SOFTGROUP_OP.get_mask_label(proposals_idx, proposals_offset, instance_labels, instance_cls, proposals_iou, nInstance, nProposal, iou_thr, mask_label)
SOFTGROUP_OP.get_mask_label(proposals_idx, proposals_offset, instance_labels, instance_cls,
proposals_iou, nInstance, nProposal, iou_thr, mask_label)
return mask_label
@ -182,10 +121,12 @@ class GetMaskLabel(Function):
def backward(ctx, a=None):
return None, None, None, None
get_mask_label = GetMaskLabel.apply
class Voxelization_Idx(Function):
@staticmethod
def forward(ctx, coords, batchsize, mode=4):
'''
@ -212,10 +153,12 @@ class Voxelization_Idx(Function):
def backward(ctx, a=None, b=None, c=None):
return None
voxelization_idx = Voxelization_Idx.apply
class Voxelization(Function):
@staticmethod
def forward(ctx, feats, map_rule, mode=4):
'''
@ -237,7 +180,6 @@ class Voxelization(Function):
SOFTGROUP_OP.voxelize_fp(feats, output_feats, map_rule, mode, M, maxActive, C)
return output_feats
@staticmethod
def backward(ctx, d_output_feats):
map_rule, mode, maxActive, N = ctx.for_backwards
@ -245,13 +187,16 @@ class Voxelization(Function):
d_feats = torch.cuda.FloatTensor(N, C).zero_()
SOFTGROUP_OP.voxelize_bp(d_output_feats.contiguous(), d_feats, map_rule, mode, M, maxActive, C)
SOFTGROUP_OP.voxelize_bp(d_output_feats.contiguous(), d_feats, map_rule, mode, M, maxActive,
C)
return d_feats, None, None
voxelization = Voxelization.apply
class PointRecover(Function):
@staticmethod
def forward(ctx, feats, map_rule, nPoint):
'''
@ -281,14 +226,17 @@ class PointRecover(Function):
d_feats = torch.cuda.FloatTensor(M, C).zero_()
SOFTGROUP_OP.point_recover_bp(d_output_feats.contiguous(), d_feats, map_rule, M, maxActive, C)
SOFTGROUP_OP.point_recover_bp(d_output_feats.contiguous(), d_feats, map_rule, M, maxActive,
C)
return d_feats, None, None
point_recover = PointRecover.apply
class BallQueryBatchP(Function):
@staticmethod
def forward(ctx, coords, batch_idxs, batch_offsets, radius, meanActive):
'''
@ -311,7 +259,8 @@ class BallQueryBatchP(Function):
while True:
idx = torch.cuda.IntTensor(n * meanActive).zero_()
start_len = torch.cuda.IntTensor(n, 2).zero_()
nActive = SOFTGROUP_OP.ballquery_batch_p(coords, batch_idxs, batch_offsets, idx, start_len, n, meanActive, radius)
nActive = SOFTGROUP_OP.ballquery_batch_p(coords, batch_idxs, batch_offsets, idx,
start_len, n, meanActive, radius)
if nActive <= n * meanActive:
break
meanActive = int(nActive // n + 1)
@ -323,12 +272,14 @@ class BallQueryBatchP(Function):
def backward(ctx, a=None, b=None):
return None, None, None
ballquery_batch_p = BallQueryBatchP.apply
class BFSCluster(Function):
@staticmethod
def forward(ctx, cluster_numpoint_mean, ball_query_idxs, start_len, threshold, class_id):
def forward(ctx, cluster_numpoint_mean, ball_query_idxs, start_len, threshold, class_id):
'''
:param ctx:
:param ball_query_idxs: (nActive), int
@ -345,7 +296,8 @@ class BFSCluster(Function):
cluster_idxs = ball_query_idxs.new()
cluster_offsets = ball_query_idxs.new()
SOFTGROUP_OP.bfs_cluster(cluster_numpoint_mean, ball_query_idxs, start_len, cluster_idxs, cluster_offsets, N, threshold, class_id)
SOFTGROUP_OP.bfs_cluster(cluster_numpoint_mean, ball_query_idxs, start_len, cluster_idxs,
cluster_offsets, N, threshold, class_id)
return cluster_idxs, cluster_offsets
@ -353,10 +305,12 @@ class BFSCluster(Function):
def backward(ctx, a=None):
return None
bfs_cluster = BFSCluster.apply
class RoiPool(Function):
@staticmethod
def forward(ctx, feats, proposals_offset):
'''
@ -388,14 +342,17 @@ class RoiPool(Function):
d_feats = torch.cuda.FloatTensor(sumNPoint, C).zero_()
SOFTGROUP_OP.roipool_bp(d_feats, proposals_offset, output_maxidx, d_output_feats.contiguous(), nProposal, C)
SOFTGROUP_OP.roipool_bp(d_feats, proposals_offset, output_maxidx,
d_output_feats.contiguous(), nProposal, C)
return d_feats, None
roipool = RoiPool.apply
class GlobalAvgPool(Function):
@staticmethod
def forward(ctx, feats, proposals_offset):
'''
@ -426,14 +383,17 @@ class GlobalAvgPool(Function):
d_feats = torch.cuda.FloatTensor(sumNPoint, C).zero_()
SOFTGROUP_OP.global_avg_pool_bp(d_feats, proposals_offset, d_output_feats.contiguous(), nProposal, C)
SOFTGROUP_OP.global_avg_pool_bp(d_feats, proposals_offset, d_output_feats.contiguous(),
nProposal, C)
return d_feats, None
global_avg_pool = GlobalAvgPool.apply
class GetIoU(Function):
@staticmethod
def forward(ctx, proposals_idx, proposals_offset, instance_labels, instance_pointnum):
'''
@ -454,7 +414,8 @@ class GetIoU(Function):
proposals_iou = torch.cuda.FloatTensor(nProposal, nInstance).zero_()
SOFTGROUP_OP.get_iou(proposals_idx, proposals_offset, instance_labels, instance_pointnum, proposals_iou, nInstance, nProposal)
SOFTGROUP_OP.get_iou(proposals_idx, proposals_offset, instance_labels, instance_pointnum,
proposals_iou, nInstance, nProposal)
return proposals_iou
@ -462,10 +423,12 @@ class GetIoU(Function):
def backward(ctx, a=None):
return None, None, None, None
get_iou = GetIoU.apply
class SecMean(Function):
@staticmethod
def forward(ctx, inp, offsets):
'''
@ -490,10 +453,12 @@ class SecMean(Function):
def backward(ctx, a=None):
return None, None
sec_mean = SecMean.apply
class SecMin(Function):
@staticmethod
def forward(ctx, inp, offsets):
'''
@ -518,10 +483,12 @@ class SecMin(Function):
def backward(ctx, a=None):
return None, None
sec_min = SecMin.apply
class SecMax(Function):
@staticmethod
def forward(ctx, inp, offsets):
'''
@ -546,4 +513,5 @@ class SecMax(Function):
def backward(ctx, a=None):
return None, None
sec_max = SecMax.apply

View File

@ -0,0 +1,14 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name='SOFTGROUP_OP',
ext_modules=[
CUDAExtension(
'SOFTGROUP_OP', ['src/softgroup_api.cpp', 'src/softgroup_ops.cpp', 'src/cuda.cu'],
extra_compile_args={
'cxx': ['-g'],
'nvcc': ['-O2']
})
],
cmdclass={'build_ext': BuildExtension})
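
Since setup.py now lives next to the ops package (presumably softgroup/lib/softgroup_ops/, judging by the relative from .. import SOFTGROUP_OP added above), the extension would still be built the usual setuptools/CUDAExtension way, for example by running python setup.py build_ext develop from that directory; treat the exact path and command as assumptions, as the commit itself shows no build instructions.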

View File

@ -0,0 +1,3 @@
from .softgroup import SoftGroup
__all__ = ['SoftGroup']
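
The new model package exposes SoftGroup directly; a short usage sketch stitched together from the test.py changes later in this diff, where the config and checkpoint paths are placeholders:

import yaml
from munch import Munch

from softgroup.model import SoftGroup
from softgroup.util import get_root_logger, load_checkpoint

# placeholder config and checkpoint paths, assumed for illustration only
cfg = Munch.fromDict(yaml.safe_load(open('configs/softgroup_scannet.yaml')))
logger = get_root_logger()
model = SoftGroup(**cfg.model).cuda()
load_checkpoint('work_dir/latest.pth', logger, model)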

View File

@ -1,16 +1,14 @@
import functools
import spconv
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..lib.softgroup_ops import (ballquery_batch_p, bfs_cluster, get_mask_iou_on_cluster,
get_mask_iou_on_pred, get_mask_label, global_avg_pool, sec_max,
sec_mean, sec_min, voxelization, voxelization_idx)
from .blocks import ResidualBlock, UBlock
sys.path.append('../../')
from lib.softgroup_ops.functions import softgroup_ops # noqa
class SoftGroup(nn.Module):
@ -119,7 +117,7 @@ class SoftGroup(nn.Module):
losses = {}
feats = torch.cat((feats, coords_float), 1)
voxel_feats = softgroup_ops.voxelization(feats, p2v_map)
voxel_feats = voxelization(feats, p2v_map)
input = spconv.SparseConvTensor(voxel_feats, voxel_coords.int(), spatial_shape, batch_size)
semantic_scores, pt_offsets, output_feats, coords_float = self.forward_backbone(
input, v2p_map, coords_float)
@ -173,8 +171,8 @@ class SoftGroup(nn.Module):
proposals_offset = proposals_offset.cuda()
# cal iou of clustered instance
ious_on_cluster = softgroup_ops.get_mask_iou_on_cluster(proposals_idx, proposals_offset,
instance_labels, instance_pointnum)
ious_on_cluster = get_mask_iou_on_cluster(proposals_idx, proposals_offset, instance_labels,
instance_pointnum)
# filter out background instances
fg_inds = (instance_cls != self.ignore_label)
@ -197,9 +195,8 @@ class SoftGroup(nn.Module):
slice_inds = torch.arange(
0, mask_cls_label.size(0), dtype=torch.long, device=mask_cls_label.device)
mask_scores_sigmoid_slice = mask_scores.sigmoid()[slice_inds, mask_cls_label]
mask_label = softgroup_ops.get_mask_label(proposals_idx, proposals_offset, instance_labels,
instance_cls, instance_pointnum, ious_on_cluster,
self.train_cfg.pos_iou_thr)
mask_label = get_mask_label(proposals_idx, proposals_offset, instance_labels, instance_cls,
instance_pointnum, ious_on_cluster, self.train_cfg.pos_iou_thr)
mask_label_weight = (mask_label != -1).float()
mask_label[mask_label == -1.] = 0.5 # any value is ok
mask_loss = F.binary_cross_entropy(
@ -208,9 +205,8 @@ class SoftGroup(nn.Module):
losses['mask_loss'] = (mask_loss, mask_label_weight.sum())
# compute iou score loss
ious = softgroup_ops.get_mask_iou_on_pred(proposals_idx, proposals_offset, instance_labels,
instance_pointnum,
mask_scores_sigmoid_slice.detach())
ious = get_mask_iou_on_pred(proposals_idx, proposals_offset, instance_labels,
instance_pointnum, mask_scores_sigmoid_slice.detach())
fg_ious = ious[:, fg_inds]
gt_ious, _ = fg_ious.max(1)
slice_inds = torch.arange(0, labels.size(0), dtype=torch.long, device=labels.device)
@ -234,7 +230,7 @@ class SoftGroup(nn.Module):
batch_size = batch['batch_size']
feats = torch.cat((feats, coords_float), 1)
voxel_feats = softgroup_ops.voxelization(feats, p2v_map)
voxel_feats = voxelization(feats, p2v_map)
input = spconv.SparseConvTensor(voxel_feats, voxel_coords.int(), spatial_shape, batch_size)
semantic_scores, pt_offsets, output_feats, coords_float = self.forward_backbone(
input, v2p_map, coords_float, x4_split=self.test_cfg.x4_split)
@ -324,10 +320,10 @@ class SoftGroup(nn.Module):
batch_offsets_ = self.get_batch_offsets(batch_idxs_, batch_size)
coords_ = coords_float[object_idxs]
pt_offsets_ = pt_offsets[object_idxs]
idx, start_len = softgroup_ops.ballquery_batch_p(coords_ + pt_offsets_, batch_idxs_,
batch_offsets_, radius, mean_active)
proposals_idx, proposals_offset = softgroup_ops.bfs_cluster(
class_numpoint_mean, idx.cpu(), start_len.cpu(), npoint_thr, class_id)
idx, start_len = ballquery_batch_p(coords_ + pt_offsets_, batch_idxs_, batch_offsets_,
radius, mean_active)
proposals_idx, proposals_offset = bfs_cluster(class_numpoint_mean, idx.cpu(),
start_len.cpu(), npoint_thr, class_id)
proposals_idx[:, 1] = object_idxs[proposals_idx[:, 1].long()].int()
# merge proposals
@ -433,13 +429,13 @@ class SoftGroup(nn.Module):
clusters_feats = feats[c_idxs.long()]
clusters_coords = coords[c_idxs.long()]
clusters_coords_mean = softgroup_ops.sec_mean(clusters_coords, clusters_offset.cuda())
clusters_coords_mean = sec_mean(clusters_coords, clusters_offset.cuda())
clusters_coords_mean = torch.index_select(clusters_coords_mean, 0,
clusters_idx[:, 0].cuda().long())
clusters_coords -= clusters_coords_mean
clusters_coords_min = softgroup_ops.sec_min(clusters_coords, clusters_offset.cuda())
clusters_coords_max = softgroup_ops.sec_max(clusters_coords, clusters_offset.cuda())
clusters_coords_min = sec_min(clusters_coords, clusters_offset.cuda())
clusters_coords_max = sec_max(clusters_coords, clusters_offset.cuda())
clusters_scale = 1 / (
(clusters_coords_max - clusters_coords_min) / spatial_shape).max(1)[0] - 0.01
@ -465,9 +461,9 @@ class SoftGroup(nn.Module):
clusters_coords = torch.cat([clusters_idx[:, 0].view(-1, 1).long(),
clusters_coords.cpu()], 1)
out_coords, inp_map, out_map = softgroup_ops.voxelization_idx(clusters_coords,
int(clusters_idx[-1, 0]) + 1)
out_feats = softgroup_ops.voxelization(clusters_feats, out_map.cuda())
out_coords, inp_map, out_map = voxelization_idx(clusters_coords,
int(clusters_idx[-1, 0]) + 1)
out_feats = voxelization(clusters_feats, out_map.cuda())
spatial_shape = [spatial_shape] * 3
voxelization_feats = spconv.SparseConvTensor(out_feats,
out_coords.int().cuda(), spatial_shape,
@ -487,7 +483,7 @@ class SoftGroup(nn.Module):
batch_offset = torch.cumsum(batch_counts, dim=0)
pad = batch_offset.new_full((1, ), 0)
batch_offset = torch.cat([pad, batch_offset]).int()
x_pool = softgroup_ops.global_avg_pool(x.features, batch_offset)
x_pool = global_avg_pool(x.features, batch_offset)
if not expand:
return x_pool

View File

@ -1,3 +1,3 @@
from .logger import get_root_logger
from .optim import build_optimizer
from .utils import get_max_memory
from .utils import *

View File

@ -8,15 +8,13 @@ def get_root_logger(log_file=None, log_level=logging.INFO):
if logger.hasHandlers():
return logger
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(message)s', level=log_level)
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=log_level)
rank, _ = get_dist_info()
if rank != 0:
logger.setLevel('ERROR')
elif log_file is not None:
file_handler = logging.FileHandler(log_file, 'w')
file_handler.setFormatter(
logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
file_handler.setLevel(log_level)
logger.addHandler(file_handler)

test.py (14 lines changed)
View File

@ -3,17 +3,13 @@ import numpy as np
import random
import torch
import yaml
from softgroup.data import build_dataloader, build_dataset
from munch import Munch
from tqdm import tqdm
import util.utils as utils
from evaluation import ScanNetEval
from model.softgroup import SoftGroup
from data.scannetv2 import ScanNetDataset
from torch.utils.data import DataLoader
from util import get_root_logger
from data import build_dataset, build_dataloader
from softgroup.evaluation import ScanNetEval
from softgroup.model import SoftGroup
from softgroup.util import get_root_logger, load_checkpoint
def get_args():
@ -78,7 +74,7 @@ if __name__ == '__main__':
model = SoftGroup(**cfg.model)
logger.info(f'Load state dict from {args.checkpoint}')
utils.load_checkpoint(args.checkpoint, logger, model)
load_checkpoint(args.checkpoint, logger, model)
model.cuda()
dataset = build_dataset(cfg.data.test, logger)

View File

@ -9,12 +9,13 @@ import sys
import time
import torch
import yaml
from data import build_dataloader, build_dataset
from model.softgroup import SoftGroup
from munch import Munch
from tensorboardX import SummaryWriter
from data import build_dataloader, build_dataset
from model.softgroup import SoftGroup
from util import build_optimizer, get_max_memory, get_root_logger, utils
from softgroup.util import (AverageMeter, build_optimizer, checkpoint_save, cosine_lr_after_step,
get_max_memory, get_root_logger, load_checkpoint)
def eval_epoch(val_loader, model, model_fn, epoch):
@ -32,7 +33,7 @@ def eval_epoch(val_loader, model, model_fn, epoch):
for k, v in meter_dict.items():
if k not in am_dict.keys():
am_dict[k] = utils.AverageMeter()
am_dict[k] = AverageMeter()
am_dict[k].update(v[0], v[1])
sys.stdout.write("\riter: {}/{} loss: {:.4f}({:.4f})".format(
i + 1, len(val_loader), am_dict['loss'].val, am_dict['loss'].avg))
@ -97,31 +98,30 @@ if __name__ == '__main__':
start_epoch = 1
if args.resume:
logger.info(f'Resume from {args.resume}')
start_epoch = utils.load_checkpoint(args.resume, logger, model, optimizer=optimizer)
start_epoch = load_checkpoint(args.resume, logger, model, optimizer=optimizer)
elif cfg.pretrain:
logger.info(f'Load pretrain from {cfg.pretrain}')
utils.load_checkpoint(cfg.pretrain, logger, model)
load_checkpoint(cfg.pretrain, logger, model)
# train and val
logger.info('Training')
for epoch in range(start_epoch, cfg.epochs + 1):
model.train()
iter_time = utils.AverageMeter()
data_time = utils.AverageMeter()
iter_time = AverageMeter()
data_time = AverageMeter()
meter_dict = {}
end = time.time()
for i, batch in enumerate(train_loader, start=1):
data_time.update(time.time() - end)
utils.cosine_lr_after_step(optimizer, cfg.optimizer.lr, epoch - 1, cfg.step_epoch,
cfg.epochs)
cosine_lr_after_step(optimizer, cfg.optimizer.lr, epoch - 1, cfg.step_epoch, cfg.epochs)
loss, log_vars = model(batch, return_loss=True)
# meter_dict
for k, v in log_vars.items():
if k not in meter_dict.keys():
meter_dict[k] = utils.AverageMeter()
meter_dict[k] = AverageMeter()
meter_dict[k].update(v[0], v[1])
# backward
@ -148,4 +148,4 @@ if __name__ == '__main__':
for k, v in meter_dict.items():
log_str += f', {k}: {v.val:.4f}'
logger.info(log_str)
utils.checkpoint_save(epoch, model, optimizer, cfg.work_dir, cfg.save_freq)
checkpoint_save(epoch, model, optimizer, cfg.work_dir, cfg.save_freq)
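
Taken together, the import changes in this commit imply roughly the following package layout; it is reconstructed only from the imports visible in the diff, so any contents beyond those names are assumptions.

softgroup/
    data/                 build_dataset, build_dataloader, CustomDataset, S3DISDataset
    evaluation/           ScanNetEval and the instance-evaluation helpers
    lib/softgroup_ops/    setup.py plus the compiled SOFTGROUP_OP CUDA ops
    model/                SoftGroup (with blocks.py providing ResidualBlock and UBlock)
    util/                 get_root_logger, build_optimizer, AverageMeter, checkpoint helpers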