@@ -194,6 +194,10 @@ def _convert_huggingface_to_jax_weights(base_model_path: str, model_size: str, m
   Returns:
     jax_weights (dict): Dictionary containing the converted weights.
   """
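+  # Vision tower (ViT) dimensions; .get() falls back to 0 when the config carries no vision settings.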
+  num_hidden_layers_for_vit = model_params.get("num_layers_vit", 0)
+  num_attention_heads_for_vit = model_params.get("num_att_head_vit", 0)
+  hidden_size_for_vit = model_params.get("hidden_size_vit", 0)
+  head_dim_for_vit = hidden_size_for_vit // num_attention_heads_for_vit if num_attention_heads_for_vit else 0
   base_num_decoder_layers = model_params["num_layers"]
   base_num_query_heads = model_params["num_heads"]
   head_dim = model_params["dims_per_head"]
@@ -215,9 +219,8 @@ def _convert_huggingface_to_jax_weights(base_model_path: str, model_size: str, m
       for key in f.keys():
         parts = key.split(".")
         layer = int(parts[3]) if "layers" in key else 0
-        # TODO: update when mutli-modality support is added
         if "vision" in key or "multi_modal_projector" in key:
-          print("WARNING: skipping vision or multi-modal key: ", key)
+          # Keep vision and projector tensors under their original HF names; only language-model keys are remapped.
+          chkpt_vars[key] = f.get_tensor(key)
         else:
           mapped_key = _hf_to_maxtext_mapping(layer)[key]
           chkpt_vars[mapped_key] = f.get_tensor(key)
@@ -230,8 +233,190 @@ def _convert_huggingface_to_jax_weights(base_model_path: str, model_size: str, m
           "logits_dense": {"kernel": None},
       },
       "token_embedder": {"embedding": None},
+      "vision_encoder": {
+          "Llama4VisionModel_0": {
+              "Llama4VisionEncoder_0": {},
+              "class_embedding": None,
+              "positional_embedding_vlm": None,
+              "Llama4UnfoldConvolution_0": {"vit_unfold_linear": {"kernel": None}},
+              "layernorm_pre": {},
+              "layernorm_post": {},
+              "Llama4VisionPixelShuffleMLP_0": {},
+          },
+          "Llama4MultiModalProjector_0": {"vit_multi_modal_projector": {"kernel": None}},
+      },
   }

+  # vision model ###########################################
+  max_logging.log("Processing vision model")
+  jax_weights["vision_encoder"]["Llama4VisionModel_0"]["class_embedding"] = (
+      chkpt_vars["vision_model.class_embedding"].to(torch.float32).numpy().astype(CAST_DTYPE)
+  )
+  jax_weights["vision_encoder"]["Llama4VisionModel_0"]["positional_embedding_vlm"] = (
+      chkpt_vars["vision_model.positional_embedding_vlm"].to(torch.float32).numpy().astype(CAST_DTYPE)
+  )
+  jax_weights["vision_encoder"]["Llama4VisionModel_0"]["Llama4UnfoldConvolution_0"]["vit_unfold_linear"]["kernel"] = (
+      chkpt_vars["vision_model.patch_embedding.linear.weight"].to(torch.float32).numpy().astype(CAST_DTYPE).transpose()
+  )
+  jax_weights["vision_encoder"]["Llama4VisionModel_0"]["layernorm_pre"].update(
+      {
+          "scale": chkpt_vars["vision_model.layernorm_pre.weight"].to(torch.float32).numpy().astype(CAST_DTYPE),
+          "bias": chkpt_vars["vision_model.layernorm_pre.bias"].to(torch.float32).numpy().astype(CAST_DTYPE),
+      }
+  )
+  jax_weights["vision_encoder"]["Llama4VisionModel_0"]["layernorm_post"].update(
+      {
+          "scale": chkpt_vars["vision_model.layernorm_post.weight"].to(torch.float32).numpy().astype(CAST_DTYPE),
+          "bias": chkpt_vars["vision_model.layernorm_post.bias"].to(torch.float32).numpy().astype(CAST_DTYPE),
+      }
+  )
+
+  # vision encoder ###########################################
+  max_logging.log("Processing vision encoder")
+  for layer_idx in tqdm(range(num_hidden_layers_for_vit), desc="layers", leave=False):
+    layer_name = f"layers_{layer_idx}"
+    wq = (
+        chkpt_vars[f"vision_model.model.layers.{layer_idx}.self_attn.q_proj.weight"]
+        .to(torch.float32)
+        .numpy()
+        .astype(CAST_DTYPE)
+        .transpose()
+    )
+    wk = (
+        chkpt_vars[f"vision_model.model.layers.{layer_idx}.self_attn.k_proj.weight"]
+        .to(torch.float32)
+        .numpy()
+        .astype(CAST_DTYPE)
+        .transpose()
+    )
+    wv = (
+        chkpt_vars[f"vision_model.model.layers.{layer_idx}.self_attn.v_proj.weight"]
+        .to(torch.float32)
+        .numpy()
+        .astype(CAST_DTYPE)
+        .transpose()
+    )
+    wo = (
+        chkpt_vars[f"vision_model.model.layers.{layer_idx}.self_attn.o_proj.weight"]
+        .to(torch.float32)
+        .numpy()
+        .astype(CAST_DTYPE)
+        .transpose()
+    )
+    bq = (
+        chkpt_vars[f"vision_model.model.layers.{layer_idx}.self_attn.q_proj.bias"]
+        .to(torch.float32)
+        .numpy()
+        .astype(CAST_DTYPE)
+    )
+    bk = (
+        chkpt_vars[f"vision_model.model.layers.{layer_idx}.self_attn.k_proj.bias"]
+        .to(torch.float32)
+        .numpy()
+        .astype(CAST_DTYPE)
+    )
+    bv = (
+        chkpt_vars[f"vision_model.model.layers.{layer_idx}.self_attn.v_proj.bias"]
+        .to(torch.float32)
+        .numpy()
+        .astype(CAST_DTYPE)
+    )
+    bo = (
+        chkpt_vars[f"vision_model.model.layers.{layer_idx}.self_attn.o_proj.bias"]
+        .to(torch.float32)
+        .numpy()
+        .astype(CAST_DTYPE)
+    )
+
+    # Split fused projections into per-head form: q/k/v -> [hidden, heads, head_dim]; out -> [heads, head_dim, hidden].
+    wq = np.reshape(wq, [hidden_size_for_vit, num_attention_heads_for_vit, head_dim_for_vit])
+    wk = np.reshape(wk, [hidden_size_for_vit, num_attention_heads_for_vit, head_dim_for_vit])
+    wv = np.reshape(wv, [hidden_size_for_vit, num_attention_heads_for_vit, head_dim_for_vit])
+    wo = np.reshape(wo, [num_attention_heads_for_vit, head_dim_for_vit, hidden_size_for_vit])
+    bq = np.reshape(bq, [num_attention_heads_for_vit, head_dim_for_vit])
+    bk = np.reshape(bk, [num_attention_heads_for_vit, head_dim_for_vit])
+    bv = np.reshape(bv, [num_attention_heads_for_vit, head_dim_for_vit])
+
+    self_attention_vision = {
+        "query": {"kernel": wq, "bias": bq},
+        "key": {"kernel": wk, "bias": bk},
+        "value": {"kernel": wv, "bias": bv},
+        "out": {"kernel": wo, "bias": bo},
+    }
+
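+    # HF Linear weights are [out_features, in_features]; transpose them to the [in, out] kernel layout expected here.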
+    fc1_w = (
+        chkpt_vars[f"vision_model.model.layers.{layer_idx}.mlp.fc1.weight"]
+        .to(torch.float32)
+        .numpy()
+        .astype(CAST_DTYPE)
+        .transpose()
+    )
+    fc2_w = (
+        chkpt_vars[f"vision_model.model.layers.{layer_idx}.mlp.fc2.weight"]
+        .to(torch.float32)
+        .numpy()
+        .astype(CAST_DTYPE)
+        .transpose()
+    )
+    fc1_b = chkpt_vars[f"vision_model.model.layers.{layer_idx}.mlp.fc1.bias"].to(torch.float32).numpy().astype(CAST_DTYPE)
+    fc2_b = chkpt_vars[f"vision_model.model.layers.{layer_idx}.mlp.fc2.bias"].to(torch.float32).numpy().astype(CAST_DTYPE)
+    vision_mlp = {
+        "vit_encoder_layer_mlp_fc1": {"kernel": fc1_w, "bias": fc1_b},
+        "vit_encoder_layer_mlp_fc2": {"kernel": fc2_w, "bias": fc2_b},
+    }
+
+    jax_weights["vision_encoder"]["Llama4VisionModel_0"]["Llama4VisionEncoder_0"].update(
+        {
+            layer_name: {
+                "self_attention_vision": self_attention_vision,
+                "Llama4VisionMLP_0": vision_mlp,
+                "input_layer_norm": {
+                    "scale": chkpt_vars[f"vision_model.model.layers.{layer_idx}.input_layernorm.weight"]
+                    .to(torch.float32)
+                    .numpy()
+                    .astype(CAST_DTYPE),
+                    "bias": chkpt_vars[f"vision_model.model.layers.{layer_idx}.input_layernorm.bias"]
+                    .to(torch.float32)
+                    .numpy()
+                    .astype(CAST_DTYPE),
+                },
+                "post_attention_layer_norm": {
+                    "scale": chkpt_vars[f"vision_model.model.layers.{layer_idx}.post_attention_layernorm.weight"]
+                    .to(torch.float32)
+                    .numpy()
+                    .astype(CAST_DTYPE),
+                    "bias": chkpt_vars[f"vision_model.model.layers.{layer_idx}.post_attention_layernorm.bias"]
+                    .to(torch.float32)
+                    .numpy()
+                    .astype(CAST_DTYPE),
+                },
+            }
+        }
+    )
+
+  # pixel shuffle mlp ###########################################
+  max_logging.log("Processing pixel shuffle mlp")
+  adaptor_fc1 = (
+      chkpt_vars["vision_model.vision_adapter.mlp.fc1.weight"].to(torch.float32).numpy().astype(CAST_DTYPE).transpose()
+  )
+  adaptor_fc2 = (
+      chkpt_vars["vision_model.vision_adapter.mlp.fc2.weight"].to(torch.float32).numpy().astype(CAST_DTYPE).transpose()
+  )
+  jax_weights["vision_encoder"]["Llama4VisionModel_0"]["Llama4VisionPixelShuffleMLP_0"].update(
+      {
+          "pixel_shuffle_mlp": {
+              "vit_pixel_shuffle_mlp_fc1": {"kernel": adaptor_fc1},
+              "vit_pixel_shuffle_mlp_fc2": {"kernel": adaptor_fc2},
+          },
+      }
+  )
+  # multimodal projector ###########################################
+  max_logging.log("Processing multimodal projector")
+  jax_weights["vision_encoder"]["Llama4MultiModalProjector_0"]["vit_multi_modal_projector"]["kernel"] = (
+      chkpt_vars["multi_modal_projector.linear_1.weight"].to(torch.float32).numpy().astype(CAST_DTYPE).transpose()
+  )
+
+  # language model ###########################################
+  max_logging.log("Processing language model")
   # decoder norm scale ###########################################
   max_logging.log("Processing decoder norm scale")
   decoder_norm_scale = chkpt_vars["norm.weight"].to(torch.float32).numpy().astype(CAST_DTYPE)