From c0c104ef14ced6563b90867bdc9a97a426836758 Mon Sep 17 00:00:00 2001 From: argo <121647050+timbektu@users.noreply.github.com> Date: Fri, 14 Jul 2023 17:33:24 -0700 Subject: [PATCH] [bug-fix] for batch_size=1, squeeze() operation also squeezes the first (batch_size) dimension, and the code breaks when using batch_size =1. patched the bug in this commit. --- criteria/moco_loss.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/criteria/moco_loss.py b/criteria/moco_loss.py index 8fb13fb..515b056 100644 --- a/criteria/moco_loss.py +++ b/criteria/moco_loss.py @@ -43,7 +43,9 @@ def extract_feats(self, x): x = F.interpolate(x, size=224) x_feats = self.model(x) x_feats = nn.functional.normalize(x_feats, dim=1) - x_feats = x_feats.squeeze() + # x_feats = x_feats.squeeze() + bs, feat_dim, _, _ = x_feats + x_feats = x_feats.reshape((bs, feat_dum)) return x_feats def forward(self, y_hat, y, x):