mirror of
https://github.com/botastic/SoftGroup.git
synced 2025-10-16 11:45:42 +00:00
30 lines
973 B
Python
30 lines
973 B
Python
# Convert an spconv1 checkpoint to an spconv2 checkpoint: permute 5-D conv
# weights from KKKIO to OKKKI layout and rename the old module keys.
|
|
import argparse
import os
from collections import OrderedDict

import torch
|
|
|
|
# (old substring, new substring) pairs applied to state-dict keys. Only the
# first matching pair is applied to a key, mirroring an if/elif chain.
KEY_RENAMES = (
    ('intra_ins_unet', 'tiny_unet'),
    ('score_linear', 'iou_score_linear'),
    ('intra_ins_outputlayer', 'tiny_unet_outputlayer'),
)

OUTPUT_TAG = '_spconv2'


def convert_state_dict(state_dict):
    """Return a new ordered state dict converted from spconv1 to spconv2.

    Two transformations are applied to every entry:

    * a 5-D tensor whose key contains ``'weight'`` is permuted from KKKIO to
      OKKKI, i.e. dims ``(0, 1, 2, 3, 4) -> (4, 0, 1, 2, 3)``;
    * the key is renamed using the first matching pair in ``KEY_RENAMES``.

    Args:
        state_dict: mapping of parameter name to tensor (``checkpoint['net']``).

    Returns:
        OrderedDict with the converted entries, in the original order. The
        input mapping is not modified.
    """
    converted = OrderedDict()
    for key, value in state_dict.items():
        new_key, new_value = key, value
        if 'weight' in key and value.dim() == 5:
            # KKKIO to OKKKI (0, 1, 2, 3, 4) -> (4, 0, 1, 2, 3)
            new_value = value.permute(4, 0, 1, 2, 3)
        for old, new in KEY_RENAMES:
            if old in key:
                new_key = key.replace(old, new)
                break
        converted[new_key] = new_value
    return converted


def spconv2_checkpoint_path(path):
    """Return the output path for the converted checkpoint.

    ``foo.pth`` becomes ``foo_spconv2.pth`` (and ``foo.pth.tar`` becomes
    ``foo_spconv2.pth.tar``). Only the file name is touched, never the
    directory part. A name without ``.pth`` gets the tag inserted before its
    extension, so the result always differs from ``path`` and the input
    checkpoint is never overwritten.
    """
    directory, filename = os.path.split(path)
    stem, sep, rest = filename.rpartition('.pth')
    if sep:
        new_name = stem + OUTPUT_TAG + '.pth' + rest
    else:
        root, ext = os.path.splitext(filename)
        new_name = root + OUTPUT_TAG + ext
    return os.path.join(directory, new_name)


def main(argv=None):
    """Convert the checkpoint named on the command line and save the result.

    Args:
        argv: optional argument list; defaults to ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('checkpoint', type=str, help='spconv1 checkpoint')
    args = parser.parse_args(argv)

    # Load onto the CPU so the conversion also works on machines without the
    # GPU the checkpoint was saved from.
    checkpoint = torch.load(args.checkpoint, map_location='cpu')
    checkpoint['net'] = convert_state_dict(checkpoint['net'])
    torch.save(checkpoint, spconv2_checkpoint_path(args.checkpoint))


if __name__ == '__main__':
    main()