Skip to content

Commit 044f3a2

Browse files
committed
[*.py] bash code_style.sh ; [MaxText/maxengine.py] Add pylint: disable=arguments-differ to MaxEngine::prefill_multisampling
1 parent c2cf4a0 commit 044f3a2

File tree

8 files changed

+19
-33
lines changed

8 files changed

+19
-33
lines changed

MaxText/checkpointing.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def create_orbax_emergency_checkpoint_manager(
103103

104104
# Only create directories if running on GPUs as the previous
105105
# directory structure might be assumed by TPUs
106-
if global_mesh.devices.flatten()[0].platform == 'gpu':
106+
if global_mesh.devices.flatten()[0].platform == "gpu":
107107
# pylint: disable=protected-access
108108
local_checkpoint_dir = f"{local_checkpoint_dir}/{jax._src.distributed.global_state.process_id}"
109109
local_p = epath.Path(local_checkpoint_dir)
@@ -374,9 +374,7 @@ def setup_checkpoint_logger(config) -> Any | None: # pytype: disable=attribute-
374374
max_logging.log("Setting up checkpoint logger...")
375375
if config.enable_checkpoint_cloud_logger:
376376
logger_name = f"goodput_{config.run_name}"
377-
options = ocp.logging.CloudLoggerOptions(
378-
job_name=config.run_name, logger_name=logger_name
379-
)
377+
options = ocp.logging.CloudLoggerOptions(job_name=config.run_name, logger_name=logger_name)
380378
orbax_cloud_logger = ocp.logging.CloudLogger(options=options)
381379
max_logging.log("Successfully set up checkpoint cloud logger.")
382380
return orbax_cloud_logger

MaxText/decode.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -140,12 +140,12 @@ def main(argv: Sequence[str]) -> None:
140140
for i in range(_NUM_STREAMS):
141141
with jax.profiler.StepTraceAnnotation("prefill", stream=i):
142142
prefill_result, first_token = engine.prefill(
143-
params=params,
144-
padded_tokens=tokens,
145-
images=processor_output.pixel_values,
146-
true_length=true_length,
147-
rng=rng_prefill,
148-
slot=i,
143+
params=params,
144+
padded_tokens=tokens,
145+
images=processor_output.pixel_values,
146+
true_length=true_length,
147+
rng=rng_prefill,
148+
slot=i,
149149
)
150150
prefill_result_list.append(prefill_result)
151151
first_token_list.append(first_token)
@@ -178,6 +178,7 @@ def main(argv: Sequence[str]) -> None:
178178
# Deactivate profiler
179179
prof.deactivate()
180180

181+
181182
def _validate_config(config):
182183
assert config.load_full_state_path == "", (
183184
"Decode doesn't operate on full states! Convert to parameter checkpoint first." "Using generate_param_only_checkpoint."

MaxText/layers/linears.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,7 @@ def __init__(
8484
axis: Union[Iterable[int], int] = -1,
8585
weight_dtype: DType = jnp.float32,
8686
dtype: DType = jnp.float32,
87-
kernel_init: NdInitializer = nd_dense_init(
88-
1.0, "fan_in", "truncated_normal"
89-
),
87+
kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "truncated_normal"),
9088
kernel_axes: Tuple[Optional[str], ...] = (),
9189
quant: Optional[Quant] = None,
9290
use_bias: bool = False,
@@ -127,9 +125,7 @@ def __init__(
127125
# Parameter initialization
128126
kernel_shape = self.in_features + self.out_features
129127
kernel_in_axis = np.arange(len(self.axis))
130-
kernel_out_axis = np.arange(
131-
len(self.axis), len(self.axis) + len(self.out_features)
132-
)
128+
kernel_out_axis = np.arange(len(self.axis), len(self.axis) + len(self.out_features))
133129

134130
if not quantizations.in_serve_mode(self.quant):
135131
self.kernel = nnx.Param(
@@ -218,9 +214,7 @@ def dense_general(
218214
axis: Union[Iterable[int], int] = -1,
219215
weight_dtype: DType = jnp.float32,
220216
dtype: DType = jnp.float32,
221-
kernel_init: NdInitializer = nd_dense_init(
222-
1.0, "fan_in", "truncated_normal"
223-
),
217+
kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "truncated_normal"),
224218
kernel_axes: Tuple[Optional[str], ...] = (),
225219
quant: Optional[Quant] = None,
226220
use_bias: bool = False,
@@ -247,15 +241,11 @@ def dense_general(
247241
name: name passed to the ToLinen Module
248242
"""
249243
if not (inputs_shape is not None) ^ (in_features is not None):
250-
raise ValueError(
251-
"Exactly one of inputs_shape or in_features must be specified."
252-
)
244+
raise ValueError("Exactly one of inputs_shape or in_features must be specified.")
253245

254246
if inputs_shape is not None:
255247
axis = _canonicalize_tuple(axis)
256-
in_features = tuple(
257-
inputs_shape[ax] for ax in _normalize_axes(axis, len(inputs_shape))
258-
)
248+
in_features = tuple(inputs_shape[ax] for ax in _normalize_axes(axis, len(inputs_shape)))
259249
else:
260250
assert in_features is not None
261251
module = nnx.bridge.to_linen(
@@ -401,4 +391,3 @@ def __call__(self, inputs, decode: bool = False, deterministic: bool = False):
401391

402392
output = checkpoint_name(output, "mlpwo")
403393
return output
404-

MaxText/layers/llama4.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ class Llama4UnfoldConvolution(nn.Module):
5858

5959
def setup(self):
6060
"""
61-
Initialize Llama4UnfoldConvolution
61+
Initialize Llama4UnfoldConvolution
6262
"""
6363
cfg = self.config
6464
# Linear projection layer using dense_general.
@@ -190,7 +190,7 @@ class Llama4VisionMLP2(nn.Module):
190190

191191
def setup(self):
192192
"""
193-
Initialize Llama4VisionMLP2
193+
Initialize Llama4VisionMLP2
194194
"""
195195
cfg = self.config
196196
self.fc1 = linears.dense_general(

MaxText/layers/models.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -666,9 +666,7 @@ def __call__(
666666
inputs_shape=y.shape,
667667
features=cfg.vocab_size,
668668
weight_dtype=cfg.weight_dtype,
669-
dtype=jnp.float32
670-
if cfg.logits_dot_in_fp32
671-
else cfg.dtype, # for logit training stability
669+
dtype=jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype, # for logit training stability
672670
kernel_axes=("embed", "vocab"),
673671
name="logits_dense",
674672
matmul_precision=self.config.matmul_precision,
@@ -804,4 +802,3 @@ def __call__(
804802
image_embeddings=image_embeddings,
805803
)
806804
return logits
807-

MaxText/layers/multi_token_prediction.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,4 +136,3 @@ def __call__(
136136
# Shape: [B, S, H]
137137
# --- Return Processed Hidden State ---
138138
return next_hidden_state
139-

MaxText/maxengine.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,7 @@ def prefill_multisampling_aot( # pylint: disable=too-many-positional-arguments
587587

588588
def prefill_multisampling(
589589
self, # pytype: disable=signature-mismatch
590+
# pylint: disable=arguments-differ
590591
*,
591592
params: Params,
592593
padded_tokens: jax.Array,

MaxText/tests/hf_checkpoint_conversion_check.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
See the License for the specific language governing permissions and
1414
limitations under the License.
1515
"""
16+
1617
import os
1718
import torch
1819
import torch.nn.functional as F

0 commit comments

Comments
 (0)