diff --git a/AUTHORS.rst b/AUTHORS.rst index b675f5b..f390cb4 100644 --- a/AUTHORS.rst +++ b/AUTHORS.rst @@ -38,3 +38,4 @@ Contributors - Jonathan Herriott - Job Evers - Cyrus Durgin +- Hugo Klepsch \ No newline at end of file diff --git a/retrying.py b/retrying.py index c08df1e..2d1bf3b 100644 --- a/retrying.py +++ b/retrying.py @@ -58,6 +58,26 @@ def wrapped_f(*args, **kw): return wrap +_default_logger = None +_configured_null_logger = False + +def _pick_logger(logger=None): + # Factor this logic out into a smaller function so that `global` only needs to be here, + # not the large __init__ function. + global _default_logger, _configured_null_logger + + if logger in (True, None): + if _default_logger is None: + _default_logger = logging.getLogger(__name__) + # Only add the null handler once, not every time we get the logger + if logger is None and not _configured_null_logger: + _configured_null_logger = True + _default_logger.addHandler(logging.NullHandler()) + _default_logger.propagate = False + return _default_logger + else: # Not None (and not True) -> must have supplied a logger. Just use that. + return logger + class Retrying(object): def __init__( @@ -110,14 +130,8 @@ def __init__( self._wait_jitter_max = 0 if wait_jitter_max is None else wait_jitter_max self._before_attempts = before_attempts self._after_attempts = after_attempts - - if logger in (True, None): - self._logger = logging.getLogger(__name__) - if logger is None: - self._logger.addHandler(logging.NullHandler()) - self._logger.propagate = False - elif logger: - self._logger = logger + + self._logger = _pick_logger(logger) # TODO add chaining of stop behaviors # stop behavior diff --git a/test_retrying.py b/test_retrying.py index 05370f6..3134e4b 100644 --- a/test_retrying.py +++ b/test_retrying.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import logging +import random import time import unittest @@ -461,7 +462,7 @@ def _test_before(): _test_before() - self.assertTrue(TestBeforeAfterAttempts._attempt_number is 1) + self.assertTrue(TestBeforeAfterAttempts._attempt_number == 1) def test_after_attempts(self): TestBeforeAfterAttempts._attempt_number = 0 @@ -478,7 +479,111 @@ def _test_after(): _test_after() - self.assertTrue(TestBeforeAfterAttempts._attempt_number is 2) + self.assertTrue(TestBeforeAfterAttempts._attempt_number == 2) + + +class LoadTest(unittest.TestCase): + total_calls = 0 + foo = "bar" # Static value to reduce variance + + @retry( + retry_on_result=lambda result: result is None, + stop_max_attempt_number=30000, # For this test, never raise an exception + wait_fixed=0 + ) + def fn_to_test(self): + LoadTest.total_calls += 1 + # Simulate sometimes returning None (triggering retry) + if random.random() < 0.4: + return LoadTest.foo + else: + return None + + def benchmark(self, duration_seconds=2): + # Reset counter + LoadTest.total_calls = 0 + + start_time = time.time() + end_time = start_time + duration_seconds + + while time.time() < end_time: + _ = self.fn_to_test() + + actual_duration = time.time() - start_time + + # Calculate metrics + calls_per_second = LoadTest.total_calls / actual_duration + return calls_per_second + + def test_load(self): + """ + In 1.3.5, there was a bug where calls wrapped with retry would take longer and longer to complete. + This test checks that wrapping a function with retry doesn't affect its performance over several calls. + This test takes ~24 seconds to run. + """ + calls_per_second_initial = self.benchmark(2) + # Run the benchmark a few more times. This triggers the performance bug in 1.3.5. + for i in range(10): + _ = self.benchmark(2) + calls_per_second_final = self.benchmark(2) + # Ensure that the later calls are within +/-20% the speed of the initial calls + self.assertTrue( + calls_per_second_initial * 0.8 <= calls_per_second_final <= calls_per_second_initial * 1.2, + { + "calls_per_second_initial": calls_per_second_initial, + "calls_per_second_final": calls_per_second_final, + "msg": "calls_per_second_final was not within +/-20% of calls_per_second_initial" + } + ) + + +class TestLogger(unittest.TestCase): + def setUp(self): + # Set up a function with a dummy logger + self.logger = logging.getLogger("test_retrying") + self.test_handler = self.TestHandler() + self.logger.addHandler(self.test_handler) + @retry(stop_max_attempt_number=1, retry_on_result=lambda r: r is None, logger=self.logger) + def foo_with_logger(): + return None + self.foo_with_logger = foo_with_logger + + class TestHandler(logging.Handler): + def __init__(self): + super().__init__() + self.records = [] + + def handle(self, record): + self.records.append(record) + + @retry(stop_max_attempt_number=1, retry_on_result=lambda r: r is None) + def foo_no_logger(self): + return None + + def test_logger_None(self): + # Assert not raises (anything except RetryError) + try: + self.foo_no_logger() + except RetryError: + pass + + def test_logger_custom(self): + try: + self.foo_with_logger() + except RetryError: + pass + self.assertEqual(len(self.test_handler.records), 1) + + def test_logger_true(self): + @retry(stop_max_attempt_number=1, retry_on_result=lambda r: r is None, logger=True) + def foo_true(): + return None + + # Assert not raises (anything except RetryError) + try: + foo_true() + except RetryError: + pass if __name__ == "__main__":