Skip to content

Commit c439630

Browse files
committed
Fix comp graph pickling
1 parent 7187d5c commit c439630

File tree

5 files changed

+106
-68
lines changed

5 files changed

+106
-68
lines changed

coconut/compiler/compiler.py

Lines changed: 43 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,6 @@
3939
from threading import Lock
4040
from copy import copy
4141

42-
if sys.version_info >= (3,):
43-
import pickle
44-
else:
45-
import cPickle as pickle
46-
4742
from coconut._pyparsing import (
4843
USE_COMPUTATION_GRAPH,
4944
USE_CACHE,
@@ -109,8 +104,10 @@
109104
incremental_mode_cache_size,
110105
incremental_cache_limit,
111106
use_line_by_line_parser,
107+
coconut_cache_dir,
112108
)
113109
from coconut.util import (
110+
pickle,
114111
pickleable_obj,
115112
checksum,
116113
clip,
@@ -126,6 +123,7 @@
126123
create_method,
127124
univ_open,
128125
staledict,
126+
ensure_dir,
129127
)
130128
from coconut.exceptions import (
131129
CoconutException,
@@ -161,6 +159,7 @@
161159
ComputationNode,
162160
StartOfStrGrammar,
163161
MatchAny,
162+
CombineToNode,
164163
sys_target,
165164
getline,
166165
addskip,
@@ -210,7 +209,8 @@
210209
get_cache_items_for,
211210
clear_packrat_cache,
212211
add_packrat_cache_items,
213-
get_cache_path,
212+
parse_elem_to_identifier,
213+
identifier_to_parse_elem,
214214
_lookup_loc,
215215
_value_exc_loc_or_ret,
216216
)
@@ -447,10 +447,7 @@ def pickle_cache(original, cache_path, include_incremental=True, protocol=pickle
447447
# are the only ones that parseIncremental will reuse
448448
if 0 < loc < len(original) - 1:
449449
elem = lookup[0]
450-
identifier = elem.parse_element_index
451-
internal_assert(lambda: elem == all_parse_elements[identifier](), "failed to look up parse element by identifier", lambda: (elem, all_parse_elements[identifier]()))
452-
if validation_dict is not None:
453-
validation_dict[identifier] = elem.__class__.__name__
450+
identifier = parse_elem_to_identifier(elem, validation_dict)
454451
pickleable_lookup = (identifier,) + lookup[1:]
455452
internal_assert(value[_value_exc_loc_or_ret] is True or isinstance(value[_value_exc_loc_or_ret], int), "cache must be dehybridized before pickling", value[_value_exc_loc_or_ret])
456453
pickleable_cache_items.append((pickleable_lookup, value))
@@ -460,21 +457,15 @@ def pickle_cache(original, cache_path, include_incremental=True, protocol=pickle
460457
for wkref in MatchAny.all_match_anys:
461458
match_any = wkref()
462459
if match_any is not None and match_any.adaptive_usage is not None:
463-
identifier = match_any.parse_element_index
464-
internal_assert(lambda: match_any == all_parse_elements[identifier](), "failed to look up match_any by identifier", lambda: (match_any, all_parse_elements[identifier]()))
465-
if validation_dict is not None:
466-
validation_dict[identifier] = match_any.__class__.__name__
460+
identifier = parse_elem_to_identifier(match_any, validation_dict)
467461
match_any.expr_order.sort(key=lambda i: (-match_any.adaptive_usage[i], i))
468462
all_adaptive_items.append((identifier, (match_any.adaptive_usage, match_any.expr_order)))
469463
logger.log("Caching adaptive item:", match_any, (match_any.adaptive_usage, match_any.expr_order))
470464

471465
# computation graph cache
472466
computation_graph_cache_items = []
473467
for (call_site_name, grammar_elem), cache in Compiler.computation_graph_caches.items():
474-
identifier = grammar_elem.parse_element_index
475-
internal_assert(lambda: grammar_elem == all_parse_elements[identifier](), "failed to look up grammar by identifier", lambda: (grammar_elem, all_parse_elements[identifier]()))
476-
if validation_dict is not None:
477-
validation_dict[identifier] = grammar_elem.__class__.__name__
468+
identifier = parse_elem_to_identifier(grammar_elem, validation_dict)
478469
computation_graph_cache_items.append(((call_site_name, identifier), cache))
479470

480471
logger.log("Saving {num_inc} incremental, {num_adapt} adaptive, and {num_comp_graph} computation graph cache items to {cache_path!r}.".format(
@@ -492,8 +483,9 @@ def pickle_cache(original, cache_path, include_incremental=True, protocol=pickle
492483
"computation_graph_cache_items": computation_graph_cache_items,
493484
}
494485
try:
495-
with univ_open(cache_path, "wb") as pickle_file:
496-
pickle.dump(pickle_info_obj, pickle_file, protocol=protocol)
486+
with CombineToNode.enable_pickling(validation_dict):
487+
with univ_open(cache_path, "wb") as pickle_file:
488+
pickle.dump(pickle_info_obj, pickle_file, protocol=protocol)
497489
except Exception:
498490
logger.log_exc()
499491
return False
@@ -531,15 +523,25 @@ def unpickle_cache(cache_path):
531523
all_adaptive_items = pickle_info_obj["all_adaptive_items"]
532524
computation_graph_cache_items = pickle_info_obj["computation_graph_cache_items"]
533525

526+
# incremental cache
527+
new_cache_items = []
528+
for pickleable_lookup, value in pickleable_cache_items:
529+
maybe_elem = identifier_to_parse_elem(pickleable_lookup[0], validation_dict)
530+
if maybe_elem is not None:
531+
internal_assert(value[_value_exc_loc_or_ret] is True or isinstance(value[_value_exc_loc_or_ret], int), "attempting to unpickle hybrid cache item", value[_value_exc_loc_or_ret])
532+
lookup = (maybe_elem,) + pickleable_lookup[1:]
533+
usefullness = value[-1][0]
534+
internal_assert(usefullness, "loaded useless cache item", (lookup, value))
535+
stale_value = value[:-1] + ([usefullness + 1],)
536+
new_cache_items.append((lookup, stale_value))
537+
add_packrat_cache_items(new_cache_items)
538+
534539
# adaptive cache
535540
for identifier, (adaptive_usage, expr_order) in all_adaptive_items:
536-
if identifier < len(all_parse_elements):
537-
maybe_elem = all_parse_elements[identifier]()
538-
if maybe_elem is not None:
539-
if validation_dict is not None:
540-
internal_assert(maybe_elem.__class__.__name__ == validation_dict[identifier], "adaptive cache pickle-unpickle inconsistency", (maybe_elem, validation_dict[identifier]))
541-
maybe_elem.adaptive_usage = adaptive_usage
542-
maybe_elem.expr_order = expr_order
541+
maybe_elem = identifier_to_parse_elem(identifier, validation_dict)
542+
if maybe_elem is not None:
543+
maybe_elem.adaptive_usage = adaptive_usage
544+
maybe_elem.expr_order = expr_order
543545

544546
max_cache_size = min(
545547
incremental_mode_cache_size or float("inf"),
@@ -548,38 +550,29 @@ def unpickle_cache(cache_path):
548550
if max_cache_size != float("inf"):
549551
pickleable_cache_items = pickleable_cache_items[-max_cache_size:]
550552

551-
# incremental cache
552-
new_cache_items = []
553-
for pickleable_lookup, value in pickleable_cache_items:
554-
identifier = pickleable_lookup[0]
555-
if identifier < len(all_parse_elements):
556-
maybe_elem = all_parse_elements[identifier]()
557-
if maybe_elem is not None:
558-
if validation_dict is not None:
559-
internal_assert(maybe_elem.__class__.__name__ == validation_dict[identifier], "incremental cache pickle-unpickle inconsistency", (maybe_elem, validation_dict[identifier]))
560-
internal_assert(value[_value_exc_loc_or_ret] is True or isinstance(value[_value_exc_loc_or_ret], int), "attempting to unpickle hybrid cache item", value[_value_exc_loc_or_ret])
561-
lookup = (maybe_elem,) + pickleable_lookup[1:]
562-
usefullness = value[-1][0]
563-
internal_assert(usefullness, "loaded useless cache item", (lookup, value))
564-
stale_value = value[:-1] + ([usefullness + 1],)
565-
new_cache_items.append((lookup, stale_value))
566-
add_packrat_cache_items(new_cache_items)
567-
568553
# computation graph cache
569554
for (call_site_name, identifier), cache in computation_graph_cache_items:
570-
if identifier < len(all_parse_elements):
571-
maybe_elem = all_parse_elements[identifier]()
572-
if maybe_elem is not None:
573-
if validation_dict is not None:
574-
internal_assert(maybe_elem.__class__.__name__ == validation_dict[identifier], "computation graph cache pickle-unpickle inconsistency", (maybe_elem, validation_dict[identifier]))
575-
Compiler.computation_graph_caches[(call_site_name, maybe_elem)].update(cache)
555+
maybe_elem = identifier_to_parse_elem(identifier, validation_dict)
556+
if maybe_elem is not None:
557+
Compiler.computation_graph_caches[(call_site_name, maybe_elem)].update(cache)
576558

577559
num_inc = len(pickleable_cache_items)
578560
num_adapt = len(all_adaptive_items)
579561
num_comp_graph = sum(len(cache) for _, cache in computation_graph_cache_items) if computation_graph_cache_items else 0
580562
return num_inc, num_adapt, num_comp_graph
581563

582564

565+
def get_cache_path(codepath):
    """Return the pickle cache path to use for the source file at codepath.

    The cache file sits in a ``coconut_cache_dir`` subdirectory of the
    directory containing ``codepath`` and is named after the source file
    with ``".pkl"`` appended.  The cache directory is handed to
    ``ensure_dir`` before the path is returned.
    """
    source_dir = os.path.dirname(codepath)
    source_name = os.path.basename(codepath)

    cache_dir = os.path.join(source_dir, coconut_cache_dir)
    ensure_dir(cache_dir, logger=logger)

    return os.path.join(cache_dir, source_name + ".pkl")
574+
575+
583576
def load_cache_for(inputstring, codepath):
584577
"""Load cache_path (for the given inputstring and filename)."""
585578
if not SUPPORTS_INCREMENTAL:

coconut/compiler/util.py

Lines changed: 54 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
from coconut.root import * # NOQA
2929

3030
import sys
31-
import os
3231
import re
3332
import ast
3433
import inspect
@@ -49,7 +48,6 @@
4948
SUPPORTS_INCREMENTAL,
5049
SUPPORTS_ADAPTIVE,
5150
SUPPORTS_PACKRAT_CONTEXT,
52-
replaceWith,
5351
ZeroOrMore,
5452
OneOrMore,
5553
Optional,
@@ -77,13 +75,15 @@
7775

7876
from coconut.integrations import embed
7977
from coconut.util import (
78+
pickle,
8079
override,
8180
get_name,
8281
get_target_info,
8382
memoize,
84-
ensure_dir,
8583
get_clock_time,
8684
literal_lines,
85+
const,
86+
pickleable_obj,
8787
)
8888
from coconut.terminal import (
8989
logger,
@@ -120,14 +120,14 @@
120120
incremental_cache_limit,
121121
incremental_mode_cache_successes,
122122
use_adaptive_any_of,
123-
coconut_cache_dir,
124123
use_fast_pyparsing_reprs,
125124
require_cache_clear_frac,
126125
reverse_any_of,
127126
all_keywords,
128127
always_keep_parse_name_prefix,
129128
keep_if_unchanged_parse_name_prefix,
130129
incremental_use_hybrid,
130+
test_computation_graph_pickling,
131131
)
132132
from coconut.exceptions import (
133133
CoconutException,
@@ -315,7 +315,7 @@ def build_new_toks_for(tokens, new_toklist, unchanged=False):
315315
cached_trim_arity = memoize()(_trim_arity)
316316

317317

318-
class ComputationNode(object):
318+
class ComputationNode(pickleable_obj):
319319
"""A single node in the computation graph."""
320320
__slots__ = ("action", "original", "loc", "tokens", "trim_arity")
321321
pprinting = False
@@ -339,6 +339,12 @@ def __new__(cls, action, original, loc, tokens, trim_arity=True, ignore_no_token
339339
If ignore_no_tokens, then don't call the action if there are no tokens.
340340
If ignore_one_token, then don't call the action if there is only one token.
341341
If greedy, then never defer the action until later."""
342+
if test_computation_graph_pickling:
343+
with CombineToNode.enable_pickling():
344+
try:
345+
pickle.dumps(action, protocol=pickle.HIGHEST_PROTOCOL)
346+
except Exception:
347+
raise ValueError("unpickleable action in ComputationNode: " + repr(action))
342348
if ignore_no_tokens and len(tokens) == 0 or ignore_one_token and len(tokens) == 1:
343349
# could be a ComputationNode, so we can't have an __init__
344350
return build_new_toks_for(tokens, tokens, unchanged=True)
@@ -452,9 +458,10 @@ def evaluate(self):
452458
raise self.exception_maker()
453459

454460

455-
class CombineToNode(Combine):
461+
class CombineToNode(Combine, pickleable_obj):
456462
"""Modified Combine to work with the computation graph."""
457463
__slots__ = ()
464+
validation_dict = None
458465

459466
def _combine(self, original, loc, tokens):
460467
"""Implement the parse action for Combine."""
@@ -468,6 +475,26 @@ def postParse(self, original, loc, tokens):
468475
"""Create a ComputationNode for Combine."""
469476
return ComputationNode(self._combine, original, loc, tokens, ignore_no_tokens=True, ignore_one_token=True, trim_arity=False)
470477

478+
@classmethod
def reconstitute(cls, identifier):
    """Look up the parse element registered under identifier.

    This is the callable that __reduce__ hands to pickle, so it runs at
    unpickling time.  The lookup is validated against whatever
    validation_dict is installed on the class at that moment (None means
    no validation).  Returns whatever identifier_to_parse_elem returns,
    which may be None.
    """
    return identifier_to_parse_elem(identifier, cls.validation_dict)
481+
482+
def __reduce__(self):
    """Pickle this element by identifier while a validation dict is installed.

    When validation_dict is None, defer to the inherited __reduce__.
    Otherwise return a (callable, args) pair that makes unpickling call
    reconstitute with this element's identifier.
    """
    active_dict = self.validation_dict
    if active_dict is not None:
        identifier = parse_elem_to_identifier(self, active_dict)
        return (self.reconstitute, (identifier,))
    return super(CombineToNode, self).__reduce__()
487+
488+
@staticmethod
@contextmanager
def enable_pickling(validation_dict=None):
    """Context manager to enable pickling for CombineToNode.

    While the context is active, CombineToNode.validation_dict is set to
    a dict, which makes __reduce__ pickle elements by identifier.  The
    previous value is restored on exit, even if the body raises.

    validation_dict is the dict to install; None (the default) installs
    a fresh empty dict so that no state is shared between calls.
    """
    # A staticmethod rather than a classmethod: the state is always stored on
    # CombineToNode itself, and a classmethod without a cls parameter would
    # bind the class to validation_dict (or raise TypeError when an argument
    # is also passed).
    if validation_dict is None:
        validation_dict = {}
    old_validation_dict = CombineToNode.validation_dict
    CombineToNode.validation_dict = validation_dict
    try:
        yield
    finally:
        CombineToNode.validation_dict = old_validation_dict
497+
471498

472499
if USE_COMPUTATION_GRAPH:
473500
combine = CombineToNode
@@ -1136,15 +1163,24 @@ def disable_incremental_parsing():
11361163
force_reset_packrat_cache()
11371164

11381165

1139-
def get_cache_path(codepath):
1140-
"""Get the cache filename to use for the given codepath."""
1141-
code_dir, code_fname = os.path.split(codepath)
1166+
def parse_elem_to_identifier(elem, validation_dict=None):
    """Return the identifier (elem.parse_element_index) for a parse element.

    Asserts that all_parse_elements maps the identifier back to elem.  If
    validation_dict is given, the element's class name is recorded in it
    under the identifier.
    """
    index = elem.parse_element_index

    def registered_elem():
        # entries of all_parse_elements are called to obtain the element
        return all_parse_elements[index]()

    internal_assert(
        lambda: elem == registered_elem(),
        "failed to look up parse element by identifier",
        lambda: (elem, registered_elem()),
    )
    if validation_dict is None:
        return index
    validation_dict[index] = elem.__class__.__name__
    return index
11421173

1143-
cache_dir = os.path.join(code_dir, coconut_cache_dir)
1144-
ensure_dir(cache_dir, logger=logger)
11451174

1146-
pickle_fname = code_fname + ".pkl"
1147-
return os.path.join(cache_dir, pickle_fname)
1175+
def identifier_to_parse_elem(identifier, validation_dict=None):
    """Return the parse element registered under identifier, or None.

    None is returned when identifier is past the end of all_parse_elements
    or when the registered entry yields None.  If validation_dict is given,
    the element's class name is asserted to match the name recorded there
    for this identifier.
    """
    if not (identifier < len(all_parse_elements)):
        return None

    found = all_parse_elements[identifier]()
    if found is None:
        return None

    if validation_dict is not None:
        internal_assert(
            found.__class__.__name__ == validation_dict[identifier],
            "parse element pickle-unpickle inconsistency",
            (found, validation_dict[identifier]),
        )
    return found
11481184

11491185

11501186
# -----------------------------------------------------------------------------------------------------------------------
@@ -1350,11 +1386,14 @@ def add_labels(tokens):
13501386
return (item, tokens._ParseResults__tokdict.keys())
13511387

13521388

1353-
def invalid_syntax_handle(msg, loc, tokens):
1389+
def invalid_syntax_handle(msg, original, loc, tokens):
    """Pickleable handler for invalid_syntax that always raises.

    Raises CoconutDeferredSyntaxError built from msg and loc; original and
    tokens are accepted but not used.
    """
    raise CoconutDeferredSyntaxError(msg, loc)
13561392

13571393

1394+
invalid_syntax_handle.trim_arity = False # fixes pypy issue
1395+
1396+
13581397
def invalid_syntax(item, msg, **kwargs):
13591398
"""Mark a grammar item as an invalid item that raises a syntax err with msg."""
13601399
if isinstance(item, str):
@@ -1405,7 +1444,7 @@ def regex_item(regex, options=None):
14051444

14061445
def fixto(item, output):
14071446
"""Force an item to result in a specific output."""
1408-
return attach(item, replaceWith(output), ignore_arguments=True)
1447+
return attach(item, const([output]), ignore_arguments=True)
14091448

14101449

14111450
def addspace(item):

coconut/constants.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,9 @@ def get_path_env_var(env_var, default):
172172
# COMPILER CONSTANTS:
173173
# -----------------------------------------------------------------------------------------------------------------------
174174

175-
# set this to True only ever temporarily for ease of debugging
175+
# set these to True only ever temporarily for ease of debugging
176176
embed_on_internal_exc = get_bool_env_var("COCONUT_EMBED_ON_INTERNAL_EXC", False)
177+
test_computation_graph_pickling = False
177178

178179
# should be the minimal ref count observed by maybe_copy_elem
179180
temp_grammar_item_ref_count = 4 if PY311 else 5

coconut/tests/constants_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,9 @@ class TestConstants(unittest.TestCase):
8080

8181
def test_defaults(self):
8282
assert constants.use_fast_pyparsing_reprs
83-
assert not constants.embed_on_internal_exc
8483
assert constants.num_assemble_logical_lines_tries >= 1
84+
assert not constants.embed_on_internal_exc
85+
assert not constants.test_computation_graph_pickling
8586

8687
def test_fixpath(self):
8788
assert os.path.basename(fixpath("CamelCase.py")) == "CamelCase.py"

coconut/util.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@
3939
from backports.functools_lru_cache import lru_cache
4040
except ImportError:
4141
lru_cache = None
42+
if sys.version_info >= (3,):
43+
import pickle # NOQA
44+
else:
45+
import cPickle as pickle # NOQA
4246

4347
from coconut.root import _get_target_info
4448
from coconut.constants import (
@@ -286,7 +290,7 @@ def add(self, item):
286290
self[item] = True
287291

288292

289-
class staledict(dict, object):
293+
class staledict(dict, pickleable_obj):
290294
"""A dictionary that keeps track of staleness.
291295
292296
Initial elements are always marked as stale and pickling always prunes stale elements."""

0 commit comments

Comments
 (0)