Tensor-parallel SSM #333
base: concatenated_dim
Conversation
@@ -284,10 +284,15 @@ def test_load_pretrained(
@pytest.mark.model_testing_group(ModelTestingGroup.convert)
def test_huggingface_model(model_testing_config, get_convert_path):
In my Vision+Hybrid-ssm PR, I updated the SSM conversion to copy the modeling files to the export directory https://github.com/ServiceNow/Fast-LLM/pull/332/files#diff-58be369d99e6722a68e734002686ae4afcfd423261e4d3d3b9d6aa552a6f2a14R729-R784
But this PR is far from being merged ...
I managed to add them here too, but the external models don't seem to be working
FAILED tests/models/test_checkpoint.py::test_huggingface_model[hybrid_mamba2]@dependency_group_2 - AttributeError: 'DynamicCache' object has no attribute 'has_previous_state'
FAILED tests/models/test_checkpoint.py::test_huggingface_model[hybrid_discrete_mamba2]@dependency_group_3 - AttributeError: 'NoneType' object has no attribute 'ssm_states'
@@ -27,23 +27,7 @@
except (ImportError, RuntimeError):
    _causal_conv1d_available = False


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
Why not keep this function?
From what I'm seeing, there is absolutely no benefit over calling `repeat_interleave` directly. I tried to figure out why it's there and got two hypotheses:

- `expand` is preferable over `repeat` because of the absence of a copy. That's pointless because the copy is still done in the `reshape` below, and there is a second copy on each usage (`contiguous`), so the function actually makes things slower...
- `repeat_interleave` may involve a CUDA synchronization because it supports tensor inputs. But that's not supposed to happen, and the explicit `output_size` ensures it.
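
For reference, a minimal sketch of the two variants being compared. The shape convention follows the usual `transformers`-style `repeat_kv` helper, which is an assumption about the removed function:

```python
import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # expand/reshape pattern: expand itself is copy-free, but the reshape
    # below has to materialize the repeated tensor anyway.
    batch, num_key_value_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, seq_len, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, seq_len, head_dim)


def repeat_kv_direct(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Equivalent direct call; the explicit output_size avoids the device
    # synchronization that tensor-valued `repeats` could otherwise trigger.
    num_key_value_heads = hidden_states.shape[1]
    return hidden_states.repeat_interleave(
        n_rep, dim=1, output_size=num_key_value_heads * n_rep
    )
```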
fast_llm/layers/ssm/config.py
Outdated
else:
    head_dim = state

tensor_space.add_tensor_dim(head_groups := TensorDim(SSMDimNames.head_groups, num_head_groups, tensor))
These var names are somewhat confusing. Wouldn't this be clearer?

- `num_head_groups` --> `num_heads_per_group` -- this is the number of heads in each group (e.g. `div(self.d_xb, self.state_size)`, where the head dim. is `self.state_size`)
- TensorDim `head_groups` --> `heads_per_group`
- `group_heads` --> `head_groups` -- this is the number of groups (like in GQA), so the number of groups with `num_head_groups` heads in each group
The names are supposed to be correct, i.e. `head_groups` and `num_head_groups` refer to the different head groups and the number of such groups, and `group_heads` refers to the heads inside a group. I might have inverted them by mistake though, I'll double-check. (They should be right here.)
I'm not a huge fan of the term `group_heads` though, so I'm open to renaming (`heads_per_group`? `heads_in_group`?). What do you think?
I see, makes sense, so the current names are correct. I guess for `group_heads` it would make sense to call it `heads_per_group`.
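
To make the convention concrete, a toy illustration (made-up numbers, not Fast-LLM code):

```python
# GQA-style layout assumed for illustration: a few head groups, each containing
# several heads that share the same state parameters.
num_heads = 8                                    # total number of heads
num_head_groups = 2                              # number of groups -> "head_groups" dim
heads_per_group = num_heads // num_head_groups   # heads inside a group -> currently "group_heads"

assert num_head_groups * heads_per_group == num_heads
```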
fast_llm/layers/ssm/mamba2.py
Outdated
bias=config.add_bias_linear,
weight_init_method=init_kaiming_(self._config.d_inner),
sequence_parallel=self._sequence_parallel,
# TODO: lr_scale?
lr_scale=lr_scale?
Just making sure it's not intentionally absent. Is it OK to add?
Same for the normalization layer?
Yeah, it should be added everywhere; I think we missed it before.
Added
fast_llm/layers/ssm/config.py
Outdated
)
tensor_space.add_tensor_dim(
    head_groups_and_state := CompositeTensorDim(
        SSMDimNames.composite_head_groups_and_state, (head_groups, state)
`SSMDimNames.composite_head_groups_and_state, (head_groups, head_dim)` instead of `SSMDimNames.composite_head_groups_and_state, (head_groups, state)`? `state` and `head_dim` don't have to be the same (currently they are).
In mamba and mamba 2, `head_dim` and `state` are the same, so the two are equivalent. The distinction is for discrete mamba.
This ended up being too confusing though, so in the next PR the dimensions are moved directly into the model with model-specific names: https://github.com/ServiceNow/Fast-LLM/pull/339/files#diff-646273ef4d1b740f2ea28a29b34ca91517f5ea91ee1530ede13acfc44d30f1b1R55
Moving the dimensions into the model would be much less confusing!
Would it make sense to move those into the model already in this PR?
I'd prefer to limit the amount of extra work since #339 is already way ahead of this PR, but I'll check if I can do it easily.
Done, wasn't too bad
# inner_projection : (batch/local_sequence, local_sequence/batch, hidden)
# -> (batch/sequence, sequence/batch, inner_projection)
inner_projection = self.in_proj(input_)
Not sure this is the best place to discuss this, but I would like to better understand how TP is working here.
My understanding of the shapes in the case of TP=2:

- `input_` -- bs x seq/2 x d
- `inner_projection` -- seq x bs x d -- so here the output is already gathered across the ranks?
- all the way down to `out_proj` we operate on the full seq. length
- the `out_proj` splits the seq. length in 2 again

Is that correct? So we only apply TP to `in_proj` and `out_proj` (+ also `dt_proj`). So does it mean that `inner_projection` is identical across the ranks and all the ranks are performing redundant work?
`inner_projection` is tensor-parallel, e.g. with size `seq x bs x d/2`. I'll change the comment to `(batch/sequence, sequence/batch, local_inner_projection)` to clear things up.
Sorry, I meant the case of `sequence_tensor_parallel=True`, where we split over the sequence length. Sequence-tensor-parallel should work, right? (My experiments show that the loss and grad norms are very similar in STP and non-TP, so I assume it works.)
So in the TP=2 case, each rank processes half of the `inner_projection`; that works fine.
With `TP=2` and `sequence_tensor_parallel=True`, does each rank process half of the sequence dim.? I just don't quite understand how it works in this case, since `input_` is bs x seq/2 x d, but `inner_projection` is already seq x bs x d --- so does it mean everything after `in_proj` operates on the full sequence length (which makes sense since `selective_scan_fn` does not support STP)?
I'm not sure if that's what you mean, but in sequence-tensor-parallel we split the activations that aren't already tensor-parallel over the sequence (mainly the hidden states), so that all (or almost all) of them are split in the TP direction in one way or another. For mamba (and attn) that means the input is sequence-parallel, then we [gather the full tensor](https://github.com/ServiceNow/Fast-LLM/blob/main/fast_llm/functional/linear.py#L114) and run `inner_projection` with standard TP, which leaves a TP output.
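
A rough sketch of that data flow, using plain `torch.distributed` collectives rather than Fast-LLM's fused gather+linear kernel; the function and variable names here are illustrative assumptions:

```python
import torch
import torch.distributed as dist


def sequence_tp_in_projection(input_: torch.Tensor, weight: torch.Tensor, group=None) -> torch.Tensor:
    """Sequence-parallel input -> all-gather over the sequence dim -> column-parallel linear.

    input_: (batch, seq / tp, hidden)  sequence-parallel activation
    weight: (inner / tp, hidden)       local shard of the column-parallel projection
    returns (batch, seq, inner / tp)   tensor-parallel along the last dim
    """
    tp = dist.get_world_size(group)
    chunks = [torch.empty_like(input_) for _ in range(tp)]
    dist.all_gather(chunks, input_.contiguous(), group=group)
    full_input = torch.cat(chunks, dim=1)  # full sequence on every rank
    return torch.nn.functional.linear(full_input, weight)  # output split over channels
```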
> For mamba (and attn) that means the input is sequence-parallel, then we [gather the full tensor](https://github.com/ServiceNow/Fast-LLM/blob/main/fast_llm/functional/linear.py#L114) and run `inner_projection` with standard TP, which leaves a TP output.

So does it mean that if the input to mamba is sequence-parallel (`bs x seq./tp x d`), then `in_proj` gathers it into `bs x seq. x d`, and doing the projection with standard TP means it should output `bs x seq. x d/tp`?
Yes
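
Summarizing the shape flow confirmed above for TP=2 with `sequence_tensor_parallel=True` (illustrative only; the batch/sequence ordering and the out_proj communication pattern are assumptions based on this discussion):

```python
# input_                 : (bs, seq/2, d)          sequence-parallel
# gather inside in_proj  : (bs, seq,   d)          full sequence on every rank
# inner_projection       : (bs, seq,   d_inner/2)  tensor-parallel channels
# conv / selective scan  : per-rank on d_inner/2 channels, full sequence length
# out_proj               : (bs, seq/2, d)          back to sequence-parallel (reduce-scatter)
```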