Closed
Description
When a model takes multiple inputs of different types, torchsummary still generates every random input with the same type, `torch.FloatTensor` (`torch.cuda.FloatTensor` on CUDA).
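A minimal illustration of why a single float type is a problem, assuming a model whose first input is a batch of token ids fed to `nn.Embedding` (the layer sizes here are hypothetical): a float tensor is rejected, while integer indices of the same shape are accepted.

```python
import torch
import torch.nn as nn

# Hypothetical layer: 10-word vocabulary, 4-dimensional embeddings.
embed = nn.Embedding(num_embeddings=10, embedding_dim=4)

# What the unpatched summary feeds every input: a float tensor.
float_ids = torch.rand(2, 5).type(torch.FloatTensor)

try:
    embed(float_ids)
    rejected = False
except RuntimeError:
    # Embedding indices must be an integer type (Long or Int).
    rejected = True

# The same shape as int64 indices in [0, 10) works.
out = embed(torch.randint(0, 10, (2, 5)))
```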
You can delete the assignment of `dtype` and instead pass a type along with each input size, so that each random input gets its own type:
```python
# Deleted: the single hard-coded input type
# if device == "cuda" and torch.cuda.is_available():
#     dtype = torch.cuda.FloatTensor
# else:
#     dtype = torch.FloatTensor

# multiple inputs to the network
if isinstance(input_size, tuple):
    input_size = [input_size]

# batch_size of 2 for batchnorm
# Modified: each entry of input_size is now a (size, dtype) pair
x = [torch.rand(2, *in_size).type(dtype) for in_size, dtype in input_size]
# print(type(x[0]))
```
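To see what the modified comprehension produces, here it is evaluated on hypothetical concrete arguments (token ids plus a float feature vector), with the batch dimension of 2 prepended. One caveat worth knowing: `torch.rand` draws from [0, 1), so casting to an integer type truncates every value to 0. Those are valid embedding indices, but the integer inputs carry no variation.

```python
import torch

# Hypothetical (size, dtype) pairs in the format the patch expects.
input_size = [((5,), torch.LongTensor), ((3,), torch.FloatTensor)]

# Same comprehension as the patch, with a batch dimension of 2.
x = [torch.rand(2, *in_size).type(dtype) for in_size, dtype in input_size]

# x[0]: int64, shape (2, 5), all zeros after truncation
# x[1]: float32, shape (2, 3), values in [0, 1)
```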
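Then get the summary of your model by passing one (size, dtype) pair per input, where each size is a shape tuple and each dtype is a tensor type such as `torch.FloatTensor` or `torch.LongTensor`:

```python
summary(model, [((size1), dtype1), ((size2), dtype2)])
```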
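If you would rather not patch torchsummary, or need integer inputs that are not all zeros, a minimal sketch is to build the inputs yourself. `random_inputs` and `TwoInputNet` below are hypothetical names, not part of torchsummary; the sketch assumes each dtype is a floating-point or integer `torch.dtype` (e.g. `torch.float32`, `torch.long`) rather than a legacy tensor class.

```python
import torch
import torch.nn as nn


def random_inputs(specs, batch_size=2, device="cpu", int_high=10):
    """Hypothetical helper: one random tensor per (shape, dtype) pair."""
    if isinstance(specs, tuple):  # a single (shape, dtype) pair
        specs = [specs]
    inputs = []
    for shape, dtype in specs:
        full_shape = (batch_size, *shape)  # prepend the batch dimension
        if dtype.is_floating_point:
            inputs.append(torch.rand(full_shape, dtype=dtype, device=device))
        else:
            # Draw integers from [0, int_high) instead of truncating floats to 0.
            inputs.append(
                torch.randint(0, int_high, full_shape, dtype=dtype, device=device)
            )
    return inputs


class TwoInputNet(nn.Module):
    """Hypothetical model: token ids go through an embedding, features through a linear layer."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 4)
        self.fc = nn.Linear(3, 4)

    def forward(self, tokens, feats):
        # (2, 5) ids -> (2, 5, 4) -> mean over tokens -> (2, 4)
        return self.embed(tokens).mean(dim=1) + self.fc(feats)


model = TwoInputNet()
x = random_inputs([((5,), torch.long), ((3,), torch.float32)])
out = model(*x)  # shape (2, 4)
```

Keeping `int_high` no larger than the embedding's vocabulary size keeps every generated index in range, and passing `torch.dtype` objects keeps the type separate from the device.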