@@ -136,6 +136,7 @@ def _hf_to_maxtext_mapping(layer_idx: int = -1) -> dict:
  """
  # pylint: disable=line-too-long
  return {
+     ## language model mapping
      "language_model.model.embed_tokens.weight": "tok_embeddings.weight",
      "language_model.model.norm.weight": "norm.weight",
      "language_model.lm_head.weight": "output.weight",
@@ -145,18 +146,43 @@ def _hf_to_maxtext_mapping(layer_idx: int = -1) -> dict:
      f"language_model.model.layers.{layer_idx}.self_attn.k_proj.weight": f"layers.{layer_idx}.attention.wk.weight",
      f"language_model.model.layers.{layer_idx}.self_attn.v_proj.weight": f"layers.{layer_idx}.attention.wv.weight",
      f"language_model.model.layers.{layer_idx}.self_attn.o_proj.weight": f"layers.{layer_idx}.attention.wo.weight",
-     # MoE
+     # MoE in language model
      f"language_model.model.layers.{layer_idx}.feed_forward.router.weight": f"layers.{layer_idx}.feed_forward.gate.weight",
      f"language_model.model.layers.{layer_idx}.feed_forward.experts.down_proj": f"layers.{layer_idx}.feed_forward.experts.down_proj",
      # NOTE: this contains up_proj and gate_proj concatenated together (we'll split/chunk them later; see the sketch after this mapping)
      f"language_model.model.layers.{layer_idx}.feed_forward.experts.gate_up_proj": f"layers.{layer_idx}.feed_forward.experts.gate_up_proj",
      f"language_model.model.layers.{layer_idx}.feed_forward.shared_expert.gate_proj.weight": f"layers.{layer_idx}.feed_forward.shared_experts.gate_proj.weight",
      f"language_model.model.layers.{layer_idx}.feed_forward.shared_expert.down_proj.weight": f"layers.{layer_idx}.feed_forward.shared_experts.down_proj.weight",
      f"language_model.model.layers.{layer_idx}.feed_forward.shared_expert.up_proj.weight": f"layers.{layer_idx}.feed_forward.shared_experts.up_proj.weight",
-     # FFN
+     # FFN in language model
      f"language_model.model.layers.{layer_idx}.feed_forward.up_proj.weight": f"layers.{layer_idx}.feed_forward.w1.weight",
      f"language_model.model.layers.{layer_idx}.feed_forward.gate_proj.weight": f"layers.{layer_idx}.feed_forward.w3.weight",
      f"language_model.model.layers.{layer_idx}.feed_forward.down_proj.weight": f"layers.{layer_idx}.feed_forward.w2.weight",
+
+     # ## vision model mapping
+     # "vision_model.class_embedding": "vision_encoder.",
+     # "vision_model.positional_embedding_vlm": "",
+     # "vision_model.patch_embedding.linear.weight": "",
+     # "vision_model.layernorm_pre.weight": "",
+     # "vision_model.layernorm_pre.bias": "",
+     # "vision_model.layernorm_post.weight": "",
+     # "vision_model.layernorm_post.bias": "",
+     # "vision_model.model.layers.{layer_idx}.input_layernorm.weight": "",
+     # "vision_model.model.layers.{layer_idx}.input_layernorm.bias": "",
+     # "vision_model.model.layers.{layer_idx}.self_attn.q_proj.weight": "",
+     # "vision_model.model.layers.{layer_idx}.self_attn.q_proj.bias": "",
+     # "vision_model.model.layers.{layer_idx}.self_attn.k_proj.weight": "",
+     # "vision_model.model.layers.{layer_idx}.self_attn.k_proj.bias": "",
+     # "vision_model.model.layers.{layer_idx}.self_attn.v_proj.weight": "",
+     # "vision_model.model.layers.{layer_idx}.self_attn.v_proj.bias": "",
+     # "vision_model.model.layers.{layer_idx}.self_attn.o_proj.weight": "",
+     # "vision_model.model.layers.{layer_idx}.self_attn.o_proj.bias": "",
+     # "vision_model.model.layers.{layer_idx}.post_attention_layernorm.weight": "",
+     # "vision_model.model.layers.{layer_idx}.post_attention_layernorm.bias": "",
+     # "vision_model.model.layers.{layer_idx}.mlp.fc1.weight": "",
+     # "vision_model.model.layers.{layer_idx}.mlp.fc1.bias": "",
+     # "vision_model.model.layers.{layer_idx}.mlp.fc2.weight": "",
+     # "vision_model.model.layers.{layer_idx}.mlp.fc2.bias": "",
  }
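For reference, the fused gate_up_proj tensor flagged in the NOTE above gets split along its last axis into separate gate and up projections during conversion. Below is a minimal sketch of that split, assuming an HF layout of [num_experts, hidden_size, 2 * intermediate_size] with the gate half stored first; both the shape and the ordering are assumptions here, not something this diff confirms.

import numpy as np

def split_gate_up_proj(gate_up_proj: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
  """Split a fused [num_experts, hidden, 2 * intermediate] tensor into gate and up halves."""
  # Assumption: the gate projection occupies the first half of the last axis.
  # np.split with 2 sections divides that axis into two equal chunks.
  gate_proj, up_proj = np.split(gate_up_proj, 2, axis=-1)
  return gate_proj, up_proj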
@@ -194,6 +220,10 @@ def _convert_huggingface_to_jax_weights(base_model_path: str, model_size: str, m
  Returns:
    jax_weights (dict): Dictionary containing the converted weights.
  """
+  num_hidden_layers_for_vit = model_params.get("num_layers_vit", 0)
+  num_attention_heads_for_vit = model_params.get("num_att_head_vit", 0)
+  hidden_size_for_vit = model_params.get("hidden_size_vit", 0)
+  head_dim_for_vit = hidden_size_for_vit // num_attention_heads_for_vit if num_attention_heads_for_vit else 0
  base_num_decoder_layers = model_params["num_layers"]
  base_num_query_heads = model_params["num_heads"]
  head_dim = model_params["dims_per_head"]
@@ -217,7 +247,8 @@ def _convert_huggingface_to_jax_weights(base_model_path: str, model_size: str, m
      layer = int(parts[3]) if "layers" in key else 0
      # TODO: update when multi-modality support is added
      if "vision" in key or "multi_modal_projector" in key:
-       print("WARNING: skipping vision or multi-modal key: ", key)
+       # print("WARNING: skipping vision or multi-modal key: ", key)
+       chkpt_vars[key] = f.get_tensor(key)
      else:
        mapped_key = _hf_to_maxtext_mapping(layer)[key]
        chkpt_vars[mapped_key] = f.get_tensor(key)
@@ -230,8 +261,81 @@ def _convert_huggingface_to_jax_weights(base_model_path: str, model_size: str, m
        "logits_dense": {"kernel": None},
      },
      "token_embedder": {"embedding": None},
+     "vision_encoder": {
+         "Llama4VisionModel_0": {
+             # dicts (not None) so the per-layer / layernorm entries can be .update()-ed below
+             "Llama4VisionEncoder_0": {},
+             "class_embedding": None,
+             "positional_embedding_vlm": None,
+             "Llama4UnfoldConvolution_0": None,
+             "layernorm_pre": {},
+             "layernorm_post": {},
+         },
+         "Llama4MultiModalProjector_0": {"vit_multi_modal_projector": {"kernel": None}},
+     },
  }

+  # vision encoder ###########################################
+  max_logging.log("Processing vision model")
+  jax_weights["vision_encoder"]["Llama4VisionModel_0"]["class_embedding"] = chkpt_vars["vision_model.class_embedding"].to(torch.float32).numpy().astype(CAST_DTYPE)
+  jax_weights["vision_encoder"]["Llama4VisionModel_0"]["positional_embedding_vlm"] = chkpt_vars["vision_model.positional_embedding_vlm"].to(torch.float32).numpy().astype(CAST_DTYPE)
+  jax_weights["vision_encoder"]["Llama4VisionModel_0"]["Llama4UnfoldConvolution_0"] = chkpt_vars["vision_model.patch_embedding.linear.weight"].to(torch.float32).numpy().astype(CAST_DTYPE).transpose()
+  jax_weights["vision_encoder"]["Llama4VisionModel_0"]["layernorm_pre"].update({
+      "scale": chkpt_vars["vision_model.layernorm_pre.weight"].to(torch.float32).numpy().astype(CAST_DTYPE),
+      "bias": chkpt_vars["vision_model.layernorm_pre.bias"].to(torch.float32).numpy().astype(CAST_DTYPE),
+  })
+  jax_weights["vision_encoder"]["Llama4VisionModel_0"]["layernorm_post"].update({
+      "scale": chkpt_vars["vision_model.layernorm_post.weight"].to(torch.float32).numpy().astype(CAST_DTYPE),
+      "bias": chkpt_vars["vision_model.layernorm_post.bias"].to(torch.float32).numpy().astype(CAST_DTYPE),
+  })
+  for layer_idx in tqdm(range(num_hidden_layers_for_vit), desc="layers", leave=False):
+    layer_name = f"layers_{layer_idx}"
+
+    wq = chkpt_vars[f"vision_model.model.layers.{layer_idx}.self_attn.q_proj.weight"].to(torch.float32).numpy().astype(CAST_DTYPE).transpose()
+    wk = chkpt_vars[f"vision_model.model.layers.{layer_idx}.self_attn.k_proj.weight"].to(torch.float32).numpy().astype(CAST_DTYPE).transpose()
+    wv = chkpt_vars[f"vision_model.model.layers.{layer_idx}.self_attn.v_proj.weight"].to(torch.float32).numpy().astype(CAST_DTYPE).transpose()
+    wo = chkpt_vars[f"vision_model.model.layers.{layer_idx}.self_attn.o_proj.weight"].to(torch.float32).numpy().astype(CAST_DTYPE).transpose()
+    bq = chkpt_vars[f"vision_model.model.layers.{layer_idx}.self_attn.q_proj.bias"].to(torch.float32).numpy().astype(CAST_DTYPE)
+    bk = chkpt_vars[f"vision_model.model.layers.{layer_idx}.self_attn.k_proj.bias"].to(torch.float32).numpy().astype(CAST_DTYPE)
+    bv = chkpt_vars[f"vision_model.model.layers.{layer_idx}.self_attn.v_proj.bias"].to(torch.float32).numpy().astype(CAST_DTYPE)
+    bo = chkpt_vars[f"vision_model.model.layers.{layer_idx}.self_attn.o_proj.bias"].to(torch.float32).numpy().astype(CAST_DTYPE)
+
+    wq = np.reshape(wq, [hidden_size_for_vit, num_attention_heads_for_vit, head_dim_for_vit])
+    wk = np.reshape(wk, [hidden_size_for_vit, num_attention_heads_for_vit, head_dim_for_vit])
+    wv = np.reshape(wv, [hidden_size_for_vit, num_attention_heads_for_vit, head_dim_for_vit])
+    wo = np.reshape(wo, [num_attention_heads_for_vit, head_dim_for_vit, hidden_size_for_vit])
+
+    self_attention_vision = {
+        "query": {"kernel": wq, "bias": bq},
+        "key": {"kernel": wk, "bias": bk},
+        "value": {"kernel": wv, "bias": bv},
+        "out": {"kernel": wo, "bias": bo},
+    }
+
+    fc1_w = chkpt_vars[f"vision_model.model.layers.{layer_idx}.mlp.fc1.weight"].to(torch.float32).numpy().astype(CAST_DTYPE).transpose()
+    fc2_w = chkpt_vars[f"vision_model.model.layers.{layer_idx}.mlp.fc2.weight"].to(torch.float32).numpy().astype(CAST_DTYPE).transpose()
+    fc1_b = chkpt_vars[f"vision_model.model.layers.{layer_idx}.mlp.fc1.bias"].to(torch.float32).numpy().astype(CAST_DTYPE)
+    fc2_b = chkpt_vars[f"vision_model.model.layers.{layer_idx}.mlp.fc2.bias"].to(torch.float32).numpy().astype(CAST_DTYPE)
+    vision_mlp = {
+        "vit_encoder_layer_mlp_fc1": {"kernel": fc1_w, "bias": fc1_b},
+        "vit_encoder_layer_mlp_fc2": {"kernel": fc2_w, "bias": fc2_b},
+    }
+
+    jax_weights["vision_encoder"]["Llama4VisionModel_0"]["Llama4VisionEncoder_0"].update(
+        {
+            layer_name: {
+                "self_attention_vision": self_attention_vision,
+                "Llama4VisionMLP_0": vision_mlp,
+            }
+        }
+    )
+
+  max_logging.log("Processing multimodal projector")
+  jax_weights["vision_encoder"]["Llama4MultiModalProjector_0"]["vit_multi_modal_projector"]["kernel"] = chkpt_vars["multi_modal_projector.linear_1.weight"].to(torch.float32).numpy().astype(CAST_DTYPE).transpose()
+
+  # language model ###########################################
+  max_logging.log("Processing language model")
  # decoder norm scale ###########################################
  max_logging.log("Processing decoder norm scale")
  decoder_norm_scale = chkpt_vars["norm.weight"].to(torch.float32).numpy().astype(CAST_DTYPE)
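The q/k/v reshapes in the vision loop above assume that, after transposing the HF projection weight (stored as [out_features, in_features] by torch.nn.Linear), the result has shape [hidden_size, num_heads * head_dim] and can be viewed as [hidden_size, num_heads, head_dim] (and, correspondingly, the output projection as [num_heads, head_dim, hidden_size]). A small numpy sketch of that assumption, with illustrative sizes only; the real values come from model_params.

import numpy as np

# Illustrative sizes only, not the real Llama4 vision dimensions.
hidden_size, num_heads = 8, 2
head_dim = hidden_size // num_heads

# HF/torch Linear stores q_proj.weight as [out_features, in_features] = [num_heads * head_dim, hidden_size].
hf_q_proj = np.arange(num_heads * head_dim * hidden_size, dtype=np.float32).reshape(num_heads * head_dim, hidden_size)

# Transpose to [hidden_size, num_heads * head_dim], then split the head axis,
# mirroring the wq reshape in the conversion loop above.
wq = hf_q_proj.transpose().reshape(hidden_size, num_heads, head_dim)
assert wq.shape == (hidden_size, num_heads, head_dim)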