Skip to content

Commit 59ddf3f

Browse files
committed
Adding back comments
1 parent 70f40d6 commit 59ddf3f

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

MaxText/layers/models.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -180,12 +180,12 @@ def __call__(self, hidden_states: jnp.ndarray, deterministic: bool, model_mode:
180 180
inputs_shape=y.shape,
181 181
features=cfg.vocab_size,
182 182
weight_dtype=cfg.weight_dtype,
183 -
dtype=jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype,
183 +
dtype=jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype, # for logit training stability
184 184
kernel_axes=("embed", "vocab"),
185 185
name="logits_dense",
186 186
matmul_precision=self.config.matmul_precision,
187 187
)
188 -
# Then, call the instance with the input tensor.
188 +
# We do not quantize the logits matmul.
189 189
logits = dense_layer(y)
190 190

191 191
# 4. Final Casting

0 commit comments

Comments (0)