
Add LLaDA 8b Diffusion model #14771


Open

am17an wants to merge 6 commits into master from add_llada_8b

Conversation

@am17an (Collaborator) commented Jul 19, 2025

Continuing from #14644, this PR adds another diffusion model, https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct, which has different generation semantics compared to the Dream-7B model and overall seems to perform better.

There are very few similarities in how the two models generate tokens, so for now I've created two separate examples: llama-diffusion-dream-cli (for the earlier model) and llama-diffusion-llada-cli (for the new LLaDA model). I've added a README as well.

I've uploaded a GGUF.

Example command
./build/bin/llama-diffusion-llada-cli -m llada-8b.gguf -p "Lily can run 12 kilometers per hour for 4 hours. After that, she runs 6 kilometers per hour. How many kilometers can she run in 8 hours?" --diffusion_steps 128 -ngl 99 --temp 0 -ub 128 --diffusion-visual

I would also like to add this to the server, but I'm not sure what API would be acceptable, so I'm hoping to have a discussion on that as well.

@github-actions bot added the examples and python (python script changes) labels Jul 19, 2025
@am17an am17an requested a review from ggerganov July 19, 2025 10:06
@am17an am17an requested a review from CISC July 19, 2025 11:05
@am17an am17an force-pushed the add_llada_8b branch 3 times, most recently from e4b7346 to 5644f2f Compare July 19, 2025 14:59
@ggerganov (Member)

I would like to avoid adding a second diffusion example - we would be increasing the maintenance effort for no significant benefit. The diffusion architecture is not yet well established.

We can think about extending the llama_sampler functionality to support these use cases, and since it is already modular, it would make more sense to implement the sampling logic there. Ideally there would be just one diffusion CLI example for all diffusion models, with different samplers attached.
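For illustration, roughly the kind of composition this would enable with the existing llama_sampler chain API (the chain below is just a sketch and the parameter values are placeholders, not a recommendation):

// sketch: a single diffusion CLI that assembles the sampler chain per model/config
llama_sampler_chain_params sparams = llama_sampler_chain_default_params();
llama_sampler * smpl = llama_sampler_chain_init(sparams);
llama_sampler_chain_add(smpl, llama_sampler_init_top_k(40));      // placeholder values
llama_sampler_chain_add(smpl, llama_sampler_init_temp(0.2f));
llama_sampler_chain_add(smpl, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
// per diffusion step, for each masked position idx:
//     llama_token tok = llama_sampler_sample(smpl, ctx, idx);
llama_sampler_free(smpl);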

@am17an (Collaborator, Author) commented Jul 21, 2025

I would like to avoid adding a second diffusion example - we would be increasing the maintenance effort for no significant benefit. The diffusion architecture is not yet well established.

We can think about extending the llama_sampler functionality to support these use cases, and since it is already modular, it would make more sense to implement the sampling logic there. Ideally there would be just one diffusion CLI example for all diffusion models, with different samplers attached.

Yeah, agreed. I initially wrote them as one example. However, passing CLI arguments for two separate sets of sampling parameters/algorithms was quite confusing to me and would be even more so for the end user, so for the sake of clarity I wrote them separately.
diffusion_generate_dream and diffusion_generate_llada are two different functions with the same outline (decode => sample => unmask), so there is an abstraction to be made. The only thing to clarify is how we pass separate sets of parameters to the example without overloading the same flag (e.g. --diffusion-algorithm is supported in Dream but not LLaDA, and vice versa). llama_sampler could be used as well, but I don't see how it would solve this particular problem.
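Roughly the shape the two functions share, as an illustrative sketch only (the names below are made up, not the code in this PR):

#include <cstdint>
#include <functional>
#include <vector>

using llama_token_t = int32_t;  // stand-in for llama_token

struct diffusion_step {
    const float *                logits;   // logits from the latest decode
    std::vector<llama_token_t> & tokens;   // current, partially masked sequence
    llama_token_t                mask_id;  // the model's mask token
};

// per-model policy: choose which masked positions to fill, and with what
using unmask_policy = std::function<void(diffusion_step &)>;

void diffusion_generate(std::vector<llama_token_t> & tokens, llama_token_t mask_id,
                        int n_steps, const unmask_policy & unmask) {
    for (int step = 0; step < n_steps; ++step) {
        // 1. decode: run the full masked sequence through the model
        //    (llama_decode + llama_get_logits in the real example)
        const float * logits = nullptr;  // placeholder for the decode output
        // 2./3. sample + unmask: delegated to the injected policy, so Dream and
        //       LLaDA would differ only in the callback (and its parameters)
        diffusion_step step_ctx { logits, tokens, mask_id };
        unmask(step_ctx);
    }
}

The open question is really just how the CLI maps its flags onto the parameters captured by each policy.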

@am17an (Collaborator, Author) commented Jul 23, 2025

@ggerganov would having them in the same example, with extra CLI args per model, be acceptable?

@ggerganov (Member)

Yes, merging the examples into a single example would be better.

@am17an (Collaborator, Author) commented Jul 26, 2025

Yes, merging the examples into a single example would be better.

I've made everything into a single example. Please have another look when you have time.

@ggerganov (Member) left a comment

I think the example can be improved by not branching between "llada" and "dream", and instead having common logic for any diffusion model. This would make it much easier to scale to more diffusion models in the future. Otherwise, the way it's implemented now, you have to add new structs, sampling types, generation functions, etc. for each new architecture, which seems a bit unnecessary.

Comment on lines +761 to +768
// For LLaDA models, forcefully add BOS token at the beginning. TODO: check why
if (arch == "llada") {
llama_token bos_token = llama_vocab_bos(vocab);
if (bos_token != LLAMA_TOKEN_NULL && (input_tokens.empty() || input_tokens[0] != bos_token)) {
input_tokens.insert(input_tokens.begin(), bos_token);
}
}

@ggerganov (Member)

This should be handled by the metadata in the GGUF model. There is a boolean field indicating whether BOS is needed.
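Something along these lines should work in the example (untested sketch), reading the flag that conversion writes into the GGUF instead of special-casing the arch:

// use the add-BOS flag stored in the GGUF metadata instead of checking for "llada"
if (llama_vocab_get_add_bos(vocab)) {
    const llama_token bos_token = llama_vocab_bos(vocab);
    if (bos_token != LLAMA_TOKEN_NULL && (input_tokens.empty() || input_tokens[0] != bos_token)) {
        input_tokens.insert(input_tokens.begin(), bos_token);
    }
}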

@am17an (Collaborator, Author) commented Jul 28, 2025

@ggerganov you're right, we can combine the sampling methods. I had assumed that the only sampling methods that would work were their respective paper implementations, but I tried various sampling methods on both models and they produce coherent outputs, though I did not do any deep correctness checks.

I refactored the code around a concept called a schedule, which is either timestep-based (like Dream) or block-based (like LLaDA). Both schedules work for both models. I also refactored the sampling methods to be shared across the models.
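For illustration, the difference mostly comes down to which positions a given step is allowed to unmask; a rough sketch (names and block semantics simplified, not the exact code in this PR):

#include <cstddef>
#include <vector>

enum class diffusion_schedule { TIMESTEP, BLOCK };

// Which output positions may this step unmask?
//   TIMESTEP: every still-masked position, on every step (Dream-style).
//   BLOCK:    only positions in the current block, sweeping left to right (LLaDA-style).
std::vector<size_t> eligible_positions(diffusion_schedule sched, int step, int steps_per_block,
                                       size_t block_len, const std::vector<bool> & is_masked) {
    std::vector<size_t> out;
    const size_t block = (sched == diffusion_schedule::BLOCK)
                             ? static_cast<size_t>(step / steps_per_block)
                             : 0;
    for (size_t i = 0; i < is_masked.size(); ++i) {
        if (!is_masked[i]) {
            continue;
        }
        if (sched == diffusion_schedule::BLOCK &&
            (i < block * block_len || i >= (block + 1) * block_len)) {
            continue;
        }
        out.push_back(i);
    }
    return out;
}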

The issues that remain, however:

  1. Shifted logits - in Dream, logits are shifted by -1 after a prompt-processing (pp) pass, which is not the case in LLaDA. Ideally this should be part of the GGUF metadata, but I'm not sure.
  2. The BOS token in LLaDA - add_bos_token is false in tokenizer_config.json, I think because the chat_template contains the bos_token.
"chat_template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}",

However, this code removes that BOS:

llama.cpp/common/chat.cpp

Lines 746 to 755 in c35f9ea

minja::chat_template_options tmpl_opts;
// To avoid double BOS / EOS tokens, we're manually removing begining / trailing tokens
// instead of using `chat_template_options.use_bos_token = false`, since these tokens
// may be needed inside the template / between messages too.
auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
if (string_starts_with(result, tmpl.bos_token())) {
result = result.substr(tmpl.bos_token().size());
}
if (string_ends_with(result, tmpl.eos_token())) {
result = result.substr(0, result.size() - tmpl.eos_token().size());

I'm not familiar with the chat-template code, and I was not able to work around this without adding a BOS token.

@am17an am17an force-pushed the add_llada_8b branch 2 times, most recently from cb015b4 to cf10ebf Compare July 28, 2025 07:27
@CISC (Collaborator) commented Jul 28, 2025

2. The BOS token in LLaDA - `add_bos_token` is false in `tokenizer_config.json`, I think because the chat_template contains the `bos_token`.

No, add_bos_token only applies to untemplated generation; it seems like a mistake. It was removed in the LLaDA 1.5 chat template, BTW.

Edit: Nvm, I'm blind, it's still there.

However, this code removes that BOS:

llama.cpp/common/chat.cpp

Lines 746 to 755 in c35f9ea

minja::chat_template_options tmpl_opts;
// To avoid double BOS / EOS tokens, we're manually removing begining / trailing tokens
// instead of using `chat_template_options.use_bos_token = false`, since these tokens
// may be needed inside the template / between messages too.
auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
if (string_starts_with(result, tmpl.bos_token())) {
result = result.substr(tmpl.bos_token().size());
}
if (string_ends_with(result, tmpl.eos_token())) {
result = result.substr(0, result.size() - tmpl.eos_token().size());

This probably needs to be improved.

I'm not familiar with the chat-template code, and I was not able to work around this without adding a BOS token.

Setting add_bos_token to True on conversion should fix that, but only applies to pre-1.5 models.

@am17an (Collaborator, Author) commented Jul 28, 2025

Setting add_bos_token to True on conversion should fix that, but only applies to pre-1.5 models.

Yep, this fixes it for a regenerated GGUF, though it might be a problem downstream if people use the HF repo to create quants (unless they patch this in the HF repo).

Comment on lines +2936 to +2937
# Add LLaDA-specific parameters
mask_token_id = self.hparams.get("mask_token_id")
@CISC (Collaborator)

Suggested change:
-        # Add LLaDA-specific parameters
-        mask_token_id = self.hparams.get("mask_token_id")
+        # Add LLaDA-specific parameters
+        self.gguf_writer.add_add_bos_token(True)
+        mask_token_id = self.hparams.get("mask_token_id")

@am17an (Collaborator, Author)

So I think we can add diffusion_shift_logits to the GGUF for this, which by default can be true (for backwards compatibility with Dream models), and set it to false here. What do you think?

@CISC (Collaborator)

Well, AFAICT it isn't configurable in the model, so perhaps just do it based on arch?

@am17an (Collaborator, Author)

I think we're saying the same thing, but just for clarity: we add another option to gguf_writer called diffusion_shift_logits, which is true by default; for the LLaDA arch we set it to false, and then read it the same way we're reading the mask token in this PR.

@CISC (Collaborator)

I was thinking along the lines of just doing the shift based on arch, and not storing that in the GGUF at all.
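i.e. something like this in the example (illustrative only):

// decide the logit shift from the architecture; no new GGUF key needed
// (Dream shifts logits by -1 after prompt processing, LLaDA does not)
const bool shift_logits = (arch != "llada");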

@am17an (Collaborator, Author)

Ah okay, if I understand correctly, that would still keep the arch-specific code in the example. Or are you suggesting something else?

@CISC (Collaborator)

Ah, yes, you do that already, should be fine IMO.

Could be moved out to load_hparams as a static value, but I don't think that's warranted.
