diff --git a/simpervisor/atexitasync.py b/simpervisor/atexitasync.py index 65aa954..d4869cb 100644 --- a/simpervisor/atexitasync.py +++ b/simpervisor/atexitasync.py @@ -8,27 +8,69 @@ _handlers = [] +_prev_handlers = {} signal_handler_set = False def add_handler(handler): - global signal_handler_set - if not signal_handler_set: - signal.signal(signal.SIGINT, _handle_signal) - signal.signal(signal.SIGTERM, _handle_signal) - signal_handler_set = True + """ + Adds a signal handler function that will be called when the Python process + receives either a SIGINT (on windows CTRL_C_EVENT) or SIGTERM signal. + """ + _ensure_signal_handlers_set() _handlers.append(handler) def remove_handler(handler): + """Removes previously added signal handler.""" _handlers.remove(handler) +def _ensure_signal_handlers_set(): + """ + Ensures _handle_signal is registered as a top level signal handler for + SIGINT and SIGTERM and saves previously registered non-default Python + callable signal handlers. + """ + global signal_handler_set + if not signal_handler_set: + # save previously registered non-default Python callable signal handlers + # + # windows note: signal.getsignal(signal.CTRL_C_EVENT) would error with + # "ValueError: signal number out of range", and + # signal.signal(signal.CTRL_C_EVENT, _handle_signal) would error with + # "ValueError: invalid signal value". + # + prev_sigint = signal.getsignal(signal.SIGINT) + prev_sigterm = signal.getsignal(signal.SIGTERM) + if callable(prev_sigint) and prev_sigint.__qualname__ != "default_int_handler": + _prev_handlers[signal.SIGINT] = prev_sigint + if callable(prev_sigterm) and prev_sigterm != signal.Handlers.SIG_DFL: + _prev_handlers[signal.SIGTERM] = prev_sigterm + + # let _handle_signal handle SIGINT and SIGTERM + signal.signal(signal.SIGINT, _handle_signal) + signal.signal(signal.SIGTERM, _handle_signal) + signal_handler_set = True + + def _handle_signal(signum, *args): + """ + Calls functions added by add_handler, and then calls the previously + registered non-default Python callable signal handler if there were one. + """ + prev_handler = _prev_handlers.get(signum) + # Windows doesn't support SIGINT. Replacing it with CTRL_C_EVENT so that it # can used with subprocess.Popen.send_signal if signum == signal.SIGINT and sys.platform == "win32": signum = signal.CTRL_C_EVENT + for handler in _handlers: handler(signum) - sys.exit(0) + + # call previously registered non-default Python callable handler or exit + if prev_handler: + prev_handler(signum, *args) + else: + sys.exit(0) diff --git a/tests/child_scripts/signalprinter.py b/tests/child_scripts/signalprinter.py index 61b583c..480f048 100644 --- a/tests/child_scripts/signalprinter.py +++ b/tests/child_scripts/signalprinter.py @@ -2,18 +2,30 @@ Print received SIGTERM & SIGINT signals """ import asyncio +import signal import sys from functools import partial from simpervisor.atexitasync import add_handler +def _non_default_handle(sig, frame): + # Print the received signum and then exit + print(f"non default handler received {sig}", flush=True) + sys.exit(0) + + def _handle_sigterm(number, received_signum): # Print the received signum so we know our handler was called print(f"handler {number} received", int(received_signum), flush=True) handlercount = int(sys.argv[1]) +# Add non default handler if arg true is passed +if len(sys.argv) == 3: + if bool(sys.argv[2]): + signal.signal(signal.SIGINT, _non_default_handle) + signal.signal(signal.SIGTERM, _non_default_handle) for i in range(handlercount): add_handler(partial(_handle_sigterm, i)) diff --git a/tests/child_scripts/simplehttpserver.py b/tests/child_scripts/simplehttpserver.py index cc460ff..44668cc 100644 --- a/tests/child_scripts/simplehttpserver.py +++ b/tests/child_scripts/simplehttpserver.py @@ -23,4 +23,4 @@ async def hello(request): app = web.Application() app.add_routes(routes) -web.run_app(app, port=PORT) +web.run_app(app, port=int(PORT)) diff --git a/tests/test_atexitasync.py b/tests/test_atexitasync.py index 94e2fe1..866c6d2 100644 --- a/tests/test_atexitasync.py +++ b/tests/test_atexitasync.py @@ -49,3 +49,48 @@ def test_atexitasync(signum, handlercount): # The code should exit cleanly retcode = proc.wait() assert retcode == 0 + + +@pytest.mark.parametrize( + "signum, handlercount", + [ + (signal.SIGTERM, 1), + (signal.SIGINT, 1), + (signal.SIGTERM, 5), + (signal.SIGINT, 5), + ], +) +@pytest.mark.skipif( + sys.platform == "win32", + reason="Testing signals on Windows doesn't seem to be possible", +) +def test_atexitasync_with_non_default_handlers(signum, handlercount): + """ + Test signal handlers receive signals properly and handler existing default handlers + correctly + """ + signalprinter_file = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "child_scripts", "signalprinter.py" + ) + proc = subprocess.Popen( + [sys.executable, signalprinter_file, str(handlercount), "true"], + stdout=subprocess.PIPE, + text=True, + ) + + # Give the process time to register signal handlers + time.sleep(1) + proc.send_signal(signum) + + # Make sure the signal is handled by our handler in the code + stdout, stderr = proc.communicate() + expected_output = ( + "\n".join([f"handler {i} received {signum}" for i in range(handlercount)]) + + f"\nnon default handler received {signum}\n" + ) + + assert stdout == expected_output + + # The code should exit cleanly + retcode = proc.wait() + assert retcode == 0