Skip to content

Commit 8d57286

Browse files
committed
Improve perf and prep for comp graph caching
1 parent 74003ce commit 8d57286

File tree

4 files changed

+73
-56
lines changed

4 files changed

+73
-56
lines changed

coconut/compiler/compiler.py

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,7 @@ class Compiler(Grammar, pickleable_obj):
431431
]
432432

433433
def __init__(self, *args, **kwargs):
434-
"""Creates a new compiler with the given parsing parameters."""
434+
"""Create a new compiler with the given parsing parameters."""
435435
self.setup(*args, **kwargs)
436436
self.reset()
437437

@@ -467,7 +467,7 @@ def setup(self, target=None, strict=False, minify=False, line_numbers=True, keep
467467
self.no_wrap = no_wrap
468468

469469
def __reduce__(self):
470-
"""Return pickling information."""
470+
"""Get pickling information."""
471471
return (self.__class__, (self.target, self.strict, self.minify, self.line_numbers, self.keep_lines, self.no_tco, self.no_wrap))
472472

473473
def get_cli_args(self):
@@ -644,6 +644,8 @@ def method(original, loc, tokens_or_item):
644644
if trim_arity:
645645
self_method = _trim_arity(self_method)
646646
return self_method(original, loc, tokens_or_item)
647+
if kwargs:
648+
method.__name__ = py_str(method.__name__ + "$(" + ", ".join(str(k) + "=" + repr(v) for k, v in kwargs.items()) + ")")
647649
internal_assert(
648650
hasattr(cls_method, "ignore_arguments") is hasattr(method, "ignore_arguments")
649651
and hasattr(cls_method, "ignore_no_tokens") is hasattr(method, "ignore_no_tokens")
@@ -1086,18 +1088,20 @@ def wrap_comment(self, text):
10861088
"""Wrap a comment."""
10871089
return "#" + self.add_ref("comment", text) + unwrapper
10881090

1089-
def wrap_error(self, error):
1091+
def wrap_error(self, error_maker):
10901092
"""Create a symbol that will raise the given error in postprocessing."""
1091-
return errwrapper + self.add_ref("error", error) + unwrapper
1093+
return errwrapper + self.add_ref("error_maker", error_maker) + unwrapper
10921094

1093-
def raise_or_wrap_error(self, error):
1094-
"""Raise if USE_COMPUTATION_GRAPH else wrap."""
1095+
def raise_or_wrap_error(self, *args, **kwargs):
1096+
"""Raise or defer if USE_COMPUTATION_GRAPH else wrap."""
1097+
error_maker = partial(self.make_err, *args, **kwargs)
10951098
if not USE_COMPUTATION_GRAPH:
1096-
return self.wrap_error(error)
1099+
return self.wrap_error(error_maker)
1100+
# differently-ordered any ofs can push these errors earlier than they should be, requiring us to defer them
10971101
elif use_adaptive_any_of or reverse_any_of:
1098-
return ExceptionNode(error)
1102+
return ExceptionNode(error_maker)
10991103
else:
1100-
raise error
1104+
raise error_maker()
11011105

11021106
def type_ignore_comment(self):
11031107
"""Get a "type: ignore" comment."""
@@ -2742,7 +2746,7 @@ def deferred_code_proc(self, inputstring, add_code_at_start=False, ignore_names=
27422746
pre_err_line, err_line = raw_line.split(errwrapper, 1)
27432747
err_ref, post_err_line = err_line.split(unwrapper, 1)
27442748
if not ignore_errors:
2745-
raise self.get_ref("error", err_ref)
2749+
raise self.get_ref("error_maker", err_ref)()
27462750
raw_line = pre_err_line + " " + post_err_line
27472751

27482752
# look for functions
@@ -4890,6 +4894,7 @@ def where_stmt_handle(self, loc, tokens):
48904894

48914895
where_assigns = self.current_parsing_context("where")["assigns"]
48924896
internal_assert(lambda: where_assigns is not None, "missing where_assigns")
4897+
print(where_assigns)
48934898

48944899
where_init = "".join(body_stmts)
48954900
where_final = main_stmt + "\n"
@@ -4989,7 +4994,8 @@ def name_handle(self, original, loc, tokens, assign=False, classname=False, expr
49894994
if self.disable_name_check:
49904995
return name
49914996

4992-
if assign:
4997+
# register non-mid-expression variable assignments inside of where statements for later mangling
4998+
if assign and not expr_setname:
49934999
where_context = self.current_parsing_context("where")
49945000
if where_context is not None:
49955001
where_assigns = where_context["assigns"]
@@ -5020,13 +5026,11 @@ def name_handle(self, original, loc, tokens, assign=False, classname=False, expr
50205026
if typevar_info["typevar_locs"].get(name, None) != loc:
50215027
if assign:
50225028
return self.raise_or_wrap_error(
5023-
self.make_err(
5024-
CoconutSyntaxError,
5025-
"cannot reassign type variable '{name}'".format(name=name),
5026-
original,
5027-
loc,
5028-
extra="use explicit '\\{name}' syntax if intended".format(name=name),
5029-
),
5029+
CoconutSyntaxError,
5030+
"cannot reassign type variable '{name}'".format(name=name),
5031+
original,
5032+
loc,
5033+
extra="use explicit '\\{name}' syntax if intended".format(name=name),
50305034
)
50315035
return typevars[name]
50325036

@@ -5057,13 +5061,11 @@ def name_handle(self, original, loc, tokens, assign=False, classname=False, expr
50575061
return name
50585062
elif assign:
50595063
return self.raise_or_wrap_error(
5060-
self.make_err(
5061-
CoconutTargetError,
5062-
"found Python-3-only assignment to 'exec' as a variable name",
5063-
original,
5064-
loc,
5065-
target="3",
5066-
),
5064+
CoconutTargetError,
5065+
"found Python-3-only assignment to 'exec' as a variable name",
5066+
original,
5067+
loc,
5068+
target="3",
50675069
)
50685070
else:
50695071
return "_coconut_exec"
@@ -5076,12 +5078,10 @@ def name_handle(self, original, loc, tokens, assign=False, classname=False, expr
50765078
return name
50775079
elif not escaped and name.startswith(reserved_prefix) and name not in self.operators:
50785080
return self.raise_or_wrap_error(
5079-
self.make_err(
5080-
CoconutSyntaxError,
5081-
"variable names cannot start with reserved prefix '{prefix}' (use explicit '\\{name}' syntax if intending to access Coconut internals)".format(prefix=reserved_prefix, name=name),
5082-
original,
5083-
loc,
5084-
),
5081+
CoconutSyntaxError,
5082+
"variable names cannot start with reserved prefix '{prefix}' (use explicit '\\{name}' syntax if intending to access Coconut internals)".format(prefix=reserved_prefix, name=name),
5083+
original,
5084+
loc,
50855085
)
50865086
else:
50875087
return name
@@ -5104,7 +5104,7 @@ def check_strict(self, name, original, loc, tokens=(None,), only_warn=False, alw
51045104
else:
51055105
if always_warn:
51065106
kwargs["extra"] = "remove --strict to downgrade to a warning"
5107-
return self.raise_or_wrap_error(self.make_err(CoconutStyleError, message, original, loc, **kwargs))
5107+
return self.raise_or_wrap_error(CoconutStyleError, message, original, loc, **kwargs)
51085108
elif always_warn:
51095109
self.syntax_warning(message, original, loc)
51105110
return tokens[0]
@@ -5145,13 +5145,13 @@ def check_py(self, version, name, original, loc, tokens):
51455145
self.internal_assert(len(tokens) == 1, original, loc, "invalid " + name + " tokens", tokens)
51465146
version_info = get_target_info(version)
51475147
if self.target_info < version_info:
5148-
return self.raise_or_wrap_error(self.make_err(
5148+
return self.raise_or_wrap_error(
51495149
CoconutTargetError,
51505150
"found Python " + ".".join(str(v) for v in version_info) + " " + name,
51515151
original,
51525152
loc,
51535153
target=version,
5154-
))
5154+
)
51555155
else:
51565156
return tokens[0]
51575157

coconut/compiler/grammar.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2616,7 +2616,6 @@ class Grammar(object):
26162616
decoratable_data_stmt,
26172617
match_stmt,
26182618
passthrough_stmt,
2619-
where_stmt,
26202619
)
26212620

26222621
flow_stmt = any_of(
@@ -2661,8 +2660,8 @@ class Grammar(object):
26612660
stmt <<= final(
26622661
compound_stmt
26632662
| simple_stmt # includes destructuring
2664-
# must be after destructuring due to ambiguity
2665-
| cases_stmt
2663+
| cases_stmt # must be after destructuring due to ambiguity
2664+
| where_stmt # slows down parsing when put before simple_stmt
26662665
# at the very end as a fallback case for the anything parser
26672666
| anything_stmt
26682667
)

coconut/compiler/util.py

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ def evaluate_tokens(tokens, **kwargs):
265265
elif isinstance(tokens, ComputationNode):
266266
result = tokens.evaluate()
267267
if is_final and isinstance(result, ExceptionNode):
268-
raise result.exception
268+
result.evaluate()
269269
elif isinstance(result, ParseResults):
270270
return make_modified_tokens(result, cls=MergeNode)
271271
elif isinstance(result, list):
@@ -286,7 +286,7 @@ def evaluate_tokens(tokens, **kwargs):
286286

287287
elif isinstance(tokens, ExceptionNode):
288288
if is_final:
289-
raise tokens.exception
289+
tokens.evaluate()
290290
return tokens
291291

292292
elif isinstance(tokens, DeferredNode):
@@ -321,9 +321,12 @@ def build_new_toks_for(tokens, new_toklist, unchanged=False):
321321
return new_toklist
322322

323323

324+
cached_trim_arity = memoize()(_trim_arity)
325+
326+
324327
class ComputationNode(object):
325328
"""A single node in the computation graph."""
326-
__slots__ = ("action", "original", "loc", "tokens")
329+
__slots__ = ("action", "original", "loc", "tokens", "trim_arity")
327330
pprinting = False
328331
override_original = None
329332
add_to_loc = 0
@@ -339,7 +342,7 @@ def using_overrides(cls):
339342
cls.override_original = override_original
340343
cls.add_to_loc = add_to_loc
341344

342-
def __new__(cls, action, original, loc, tokens, ignore_no_tokens=False, ignore_one_token=False, greedy=False, trim_arity=True):
345+
def __new__(cls, action, original, loc, tokens, trim_arity=True, ignore_no_tokens=False, ignore_one_token=False, greedy=False):
343346
"""Create a ComputionNode to return from a parse action.
344347
345348
If ignore_no_tokens, then don't call the action if there are no tokens.
@@ -350,18 +353,20 @@ def __new__(cls, action, original, loc, tokens, ignore_no_tokens=False, ignore_o
350353
return build_new_toks_for(tokens, tokens, unchanged=True)
351354
else:
352355
self = super(ComputationNode, cls).__new__(cls)
353-
if trim_arity:
354-
self.action = _trim_arity(action)
355-
else:
356-
self.action = action
357-
self.original = original if self.override_original is None else self.override_original
358-
self.loc = self.add_to_loc + loc
356+
self.action = action
357+
self.original = original
358+
self.loc = loc
359359
self.tokens = tokens
360+
self.trim_arity = trim_arity
360361
if greedy:
361362
return self.evaluate()
362363
else:
363364
return self
364365

366+
def __reduce__(self):
367+
"""Get pickling information."""
368+
return (self.__class__, (self.action, self.original, self.loc, self.tokens, self.trim_arity))
369+
365370
@property
366371
def name(self):
367372
"""Get the name of the action."""
@@ -377,15 +382,23 @@ def evaluate(self):
377382
# to actually be reevaluated
378383
if logger.tracing and not final_evaluate_tokens.enabled:
379384
logger.log_tag("cached_parse invalidated by", self)
385+
386+
if self.trim_arity:
387+
using_action = cached_trim_arity(self.action)
388+
else:
389+
using_action = self.action
390+
using_original = self.original if self.override_original is None else self.override_original
391+
using_loc = self.loc + self.add_to_loc
380392
evaluated_toks = evaluate_tokens(self.tokens)
393+
381394
if logger.tracing: # avoid the overhead of the call if not tracing
382-
logger.log_trace(self.name, self.original, self.loc, evaluated_toks, self.tokens)
395+
logger.log_trace(self.name, using_original, using_loc, evaluated_toks, self.tokens)
383396
if isinstance(evaluated_toks, ExceptionNode):
384397
return evaluated_toks # short-circuit if we got an ExceptionNode
385398
try:
386-
result = self.action(
387-
self.original,
388-
self.loc,
399+
result = using_action(
400+
using_original,
401+
using_loc,
389402
evaluated_toks,
390403
)
391404
except CoconutException:
@@ -398,6 +411,7 @@ def evaluate(self):
398411
embed(depth=2)
399412
else:
400413
raise error
414+
401415
out = build_new_toks_for(evaluated_toks, result)
402416
if logger.tracing: # avoid the overhead if not tracing
403417
dropped_keys = set(self.tokens._ParseResults__tokdict.keys())
@@ -434,12 +448,16 @@ def evaluate(self):
434448

435449
class ExceptionNode(object):
436450
"""A node in the computation graph that stores an exception that will be raised upon final evaluation."""
437-
__slots__ = ("exception",)
451+
__slots__ = ("exception_maker",)
438452

439-
def __init__(self, exception):
453+
def __init__(self, exception_maker):
440454
if not USE_COMPUTATION_GRAPH:
441-
raise exception
442-
self.exception = exception
455+
raise exception_maker()
456+
self.exception_maker = exception_maker
457+
458+
def evaluate(self):
459+
"""Raise the stored exception."""
460+
raise self.exception_maker()
443461

444462

445463
class CombineToNode(Combine):

coconut/root.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
VERSION = "3.1.2"
2727
VERSION_NAME = None
2828
# False for release, int >= 1 for develop
29-
DEVELOP = 1
29+
DEVELOP = 2
3030
ALPHA = False # for pre releases rather than post releases
3131

3232
assert DEVELOP is False or DEVELOP >= 1, "DEVELOP must be False or an int >= 1"

0 commit comments

Comments
 (0)