diff --git a/torchsummary/torchsummary.py b/torchsummary/torchsummary.py
index cbe18e3..43c252f 100644
--- a/torchsummary/torchsummary.py
+++ b/torchsummary/torchsummary.py
@@ -1,5 +1,6 @@
 import torch
 import torch.nn as nn
+import funcy
 from torch.autograd import Variable
 
 from collections import OrderedDict
@@ -79,9 +80,9 @@ def hook(module, input, output):
     line_new = "{:>20} {:>25} {:>15}".format("Layer (type)", "Output Shape", "Param #")
     print(line_new)
     print("================================================================")
-    total_params = 0
+    total_params = sum(p.numel() for p in model.parameters())
+    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
     total_output = 0
-    trainable_params = 0
     for layer in summary:
         # input_shape, output_shape, trainable, nb_params
         line_new = "{:>20} {:>25} {:>15}".format(
@@ -89,17 +90,13 @@ def hook(module, input, output):
             str(summary[layer]["output_shape"]),
             "{0:,}".format(summary[layer]["nb_params"]),
         )
-        total_params += summary[layer]["nb_params"]
-        total_output += np.prod(summary[layer]["output_shape"])
-        if "trainable" in summary[layer]:
-            if summary[layer]["trainable"] == True:
-                trainable_params += summary[layer]["nb_params"]
+        total_output += np.prod(list(funcy.flatten(summary[layer]["output_shape"])))
         print(line_new)
 
     # assume 4 bytes/number (float on cuda).
-    total_input_size = abs(np.prod(input_size) * batch_size * 4. / (1024 ** 2.))
+    total_input_size = abs(sum([np.prod(input_item) for input_item in input_size]) * batch_size * 4. / (1024 ** 2.))
     total_output_size = abs(2. * total_output * 4. / (1024 ** 2.))  # x2 for gradients
-    total_params_size = abs(total_params.numpy() * 4. / (1024 ** 2.))
+    total_params_size = abs(total_params * 4. / (1024 ** 2.))
     total_size = total_params_size + total_output_size + total_input_size
 
     print("================================================================")
@@ -113,3 +110,4 @@ def hook(module, input, output):
     print("Estimated Total Size (MB): %0.2f" % total_size)
     print("----------------------------------------------------------------")
     # return summary
+
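
Note on the size arithmetic above: with multiple inputs, input_size becomes a list of per-input shape tuples, and a layer with several outputs records a nested (possibly ragged) output_shape that a plain np.prod cannot reduce, hence the funcy.flatten call. A minimal sketch of both expressions, with the shapes invented purely for illustration:

    import numpy as np
    import funcy

    # Hypothetical two-input model: the input memory is the sum of the
    # per-input element counts, not the product of one merged shape.
    input_size = [(1, 16, 16), (1, 28, 28)]
    input_numbers = sum(np.prod(s) for s in input_size)   # 256 + 784 = 1040

    # Hypothetical layer with two outputs of different ranks; flattening
    # the ragged nesting first lets np.prod collapse it to one scalar.
    output_shape = [[2, 10, 512], [1, 2, 512]]
    flat = list(funcy.flatten(output_shape))              # [2, 10, 512, 1, 2, 512]
    output_numbers = np.prod(flat)

funcy.flatten returns a lazy iterator rather than a list, which is why the patch wraps it in list(...) before handing it to np.prod.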