
Error with multiple inputs of different dtypes #102

@trajepl


When a model takes multiple inputs with different dtypes, torchsummary generates every random input with the same type, torch.FloatTensor (torch.cuda.FloatTensor when running on CUDA).
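The failure is easy to reproduce with a hypothetical model (TwoInputNet, not from the issue) whose second input feeds an nn.Embedding, which only accepts integer indices. The sketch below casts both inputs to torch.FloatTensor, mimicking what torchsummary does, and the forward pass raises a RuntimeError:

```python
import torch
import torch.nn as nn


class TwoInputNet(nn.Module):
    """Hypothetical model: a float feature vector plus integer token ids."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 4)
        self.emb = nn.Embedding(10, 4)

    def forward(self, feats, tokens):
        # Dense projection plus the mean embedding of the tokens, both of shape (batch, 4).
        return self.fc(feats) + self.emb(tokens).mean(dim=1)


model = TwoInputNet()
# Mimic torchsummary: every input (batch of 2) is cast to torch.FloatTensor.
inputs = [torch.rand(2, *size).type(torch.FloatTensor) for size in [(8,), (5,)]]
try:
    model(*inputs)
    failed = False
except RuntimeError as err:
    # nn.Embedding rejects float indices.
    failed = True
    print("forward failed:", err)
```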

You can delete the assignment of dtype and instead pass a dtype together with each input size, so that each random input is generated with its own type:

    # Deleted: a single dtype shared by every input
    # if device == "cuda" and torch.cuda.is_available():
    #     dtype = torch.cuda.FloatTensor
    # else:
    #     dtype = torch.FloatTensor

    # multiple inputs to the network
    if isinstance(input_size, tuple):
        input_size = [input_size]

    # batch_size of 2 for batchnorm
    # modified: each entry of input_size is now a (size, dtype) pair
    x = [torch.rand(2, *in_size).type(dtype) for in_size, dtype in input_size]

Then get the summary of your model with:

    summary(model, [(size1, dtype1), (size2, dtype2)])
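The modified input generation can be tried outside torchsummary. This sketch wraps the same list comprehension in a hypothetical helper, make_inputs (not part of torchsummary), and builds a float input and a long input in one call:

```python
import torch


def make_inputs(input_size, batch_size=2):
    """Hypothetical helper: one random tensor per (size, dtype) pair."""
    # A single (size, dtype) pair is a tuple, so wrap it in a list as summary() does.
    if isinstance(input_size, tuple):
        input_size = [input_size]
    # torch.rand draws floats in [0, 1); .type() then casts each tensor to its own type.
    return [torch.rand(batch_size, *in_size).type(dtype) for in_size, dtype in input_size]


xs = make_inputs([((8,), torch.FloatTensor), ((5,), torch.LongTensor)])
for t in xs:
    print(t.shape, t.dtype)
```

Because torch.rand draws values in [0, 1), casting to torch.LongTensor yields all-zero tensors: valid indices for an embedding, but not a realistic token distribution. Also, since the device-dependent assignment is deleted, a model on the GPU needs CUDA types such as torch.cuda.FloatTensor in the pairs.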
