Skip to content
54 changes: 48 additions & 6 deletions simpervisor/atexitasync.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,69 @@

_handlers = []

_prev_handlers = {}
signal_handler_set = False


def add_handler(handler):
global signal_handler_set
if not signal_handler_set:
signal.signal(signal.SIGINT, _handle_signal)
signal.signal(signal.SIGTERM, _handle_signal)
signal_handler_set = True
"""
Adds a signal handler function that will be called when the Python process
receives either a SIGINT (on windows CTRL_C_EVENT) or SIGTERM signal.
"""
_ensure_signal_handlers_set()
_handlers.append(handler)


def remove_handler(handler):
"""Removes previously added signal handler."""
_handlers.remove(handler)


def _ensure_signal_handlers_set():
"""
Ensures _handle_signal is registered as a top level signal handler for
SIGINT and SIGTERM and saves previously registered non-default Python
callable signal handlers.
"""
global signal_handler_set
if not signal_handler_set:
# save previously registered non-default Python callable signal handlers
#
# windows note: signal.getsignal(signal.CTRL_C_EVENT) would error with
# "ValueError: signal number out of range", and
# signal.signal(signal.CTRL_C_EVENT, _handle_signal) would error with
# "ValueError: invalid signal value".
#
prev_sigint = signal.getsignal(signal.SIGINT)
prev_sigterm = signal.getsignal(signal.SIGTERM)
if callable(prev_sigint) and prev_sigint.__qualname__ != "default_int_handler":
_prev_handlers[signal.SIGINT] = prev_sigint
if callable(prev_sigterm) and prev_sigterm != signal.Handlers.SIG_DFL:
_prev_handlers[signal.SIGTERM] = prev_sigterm

# let _handle_signal handle SIGINT and SIGTERM
signal.signal(signal.SIGINT, _handle_signal)
signal.signal(signal.SIGTERM, _handle_signal)
signal_handler_set = True


def _handle_signal(signum, *args):
"""
Calls functions added by add_handler, and then calls the previously
registered non-default Python callable signal handler if there were one.
"""
prev_handler = _prev_handlers.get(signum)

# Windows doesn't support SIGINT. Replacing it with CTRL_C_EVENT so that it
# can used with subprocess.Popen.send_signal
if signum == signal.SIGINT and sys.platform == "win32":
signum = signal.CTRL_C_EVENT

for handler in _handlers:
handler(signum)
sys.exit(0)

# call previously registered non-default Python callable handler or exit
if prev_handler:
prev_handler(signum, *args)
Comment on lines +72 to +74
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should get review and ideally tested on windows:

  • I made this prev_handler(signum, *args) instead of prev_handler(signum, None)`, is this right?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

*args will be frame argument. Not sure if it works on windows though. I think we can pass *args to handler.

else:
sys.exit(0)
12 changes: 12 additions & 0 deletions tests/child_scripts/signalprinter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,30 @@
Print received SIGTERM & SIGINT signals
"""
import asyncio
import signal
import sys
from functools import partial

from simpervisor.atexitasync import add_handler


def _non_default_handle(sig, frame):
# Print the received signum and then exit
print(f"non default handler received {sig}", flush=True)
sys.exit(0)


def _handle_sigterm(number, received_signum):
# Print the received signum so we know our handler was called
print(f"handler {number} received", int(received_signum), flush=True)


handlercount = int(sys.argv[1])
# Add non default handler if arg true is passed
if len(sys.argv) == 3:
if bool(sys.argv[2]):
signal.signal(signal.SIGINT, _non_default_handle)
signal.signal(signal.SIGTERM, _non_default_handle)
for i in range(handlercount):
add_handler(partial(_handle_sigterm, i))

Expand Down
2 changes: 1 addition & 1 deletion tests/child_scripts/simplehttpserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@ async def hello(request):

app = web.Application()
app.add_routes(routes)
web.run_app(app, port=PORT)
web.run_app(app, port=int(PORT))
45 changes: 45 additions & 0 deletions tests/test_atexitasync.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,48 @@ def test_atexitasync(signum, handlercount):
# The code should exit cleanly
retcode = proc.wait()
assert retcode == 0


@pytest.mark.parametrize(
"signum, handlercount",
[
(signal.SIGTERM, 1),
(signal.SIGINT, 1),
(signal.SIGTERM, 5),
(signal.SIGINT, 5),
],
)
@pytest.mark.skipif(
sys.platform == "win32",
reason="Testing signals on Windows doesn't seem to be possible",
)
def test_atexitasync_with_non_default_handlers(signum, handlercount):
"""
Test signal handlers receive signals properly and handler existing default handlers
correctly
"""
signalprinter_file = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "child_scripts", "signalprinter.py"
)
proc = subprocess.Popen(
[sys.executable, signalprinter_file, str(handlercount), "true"],
stdout=subprocess.PIPE,
text=True,
)

# Give the process time to register signal handlers
time.sleep(1)
proc.send_signal(signum)

# Make sure the signal is handled by our handler in the code
stdout, stderr = proc.communicate()
expected_output = (
"\n".join([f"handler {i} received {signum}" for i in range(handlercount)])
+ f"\nnon default handler received {signum}\n"
)

assert stdout == expected_output

# The code should exit cleanly
retcode = proc.wait()
assert retcode == 0