Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ isort==5.10.1

pytest==7.1.2
pytest-asyncio==0.18.3
pytest-timeout==2.1.0 # used to timeout tests

flaky # Used for flaky tests (flaky decorator)
beautifulsoup4 # used in test_official for parsing tg docs
Expand Down
17 changes: 11 additions & 6 deletions telegram/ext/_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import inspect
import itertools
import logging
import platform
import signal
from collections import defaultdict
from contextlib import AbstractAsyncContextManager
Expand Down Expand Up @@ -547,7 +548,7 @@ def run_polling(
allowed_updates: List[str] = None,
drop_pending_updates: bool = None,
close_loop: bool = True,
stop_signals: Optional[Sequence[int]] = (signal.SIGINT, signal.SIGTERM, signal.SIGABRT),
stop_signals: ODVInput[Sequence[int]] = DEFAULT_NONE,
) -> None:
"""Convenience method that takes care of initializing and starting the app,
polling updates from Telegram using :meth:`telegram.ext.Updater.start_polling` and
Expand Down Expand Up @@ -596,7 +597,7 @@ def run_polling(
stop_signals (Sequence[:obj:`int`] | :obj:`None`, optional): Signals that will shut
down the app. Pass :obj:`None` to not use stop signals.
Defaults to :data:`signal.SIGINT`, :data:`signal.SIGTERM` and
:data:`signal.SIGABRT`.
:data:`signal.SIGABRT` on non Windows platforms.

Caution:
Not every :class:`asyncio.AbstractEventLoop` implements
Expand Down Expand Up @@ -646,7 +647,7 @@ def run_webhook(
ip_address: str = None,
max_connections: int = 40,
close_loop: bool = True,
stop_signals: Optional[Sequence[int]] = (signal.SIGINT, signal.SIGTERM, signal.SIGABRT),
stop_signals: ODVInput[Sequence[int]] = DEFAULT_NONE,
) -> None:
"""Convenience method that takes care of initializing and starting the app,
polling updates from Telegram using :meth:`telegram.ext.Updater.start_webhook` and
Expand Down Expand Up @@ -736,17 +737,21 @@ def _raise_system_exit() -> NoReturn:
def __run(
self,
updater_coroutine: Coroutine,
stop_signals: Optional[Sequence[int]],
stop_signals: ODVInput[Sequence[int]],
close_loop: bool = True,
) -> None:
# Calling get_event_loop() should still be okay even in py3.10+ as long as there is a
# running event loop or we are in the main thread, which are the intended use cases.
# See the docs of get_event_loop() and get_running_loop() for more info
loop = asyncio.get_event_loop()

if stop_signals is DEFAULT_NONE and platform.system() != "Windows":
stop_signals = (signal.SIGINT, signal.SIGTERM, signal.SIGABRT)

try:
for sig in stop_signals or []:
loop.add_signal_handler(sig, self._raise_system_exit)
if not isinstance(stop_signals, DefaultValue):
for sig in stop_signals or []:
loop.add_signal_handler(sig, self._raise_system_exit)
except NotImplementedError as exc:
warn(
f"Could not add signal handlers for the stop signals {stop_signals} due to "
Expand Down
40 changes: 38 additions & 2 deletions tests/test_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -1659,9 +1659,9 @@ async def raise_method(*args, **kwargs):

with pytest.raises(RuntimeError, match="Prevent Actually Running"):
if "polling" in method:
app.run_polling(close_loop=False)
app.run_polling(close_loop=False, stop_signals=(signal.SIGINT,))
else:
app.run_webhook(close_loop=False)
app.run_webhook(close_loop=False, stop_signals=(signal.SIGTERM,))

assert len(recwarn) >= 1
found = False
Expand All @@ -1680,3 +1680,39 @@ async def raise_method(*args, **kwargs):
app.run_webhook(close_loop=False, stop_signals=None)

assert len(recwarn) == 0

@pytest.mark.timeout(6)
def test_signal_handlers(self, app, monkeypatch):
# this test should make sure that signal handlers are set by default on Linux + Mac,
# and not on Windows.

received_signals = []

def signal_handler_test(*args, **kwargs):
# args[0] is the signal, [1] the callback
received_signals.append(args[0])

loop = asyncio.get_event_loop()
monkeypatch.setattr(loop, "add_signal_handler", signal_handler_test)

async def abort_app():
await asyncio.sleep(2)
raise SystemExit

loop.create_task(abort_app())

app.run_polling(close_loop=False)

if platform.system() == "Windows":
assert received_signals == []
else:
assert received_signals == [signal.SIGINT, signal.SIGTERM, signal.SIGABRT]

received_signals.clear()
loop.create_task(abort_app())
app.run_webhook(port=49152, webhook_url="example.com", close_loop=False)

if platform.system() == "Windows":
assert received_signals == []
else:
assert received_signals == [signal.SIGINT, signal.SIGTERM, signal.SIGABRT]