fix model.eval() return None in non-distributed mode

This commit is contained in:
Thang Vu 2022-05-08 14:24:35 +00:00
parent e0c3b4e1ab
commit d8665970f9
2 changed files with 2 additions and 2 deletions

View File

@ -95,7 +95,7 @@ def main():
_, world_size = get_dist_info()
progress_bar = tqdm(total=len(dataloader) * world_size, disable=not is_main_process())
with torch.no_grad():
model = model.eval()
model.eval()
for i, batch in enumerate(dataloader):
result = model(batch)
results.append(result)

View File

@ -89,7 +89,7 @@ def validate(epoch, model, val_loader, cfg, logger, writer):
progress_bar = tqdm(total=len(val_loader) * world_size, disable=not is_main_process())
val_set = val_loader.dataset
with torch.no_grad():
model = model.eval()
model.eval()
for i, batch in enumerate(val_loader):
result = model(batch)
results.append(result)