Skip to content

Commit c551e72

Browse files
author
maxtext authors
committed
[maxtext] avoid creating kernel param in serve mode
self.kernel is not used in serve model anyway. In serve mode, the checkpoint is quantized. Quantized weights will be read by aqt, so we don't need to create this param. PiperOrigin-RevId: 772594747
1 parent 3662540 commit c551e72

File tree

1 file changed

+11
-10
lines changed

1 file changed

+11
-10
lines changed

MaxText/layers/linears.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -131,16 +131,17 @@ def __init__(
131131
len(self.axis), len(self.axis) + len(self.out_features)
132132
)
133133

134-
self.kernel = nnx.Param(
135-
self.kernel_init(
136-
rngs.params(),
137-
kernel_shape,
138-
self.weight_dtype,
139-
kernel_in_axis,
140-
kernel_out_axis,
141-
),
142-
sharding=self.kernel_axes,
143-
)
134+
if not quantizations.in_serve_mode(self.quant):
135+
self.kernel = nnx.Param(
136+
self.kernel_init(
137+
rngs.params(),
138+
kernel_shape,
139+
self.weight_dtype,
140+
kernel_in_axis,
141+
kernel_out_axis,
142+
),
143+
sharding=self.kernel_axes,
144+
)
144145

145146
if self.use_bias:
146147
bias_axes = self.kernel_axes[-len(self.out_features) :]

0 commit comments

Comments
 (0)