Skip to content

Commit e7e7f16

Browse files
committed
second attempt at adding warning
1 parent 9f62bc4 commit e7e7f16

File tree

2 files changed

+62
-2
lines changed

2 files changed

+62
-2
lines changed

MaxText/tests/module_tests.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
"""
2+
Copyright 2023 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
"""Tests for train.py with various configs"""
18+
19+
import os
20+
import sys
21+
import subprocess
22+
import unittest
23+
from absl.testing import absltest
24+
25+
from MaxText.globals import PKG_DIR
26+
27+
28+
class ModuleTests(unittest.TestCase):
29+
"""Tests train.py with various invocation methods"""
30+
31+
def test_get_informative_error(self):
32+
command = [
33+
sys.executable, # use the same interpreter instance
34+
"MaxText/train.py",
35+
"MaxText/configs/base.yml",
36+
"run_name=maxtext-module-test",
37+
"base_output_directory=gs://does-not-exist",
38+
"dataset_type=synthetic",
39+
]
40+
41+
with self.assertRaises(subprocess.CalledProcessError) as context:
42+
subprocess.run(command, check=True, text=True, capture_output=True)
43+
44+
ex = context.exception
45+
self.assertEqual(ex.returncode, 64)
46+
self.assertIn("The MaxText API has changed", ex.stderr)
47+
48+
49+
if __name__ == "__main__":
50+
absltest.main()

MaxText/train.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,18 @@
1717
# pylint: disable=g-bad-todo, abstract-method, consider-using-with, ungrouped-imports
1818
"""Training loop and Decoding of the model."""
1919

20-
# Calling jax.device_count here prevents a "TPU platform already registered" error.
21-
# See github.com/google/maxtext/issues/20 for more
20+
import sys
21+
22+
if not __package__:
23+
print(
24+
"Error: The MaxText API has changed. MaxText entry-points are now "
25+
"invoked as modules, with syntax of the form "
26+
"`python3 -m MaxText.module <args>`, e.g., "
27+
"`python3 -m MaxText.train <args>`.",
28+
file=sys.stderr,
29+
)
30+
# EX_USAGE
31+
sys.exit(64)
2232

2333
import datetime
2434
import os

0 commit comments

Comments
 (0)