MultiHeadAttentionWrapper should instantiate CausalSelfAttention with d_out = d_out // num_heads? #609
henrythe9th asked this question in Q&A (unanswered)
Replies: 1 comment, 1 reply
-
Question:
Since the MultiHeadAttentionWrapper class calls torch.cat([head(x) for head in self.heads], dim=-1), shouldn't we be instantiating CausalSelfAttention with d_out = d_out // num_heads, so that the final MultiHeadAttentionWrapper output has the same shape and d_out as was specified in the input? In other words, is this a clearer implementation?
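The question asks whether "this" is a clearer implementation, so a code block presumably accompanied the original post; what follows is only a sketch of the proposed variant, not the poster's code. The CausalSelfAttention stand-in here (constructor CausalSelfAttention(d_in, d_out, context_length, dropout, qkv_bias) and its internals) is an assumption modeled on the chapter's single-head causal attention, included so the example runs on its own.

```python
import torch
import torch.nn as nn


class CausalSelfAttention(nn.Module):
    """Minimal single-head causal attention stand-in (assumed, chapter-style)."""

    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask", torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, _ = x.shape
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)
        attn_scores = queries @ keys.transpose(1, 2)
        attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], float("-inf"))
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        return self.dropout(attn_weights) @ values  # (batch, num_tokens, d_out)


class MultiHeadAttentionWrapper(nn.Module):
    """Variant proposed in the question: each head is built with d_out // num_heads."""

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        head_dim = d_out // num_heads  # per-head output width
        self.heads = nn.ModuleList(
            [CausalSelfAttention(d_in, head_dim, context_length, dropout, qkv_bias)
             for _ in range(num_heads)]
        )

    def forward(self, x):
        # Each head returns (batch, num_tokens, d_out // num_heads), so the
        # concatenation along dim=-1 has exactly d_out columns.
        return torch.cat([head(x) for head in self.heads], dim=-1)
```

With d_out=4 and num_heads=2, for example, each head produces 2-dimensional context vectors and the concatenated output has a last dimension of 4, matching the d_out passed to the wrapper.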
-
Reply:
I believe the confusion lies in how we are interpreting [...] In your impl, [...] It's true that it's clearer in the sense that the [...]
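The reply is cut off above, but it appears to hinge on how d_out is interpreted: in the wrapper as quoted in the question, each head is constructed with the full d_out, so the concatenated output has num_heads * d_out columns, whereas the proposal treats d_out as the total output width. A quick shape comparison, reusing the assumed CausalSelfAttention stand-in from the sketch above (all dimensions are arbitrary):

```python
torch.manual_seed(123)
x = torch.rand(2, 6, 3)  # (batch=2, num_tokens=6, d_in=3)

# Each of the 2 heads built with the full d_out=2:
# concatenation yields num_heads * d_out = 4 output columns.
heads_full = [CausalSelfAttention(d_in=3, d_out=2, context_length=6, dropout=0.0)
              for _ in range(2)]
print(torch.cat([h(x) for h in heads_full], dim=-1).shape)   # torch.Size([2, 6, 4])

# Proposed variant: each head built with d_out // num_heads = 1:
# concatenation yields exactly d_out = 2 output columns.
heads_split = [CausalSelfAttention(d_in=3, d_out=1, context_length=6, dropout=0.0)
               for _ in range(2)]
print(torch.cat([h(x) for h in heads_split], dim=-1).shape)  # torch.Size([2, 6, 2])
```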