Skip to content

Commit 283fb57

Browse files
committed
Correcting the params
1 parent a7808e4 commit 283fb57

File tree

1 file changed

+20
-17
lines changed

1 file changed

+20
-17
lines changed

MaxText/layers/models.py

Lines changed: 20 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -48,60 +48,63 @@
4848
# ------------------------------------------------------------------------------
4949

5050

51-
def get_decoder_layers(self):
52-
"""Get decoder layers, one of `DecoderBlockType` discriminants or a direct `nn.Module` inheritor"""
53-
if self.config.decoder_block == DecoderBlockType.DEFAULT:
51+
def get_decoder_layers(config: Config):
52+
"""
53+
Helper function to get the list of decoder layer classes based on config.
54+
Get decoder layers, one of `DecoderBlockType` discriminants or a direct `nn.Module` inheritor
55+
"""
56+
if config.decoder_block == DecoderBlockType.DEFAULT:
5457
return [DecoderLayer]
55-
elif self.config.decoder_block == DecoderBlockType.LLAMA2:
58+
elif config.decoder_block == DecoderBlockType.LLAMA2:
5659
from MaxText.layers import llama2 # pylint: disable=import-outside-toplevel
5760

5861
return [llama2.LlamaDecoderLayer]
59-
elif self.config.decoder_block == DecoderBlockType.MISTRAL:
62+
elif config.decoder_block == DecoderBlockType.MISTRAL:
6063
# TODO(ranran): update to Mistral with sliding window attention
6164
from MaxText.layers import mistral # pylint: disable=import-outside-toplevel
6265

6366
return [mistral.MistralDecoderLayer]
64-
elif self.config.decoder_block == DecoderBlockType.MIXTRAL:
67+
elif config.decoder_block == DecoderBlockType.MIXTRAL:
6568
from MaxText.layers import mixtral # pylint: disable=import-outside-toplevel
6669

6770
return [mixtral.MixtralDecoderLayer]
68-
elif self.config.decoder_block == DecoderBlockType.DEEPSEEK:
71+
elif config.decoder_block == DecoderBlockType.DEEPSEEK:
6972
from MaxText.layers import deepseek # pylint: disable=import-outside-toplevel
7073

7174
return [deepseek.DeepSeekDenseLayer, deepseek.DeepSeekMoELayer]
72-
elif self.config.decoder_block == DecoderBlockType.GEMMA:
75+
elif config.decoder_block == DecoderBlockType.GEMMA:
7376
from MaxText.layers import gemma # pylint: disable=import-outside-toplevel
7477

7578
return [gemma.GemmaDecoderLayer]
76-
elif self.config.decoder_block == DecoderBlockType.GEMMA2:
79+
elif config.decoder_block == DecoderBlockType.GEMMA2:
7780
from MaxText.layers import gemma2 # pylint: disable=import-outside-toplevel
7881

7982
return [gemma2.Gemma2DecoderLayer]
80-
elif self.config.decoder_block == DecoderBlockType.GEMMA3:
83+
elif config.decoder_block == DecoderBlockType.GEMMA3:
8184
from MaxText.layers import gemma3 # pylint: disable=import-outside-toplevel
8285

8386
return [gemma3.Gemma3DecoderLayer]
84-
elif self.config.decoder_block == DecoderBlockType.GPT3:
87+
elif config.decoder_block == DecoderBlockType.GPT3:
8588
from MaxText.layers import gpt3 # pylint: disable=import-outside-toplevel
8689

8790
return [gpt3.Gpt3DecoderLayer]
88-
elif self.config.decoder_block == DecoderBlockType.SIMPLE:
91+
elif config.decoder_block == DecoderBlockType.SIMPLE:
8992
from MaxText.layers import simple_layer # pylint: disable=import-outside-toplevel
9093

9194
return [simple_layer.SimpleDecoderLayer]
92-
elif self.config.decoder_block == DecoderBlockType.SIMPLE_MLP:
95+
elif config.decoder_block == DecoderBlockType.SIMPLE_MLP:
9396
from MaxText.layers import simple_layer # pylint: disable=import-outside-toplevel
9497

9598
return [simple_layer.SimpleMlpDecoderLayer]
96-
elif self.config.decoder_block == DecoderBlockType.LLAMA4:
99+
elif config.decoder_block == DecoderBlockType.LLAMA4:
97100
from MaxText.layers import llama4 # pylint: disable=import-outside-toplevel
98101

99-
if self.config.scan_layers:
102+
if config.scan_layers:
100103
return [llama4.Llama4ScannableBlock]
101104
else:
102105
return [llama4.Llama4DecoderLayer]
103106
else:
104-
raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block.value=}")
107+
raise ValueError(f"Incorrect decoder_block name {config.decoder_block.value=}")
105108

106109

107110
class SequentialBlockDecoderLayers(nn.Module):
@@ -152,7 +155,7 @@ class Decoder(nn.Module):
152155

153156
def setup(self):
154157
"""Initialize decoder layer."""
155-
self.decoder_layer = get_decoder_layers()
158+
self.decoder_layer = get_decoder_layers(self.config)
156159
self.norm_layer = self.get_norm_layer()
157160
if self.config.using_pipeline_parallelism:
158161
pipeline_stage_module = self.get_pipeline_stage_module(self.decoder_layer)

0 commit comments

Comments
 (0)