Skip to content

Commit c60d601

Browse files
committed
Fix CombineToNode pickling
1 parent c439630 commit c60d601

File tree

1 file changed

+9
-10
lines changed

1 file changed

+9
-10
lines changed

coconut/compiler/util.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,7 @@ class CombineToNode(Combine, pickleable_obj):
462462
"""Modified Combine to work with the computation graph."""
463463
__slots__ = ()
464464
validation_dict = None
465+
pickling_enabled = False
465466

466467
def _combine(self, original, loc, tokens):
467468
"""Implement the parse action for Combine."""
@@ -475,25 +476,23 @@ def postParse(self, original, loc, tokens):
475476
"""Create a ComputationNode for Combine."""
476477
return ComputationNode(self._combine, original, loc, tokens, ignore_no_tokens=True, ignore_one_token=True, trim_arity=False)
477478

478-
@classmethod
479-
def reconstitute(self, identifier):
480-
return identifier_to_parse_elem(identifier, self.validation_dict)
481-
482479
def __reduce__(self):
483-
if self.validation_dict is None:
484-
return super(CombineToNode, self).__reduce__()
480+
if self.pickling_enabled:
481+
return (identifier_to_parse_elem, (parse_elem_to_identifier(self, self.validation_dict),))
485482
else:
486-
return (self.reconstitute, (parse_elem_to_identifier(self, self.validation_dict),))
483+
return super(CombineToNode, self).__reduce__()
487484

488485
@classmethod
489486
@contextmanager
490-
def enable_pickling(validation_dict={}):
487+
def enable_pickling(cls, validation_dict=None):
491488
"""Context manager to enable pickling for CombineToNode."""
492-
old_validation_dict, CombineToNode.validation_dict = CombineToNode.validation_dict, validation_dict
489+
old_validation_dict, cls.validation_dict = cls.validation_dict, validation_dict
490+
old_pickling_enabled, cls.pickling_enabled = cls.pickling_enabled, True
493491
try:
494492
yield
495493
finally:
496-
CombineToNode.validation_dict = old_validation_dict
494+
cls.pickling_enabled = old_pickling_enabled
495+
cls.validation_dict = old_validation_dict
497496

498497

499498
if USE_COMPUTATION_GRAPH:

0 commit comments

Comments
 (0)