|
48 | 48 | # ------------------------------------------------------------------------------
|
49 | 49 |
|
50 | 50 |
|
51 |
def get_decoder_layers(config: Config):
  """Return the decoder layer classes selected by `config.decoder_block`.

  Args:
    config: Configuration object. `config.decoder_block` is compared against
      the `DecoderBlockType` members; for `LLAMA4`, `config.scan_layers` is
      read as well.

  Returns:
    A list of decoder layer classes. Every block type yields a one-element
    list except `DEEPSEEK`, which yields
    `[DeepSeekDenseLayer, DeepSeekMoELayer]`. For `LLAMA4` the single element
    is `Llama4ScannableBlock` when `config.scan_layers` is truthy and
    `Llama4DecoderLayer` otherwise.

  Raises:
    ValueError: If `config.decoder_block` matches none of the handled
      `DecoderBlockType` members.
  """
  # NOTE(review): the model-specific modules are imported inside each branch
  # rather than at module top level -- presumably to avoid import cycles or
  # loading unused models; confirm before hoisting them.
  if config.decoder_block == DecoderBlockType.DEFAULT:
    return [DecoderLayer]
  elif config.decoder_block == DecoderBlockType.LLAMA2:
    from MaxText.layers import llama2  # pylint: disable=import-outside-toplevel

    return [llama2.LlamaDecoderLayer]
  elif config.decoder_block == DecoderBlockType.MISTRAL:
    # TODO(ranran): update to Mistral with sliding window attention
    from MaxText.layers import mistral  # pylint: disable=import-outside-toplevel

    return [mistral.MistralDecoderLayer]
  elif config.decoder_block == DecoderBlockType.MIXTRAL:
    from MaxText.layers import mixtral  # pylint: disable=import-outside-toplevel

    return [mixtral.MixtralDecoderLayer]
  elif config.decoder_block == DecoderBlockType.DEEPSEEK:
    from MaxText.layers import deepseek  # pylint: disable=import-outside-toplevel

    # DeepSeek is the only block type that contributes two layer classes.
    return [deepseek.DeepSeekDenseLayer, deepseek.DeepSeekMoELayer]
  elif config.decoder_block == DecoderBlockType.GEMMA:
    from MaxText.layers import gemma  # pylint: disable=import-outside-toplevel

    return [gemma.GemmaDecoderLayer]
  elif config.decoder_block == DecoderBlockType.GEMMA2:
    from MaxText.layers import gemma2  # pylint: disable=import-outside-toplevel

    return [gemma2.Gemma2DecoderLayer]
  elif config.decoder_block == DecoderBlockType.GEMMA3:
    from MaxText.layers import gemma3  # pylint: disable=import-outside-toplevel

    return [gemma3.Gemma3DecoderLayer]
  elif config.decoder_block == DecoderBlockType.GPT3:
    from MaxText.layers import gpt3  # pylint: disable=import-outside-toplevel

    return [gpt3.Gpt3DecoderLayer]
  elif config.decoder_block == DecoderBlockType.SIMPLE:
    from MaxText.layers import simple_layer  # pylint: disable=import-outside-toplevel

    return [simple_layer.SimpleDecoderLayer]
  elif config.decoder_block == DecoderBlockType.SIMPLE_MLP:
    from MaxText.layers import simple_layer  # pylint: disable=import-outside-toplevel

    return [simple_layer.SimpleMlpDecoderLayer]
  elif config.decoder_block == DecoderBlockType.LLAMA4:
    from MaxText.layers import llama4  # pylint: disable=import-outside-toplevel

    # Llama4 picks its layer class based on whether layers are scanned.
    if config.scan_layers:
      return [llama4.Llama4ScannableBlock]
    else:
      return [llama4.Llama4DecoderLayer]
  else:
    raise ValueError(f"Incorrect decoder_block name {config.decoder_block.value=}")
105 | 108 |
|
106 | 109 |
|
107 | 110 | class SequentialBlockDecoderLayers(nn.Module):
|
@@ -152,7 +155,7 @@ class Decoder(nn.Module):
|
152 | 155 |
|
153 | 156 | def setup(self):
|
154 | 157 | """Initialize decoder layer."""
|
155 |
| - self.decoder_layer = get_decoder_layers() |
| 158 | + self.decoder_layer = get_decoder_layers(self.config) |
156 | 159 | self.norm_layer = self.get_norm_layer()
|
157 | 160 | if self.config.using_pipeline_parallelism:
|
158 | 161 | pipeline_stage_module = self.get_pipeline_stage_module(self.decoder_layer)
|
|
0 commit comments