Commit 4ff7e82

llama4 ckpt conversion

1 parent 3a83b61 commit 4ff7e82

2 files changed: +190 -2 lines changed

MaxText/llama4_ckpt_unscanned.py

Lines changed: 187 additions & 2 deletions
@@ -194,6 +194,10 @@ def _convert_huggingface_to_jax_weights(base_model_path: str, model_size: str, m
   Returns:
     jax_weights (dict): Dictionary containing the converted weights.
   """
+  num_hidden_layers_for_vit = model_params.get("num_layers_vit", 0)
+  num_attention_heads_for_vit = model_params.get("num_att_head_vit", 0)
+  hidden_size_for_vit = model_params.get("hidden_size_vit", 0)
+  head_dim_for_vit = hidden_size_for_vit // num_attention_heads_for_vit
   base_num_decoder_layers = model_params["num_layers"]
   base_num_query_heads = model_params["num_heads"]
   head_dim = model_params["dims_per_head"]
@@ -215,9 +219,8 @@ def _convert_huggingface_to_jax_weights(base_model_path: str, model_size: str, m
       for key in f.keys():
         parts = key.split(".")
         layer = int(parts[3]) if "layers" in key else 0
-        # TODO: update when mutli-modality support is added
         if "vision" in key or "multi_modal_projector" in key:
-          print("WARNING: skipping vision or multi-modal key: ", key)
+          chkpt_vars[key] = f.get_tensor(key)
         else:
           mapped_key = _hf_to_maxtext_mapping(layer)[key]
           chkpt_vars[mapped_key] = f.get_tensor(key)
@@ -230,8 +233,190 @@ def _convert_huggingface_to_jax_weights(base_model_path: str, model_size: str, m
           "logits_dense": {"kernel": None},
       },
       "token_embedder": {"embedding": None},
+      "vision_encoder": {
+          "Llama4VisionModel_0": {
+              "Llama4VisionEncoder_0": {},
+              "class_embedding": None,
+              "positional_embedding_vlm": None,
+              "Llama4UnfoldConvolution_0": {"vit_unfold_linear": {"kernel": None}},
+              "layernorm_pre": {},
+              "layernorm_post": {},
+              "Llama4VisionPixelShuffleMLP_0": {},
+          },
+          "Llama4MultiModalProjector_0": {"vit_multi_modal_projector": {"kernel": None}},
+      },
   }
 
+  # vision model ###########################################
+  max_logging.log("Processing vision model")
+  jax_weights["vision_encoder"]["Llama4VisionModel_0"]["class_embedding"] = (
+      chkpt_vars["vision_model.class_embedding"].to(torch.float32).numpy().astype(CAST_DTYPE)
+  )
+  jax_weights["vision_encoder"]["Llama4VisionModel_0"]["positional_embedding_vlm"] = (
+      chkpt_vars["vision_model.positional_embedding_vlm"].to(torch.float32).numpy().astype(CAST_DTYPE)
+  )
+  jax_weights["vision_encoder"]["Llama4VisionModel_0"]["Llama4UnfoldConvolution_0"]["vit_unfold_linear"]["kernel"] = (
+      chkpt_vars["vision_model.patch_embedding.linear.weight"].to(torch.float32).numpy().astype(CAST_DTYPE).transpose()
+  )
+  jax_weights["vision_encoder"]["Llama4VisionModel_0"]["layernorm_pre"].update(
+      {
+          "scale": chkpt_vars["vision_model.layernorm_pre.weight"].to(torch.float32).numpy().astype(CAST_DTYPE),
+          "bias": chkpt_vars["vision_model.layernorm_pre.bias"].to(torch.float32).numpy().astype(CAST_DTYPE),
+      }
+  )
+  jax_weights["vision_encoder"]["Llama4VisionModel_0"]["layernorm_post"].update(
+      {
+          "scale": chkpt_vars["vision_model.layernorm_post.weight"].to(torch.float32).numpy().astype(CAST_DTYPE),
+          "bias": chkpt_vars["vision_model.layernorm_post.bias"].to(torch.float32).numpy().astype(CAST_DTYPE),
+      }
+  )
+
+  # vision encoder ###########################################
+  max_logging.log("Processing vision encoder")
+  for layer_idx in tqdm(range(num_hidden_layers_for_vit), desc="layers", leave=False):
+    layer_name = f"layers_{layer_idx}"
+    wq = (
+        chkpt_vars[f"vision_model.model.layers.{layer_idx}.self_attn.q_proj.weight"]
+        .to(torch.float32)
+        .numpy()
+        .astype(CAST_DTYPE)
+        .transpose()
+    )
+    wk = (
+        chkpt_vars[f"vision_model.model.layers.{layer_idx}.self_attn.k_proj.weight"]
+        .to(torch.float32)
+        .numpy()
+        .astype(CAST_DTYPE)
+        .transpose()
+    )
+    wv = (
+        chkpt_vars[f"vision_model.model.layers.{layer_idx}.self_attn.v_proj.weight"]
+        .to(torch.float32)
+        .numpy()
+        .astype(CAST_DTYPE)
+        .transpose()
+    )
+    wo = (
+        chkpt_vars[f"vision_model.model.layers.{layer_idx}.self_attn.o_proj.weight"]
+        .to(torch.float32)
+        .numpy()
+        .astype(CAST_DTYPE)
+        .transpose()
+    )
+    bq = (
+        chkpt_vars[f"vision_model.model.layers.{layer_idx}.self_attn.q_proj.bias"]
+        .to(torch.float32)
+        .numpy()
+        .astype(CAST_DTYPE)
+    )
+    bk = (
+        chkpt_vars[f"vision_model.model.layers.{layer_idx}.self_attn.k_proj.bias"]
+        .to(torch.float32)
+        .numpy()
+        .astype(CAST_DTYPE)
+    )
+    bv = (
+        chkpt_vars[f"vision_model.model.layers.{layer_idx}.self_attn.v_proj.bias"]
+        .to(torch.float32)
+        .numpy()
+        .astype(CAST_DTYPE)
+    )
+    bo = (
+        chkpt_vars[f"vision_model.model.layers.{layer_idx}.self_attn.o_proj.bias"]
+        .to(torch.float32)
+        .numpy()
+        .astype(CAST_DTYPE)
+    )
+
+    wq = np.reshape(wq, [hidden_size_for_vit, num_attention_heads_for_vit, head_dim_for_vit])
+    wk = np.reshape(wk, [hidden_size_for_vit, num_attention_heads_for_vit, head_dim_for_vit])
+    wv = np.reshape(wv, [hidden_size_for_vit, num_attention_heads_for_vit, head_dim_for_vit])
+    wo = np.reshape(wo, [num_attention_heads_for_vit, head_dim_for_vit, hidden_size_for_vit])
+    bq = np.reshape(bq, [num_attention_heads_for_vit, head_dim_for_vit])
+    bk = np.reshape(bk, [num_attention_heads_for_vit, head_dim_for_vit])
+    bv = np.reshape(bv, [num_attention_heads_for_vit, head_dim_for_vit])
+
+    self_attention_vision = {
+        "query": {"kernel": wq, "bias": bq},
+        "key": {"kernel": wk, "bias": bk},
+        "value": {"kernel": wv, "bias": bv},
+        "out": {"kernel": wo, "bias": bo},
+    }
+
+    fc1_w = (
+        chkpt_vars[f"vision_model.model.layers.{layer_idx}.mlp.fc1.weight"]
+        .to(torch.float32)
+        .numpy()
+        .astype(CAST_DTYPE)
+        .transpose()
+    )
+    fc2_w = (
+        chkpt_vars[f"vision_model.model.layers.{layer_idx}.mlp.fc2.weight"]
+        .to(torch.float32)
+        .numpy()
+        .astype(CAST_DTYPE)
+        .transpose()
+    )
+    fc1_b = chkpt_vars[f"vision_model.model.layers.{layer_idx}.mlp.fc1.bias"].to(torch.float32).numpy().astype(CAST_DTYPE)
+    fc2_b = chkpt_vars[f"vision_model.model.layers.{layer_idx}.mlp.fc2.bias"].to(torch.float32).numpy().astype(CAST_DTYPE)
+    vision_mlp = {
+        "vit_encoder_layer_mlp_fc1": {"kernel": fc1_w, "bias": fc1_b},
+        "vit_encoder_layer_mlp_fc2": {"kernel": fc2_w, "bias": fc2_b},
+    }
+
+    jax_weights["vision_encoder"]["Llama4VisionModel_0"]["Llama4VisionEncoder_0"].update(
+        {
+            layer_name: {
+                "self_attention_vision": self_attention_vision,
+                "Llama4VisionMLP_0": vision_mlp,
+                "input_layer_norm": {
+                    "scale": chkpt_vars[f"vision_model.model.layers.{layer_idx}.input_layernorm.weight"]
+                    .to(torch.float32)
+                    .numpy()
+                    .astype(CAST_DTYPE),
+                    "bias": chkpt_vars[f"vision_model.model.layers.{layer_idx}.input_layernorm.bias"]
+                    .to(torch.float32)
+                    .numpy()
+                    .astype(CAST_DTYPE),
+                },
+                "post_attention_layer_norm": {
+                    "scale": chkpt_vars[f"vision_model.model.layers.{layer_idx}.post_attention_layernorm.weight"]
+                    .to(torch.float32)
+                    .numpy()
+                    .astype(CAST_DTYPE),
+                    "bias": chkpt_vars[f"vision_model.model.layers.{layer_idx}.post_attention_layernorm.bias"]
+                    .to(torch.float32)
+                    .numpy()
+                    .astype(CAST_DTYPE),
+                },
+            }
+        }
+    )
+
+  # pixel shuffle mlp ###########################################
+  max_logging.log("Processing pixel shuffle mlp")
+  adaptor_fc1 = (
+      chkpt_vars["vision_model.vision_adapter.mlp.fc1.weight"].to(torch.float32).numpy().astype(CAST_DTYPE).transpose()
+  )
+  adaptor_fc2 = (
+      chkpt_vars["vision_model.vision_adapter.mlp.fc2.weight"].to(torch.float32).numpy().astype(CAST_DTYPE).transpose()
+  )
+  jax_weights["vision_encoder"]["Llama4VisionModel_0"]["Llama4VisionPixelShuffleMLP_0"].update(
+      {
+          "pixel_shuffle_mlp": {
+              "vit_pixel_shuffle_mlp_fc1": {"kernel": adaptor_fc1},
+              "vit_pixel_shuffle_mlp_fc2": {"kernel": adaptor_fc2},
+          },
+      }
+  )
+  # multimodal projector ###########################################
+  max_logging.log("Processing multimodal projector")
+  jax_weights["vision_encoder"]["Llama4MultiModalProjector_0"]["vit_multi_modal_projector"]["kernel"] = (
+      chkpt_vars["multi_modal_projector.linear_1.weight"].to(torch.float32).numpy().astype(CAST_DTYPE).transpose()
+  )
+
+  # language model ###########################################
+  max_logging.log("Processing language model")
   # decoder norm scale ###########################################
   max_logging.log("Processing decoder norm scale")
   decoder_norm_scale = chkpt_vars["norm.weight"].to(torch.float32).numpy().astype(CAST_DTYPE)

MaxText/llama_or_mistral_ckpt.py

Lines changed: 3 additions & 0 deletions
@@ -141,6 +141,9 @@
         "rope_type": "llama3.1",
         "scale_query": False,
         "interleave_moe_layer_step": 1,
+        "num_layers_vit": 34,
+        "num_att_head_vit": 16,
+        "hidden_size_vit": 1408,
     },
     "llama4-17b-128e": {
         "num_layers": 48,

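Note on the weight layout (a minimal sketch, not part of the commit): the conversion above transposes each HuggingFace projection weight of shape [num_heads * head_dim, hidden_size] and then reshapes it to the MaxText layout [hidden_size, num_heads, head_dim] (the output projection instead becomes [num_heads, head_dim, hidden_size]). The snippet below illustrates that convention with the ViT sizes added in llama_or_mistral_ckpt.py (hidden_size_vit = 1408, num_att_head_vit = 16, so head_dim = 88); the zero-filled hf_q_proj array is a hypothetical stand-in for a real q_proj.weight tensor.

import numpy as np

# ViT sizes from the llama4-17b-16e entry above; head_dim uses the same
# integer division as _convert_huggingface_to_jax_weights.
hidden_size_vit = 1408
num_att_head_vit = 16
head_dim_vit = hidden_size_vit // num_att_head_vit  # 88

# Hypothetical HF-style q_proj.weight with shape [out_features, in_features].
hf_q_proj = np.zeros((num_att_head_vit * head_dim_vit, hidden_size_vit), dtype=np.float32)

# Transpose to [in_features, out_features], then split out the head dimension,
# mirroring the wq handling in the diff above.
wq = hf_q_proj.transpose()
wq = np.reshape(wq, [hidden_size_vit, num_att_head_vit, head_dim_vit])
print(wq.shape)  # (1408, 16, 88)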