
Commit 3b1d43d

llama4 ckpt conversion
1 parent 3a83b61 commit 3b1d43d

File tree: 2 files changed, +110 -3 lines changed

MaxText/llama4_ckpt_unscanned.py

Lines changed: 107 additions & 3 deletions
@@ -136,6 +136,7 @@ def _hf_to_maxtext_mapping(layer_idx: int = -1) -> dict:
   """
   # pylint: disable=line-too-long
   return {
+      ## language model mapping
       "language_model.model.embed_tokens.weight": "tok_embeddings.weight",
       "language_model.model.norm.weight": "norm.weight",
       "language_model.lm_head.weight": "output.weight",
@@ -145,18 +146,43 @@ def _hf_to_maxtext_mapping(layer_idx: int = -1) -> dict:
       f"language_model.model.layers.{layer_idx}.self_attn.k_proj.weight": f"layers.{layer_idx}.attention.wk.weight",
       f"language_model.model.layers.{layer_idx}.self_attn.v_proj.weight": f"layers.{layer_idx}.attention.wv.weight",
       f"language_model.model.layers.{layer_idx}.self_attn.o_proj.weight": f"layers.{layer_idx}.attention.wo.weight",
-      # MoE
+      # MoE in language model
       f"language_model.model.layers.{layer_idx}.feed_forward.router.weight": f"layers.{layer_idx}.feed_forward.gate.weight",
       f"language_model.model.layers.{layer_idx}.feed_forward.experts.down_proj": f"layers.{layer_idx}.feed_forward.experts.down_proj",
       # NOTE: this contains up_proj and gate_proj concated together (we'll split/chunk them later)
       f"language_model.model.layers.{layer_idx}.feed_forward.experts.gate_up_proj": f"layers.{layer_idx}.feed_forward.experts.gate_up_proj",
       f"language_model.model.layers.{layer_idx}.feed_forward.shared_expert.gate_proj.weight": f"layers.{layer_idx}.feed_forward.shared_experts.gate_proj.weight",
       f"language_model.model.layers.{layer_idx}.feed_forward.shared_expert.down_proj.weight": f"layers.{layer_idx}.feed_forward.shared_experts.down_proj.weight",
       f"language_model.model.layers.{layer_idx}.feed_forward.shared_expert.up_proj.weight": f"layers.{layer_idx}.feed_forward.shared_experts.up_proj.weight",
-      # FFN
+      # FFN in language model
       f"language_model.model.layers.{layer_idx}.feed_forward.up_proj.weight": f"layers.{layer_idx}.feed_forward.w1.weight",
       f"language_model.model.layers.{layer_idx}.feed_forward.gate_proj.weight": f"layers.{layer_idx}.feed_forward.w3.weight",
       f"language_model.model.layers.{layer_idx}.feed_forward.down_proj.weight": f"layers.{layer_idx}.feed_forward.w2.weight",
+
+      # ## vision model mapping
+      # "vision_model.class_embedding": "vision_encoder.",
+      # "vision_model.positional_embedding_vlm": "",
+      # "vision_model.patch_embedding.linear.weight": "",
+      # "vision_model.layernorm_pre.weight": "",
+      # "vision_model.layernorm_pre.bias": "",
+      # "vision_model.layernorm_post.weight": "",
+      # "vision_model.layernorm_post.bias": "",
+      # "vision_model.model.layers.{layer_idx}.input_layernorm.weight": "",
+      # "vision_model.model.layers.{layer_idx}.input_layernorm.bias": "",
+      # "vision_model.model.layers.{layer_idx}.self_attn.q_proj.weight": "",
+      # "vision_model.model.layers.{layer_idx}.self_attn.q_proj.bias": "",
+      # "vision_model.model.layers.{layer_idx}.self_attn.k_proj.weight": "",
+      # "vision_model.model.layers.{layer_idx}.self_attn.k_proj.bias": "",
+      # "vision_model.model.layers.{layer_idx}.self_attn.v_proj.weight": "",
+      # "vision_model.model.layers.{layer_idx}.self_attn.v_proj.bias": "",
+      # "vision_model.model.layers.{layer_idx}.self_attn.o_proj.weight": "",
+      # "vision_model.model.layers.{layer_idx}.self_attn.o_proj.bias": "",
+      # "vision_model.model.layers.{layer_idx}.post_attention_layernorm.weight": "",
+      # "vision_model.model.layers.{layer_idx}.post_attention_layernorm.bias": "",
+      # "vision_model.model.layers.{layer_idx}.mlp.fc1.weight": "",
+      # "vision_model.model.layers.{layer_idx}.mlp.fc1.bias": "",
+      # "vision_model.model.layers.{layer_idx}.mlp.fc2.weight": "",
+      # "vision_model.model.layers.{layer_idx}.mlp.fc2.bias": "",
   }
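The NOTE on the experts.gate_up_proj entry says the fused tensor is split later. As a point of reference, here is a minimal standalone sketch of that split, assuming the HF tensor stacks the gate and up projections along its last axis with shape (num_experts, hidden_size, 2 * intermediate_size); the sizes below are toy values for illustration only.

import numpy as np

# Toy sizes, not the real model dimensions.
num_experts, hidden_size, intermediate_size = 4, 8, 16
gate_up_proj = np.zeros((num_experts, hidden_size, 2 * intermediate_size), dtype=np.float32)

# Chunk the fused tensor into its gate and up halves along the last axis.
gate_proj, up_proj = np.split(gate_up_proj, 2, axis=-1)
assert gate_proj.shape == up_proj.shape == (num_experts, hidden_size, intermediate_size)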

@@ -194,6 +220,10 @@ def _convert_huggingface_to_jax_weights(base_model_path: str, model_size: str, m
   Returns:
     jax_weights (dict): Dictionary containing the converted weights.
   """
+  num_hidden_layers_for_vit = model_params.get("num_layers_vit", 0)
+  num_attention_heads_for_vit = model_params.get("num_att_head_vit", 0)
+  hidden_size_for_vit = model_params.get("hidden_size_vit", 0)
+  head_dim_for_vit = hidden_size_for_vit // num_attention_heads_for_vit
   base_num_decoder_layers = model_params["num_layers"]
   base_num_query_heads = model_params["num_heads"]
   head_dim = model_params["dims_per_head"]
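For reference, this is how the new ViT hyperparameters resolve for the config entry extended in MaxText/llama_or_mistral_ckpt.py below. The zero-guard in the last line is an illustration-only addition, not part of the commit; the committed code divides unconditionally, so a config without these keys would hit a ZeroDivisionError.

model_params = {"num_layers_vit": 1, "num_att_head_vit": 16, "hidden_size_vit": 1408}

num_attention_heads_for_vit = model_params.get("num_att_head_vit", 0)
hidden_size_for_vit = model_params.get("hidden_size_vit", 0)
# Guarded variant for configs that do not define the ViT keys.
head_dim_for_vit = hidden_size_for_vit // num_attention_heads_for_vit if num_attention_heads_for_vit else 0
print(head_dim_for_vit)  # 88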
@@ -217,7 +247,8 @@ def _convert_huggingface_to_jax_weights(base_model_path: str, model_size: str, m
         layer = int(parts[3]) if "layers" in key else 0
         # TODO: update when mutli-modality support is added
         if "vision" in key or "multi_modal_projector" in key:
-          print("WARNING: skipping vision or multi-modal key: ", key)
+          # print("WARNING: skipping vision or multi-modal key: ", key)
+          chkpt_vars[key] = f.get_tensor(key)
         else:
           mapped_key = _hf_to_maxtext_mapping(layer)[key]
           chkpt_vars[mapped_key] = f.get_tensor(key)
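A rough, self-contained sketch of the loading loop this hunk modifies, assuming ckpt_paths lists the HF *.safetensors shards: vision and multimodal-projector tensors are now kept under their original HF names, while language-model tensors are renamed through _hf_to_maxtext_mapping.

from safetensors import safe_open

chkpt_vars = {}
for ckpt_path in ckpt_paths:  # ckpt_paths: assumed list of *.safetensors shard paths
  with safe_open(ckpt_path, framework="pt", device="cpu") as f:
    for key in f.keys():
      parts = key.split(".")
      layer = int(parts[3]) if "layers" in key else 0
      if "vision" in key or "multi_modal_projector" in key:
        chkpt_vars[key] = f.get_tensor(key)  # stored under the original HF key
      else:
        mapped_key = _hf_to_maxtext_mapping(layer)[key]
        chkpt_vars[mapped_key] = f.get_tensor(key)  # stored under the MaxText key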
@@ -230,8 +261,81 @@ def _convert_huggingface_to_jax_weights(base_model_path: str, model_size: str, m
           "logits_dense": {"kernel": None},
       },
       "token_embedder": {"embedding": None},
+      "vision_encoder": {
+          "Llama4VisionModel_0": {
+              "Llama4VisionEncoder_0": {},
+              "class_embedding": None,
+              "positional_embedding_vlm": None,
+              "Llama4UnfoldConvolution_0": None,
+              "layernorm_pre": {},
+              "layernorm_post": {},
+          },
+          "Llama4MultiModalProjector_0": {"vit_multi_modal_projector": {"kernel": None}},
+      },
   }

+  # vision encoder ###########################################
+  max_logging.log("Processing vision model")
+  jax_weights["vision_encoder"]["Llama4VisionModel_0"]["class_embedding"] = chkpt_vars["vision_model.class_embedding"].to(torch.float32).numpy().astype(CAST_DTYPE)
+  jax_weights["vision_encoder"]["Llama4VisionModel_0"]["positional_embedding_vlm"] = chkpt_vars["vision_model.positional_embedding_vlm"].to(torch.float32).numpy().astype(CAST_DTYPE)
+  jax_weights["vision_encoder"]["Llama4VisionModel_0"]["Llama4UnfoldConvolution_0"] = chkpt_vars["vision_model.patch_embedding.linear.weight"].to(torch.float32).numpy().astype(CAST_DTYPE).transpose()
+  jax_weights["vision_encoder"]["Llama4VisionModel_0"]["layernorm_pre"].update({
+      "scale": chkpt_vars["vision_model.layernorm_pre.weight"].to(torch.float32).numpy().astype(CAST_DTYPE),
+      "bias": chkpt_vars["vision_model.layernorm_pre.bias"].to(torch.float32).numpy().astype(CAST_DTYPE),
+  })
+  jax_weights["vision_encoder"]["Llama4VisionModel_0"]["layernorm_post"].update({
+      "scale": chkpt_vars["vision_model.layernorm_post.weight"].to(torch.float32).numpy().astype(CAST_DTYPE),
+      "bias": chkpt_vars["vision_model.layernorm_post.bias"].to(torch.float32).numpy().astype(CAST_DTYPE),
+  })
+  for layer_idx in tqdm(range(num_hidden_layers_for_vit), desc="layers", leave=False):
+    layer_name = f"layers_{layer_idx}"
+
+    wq = chkpt_vars[f"vision_model.model.layers.{layer_idx}.self_attn.q_proj.weight"].to(torch.float32).numpy().astype(CAST_DTYPE).transpose()
+    wk = chkpt_vars[f"vision_model.model.layers.{layer_idx}.self_attn.k_proj.weight"].to(torch.float32).numpy().astype(CAST_DTYPE).transpose()
+    wv = chkpt_vars[f"vision_model.model.layers.{layer_idx}.self_attn.v_proj.weight"].to(torch.float32).numpy().astype(CAST_DTYPE).transpose()
+    wo = chkpt_vars[f"vision_model.model.layers.{layer_idx}.self_attn.o_proj.weight"].to(torch.float32).numpy().astype(CAST_DTYPE).transpose()
+    bq = chkpt_vars[f"vision_model.model.layers.{layer_idx}.self_attn.q_proj.bias"].to(torch.float32).numpy().astype(CAST_DTYPE)
+    bk = chkpt_vars[f"vision_model.model.layers.{layer_idx}.self_attn.k_proj.bias"].to(torch.float32).numpy().astype(CAST_DTYPE)
+    bv = chkpt_vars[f"vision_model.model.layers.{layer_idx}.self_attn.v_proj.bias"].to(torch.float32).numpy().astype(CAST_DTYPE)
+    bo = chkpt_vars[f"vision_model.model.layers.{layer_idx}.self_attn.o_proj.bias"].to(torch.float32).numpy().astype(CAST_DTYPE)
+
+    wq = np.reshape(wq, [hidden_size_for_vit, num_attention_heads_for_vit, head_dim_for_vit])
+    wk = np.reshape(wk, [hidden_size_for_vit, num_attention_heads_for_vit, head_dim_for_vit])
+    wv = np.reshape(wv, [hidden_size_for_vit, num_attention_heads_for_vit, head_dim_for_vit])
+    wo = np.reshape(wo, [num_attention_heads_for_vit, head_dim_for_vit, hidden_size_for_vit])
+
+    self_attention_vision = {
+        "query": {"kernel": wq, "bias": bq},
+        "key": {"kernel": wk, "bias": bk},
+        "value": {"kernel": wv, "bias": bv},
+        "out": {"kernel": wo, "bias": bo},
+    }
+
+    fc1_w = chkpt_vars[f"vision_model.model.layers.{layer_idx}.mlp.fc1.weight"].to(torch.float32).numpy().astype(CAST_DTYPE).transpose()
+    fc2_w = chkpt_vars[f"vision_model.model.layers.{layer_idx}.mlp.fc2.weight"].to(torch.float32).numpy().astype(CAST_DTYPE).transpose()
+    fc1_b = chkpt_vars[f"vision_model.model.layers.{layer_idx}.mlp.fc1.bias"].to(torch.float32).numpy().astype(CAST_DTYPE)
+    fc2_b = chkpt_vars[f"vision_model.model.layers.{layer_idx}.mlp.fc2.bias"].to(torch.float32).numpy().astype(CAST_DTYPE)
+    vision_mlp = {
+        "vit_encoder_layer_mlp_fc1": {"kernel": fc1_w, "bias": fc1_b},
+        "vit_encoder_layer_mlp_fc2": {"kernel": fc2_w, "bias": fc2_b},
+    }
+
+    jax_weights["vision_encoder"]["Llama4VisionModel_0"]["Llama4VisionEncoder_0"].update(
+        {
+            layer_name: {
+                "self_attention_vision": self_attention_vision,
+                "Llama4VisionMLP_0": vision_mlp,
+            }
+        }
+    )
+
+  max_logging.log("Processing multimodal projector")
+  jax_weights["vision_encoder"]["Llama4MultiModalProjector_0"]["vit_multi_modal_projector"]["kernel"] = chkpt_vars["multi_modal_projector.linear_1.weight"].to(torch.float32).numpy().astype(CAST_DTYPE).transpose()
+
+  return jax_weights
+
+  # language model ###########################################
+  max_logging.log("Processing language model")
   # decoder norm scale ###########################################
   max_logging.log("Processing decoder norm scale")
   decoder_norm_scale = chkpt_vars["norm.weight"].to(torch.float32).numpy().astype(CAST_DTYPE)
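To make the attention-weight layout change in the vision loop concrete, here is a small numpy check (not from the commit) that transposing an HF (out_features, in_features) q_proj weight and reshaping it to (hidden, heads, head_dim) preserves the projection; the dimensions below match hidden_size_vit, num_att_head_vit, and head_dim_for_vit for this config.

import numpy as np

hidden, heads, head_dim = 1408, 16, 88
rng = np.random.default_rng(0)

hf_q_proj = rng.standard_normal((hidden, hidden), dtype=np.float32)  # HF layout: (out_features, in_features)
x = rng.standard_normal((4, hidden), dtype=np.float32)               # a few dummy patch embeddings

# What the converter does: transpose to (in, out), then split the output dim per head.
kernel = hf_q_proj.T.reshape(hidden, heads, head_dim)

flat = x @ hf_q_proj.T                          # (4, hidden)
per_head = np.einsum("bd,dhk->bhk", x, kernel)  # (4, heads, head_dim)
assert np.allclose(flat.reshape(4, heads, head_dim), per_head)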

MaxText/llama_or_mistral_ckpt.py

Lines changed: 3 additions & 0 deletions
@@ -141,6 +141,9 @@
         "rope_type": "llama3.1",
         "scale_query": False,
         "interleave_moe_layer_step": 1,
+        "num_layers_vit": 1,
+        "num_att_head_vit": 16,
+        "hidden_size_vit": 1408,
     },
     "llama4-17b-128e": {
         "num_layers": 48,
