From f21607650c934f32f4e56842f4460a30d6f4bf8a Mon Sep 17 00:00:00 2001 From: Raymond Pan Date: Fri, 16 May 2025 12:31:36 -0400 Subject: [PATCH] update guide and tests --- tests/test_feature_engineering.py | 6 +- tests/test_guide.py | 173 ++++++++++++++++++++++++++++++ zephyr_ml/core.py | 170 ++++++++++++++++++----------- zephyr_ml/core_prev.py | 6 +- 4 files changed, 287 insertions(+), 68 deletions(-) create mode 100644 tests/test_guide.py diff --git a/tests/test_feature_engineering.py b/tests/test_feature_engineering.py index 5baf7dd..e4c08ba 100644 --- a/tests/test_feature_engineering.py +++ b/tests/test_feature_engineering.py @@ -182,7 +182,8 @@ def test_process_signals_pidata(pidata_es, transformations, aggregations): pd.testing.assert_frame_equal(processed, expected) -def test_process_signals_pidata_replace(pidata_es, transformations, aggregations): +def test_process_signals_pidata_replace( + pidata_es, transformations, aggregations): signal_dataframe_name = 'pidata' signal_column = 'val1' window_size = '1m' @@ -248,7 +249,8 @@ def test_process_signals_scada(scada_es, transformations, aggregations): pd.testing.assert_frame_equal(scada_es['scada_processed'], expected) -def test_process_signals_scada_replace(scada_es, transformations, aggregations): +def test_process_signals_scada_replace( + scada_es, transformations, aggregations): signal_dataframe_name = 'scada' signal_column = 'val1' window_size = '1m' diff --git a/tests/test_guide.py b/tests/test_guide.py new file mode 100644 index 0000000..4284f72 --- /dev/null +++ b/tests/test_guide.py @@ -0,0 +1,173 @@ +from zephyr_ml.core import GuideHandler, guide + + +class DummyObject: + def __init__(self): + producers_and_getters = [ + ([self.step0_key, self.step0_set], [self.step0_getter]), + ([self.step1_key, self.step1_set], [self.step1_getter]), + ([self.step2_key, self.step2_set], [self.step2_getter]) + ] + set_methods = { + self.step0_set.__name__, + self.step1_set.__name__, + self.step2_set.__name__ + } + self._guide_handler = GuideHandler(producers_and_getters, set_methods) + + @guide + def step0_key(self): + return "step0_key_result" + + @guide + def step0_set(self): + return "step0_set_result" + + @guide + def step0_getter(self): + return "step0_get_result" + + @guide + def step1_key(self): + return "step1_key_result" + + @guide + def step1_set(self): + return "step1_set_result" + + @guide + def step1_getter(self): + return "step1_get_result" + + @guide + def step2_key(self): + return "step2_key_result" + + @guide + def step2_set(self): + return "step2_set_result" + + @guide + def step2_getter(self): + return "step2_get_result" + + +def test_forward_key_steps(): + """Test performing key steps in forward order""" + obj = DummyObject() + + # First step should work without warnings + assert obj.step0_key() == "step0_key_result" + + # Second step should work without warnings since previous step is up to + # date + assert obj.step1_key() == "step1_key_result" + + # Third step should work without warnings since previous step is up to date + assert obj.step2_key() == "step2_key_result" + + +def test_set_methods_can_skip(caplog): + """Test that set methods can skip steps""" + obj = DummyObject() + + # Set methods should work in any order and start new iterations + assert obj.step2_set() == "step2_set_result" # Skip to step 2 + assert "[GUIDE] STALE WARNING" in caplog.text + + assert obj.step0_set() == "step0_set_result" # Go back to step 0 + assert "[GUIDE] STALE WARNING" in caplog.text + assert obj.step1_set() == "step1_set_result" # Do 
step 1 + assert "[GUIDE] STALE WARNING" in caplog.text + + +def test_key_methods_require_previous_step(caplog): + """Test that key methods require the previous step to be up to date""" + obj = DummyObject() + + # Try to do step 1 without doing step 0 first + obj.step1_key() + assert "[GUIDE] INCONSISTENCY WARNING" in caplog.text + + # Do step 0, then step 1 should work + obj.step0_key() + assert obj.step1_key() == "step1_key_result" + + +def test_stale_data_warning(caplog): + """Test warning when data becomes stale""" + obj = DummyObject() + + # Complete steps 0 and 1 + obj.step0_key() + obj.step1_key() + + # Go back to step 0 with set method (allowed, but warns about stale data) + obj.step0_set() + assert "[GUIDE] STALE WARNING" in caplog.text + + +def test_getter_with_stale_data(caplog): + """Test getting data that may be stale""" + obj = DummyObject() + + # Complete steps 0 and 1 + obj.step0_key() + obj.step1_key() + + # Go back to step 0 with set method + obj.step0_set() + + # Try to get data from step 1, should warn about stale data + obj.step1_getter() + assert "[GUIDE] STALE WARNING" in caplog.text + + +def test_getter_with_missing_key(caplog): + """Test getting data when the key method hasn't been run""" + obj = DummyObject() + + # Try to get data without running key method first + obj.step1_getter() + assert "[GUIDE] INCONSISTENCY WARNING" in caplog.text + + +def test_key_method_after_stale_data(caplog): + """Test that key methods cannot be run when previous step is stale""" + obj = DummyObject() + + # Complete steps 0 and 1 + obj.step0_key() + obj.step1_key() + + # Go back to step 0 with set method + obj.step1_set() + obj.step1_key() + assert "[GUIDE] INCONSISTENCY WARNING" in caplog.text + + +def test_multiple_iterations(): + """Test multiple iterations through the steps""" + obj = DummyObject() + + # First iteration with key methods + assert obj.step0_key() == "step0_key_result" + assert obj.step1_key() == "step1_key_result" + + # Second iteration starting with set method + assert obj.step0_set() == "step0_set_result" + # Can't do step 1 with key method after set without redoing step 0 key + assert obj.step1_key() == "step1_key_result" + + +def test_guide_decorator(): + """Test that the guide decorator properly wraps methods""" + obj = DummyObject() + + # Check that the decorator preserves function metadata + assert obj.step0_key.__name__ == "step0_key" + assert obj.step0_getter.__name__ == "step0_getter" + + # Check that the decorator routes through the guide handler + assert obj.step0_key() == "step0_key_result" + assert obj.step0_getter() == "step0_get_result" diff --git a/zephyr_ml/core.py b/zephyr_ml/core.py index 309e08b..50cfb64 100644 --- a/zephyr_ml/core.py +++ b/zephyr_ml/core.py @@ -31,7 +31,7 @@ class GuideHandler: def __init__(self, producers_and_getters, set_methods): - self.cur_term = 0 + self.cur_iteration = 0 self.current_step = -1 self.start_point = -1 self.producers_and_getters = producers_and_getters @@ -40,9 +40,9 @@ def __init__(self, producers_and_getters, set_methods): self.producer_to_step_map = {} self.getter_to_step_map = {} - self.terms = [] + self.iterations = [] for idx, (producers, getters) in enumerate(self.producers_and_getters): - self.terms.append(-1) + self.iterations.append(-1) for prod in producers: self.producer_to_step_map[prod.__name__] = idx @@ -69,7 +69,7 @@ def get_get_steps_in_between(self, cur_step, next_step): def get_last_up_to_date(self, next_step): latest_up_to_date = 0 for step in range(next_step): - if self.terms[step] == 
self.cur_term: + if self.iterations[step] == self.cur_iteration: latest_up_to_date = step return latest_up_to_date @@ -85,83 +85,112 @@ def get_steps_in_between(self, cur_step, next_step): step_strs.append(f"{step}. {' or '.join(option_strs)}") return step_strs - def perform_producer_step(self, zephyr, method, *method_args, **method_kwargs): + def log_next_producer_step(self, name): + next_step = self.current_step + 1 + + if next_step >= len(self.producers_and_getters): + LOGGER.warning("[GUIDE] You have reached the end of the \ + predictive engineering workflow.\ + You may continue to go back and reperform steps based on results.") + else: + next_step_name = self.producers_and_getters[next_step][0][0].__name__ + LOGGER.warning(f"[GUIDE] Successfully performed {name}. You can perform the \ + next step by calling\ + {next_step_name}.") + + def perform_producer_step(self, zephyr, method, + *method_args, **method_kwargs): step_num = self.producer_to_step_map[method.__name__] res = method(zephyr, *method_args, **method_kwargs) self.current_step = step_num - self.terms[step_num] = self.cur_term + self.iterations[step_num] = self.cur_iteration + self.log_next_producer_step(method.__name__) return res - def try_log_skipping_steps_warning(self, name, next_step): - steps_skipped = self.get_steps_in_between(self.current_step, next_step) - if len(steps_skipped) > 0: - necc_steps = self.join_steps(steps_skipped) - LOGGER.warning( - f"Performing {name}. You are skipping the following steps:\n{necc_steps}") - - def try_log_making_stale_warning(self, name, next_step): - next_next_step = next_step + 1 - prod_steps = f"step {next_next_step}: \ - {' or '.join(self.producers_and_getters[next_next_step][0])}" - # add later set methods - get_steps = self.join_steps( - self.get_get_steps_in_between( - next_step, self.current_step + 1)) - - LOGGER.warning(f"Performing {name}. You are beginning a new iteration.\ - Any data returned by the following get methods will be \ - considered stale:\n{get_steps}. To continue with this \ - iteration, please perform \n{prod_steps}") + def try_log_forward_set_method_warning(self, name, next_step): + LOGGER.warning( + f"[GUIDE] STALE WARNING: Performing step {next_step} \ + from step {self.current_step} \ + via {name}.\ + This is a forward step via a set method.\ + The current step is {self.current_step}.\ + All previous steps will be considered stale.") + + def try_log_backwards_set_method_warning(self, name, next_step): + LOGGER.warning(f"[GUIDE] STALE WARNING: Performing step \ + {next_step} from step {self.current_step} \ + via {name}. \ + This is a backwards step via a set method.\ + All other steps will be considered stale.") + + def try_log_backwards_key_method_warning(self, name, next_step): + steps_in_between = self.join_steps( + self.get_steps_in_between( + self.current_step, next_step)) + LOGGER.warning(f"[GUIDE] STALE WARNING: Performing step {next_step} \ + from step {self.current_step} via {name}. \ + This is a backwards step via a key method.\ + The following steps will be considered stale:\n{steps_in_between}") def log_get_inconsistent_warning(self, name, next_step): - prod_steps = f"{next_step}. 
\ - {' or '.join(self.producers_and_getters[next_step][0])}" + prod_steps_str = ' or '.join([method.__name__ for method in + self.producers_and_getters[next_step][0]]) + prod_steps = f"{next_step}.{prod_steps_str}" latest_up_to_date = self.get_last_up_to_date(next_step) - LOGGER.warning(f"Unable to perform {name} because {prod_steps} has not \ + LOGGER.warning(f"[GUIDE] INCONSISTENCY WARNING: Unable to perform {name} \ + because {prod_steps} has not \ been run yet. Run steps starting at or before \ {latest_up_to_date} ") def log_get_stale_warning(self, name, next_step): latest_up_to_date = self.get_last_up_to_date(next_step) - LOGGER.warning(f"Performing {name}. This data is potentially stale. \ + LOGGER.warning(f"[GUIDE] STALE WARNING: Performing {name}. \ + This data is potentially stale. \ Re-run steps starting at or before \ {latest_up_to_date} to ensure data is up to date.") # tries to perform step if possible -> warns that data might be stale - def try_perform_forward_producer_step(self, zephyr, method, *method_args, **method_kwargs): + def try_perform_forward_producer_step( + self, zephyr, method, *method_args, **method_kwargs): name = method.__name__ next_step = self.producer_to_step_map[name] if name in self.set_methods: # set method will update start point and start new iteration - self.try_log_skipping_steps_warning(name, next_step) + self.try_log_forward_set_method_warning(name, next_step) self.start_point = next_step - self.cur_term += 1 - # next_step == 0, set method (already warned), or previous step is up to term + self.cur_iteration += 1 + # next_step == 0, set method (already warned), or previous step is up + # to term res = self.perform_producer_step( zephyr, method, *method_args, **method_kwargs) return res - # next_step == 0, set method, or previous step is up to term - - def try_perform_backward_producer_step(self, zephyr, method, *method_args, **method_kwargs): + def try_perform_backward_producer_step( + self, zephyr, method, *method_args, **method_kwargs): name = method.__name__ next_step = self.producer_to_step_map[name] # starting new iteration - self.cur_term += 1 + self.cur_iteration += 1 if next_step == 0 or name in self.set_methods: self.start_point = next_step else: # key method # mark everything from start point to next step as current term for i in range(self.start_point, next_step): - if self.terms[i] != -1: - self.terms[i] = self.cur_term + if self.iterations[i] != -1: + self.iterations[i] = self.cur_iteration + + if name in self.set_methods: + self.try_log_backwards_set_method_warning(name, next_step) + else: + self.try_log_backwards_key_method_warning(name, next_step) - self.try_log_making_stale_warning(next_step) res = self.perform_producer_step( zephyr, method, *method_args, **method_kwargs) + return res - def try_perform_producer_step(self, zephyr, method, *method_args, **method_kwargs): + def try_perform_producer_step( + self, zephyr, method, *method_args, **method_kwargs): name = method.__name__ next_step = self.producer_to_step_map[name] if next_step >= self.current_step: @@ -179,13 +208,17 @@ def try_perform_inconsistent_producer_step( # add using stale and overwriting self, zephyr, method, *method_args, **method_kwargs): name = method.__name__ next_step = self.producer_to_step_map[name] - # inconsistent forward step: performing key method but previous step is not up to date - if next_step >= self.current_step and self.terms[next_step-1] != self.cur_term: + # inconsistent forward step: performing key method but previous step is + # not up to date 
+ if (next_step >= self.current_step and + self.iterations[next_step - 1] != self.cur_iteration): corr_set_method = self.producers_and_getters[next_step][0][1].__name__ - prev_step = next_step-1 + prev_step = next_step - 1 prev_set_method = self.producers_and_getters[prev_step][0][1].__name__ prev_key_method = self.producers_and_getters[prev_step][0][0].__name__ - LOGGER.warning(f"Unable to perform {name} because you are performing a key method at\ + LOGGER.warning(f"[GUIDE] INCONSISTENCY WARNING:Unable \ + to perform {name} because you are\ + performing a key method at\ step {next_step} but the result of the previous step, \ step {prev_step}, is STALE.\ If you already have the data for step {next_step}, \ @@ -202,14 +235,17 @@ def try_perform_inconsistent_producer_step( # add using stale and overwriting # method at step 0: {first_set_method}.\ # If you would like to perform step {next_step}, \ # please use the corresponding key method: {corr_key_method}.") - # inconsistent backward step: performing key method but previous step is not up to date - elif next_step < self.current_step and self.terms[next_step-1] != self.cur_term: - prev_step = next_step-1 + # inconsistent backward step: performing key method but previous step + # is not up to date + elif (next_step < self.current_step and + self.iterations[next_step - 1] != self.cur_iteration): + prev_step = next_step - 1 prev_key_method = self.producers_and_getters[prev_step][0][0].__name__ corr_set_method = self.producers_and_getters[next_step][0][1].__name__ prev_get_method = self.producers_and_getters[prev_step][1][0].__name__ prev_set_method = self.producers_and_getters[prev_step][0][1].__name__ - LOGGER.warning(f"Unable to perform {name} because you are going \ + LOGGER.warning(f"[GUIDE] INCONSISTENCY WARNING: Unable to perform {name} \ + because you are going \ backwards and starting a new iteration by\ performing a key method at step {next_step} \ but the result of the previous step,\ @@ -223,18 +259,19 @@ def try_perform_inconsistent_producer_step( # add using stale and overwriting call {corr_set_method} to set the data.\ ") - def try_perform_getter_step(self, zephyr, method, *method_args, **method_kwargs): + def try_perform_getter_step( + self, zephyr, method, *method_args, **method_kwargs): name = method.__name__ # either inconsistent, stale, or up to date step_num = self.getter_to_step_map[name] - step_term = self.terms[step_num] - if step_term == -1: - self.log_get_inconsistent_warning(step_num) - elif step_term == self.cur_term: + step_iteration = self.iterations[step_num] + if step_iteration == -1: + self.log_get_inconsistent_warning(name, step_num) + elif step_iteration == self.cur_iteration: res = method(zephyr, *method_args, **method_kwargs) return res else: - self.log_get_stale_warning(step_num) + self.log_get_stale_warning(name, step_num) res = method(zephyr, *method_args, **method_kwargs) return res @@ -244,11 +281,13 @@ def guide_step(self, zephyr, method, *method_args, **method_kwargs): # up-todate next_step = self.producer_to_step_map[method_name] if (next_step == 0 or # 0 step always valid, starting new iteration - # set method always valid, but will update start point and start new iteration + # set method always valid, but will update start point and + # start new iteration method_name in self.set_methods or # key method valid if previous step is up to date - self.terms[next_step-1] == self.cur_term): - # forward step only valid if set method or key method w/ no skips + self.iterations[next_step - 1] == 
self.cur_iteration): + # forward step only valid if set method or key method w/ no + # skips res = self.try_perform_producer_step( zephyr, method, *method_args, **method_kwargs) return res @@ -267,8 +306,9 @@ def guide_step(self, zephyr, method, *method_args, **method_kwargs): def guide(method): @wraps(method) - def guided_step(self, *method_args, **method_kwargs): - return self.guide_handler.guide_step(self, method, *method_args, **method_kwargs) + def guided_step(instance, *method_args, **method_kwargs): + return instance._guide_handler.guide_step( + instance, method, *method_args, **method_kwargs) return guided_step @@ -317,7 +357,7 @@ def __init__(self): self.set_feature_matrix.__name__, self.set_train_test_split.__name__, self.set_fitted_pipeline.__name__]) - self.guide_handler = GuideHandler(step_order, set_methods) + self._guide_handler = GuideHandler(step_order, set_methods) def GET_ENTITYSET_TYPES(self): """Get the supported entityset types and their required dataframes/columns. @@ -735,7 +775,8 @@ def get_feature_matrix(self): return self._feature_matrix, self._label_col_name, self._features @guide - def set_feature_matrix(self, feature_matrix, labels=None, label_col_name="label"): + def set_feature_matrix(self, feature_matrix, + labels=None, label_col_name="label"): """Set the feature matrix for this Zephyr instance. Args: @@ -909,7 +950,8 @@ def predict(self, X=None, visual=False, **kwargs): outputs = self._pipeline.predict(X, output_=outputs_spec, **kwargs) if visual and visual_names: prediction = outputs[0] - return prediction, dict(zip(visual_names, outputs[-len(visual_names):])) + return prediction, dict( + zip(visual_names, outputs[-len(visual_names):])) return outputs diff --git a/zephyr_ml/core_prev.py b/zephyr_ml/core_prev.py index ca16567..aaadd8d 100644 --- a/zephyr_ml/core_prev.py +++ b/zephyr_ml/core_prev.py @@ -130,7 +130,8 @@ def fit( if visual and outputs is not None: return dict(zip(visual_names, outputs)) - def predict(self, X: pd.DataFrame, visual: bool = False, **kwargs) -> pd.Series: + def predict(self, X: pd.DataFrame, visual: bool = False, + **kwargs) -> pd.Series: """Predict the pipeline to the given data. Args: @@ -155,7 +156,8 @@ def predict(self, X: pd.DataFrame, visual: bool = False, **kwargs) -> pd.Series: if visual and visual_names: prediction = outputs[0] - return prediction, dict(zip(visual_names, outputs[-len(visual_names):])) + return prediction, dict( + zip(visual_names, outputs[-len(visual_names):])) return outputs
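
For reference, a minimal runnable sketch (illustrative only, not part of the patch) of how the new "[GUIDE]" warnings surface at runtime. It assumes zephyr_ml.core obtains LOGGER from the standard logging module and reuses the DummyObject helper added in tests/test_guide.py, which is importable only if the tests directory is on the Python path:

    import logging

    # Assumption: importing the test helper directly; adjust the import if the
    # tests directory is not on sys.path in your environment.
    from tests.test_guide import DummyObject

    # GuideHandler emits its "[GUIDE] ..." messages at WARNING level through the
    # standard logging module, so this is enough to see them on stderr.
    logging.basicConfig(level=logging.WARNING)

    obj = DummyObject()
    obj.step0_key()     # first key step, performed in order
    obj.step2_set()     # forward jump via a set method -> "[GUIDE] STALE WARNING"
    obj.step1_getter()  # key step 1 never ran -> "[GUIDE] INCONSISTENCY WARNING"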