From 4852f2efb1ccfa9cfda1a49519a633d21a170fe7 Mon Sep 17 00:00:00 2001 From: Ahmed Shahin Date: Thu, 11 Mar 2021 13:17:16 +0200 Subject: [PATCH] Fix handling input_size with multi-input --- torchsummary/torchsummary.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchsummary/torchsummary.py b/torchsummary/torchsummary.py index 1ed065f..b77113c 100644 --- a/torchsummary/torchsummary.py +++ b/torchsummary/torchsummary.py @@ -98,7 +98,9 @@ def hook(module, input, output): summary_str += line_new + "\n" # assume 4 bytes/number (float on cuda). - total_input_size = abs(np.prod(sum(input_size, ())) + # to handle the case of multi-input: prod(input1) + prod(input2) + ... + n_input_size = np.array([np.prod(i) for i in input_size]).sum() if isinstance(input_size, list) else np.prod(input_size) + total_input_size = abs(n_input_size * batch_size * 4. / (1024 ** 2.)) total_output_size = abs(2. * total_output * 4. / (1024 ** 2.)) # x2 for gradients