
Conversation

jlamypoirier (Collaborator):

✨ Description

Please provide a brief summary of the changes, relevant motivation, and context.
Include any related issue numbers or links to discussions, and explain why this change is necessary.

Closes #

🔍 Type of change

Select all that apply:

  • 🐛 Bug fix (non-breaking change that addresses a specific issue)
  • 🚀 New feature (non-breaking change that adds functionality)
  • ⚠️ Breaking change (a change that could affect existing functionality)
  • 📈 Performance improvement/optimization (improves speed, memory usage, or efficiency)
  • 🛠️ Code refactor (non-functional changes that improve code readability, structure, etc.)
  • 📦 Dependency bump (updates dependencies, including Dockerfile or package changes)
  • 📝 Documentation change (updates documentation, including new content or typo fixes)
  • 🔧 Infrastructure/Build change (affects build process, CI/CD, or dependencies)

@jlamypoirier jlamypoirier changed the base branch from main to debug_mamba July 24, 2025 19:21
@jlamypoirier jlamypoirier changed the base branch from debug_mamba to concatenated_dim July 28, 2025 22:11
@jlamypoirier jlamypoirier marked this pull request as ready for review July 29, 2025 22:04
@@ -284,10 +284,15 @@ def test_load_pretrained(
@pytest.mark.model_testing_group(ModelTestingGroup.convert)
def test_huggingface_model(model_testing_config, get_convert_path):

jlamypoirier (Collaborator, Author):

This test isn't working for SSM because the HF wrapper can't find the external model. I would like to make it work so we have at least one correctness test. Any idea how to make it work? (@bigximik @oleksost @tscholak ?)

Contributor:

In my Vision+Hybrid-SSM PR, I updated the SSM conversion to copy the modeling files to the export directory: https://github.com/ServiceNow/Fast-LLM/pull/332/files#diff-58be369d99e6722a68e734002686ae4afcfd423261e4d3d3b9d6aa552a6f2a14R729-R784
But that PR is far from being merged...

jlamypoirier (Collaborator, Author):

I managed to add them here too, but the external models don't seem to be working:

FAILED tests/models/test_checkpoint.py::test_huggingface_model[hybrid_mamba2]@dependency_group_2 - AttributeError: 'DynamicCache' object has no attribute 'has_previous_state'
FAILED tests/models/test_checkpoint.py::test_huggingface_model[hybrid_discrete_mamba2]@dependency_group_3 - AttributeError: 'NoneType' object has no attribute 'ssm_states'
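
(Purely as an illustration of what the tracebacks imply: both failures point at the external modeling code expecting a Mamba-style cache object rather than a plain DynamicCache. The attribute names below are inferred from the error messages only; the class is hypothetical, and the real fix presumably belongs in the exported modeling files, not the test.)

import torch

# Hypothetical sketch: a minimal cache exposing the attributes the errors above
# suggest the external SSM code reads. Not the actual modeling-file interface.
class MinimalHybridSSMCache:
    def __init__(self, num_layers: int):
        self.has_previous_state = False        # missing on DynamicCache (hybrid_mamba2 failure)
        self.ssm_states = [None] * num_layers  # accessed on a None cache (hybrid_discrete_mamba2 failure)
        self.conv_states = [None] * num_layers # assumed: mamba-style caches usually track conv state too

    def update_ssm_state(self, layer_idx: int, state: torch.Tensor) -> None:
        self.ssm_states[layer_idx] = state
        self.has_previous_state = True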

@jlamypoirier jlamypoirier requested review from RaymondLi0, oleksost, nitsanluke, tscholak and bigximik and removed request for RaymondLi0 July 31, 2025 19:10
@@ -27,23 +27,7 @@
except (ImportError, RuntimeError):
_causal_conv1d_available = False


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:

Contributor:

Why not keep this function?

jlamypoirier (Collaborator, Author):

From what I'm seeing, there is absolutely no benefit over calling repeat_interleave directly. I tried to figure out why it was there and came up with two hypotheses:

  • expand is preferable over repeat because it avoids a copy. That's pointless here because the copy still happens in the reshape below, and there is a second copy on each usage (contiguous), so the function actually makes things slower...
  • repeat_interleave may involve a cuda synchronization because it supports tensor inputs. But that's not supposed to happen here, and passing an explicit output_size guarantees it doesn't. (See the sketch below.)
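
For reference, a minimal sketch of the two options being compared, assuming the usual (batch, heads, sequence, head_dim) layout (the exact shapes handled by the removed function may differ):

import torch

def repeat_kv_expand(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # The removed pattern: expand avoids a copy here, but the reshape below copies
    # anyway, and each caller typically adds another copy via .contiguous().
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    expanded = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

def repeat_kv_direct(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Same result in one call; the explicit output_size keeps repeat_interleave
    # from needing a device synchronization.
    return hidden_states.repeat_interleave(n_rep, dim=1, output_size=hidden_states.shape[1] * n_rep)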

else:
    head_dim = state

tensor_space.add_tensor_dim(head_groups := TensorDim(SSMDimNames.head_groups, num_head_groups, tensor))

Contributor:

These var names are somewhat confusing. Wouldn't this be clearer?

  • num_head_groups --> num_heads_per_group -- this is the number of heads in each group (e.g. div(self.d_xb, self.state_size), where head dim. is self.state_size)
  • TensorDim head_groups --> heads_per_group
  • group_heads --> head_groups --- this is the number of groups (like in GQA), i.e. the number of groups with num_head_groups heads in each

jlamypoirier (Collaborator, Author):

The names are supposed to be correct, i.e. head_groups and num_head_groups refer to the different head groups and the number of such groups, and group_heads refers to the heads inside a group. I might have inverted them somewhere by mistake though, I'll double-check. (They should be right here.)

I'm not a huge fan of the term group_heads though, so I'm open to renaming. (heads_per_group? heads_in_group?) What do you think?

oleksost (Contributor), Aug 20, 2025:

I see, that makes sense; the current names are correct. For group_heads, I guess it would make sense to call it heads_per_group.
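
To pin down the agreed terminology, a tiny standalone sketch (the sizes are made up, and reading d_xb as the grouped x/B width and d_inner as the full inner width is an assumption based on the div(self.d_xb, self.state_size) relation mentioned above):

# Hypothetical sizes, for illustration only.
state_size = 16                                # head_dim == state_size for mamba/mamba2
d_xb = 128                                     # assumed: width of the grouped (x/B) projection
d_inner = 512                                  # assumed: width of the full inner projection

num_head_groups = d_xb // state_size           # number of head groups, as in GQA -> 8
heads_per_group = d_inner // d_xb              # heads inside each group (the old "group_heads") -> 4
num_heads = num_head_groups * heads_per_group  # -> 32
assert num_heads * state_size == d_inner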

bias=config.add_bias_linear,
weight_init_method=init_kaiming_(self._config.d_inner),
sequence_parallel=self._sequence_parallel,
# TODO: lr_scale?

Contributor:

lr_scale=lr_scale?

jlamypoirier (Collaborator, Author):

Just making sure it's not intentionally absent. Is it ok to add?

Same question for the normalization layer?

Contributor:

Yeah, it should be added everywhere; I think we just missed it before.

jlamypoirier (Collaborator, Author):

Added

)
tensor_space.add_tensor_dim(
    head_groups_and_state := CompositeTensorDim(
        SSMDimNames.composite_head_groups_and_state, (head_groups, state)

Contributor:

Should this be SSMDimNames.composite_head_groups_and_state, (head_groups, head_dim) instead of SSMDimNames.composite_head_groups_and_state, (head_groups, state)? state and head_dim need not be the same (currently they are).

jlamypoirier (Collaborator, Author):

In mamba and mamba2, head_dim and state are the same, so the two are equivalent. The distinction only matters for discrete mamba.

jlamypoirier (Collaborator, Author):

This ended up being too confusing though, so in the next PR the dimensions are moved directly into the model with model-specific names: https://github.com/ServiceNow/Fast-LLM/pull/339/files#diff-646273ef4d1b740f2ea28a29b34ca91517f5ea91ee1530ede13acfc44d30f1b1R55

oleksost (Contributor), Aug 21, 2025:

Moving the dimensions into the model would be much less confusing!

Contributor:

Would it make sense to move those into the model already in this PR?

jlamypoirier (Collaborator, Author):

I'd prefer to limit the amount of extra work since #339 is already way ahead of this PR, but I'll check if I can do it easily.

jlamypoirier (Collaborator, Author):

Done, wasn't too bad


# inner_projection : (batch/local_sequence, local_sequence/batch, hidden)
# -> (batch/sequence, sequence/batch, inner_projection)
inner_projection = self.in_proj(input_)

Contributor:

Not sure this is the best place to discuss this, but I would like to better understand how TP works here.
My understanding of the shapes in the case of TP=2:

  • input_ -- bs x seq/2 x d
  • inner_projection -- seq x bs x d -- so here the output is already gathered across the ranks?
  • all the way down to out_proj we operate on full seq. length
  • the out_proj splits the seq. length in 2 again

Is that correct? So we only apply TP to in_proj and out_proj (and also dt_proj). Does that mean inner_projection is identical across the ranks and all the ranks are performing redundant work?

jlamypoirier (Collaborator, Author):

inner_projection is tensor-parallel, e.g. with size seq x bs x d/2. I'll change the comment to (batch/sequence, sequence/batch, local_inner_projection) to clear things up.

oleksost (Contributor), Aug 22, 2025:

Sorry, I meant the case of sequence_tensor_parallel=True, where we split over the sequence length. sequence_tensor_parallel should work, right? (My experiments show that loss and grad norms are very similar with STP and without TP, so I assume it works.)

So in the TP=2 case, each rank processes half of inner_projection; that works fine.

With TP=2 and sequence_tensor_parallel=True, each rank processes half of the sequence dim? I just don't quite understand how it works in this case, since input_ is bs x seq/2 x d, but inner_projection is already seq x bs x d --- so does that mean everything after in_proj operates on the full sequence length (which would make sense, since selective_scan_fn does not support STP)?

jlamypoirier (Collaborator, Author):

I'm not sure if that's what you mean, but in sequence-tensor-parallel we split the activations that aren't already tensor-parallel over the sequence (mainly the hidden states), so that all (or almost all) of them are split in the TP direction one way or another. For mamba (and attention) that means the input is sequence-parallel, then we [gather the full tensor](https://github.com/ServiceNow/Fast-LLM/blob/main/fast_llm/functional/linear.py#L114) and compute inner_projection with standard TP, which leaves a tensor-parallel output.

Contributor:

> For mamba (and attention) that means the input is sequence-parallel, then we [gather the full tensor](https://github.com/ServiceNow/Fast-LLM/blob/main/fast_llm/functional/linear.py#L114) and compute inner_projection with standard TP, which leaves a tensor-parallel output.

So does that mean, if the input to mamba is sequence-parallel (bs x seq/tp x d), in_proj first gathers it into bs x seq x d, and then the projection with standard TP outputs bs x seq x d/tp?

jlamypoirier (Collaborator, Author):

Yes
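
To recap the shape flow agreed on in this thread, a standalone sketch (sizes are hypothetical, and plain tensor ops stand in for the actual distributed all-gather and parallel linear):

import torch

tp = 2                        # tensor-parallel group size
bs, seq, d = 4, 1024, 2048    # hypothetical batch, sequence and hidden sizes
d_proj = 4 * d                # hypothetical full inner-projection width

# Input to the mixer is sequence-parallel: each rank holds (bs, seq / tp, d).
local_input = torch.randn(bs, seq // tp, d)

# in_proj first gathers the full tensor over the sequence dimension -> (bs, seq, d).
# (torch.cat stands in for the all-gather; real code collects a distinct shard per rank.)
full_input = torch.cat([local_input, torch.randn_like(local_input)], dim=1)

# It then applies a standard tensor-parallel (column-parallel) linear, so each rank
# only produces its own slice of the projection: (bs, seq, d_proj / tp).
local_weight = torch.randn(d, d_proj // tp)
inner_projection = full_input @ local_weight
assert inner_projection.shape == (bs, seq, d_proj // tp)

# out_proj mirrors this: a row-parallel linear whose output is reduce-scattered back
# to the sequence-parallel layout (bs, seq / tp, d).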
