From d8665970f91aaf6ef1a0361c70103bc9aa67084d Mon Sep 17 00:00:00 2001 From: Thang Vu Date: Sun, 8 May 2022 14:24:35 +0000 Subject: [PATCH] fix model.eval() return None in non-distributed mode --- tools/test.py | 2 +- tools/train.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/test.py b/tools/test.py index 6bdea9a..49b5d6d 100644 --- a/tools/test.py +++ b/tools/test.py @@ -95,7 +95,7 @@ def main(): _, world_size = get_dist_info() progress_bar = tqdm(total=len(dataloader) * world_size, disable=not is_main_process()) with torch.no_grad(): - model = model.eval() + model.eval() for i, batch in enumerate(dataloader): result = model(batch) results.append(result) diff --git a/tools/train.py b/tools/train.py index 5ddf9d2..4df59dc 100644 --- a/tools/train.py +++ b/tools/train.py @@ -89,7 +89,7 @@ def validate(epoch, model, val_loader, cfg, logger, writer): progress_bar = tqdm(total=len(val_loader) * world_size, disable=not is_main_process()) val_set = val_loader.dataset with torch.no_grad(): - model = model.eval() + model.eval() for i, batch in enumerate(val_loader): result = model(batch) results.append(result)