
transformer_asr.py: incorrect source_maxlen #1389

Open

@MicahDoo

https://github.com/keras-team/keras-io/blob/master/examples/audio/transformer_asr.py

In the code at the link above, source_maxlen defaults to 100 in the Transformer.
The problem is that the inputs are actually padded to length 2754 and then downsampled by the CNN by a factor of 8. The result is a sequence of length 345, which is far greater than the default source_maxlen of 100.
Correct me if I am wrong, but I reckon that is a bug?
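
For reference, a rough sketch of the arithmetic behind the mismatch (this assumes the convolutions round the length up, e.g. "same" padding; the variable names are mine, not the example's):

import math

padded_len = 2754           # length the audio features are padded to
downsample_factor = 8       # combined stride of the CNN front end
encoder_len = math.ceil(padded_len / downsample_factor)
print(encoder_len)          # 345, far above the default source_maxlen of 100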

Problem code:

In the Transformer definition, source_maxlen defaults to 100:

class Transformer(keras.Model):
    def __init__(
        self,
        num_hid=64,
        num_head=2,
        num_feed_forward=128,
        source_maxlen=100,
        target_maxlen=100,
        num_layers_enc=4,
        num_layers_dec=1,
        num_classes=10,
    ):

... and source_maxlen is not explicitly set at instantiation:

model = Transformer(
    num_hid=200,
    num_head=2,
    num_feed_forward=400,
    target_maxlen=max_target_len,
    num_layers_enc=4,
    num_layers_dec=1,
    num_classes=34,
)
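
A minimal sketch of one possible workaround, assuming the padded input length of 2754 and the downsampling factor of 8 described above (I'm not claiming this is the fix the example authors would choose; padded_source_len and downsample_factor are illustrative names):

import math

padded_source_len = 2754   # length the audio features are padded to (from the description above)
downsample_factor = 8      # overall stride of the CNN front end (from the description above)

model = Transformer(
    num_hid=200,
    num_head=2,
    num_feed_forward=400,
    # pass the actual post-CNN length instead of relying on the default of 100
    source_maxlen=math.ceil(padded_source_len / downsample_factor),  # -> 345
    target_maxlen=max_target_len,
    num_layers_enc=4,
    num_layers_dec=1,
    num_classes=34,
)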
