Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,4 @@ Contributors
- Jonathan Herriott
- Job Evers
- Cyrus Durgin
- Hugo Klepsch
30 changes: 22 additions & 8 deletions retrying.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,26 @@ def wrapped_f(*args, **kw):

return wrap

_default_logger = None
_configured_null_logger = False


def _pick_logger(logger=None):
    """Resolve which logger a Retrying instance should use.

    Keeping this in a helper keeps the ``global`` statement out of the
    (much larger) ``Retrying.__init__``.

    - ``logger`` is ``None`` or ``True``: return a lazily created,
      module-wide default logger.  For ``None`` a ``NullHandler`` is
      attached (once, on first use) and propagation is disabled so the
      default logger stays silent.
    - anything else: the caller supplied their own logger; return it
      unchanged.
    """
    global _default_logger, _configured_null_logger

    # Caller passed an explicit logger object -- use it as-is.
    if logger not in (True, None):
        return logger

    # Lazily create the shared default logger on first request.
    if _default_logger is None:
        _default_logger = logging.getLogger(__name__)

    # ``None`` means "stay quiet": install the NullHandler exactly once
    # (not on every construction) and stop records from propagating to
    # ancestor loggers.
    if logger is None and not _configured_null_logger:
        _configured_null_logger = True
        _default_logger.addHandler(logging.NullHandler())
        _default_logger.propagate = False
    return _default_logger


class Retrying(object):
def __init__(
Expand Down Expand Up @@ -110,14 +130,8 @@ def __init__(
self._wait_jitter_max = 0 if wait_jitter_max is None else wait_jitter_max
self._before_attempts = before_attempts
self._after_attempts = after_attempts

if logger in (True, None):
self._logger = logging.getLogger(__name__)
if logger is None:
self._logger.addHandler(logging.NullHandler())
self._logger.propagate = False
elif logger:
self._logger = logger

self._logger = _pick_logger(logger)

# TODO add chaining of stop behaviors
# stop behavior
Expand Down
111 changes: 108 additions & 3 deletions test_retrying.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import random
import time
import unittest

Expand Down Expand Up @@ -461,7 +462,7 @@ def _test_before():

_test_before()

self.assertTrue(TestBeforeAfterAttempts._attempt_number is 1)
self.assertTrue(TestBeforeAfterAttempts._attempt_number == 1)

def test_after_attempts(self):
TestBeforeAfterAttempts._attempt_number = 0
Expand All @@ -478,7 +479,111 @@ def _test_after():

_test_after()

self.assertTrue(TestBeforeAfterAttempts._attempt_number is 2)
self.assertTrue(TestBeforeAfterAttempts._attempt_number == 2)


class LoadTest(unittest.TestCase):
    """Performance regression test for the retry wrapper (see test_load)."""

    total_calls = 0
    foo = "bar"  # Static value to reduce variance

    @retry(
        retry_on_result=lambda result: result is None,
        stop_max_attempt_number=30000,  # For this test, never raise an exception
        wait_fixed=0
    )
    def fn_to_test(self):
        """Bump the shared call counter; return None often enough to force retries."""
        LoadTest.total_calls += 1
        # Simulate sometimes returning None (triggering retry)
        return LoadTest.foo if random.random() < 0.4 else None

    def benchmark(self, duration_seconds=2):
        """Call fn_to_test repeatedly for ~duration_seconds; return achieved calls/sec."""
        LoadTest.total_calls = 0  # Reset counter

        started = time.time()
        deadline = started + duration_seconds
        while time.time() < deadline:
            self.fn_to_test()

        elapsed = time.time() - started

        # Calculate metrics
        return LoadTest.total_calls / elapsed

    def test_load(self):
        """
        In 1.3.5, there was a bug where calls wrapped with retry would take longer and longer to complete.
        This test checks that wrapping a function with retry doesn't affect its performance over several calls.
        This test takes ~24 seconds to run.
        """
        calls_per_second_initial = self.benchmark(2)
        # Run the benchmark a few more times. This triggers the performance bug in 1.3.5.
        for _ in range(10):
            self.benchmark(2)
        calls_per_second_final = self.benchmark(2)
        # Ensure that the later calls are within +/-20% the speed of the initial calls
        lower = calls_per_second_initial * 0.8
        upper = calls_per_second_initial * 1.2
        self.assertTrue(
            lower <= calls_per_second_final <= upper,
            {
                "calls_per_second_initial": calls_per_second_initial,
                "calls_per_second_final": calls_per_second_final,
                "msg": "calls_per_second_final was not within +/-20% of calls_per_second_initial"
            }
        )


class TestLogger(unittest.TestCase):
    """Tests for the ``logger`` argument of @retry (None, True, or a custom logger)."""

    class TestHandler(logging.Handler):
        """Handler that records every LogRecord it receives, for later inspection."""

        def __init__(self):
            super().__init__()
            self.records = []  # LogRecords captured during the test

        def emit(self, record):
            # Override emit() (not handle()) so the standard Handler
            # machinery -- level checks, filters, locking -- still runs.
            self.records.append(record)

    def setUp(self):
        # Set up a function with a dummy logger.
        # getLogger() returns a shared, process-wide logger, so the handler
        # must be detached again after each test; otherwise every setUp()
        # call leaks one more handler onto the "test_retrying" logger and
        # later tests see records delivered to stale handlers as well.
        self.logger = logging.getLogger("test_retrying")
        self.test_handler = self.TestHandler()
        self.logger.addHandler(self.test_handler)
        self.addCleanup(self.logger.removeHandler, self.test_handler)

        @retry(stop_max_attempt_number=1, retry_on_result=lambda r: r is None, logger=self.logger)
        def foo_with_logger():
            return None

        self.foo_with_logger = foo_with_logger

    @retry(stop_max_attempt_number=1, retry_on_result=lambda r: r is None)
    def foo_no_logger(self):
        return None

    def test_logger_None(self):
        """logger=None (the default) must not raise anything except RetryError."""
        # Assert not raises (anything except RetryError)
        try:
            self.foo_no_logger()
        except RetryError:
            pass

    def test_logger_custom(self):
        """A user-supplied logger receives exactly one record for the failed retry."""
        try:
            self.foo_with_logger()
        except RetryError:
            pass
        self.assertEqual(len(self.test_handler.records), 1)

    def test_logger_true(self):
        """logger=True uses the shared default logger and must not raise."""
        @retry(stop_max_attempt_number=1, retry_on_result=lambda r: r is None, logger=True)
        def foo_true():
            return None

        # Assert not raises (anything except RetryError)
        try:
            foo_true()
        except RetryError:
            pass


if __name__ == "__main__":
Expand Down