From b6f085e86f0a62e1fe3a38b33256479cbf84a0dd Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 May 2021 20:55:50 +0200 Subject: [PATCH 001/762] Fix badge. --- README.rst | 4 ++-- docs/index.rst | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/README.rst b/README.rst index bda73c640..6b83ffed6 100644 --- a/README.rst +++ b/README.rst @@ -19,8 +19,8 @@ .. |pypi-wheel| image:: https://img.shields.io/pypi/wheel/websockets.svg :target: https://pypi.python.org/pypi/websockets -.. |tests| image:: https://github.com/aaugustin/websockets/workflows/tests/badge.svg?branch=master - :target: https://github.com/aaugustin/websockets/actions?workflow=tests +.. |tests| image:: https://github.com/aaugustin/websockets/actions/workflows/tests.yml/badge.svg + :target: https://github.com/aaugustin/websockets/actions/workflows/tests.yml What is ``websockets``? ----------------------- diff --git a/docs/index.rst b/docs/index.rst index 5914d7289..f0c5f8d00 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -15,8 +15,8 @@ websockets .. |pypi-wheel| image:: https://img.shields.io/pypi/wheel/websockets.svg :target: https://pypi.python.org/pypi/websockets -.. |tests| image:: https://github.com/aaugustin/websockets/workflows/tests/badge.svg?branch=master - :target: https://github.com/aaugustin/websockets/actions?workflow=tests +.. |tests| image:: https://github.com/aaugustin/websockets/actions/workflows/tests.yml/badge.svg + :target: https://github.com/aaugustin/websockets/actions/workflows/tests.yml ``websockets`` is a library for building WebSocket servers_ and clients_ in Python with a focus on correctness and simplicity. From db6f5b50d09ae86606c4e0d9b92079eb02dbff5e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 May 2021 21:24:13 +0200 Subject: [PATCH 002/762] Document public methods that can raise ConnectionClosed. Fix #768. --- src/websockets/legacy/protocol.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index e4c6d63c5..c155c3bd7 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -465,6 +465,8 @@ async def send( Stopping in the middle of a fragmented message will cause a protocol error. Closing the connection has the same effect. + :raises ~websockets.exceptions.ConnectionClosed: when the + connection is closed :raises TypeError: for unsupported inputs """ @@ -653,6 +655,11 @@ async def ping(self, data: Optional[Data] = None) -> Awaitable[None]: Canceling the :class:`~asyncio.Future` returned by :meth:`ping` has no effect. + :raises ~websockets.exceptions.ConnectionClosed: when the + connection is closed + :raises ValueError: if another ping was sent with the same data and + the corresponding pong wasn't received yet + """ await self.ensure_open() @@ -685,6 +692,9 @@ async def pong(self, data: Data = b"") -> None: Canceling :meth:`pong` is discouraged for the same reason as :meth:`ping`. + :raises ~websockets.exceptions.ConnectionClosed: when the + connection is closed + """ await self.ensure_open() From e4fe1e0999237c3def055dddfee39c3a09e743a4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 3 May 2021 07:38:21 +0200 Subject: [PATCH 003/762] Give up on a flaky test. Ref #390. --- tests/legacy/test_protocol.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index a89bcc88b..f928322ca 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -1389,7 +1389,8 @@ def test_local_close_connection_lost_timeout_after_close(self): # half-close our side with write_eof() and close it with close(), time # out in 20ms. # Check the timing within -1/+9ms for robustness. - with self.assertCompletesWithin(19 * MS, 29 * MS): + # Add another 10ms because this test is flaky and I don't understand. + with self.assertCompletesWithin(19 * MS, 39 * MS): # HACK: disable write_eof => other end drops connection emulation. self.transport._eof = True # HACK: disable close => other end drops connection emulation. @@ -1444,12 +1445,13 @@ def test_local_close_connection_lost_timeout_after_close(self): self.protocol.close_timeout = 10 * MS # If the client doesn't close its side of the TCP connection after we # half-close our side with write_eof() and close it with close(), time - # out in 20ms. + # out in 30ms. # - 10ms waiting for receiving a half-close # - 10ms waiting for receiving a close after write_eof # - 10ms waiting for receiving a close after close # Check the timing within -1/+9ms for robustness. - with self.assertCompletesWithin(29 * MS, 39 * MS): + # Add another 10ms because this test is flaky and I don't understand. + with self.assertCompletesWithin(29 * MS, 49 * MS): # HACK: disable write_eof => other end drops connection emulation. self.transport._eof = True # HACK: disable close => other end drops connection emulation. From 4ca3721bd29e1e0efddb89dff17809382af339dd Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 May 2021 09:01:01 +0200 Subject: [PATCH 004/762] Drop support for Python 3.6. Also fix a few places that were missed when adding 3.9. --- .github/workflows/tests.yml | 2 +- .github/workflows/wheels.yml | 2 +- README.rst | 2 +- docs/changelog.rst | 8 +++- docs/intro.rst | 2 +- setup.cfg | 2 +- setup.py | 7 +-- src/websockets/__main__.py | 6 +-- src/websockets/datastructures.py | 6 +-- src/websockets/imports.py | 74 +++++++++++++------------------- src/websockets/legacy/server.py | 9 +--- src/websockets/typing.py | 20 ++------- tox.ini | 2 +- 13 files changed, 52 insertions(+), 90 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index eb06ebfea..0042c302c 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -38,7 +38,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python: [3.6, 3.7, 3.8, 3.9] + python: [3.7, 3.8, 3.9] steps: - name: Check out repository uses: actions/checkout@v2 diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 7ea97c61f..249cd36ce 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -50,7 +50,7 @@ jobs: uses: joerick/cibuildwheel@v1.11.0 env: CIBW_ARCHS_LINUX: auto aarch64 - CIBW_BUILD: cp36-* cp37-* cp38-* cp39-* + CIBW_BUILD: cp37-* cp38-* cp39-* - name: Save wheels uses: actions/upload-artifact@v2 with: diff --git a/README.rst b/README.rst index 6b83ffed6..7a163f5b7 100644 --- a/README.rst +++ b/README.rst @@ -125,7 +125,7 @@ Why shouldn't I use ``websockets``? and :rfc:`7692`: Compression Extensions for WebSocket. Its support for HTTP is minimal — just enough for a HTTP health check. * If you want to use Python 2: ``websockets`` builds upon ``asyncio`` which - only works on Python 3. ``websockets`` requires Python ≥ 3.6.1. + only works on Python 3. ``websockets`` requires Python ≥ 3.7. What else? ---------- diff --git a/docs/changelog.rst b/docs/changelog.rst index 1e5f92211..35f292e49 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -25,11 +25,15 @@ fixing regressions shortly after a release. Only documented APIs are public. Undocumented APIs are considered private. They may change at any time. -9.1 -... +10.0 +.... *In development* +.. warning:: + + **Version 10.0 drops compatibility with Python 3.6.** + 9.0.1 ..... diff --git a/docs/intro.rst b/docs/intro.rst index c77139cab..58d482a09 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -6,7 +6,7 @@ Getting started Requirements ------------ -``websockets`` requires Python ≥ 3.6.1. +``websockets`` requires Python ≥ 3.7. You should use the latest version of Python if possible. If you're using an older version, be aware that for each minor version (3.x), only the latest diff --git a/setup.cfg b/setup.cfg index 04b792989..b15a7515c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bdist_wheel] -python-tag = py36.py37.py38.py39 +python-tag = py37.py38.py39 [metadata] license_file = LICENSE diff --git a/setup.py b/setup.py index 5adb8e835..d493a12f1 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,5 @@ import pathlib import re -import sys import setuptools @@ -21,9 +20,6 @@ exec((root_dir / 'src' / 'websockets' / 'version.py').read_text(encoding='utf-8')) -if sys.version_info[:3] < (3, 6, 1): - raise Exception("websockets requires Python >= 3.6.1.") - packages = ['websockets', 'websockets/legacy', 'websockets/extensions'] ext_modules = [ @@ -51,7 +47,6 @@ 'Operating System :: OS Independent', 'Programming Language :: Python', 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', @@ -62,6 +57,6 @@ ext_modules=ext_modules, include_package_data=True, zip_safe=False, - python_requires='>=3.6.1', + python_requires='>=3.7', test_loader='unittest:TestLoader', ) diff --git a/src/websockets/__main__.py b/src/websockets/__main__.py index d44e34e74..530daef85 100644 --- a/src/websockets/__main__.py +++ b/src/websockets/__main__.py @@ -105,8 +105,8 @@ async def run_client( try: while True: - incoming: asyncio.Future[Any] = asyncio.ensure_future(websocket.recv()) - outgoing: asyncio.Future[Any] = asyncio.ensure_future(inputs.get()) + incoming: asyncio.Future[Any] = asyncio.create_task(websocket.recv()) + outgoing: asyncio.Future[Any] = asyncio.create_task(inputs.get()) done: Set[asyncio.Future[Any]] pending: Set[asyncio.Future[Any]] done, pending = await asyncio.wait( @@ -188,7 +188,7 @@ async def queue_factory() -> asyncio.Queue[str]: stop: asyncio.Future[None] = loop.create_future() # Schedule the task that will manage the connection. - asyncio.ensure_future(run_client(args.uri, loop, inputs, stop), loop=loop) + loop.create_task(run_client(args.uri, loop, inputs, stop)) # Start the event loop in a background thread. thread = threading.Thread(target=loop.run_forever) diff --git a/src/websockets/datastructures.py b/src/websockets/datastructures.py index c8e17fa98..66f91e9bb 100644 --- a/src/websockets/datastructures.py +++ b/src/websockets/datastructures.py @@ -158,8 +158,4 @@ def raw_items(self) -> Iterator[Tuple[str, str]]: HeadersLike = Union[Headers, Mapping[str, str], Iterable[Tuple[str, str]]] HeadersLike__doc__ = """Types accepted wherever :class:`Headers` is expected""" -# Remove try / except when dropping support for Python < 3.7 -try: - HeadersLike.__doc__ = HeadersLike__doc__ -except AttributeError: # pragma: no cover - pass +HeadersLike.__doc__ = HeadersLike__doc__ diff --git a/src/websockets/imports.py b/src/websockets/imports.py index efd3eabf3..06917ce1d 100644 --- a/src/websockets/imports.py +++ b/src/websockets/imports.py @@ -1,4 +1,3 @@ -import sys import warnings from typing import Any, Dict, Iterable, Optional @@ -50,9 +49,6 @@ def lazy_import( This function defines __getattr__ and __dir__ per PEP 562. - On Python 3.6 and earlier, it falls back to non-lazy imports and doesn't - raise deprecation warnings. - """ if aliases is None: aliases = {} @@ -69,43 +65,33 @@ def lazy_import( package = namespace["__name__"] - if sys.version_info[:2] >= (3, 7): - - def __getattr__(name: str) -> Any: - assert aliases is not None # mypy cannot figure this out - try: - source = aliases[name] - except KeyError: - pass - else: - return import_name(name, source, namespace) - - assert deprecated_aliases is not None # mypy cannot figure this out - try: - source = deprecated_aliases[name] - except KeyError: - pass - else: - warnings.warn( - f"{package}.{name} is deprecated", - DeprecationWarning, - stacklevel=2, - ) - return import_name(name, source, namespace) - - raise AttributeError(f"module {package!r} has no attribute {name!r}") - - namespace["__getattr__"] = __getattr__ - - def __dir__() -> Iterable[str]: - return sorted(namespace_set | aliases_set | deprecated_aliases_set) - - namespace["__dir__"] = __dir__ - - else: # pragma: no cover - - for name, source in aliases.items(): - namespace[name] = import_name(name, source, namespace) - - for name, source in deprecated_aliases.items(): - namespace[name] = import_name(name, source, namespace) + def __getattr__(name: str) -> Any: + assert aliases is not None # mypy cannot figure this out + try: + source = aliases[name] + except KeyError: + pass + else: + return import_name(name, source, namespace) + + assert deprecated_aliases is not None # mypy cannot figure this out + try: + source = deprecated_aliases[name] + except KeyError: + pass + else: + warnings.warn( + f"{package}.{name} is deprecated", + DeprecationWarning, + stacklevel=2, + ) + return import_name(name, source, namespace) + + raise AttributeError(f"module {package!r} has no attribute {name!r}") + + namespace["__getattr__"] = __getattr__ + + def __dir__() -> Iterable[str]: + return sorted(namespace_set | aliases_set | deprecated_aliases_set) + + namespace["__dir__"] = __dir__ diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index e693bbd2f..1daf3a9ad 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -760,12 +760,7 @@ def is_serving(self) -> bool: Tell whether the server is accepting new connections or shutting down. """ - try: - # Python ≥ 3.7 - return self.server.is_serving() - except AttributeError: # pragma: no cover - # Python < 3.7 - return self.server.sockets is not None + return self.server.is_serving() def close(self) -> None: """ @@ -815,7 +810,7 @@ async def _close(self) -> None: if self.websockets: await asyncio.wait( [ - asyncio.ensure_future(websocket.close(1001)) + asyncio.create_task(websocket.close(1001)) for websocket in self.websockets ], loop=self.loop if sys.version_info[:2] < (3, 8) else None, diff --git a/src/websockets/typing.py b/src/websockets/typing.py index ca66a8c54..630a9fbe3 100644 --- a/src/websockets/typing.py +++ b/src/websockets/typing.py @@ -4,19 +4,13 @@ __all__ = ["Data", "Origin", "ExtensionHeader", "ExtensionParameter", "Subprotocol"] Data = Union[str, bytes] - -Data__doc__ = """ +Data.__doc__ = """ Types supported in a WebSocket message: - :class:`str` for text messages - :class:`bytes` for binary messages """ -# Remove try / except when dropping support for Python < 3.7 -try: - Data.__doc__ = Data__doc__ -except AttributeError: # pragma: no cover - pass Origin = NewType("Origin", str) @@ -28,19 +22,11 @@ ExtensionParameter = Tuple[str, Optional[str]] -ExtensionParameter__doc__ = """Parameter of a WebSocket extension""" -try: - ExtensionParameter.__doc__ = ExtensionParameter__doc__ -except AttributeError: # pragma: no cover - pass +ExtensionParameter.__doc__ = """Parameter of a WebSocket extension""" ExtensionHeader = Tuple[ExtensionName, List[ExtensionParameter]] -ExtensionHeader__doc__ = """Extension in a Sec-WebSocket-Extensions header""" -try: - ExtensionHeader.__doc__ = ExtensionHeader__doc__ -except AttributeError: # pragma: no cover - pass +ExtensionHeader.__doc__ = """Extension in a Sec-WebSocket-Extensions header""" Subprotocol = NewType("Subprotocol", str) diff --git a/tox.ini b/tox.ini index b5488e5b0..e74c979ba 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py36,py37,py38,py39,coverage,black,flake8,isort,mypy +envlist = py37,py38,py39,coverage,black,flake8,isort,mypy [testenv] commands = python -W default -m unittest {posargs} From ab6f4382ec1cd122b3a515601d413a8f247ea79e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 May 2021 09:42:50 +0200 Subject: [PATCH 005/762] Remove use of get_event_loop. Refs #916 - get_event_loop is deprecated in Python 3.10. Fix #534. --- README.rst | 15 +++++++-------- compliance/test_client.py | 3 +-- compliance/test_server.py | 14 ++++++++------ docs/faq.rst | 3 ++- docs/heroku.rst | 7 ++++--- example/basic_auth_client.py | 2 +- example/basic_auth_server.py | 17 +++++++++-------- example/client.py | 2 +- example/counter.py | 7 +++---- example/echo.py | 7 ++++--- example/health_check_server.py | 11 ++++++----- example/hello.py | 2 +- example/secure_client.py | 2 +- example/secure_server.py | 11 ++++++----- example/server.py | 7 ++++--- example/show_time.py | 7 ++++--- example/shutdown_client.py | 4 ++-- example/shutdown_server.py | 15 ++++++--------- example/unix_client.py | 2 +- example/unix_server.py | 9 +++++---- performance/mem_client.py | 2 +- performance/mem_server.py | 17 ++++++++--------- 22 files changed, 85 insertions(+), 81 deletions(-) diff --git a/README.rst b/README.rst index 7a163f5b7..5db7d9d0b 100644 --- a/README.rst +++ b/README.rst @@ -45,15 +45,14 @@ Here's how a client sends and receives messages: #!/usr/bin/env python import asyncio - import websockets + from websockets import connect async def hello(uri): - async with websockets.connect(uri) as websocket: + async with connect(uri) as websocket: await websocket.send("Hello world!") await websocket.recv() - asyncio.get_event_loop().run_until_complete( - hello('ws://localhost:8765')) + asyncio.run(hello('ws://localhost:8765')) And here's an echo server: @@ -62,15 +61,15 @@ And here's an echo server: #!/usr/bin/env python import asyncio - import websockets + from websockets import serve async def echo(websocket, path): async for message in websocket: await websocket.send(message) - asyncio.get_event_loop().run_until_complete( - websockets.serve(echo, 'localhost', 8765)) - asyncio.get_event_loop().run_forever() + async def main(): + async with serve(echo, 'localhost', 8765): + await asyncio.Future() # run forever Does that look good? diff --git a/compliance/test_client.py b/compliance/test_client.py index 5fd0f4b4f..1ed4d711e 100644 --- a/compliance/test_client.py +++ b/compliance/test_client.py @@ -45,5 +45,4 @@ async def run_tests(server, agent): await update_reports(server, agent) -main = run_tests(SERVER, urllib.parse.quote(AGENT)) -asyncio.get_event_loop().run_until_complete(main) +asyncio.run(run_tests(SERVER, urllib.parse.quote(AGENT))) diff --git a/compliance/test_server.py b/compliance/test_server.py index 8020f68d3..58e357bc7 100644 --- a/compliance/test_server.py +++ b/compliance/test_server.py @@ -18,10 +18,12 @@ async def echo(ws, path): await ws.send(msg) -start_server = websockets.serve(echo, HOST, PORT, max_size=2 ** 25, max_queue=1) +async def main(): + with websockets.serve(echo, HOST, PORT, max_size=2 ** 25, max_queue=1): + try: + await asyncio.Future() + except KeyboardInterrupt: + pass -try: - asyncio.get_event_loop().run_until_complete(start_server) - asyncio.get_event_loop().run_forever() -except KeyboardInterrupt: - pass + +asyncio.run(main) diff --git a/docs/faq.rst b/docs/faq.rst index ff91105b4..20b74ba98 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -274,7 +274,8 @@ Can I use ``websockets`` synchronously, without ``async`` / ``await``? ...................................................................... You can convert every asynchronous call to a synchronous call by wrapping it -in ``asyncio.get_event_loop().run_until_complete(...)``. +in ``asyncio.get_event_loop().run_until_complete(...)``. Unfortunately, this +is deprecated as of Python 3.10. If this turns out to be impractical, you should use another library. diff --git a/docs/heroku.rst b/docs/heroku.rst index 31c4b3f19..5ef4b120e 100644 --- a/docs/heroku.rst +++ b/docs/heroku.rst @@ -52,10 +52,11 @@ Here's the implementation of the app, an echo server. Save it in a file called async for message in websocket: await websocket.send(message) - start_server = websockets.serve(echo, "", int(os.environ["PORT"])) + async def main(): + async with websockets.serve(echo, "", int(os.environ["PORT"])): + await asyncio.Future() # run forever - asyncio.get_event_loop().run_until_complete(start_server) - asyncio.get_event_loop().run_forever() + asyncio.run(main) The server relies on the ``$PORT`` environment variable to tell on which port it will listen, according to Heroku's conventions. diff --git a/example/basic_auth_client.py b/example/basic_auth_client.py index cc94dbe4b..164732152 100755 --- a/example/basic_auth_client.py +++ b/example/basic_auth_client.py @@ -11,4 +11,4 @@ async def hello(): greeting = await websocket.recv() print(greeting) -asyncio.get_event_loop().run_until_complete(hello()) +asyncio.run(hello()) diff --git a/example/basic_auth_server.py b/example/basic_auth_server.py index 6740d5798..a11910445 100755 --- a/example/basic_auth_server.py +++ b/example/basic_auth_server.py @@ -9,12 +9,13 @@ async def hello(websocket, path): greeting = f"Hello {websocket.username}!" await websocket.send(greeting) -start_server = websockets.serve( - hello, "localhost", 8765, - create_protocol=websockets.basic_auth_protocol_factory( - realm="example", credentials=("mary", "p@ssw0rd") - ), -) +async def main(): + async with websockets.serve( + hello, "localhost", 8765, + create_protocol=websockets.basic_auth_protocol_factory( + realm="example", credentials=("mary", "p@ssw0rd") + ), + ): + await asyncio.Future() # run forever -asyncio.get_event_loop().run_until_complete(start_server) -asyncio.get_event_loop().run_forever() +asyncio.run(main) diff --git a/example/client.py b/example/client.py index 4f969c478..e39df81f7 100755 --- a/example/client.py +++ b/example/client.py @@ -16,4 +16,4 @@ async def hello(): greeting = await websocket.recv() print(f"< {greeting}") -asyncio.get_event_loop().run_until_complete(hello()) +asyncio.run(hello()) diff --git a/example/counter.py b/example/counter.py index 239ec203a..81cbdb55c 100755 --- a/example/counter.py +++ b/example/counter.py @@ -63,7 +63,6 @@ async def counter(websocket, path): await unregister(websocket) -start_server = websockets.serve(counter, "localhost", 6789) - -asyncio.get_event_loop().run_until_complete(start_server) -asyncio.get_event_loop().run_forever() +async def main(): + async with websockets.serve(counter, "localhost", 6789): + await asyncio.Future() # run forever diff --git a/example/echo.py b/example/echo.py index b7ca38d32..b285f1664 100755 --- a/example/echo.py +++ b/example/echo.py @@ -7,7 +7,8 @@ async def echo(websocket, path): async for message in websocket: await websocket.send(message) -start_server = websockets.serve(echo, "localhost", 8765) +async def main(): + async with websockets.serve(echo, "localhost", 8765): + await asyncio.Future() # run forever -asyncio.get_event_loop().run_until_complete(start_server) -asyncio.get_event_loop().run_forever() +asyncio.run(main) diff --git a/example/health_check_server.py b/example/health_check_server.py index 417063fce..bb861fad3 100755 --- a/example/health_check_server.py +++ b/example/health_check_server.py @@ -14,9 +14,10 @@ async def echo(websocket, path): async for message in websocket: await websocket.send(message) -start_server = websockets.serve( - echo, "localhost", 8765, process_request=health_check -) +async def main(): + async with websockets.serve( + echo, "localhost", 8765, process_request=health_check + ): + await asyncio.Future() # run forever -asyncio.get_event_loop().run_until_complete(start_server) -asyncio.get_event_loop().run_forever() +asyncio.run(main) diff --git a/example/hello.py b/example/hello.py index 6c9c839d8..96095dd02 100755 --- a/example/hello.py +++ b/example/hello.py @@ -9,4 +9,4 @@ async def hello(): await websocket.send("Hello world!") await websocket.recv() -asyncio.get_event_loop().run_until_complete(hello()) +asyncio.run(hello()) diff --git a/example/secure_client.py b/example/secure_client.py index 54971b984..455b6492a 100755 --- a/example/secure_client.py +++ b/example/secure_client.py @@ -24,4 +24,4 @@ async def hello(): greeting = await websocket.recv() print(f"< {greeting}") -asyncio.get_event_loop().run_until_complete(hello()) +asyncio.run(hello()) diff --git a/example/secure_server.py b/example/secure_server.py index 2a00bdb50..cc1cea876 100755 --- a/example/secure_server.py +++ b/example/secure_server.py @@ -20,9 +20,10 @@ async def hello(websocket, path): localhost_pem = pathlib.Path(__file__).with_name("localhost.pem") ssl_context.load_cert_chain(localhost_pem) -start_server = websockets.serve( - hello, "localhost", 8765, ssl=ssl_context -) +async def main(): + async with websockets.serve( + hello, "localhost", 8765, ssl=ssl_context + ): + await asyncio.Future() # run forever -asyncio.get_event_loop().run_until_complete(start_server) -asyncio.get_event_loop().run_forever() +asyncio.run(main) diff --git a/example/server.py b/example/server.py index c8ab69971..3ad6c8226 100755 --- a/example/server.py +++ b/example/server.py @@ -14,7 +14,8 @@ async def hello(websocket, path): await websocket.send(greeting) print(f"> {greeting}") -start_server = websockets.serve(hello, "localhost", 8765) +async def main(): + async with websockets.serve(hello, "localhost", 8765): + await asyncio.Future() # run forever -asyncio.get_event_loop().run_until_complete(start_server) -asyncio.get_event_loop().run_forever() +asyncio.run(main) diff --git a/example/show_time.py b/example/show_time.py index e5d6ac9aa..2f2ad897a 100755 --- a/example/show_time.py +++ b/example/show_time.py @@ -13,7 +13,8 @@ async def time(websocket, path): await websocket.send(now) await asyncio.sleep(random.random() * 3) -start_server = websockets.serve(time, "127.0.0.1", 5678) +async def main(): + async with websockets.serve(time, "localhost", 5678): + await asyncio.Future() # run forever -asyncio.get_event_loop().run_until_complete(start_server) -asyncio.get_event_loop().run_forever() +asyncio.run(main) diff --git a/example/shutdown_client.py b/example/shutdown_client.py index f21c0f6fa..ba1287801 100755 --- a/example/shutdown_client.py +++ b/example/shutdown_client.py @@ -7,8 +7,8 @@ async def client(): uri = "ws://localhost:8765" async with websockets.connect(uri) as websocket: + loop = asyncio.get_running_loop() # Close the connection when receiving SIGTERM. - loop = asyncio.get_event_loop() loop.add_signal_handler( signal.SIGTERM, loop.create_task, websocket.close()) @@ -16,4 +16,4 @@ async def client(): async for message in websocket: ... -asyncio.get_event_loop().run_until_complete(client()) +asyncio.run(client()) diff --git a/example/shutdown_server.py b/example/shutdown_server.py index 86846abe7..5732313cb 100755 --- a/example/shutdown_server.py +++ b/example/shutdown_server.py @@ -8,15 +8,12 @@ async def echo(websocket, path): async for message in websocket: await websocket.send(message) -async def echo_server(stop): +async def server(): + loop = asyncio.get_running_loop() + stop = loop.create_future() + # Set the stop condition when receiving SIGTERM. + loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) async with websockets.serve(echo, "localhost", 8765): await stop -loop = asyncio.get_event_loop() - -# The stop condition is set when receiving SIGTERM. -stop = loop.create_future() -loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - -# Run the server until the stop condition is met. -loop.run_until_complete(echo_server(stop)) +asyncio.run(server()) diff --git a/example/unix_client.py b/example/unix_client.py index 577135b3d..434638c80 100755 --- a/example/unix_client.py +++ b/example/unix_client.py @@ -16,4 +16,4 @@ async def hello(): greeting = await websocket.recv() print(f"< {greeting}") -asyncio.get_event_loop().run_until_complete(hello()) +asyncio.run(hello()) diff --git a/example/unix_server.py b/example/unix_server.py index a6ec0168a..80bc824cf 100755 --- a/example/unix_server.py +++ b/example/unix_server.py @@ -15,8 +15,9 @@ async def hello(websocket, path): await websocket.send(greeting) print(f"> {greeting}") -socket_path = os.path.join(os.path.dirname(__file__), "socket") -start_server = websockets.unix_serve(hello, socket_path) +async def main(): + socket_path = os.path.join(os.path.dirname(__file__), "socket") + async with websockets.unix_serve(hello, socket_path): + await asyncio.Future() # run forever -asyncio.get_event_loop().run_until_complete(start_server) -asyncio.get_event_loop().run_forever() +asyncio.run(main) diff --git a/performance/mem_client.py b/performance/mem_client.py index 890216edf..6eab690d8 100644 --- a/performance/mem_client.py +++ b/performance/mem_client.py @@ -43,7 +43,7 @@ async def mem_client(client): await asyncio.sleep(CLIENTS * INTERVAL) -asyncio.get_event_loop().run_until_complete( +asyncio.run( asyncio.gather(*[mem_client(client) for client in range(CLIENTS + 1)]) ) diff --git a/performance/mem_server.py b/performance/mem_server.py index 0a4a29f76..81490a0e7 100644 --- a/performance/mem_server.py +++ b/performance/mem_server.py @@ -31,7 +31,11 @@ async def handler(ws, path): await asyncio.sleep(CLIENTS * INTERVAL) -async def mem_server(stop): +async def mem_server(): + loop = asyncio.get_running_loop() + stop = loop.create_future() + # Set the stop condition when receiving SIGTERM. + loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) async with websockets.serve( handler, "localhost", @@ -44,19 +48,14 @@ async def mem_server(stop): ) ], ): + tracemalloc.start() await stop -loop = asyncio.get_event_loop() +asyncio.run(mem_server()) -stop = loop.create_future() -loop.add_signal_handler(signal.SIGINT, stop.set_result, None) -tracemalloc.start() - -loop.run_until_complete(mem_server(stop)) - -# First connection incurs non-representative setup costs. +# First connection may incur non-representative setup costs. del MEM_SIZE[0] print(f"µ = {statistics.mean(MEM_SIZE) / 1024:.1f} KiB") From 08d8011132ba038b3f6c4d591189b57af4c9f147 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 May 2021 09:19:16 +0200 Subject: [PATCH 006/762] Add support for Python 3.10. --- compliance/test_server.py | 2 +- docs/api/client.rst | 6 +- docs/api/server.rst | 6 +- docs/changelog.rst | 8 + docs/heroku.rst | 2 +- example/basic_auth_server.py | 2 +- example/echo.py | 2 +- example/health_check_server.py | 2 +- example/secure_server.py | 2 +- example/server.py | 2 +- example/show_time.py | 2 +- example/unix_server.py | 2 +- setup.cfg | 2 +- setup.py | 1 + src/websockets/exceptions.py | 2 +- src/websockets/legacy/client.py | 8 +- src/websockets/legacy/compatibility.py | 22 ++ src/websockets/legacy/protocol.py | 30 ++- src/websockets/legacy/server.py | 19 +- tests/legacy/test_client_server.py | 272 ++++++++++++------------- tests/legacy/test_protocol.py | 41 ++-- tests/legacy/utils.py | 11 +- tox.ini | 2 +- 23 files changed, 235 insertions(+), 213 deletions(-) create mode 100644 src/websockets/legacy/compatibility.py diff --git a/compliance/test_server.py b/compliance/test_server.py index 58e357bc7..14ac90fe6 100644 --- a/compliance/test_server.py +++ b/compliance/test_server.py @@ -26,4 +26,4 @@ async def main(): pass -asyncio.run(main) +asyncio.run(main()) diff --git a/docs/api/client.rst b/docs/api/client.rst index f969227a9..af341b2ba 100644 --- a/docs/api/client.rst +++ b/docs/api/client.rst @@ -6,16 +6,16 @@ Client Opening a connection -------------------- - .. autofunction:: connect(uri, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, compression='deflate', origin=None, extensions=None, subprotocols=None, extra_headers=None, **kwds) + .. autofunction:: connect(uri, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, compression='deflate', origin=None, extensions=None, subprotocols=None, extra_headers=None, **kwds) :async: - .. autofunction:: unix_connect(path, uri="ws://localhost/", *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, compression='deflate', origin=None, extensions=None, subprotocols=None, extra_headers=None, **kwds) + .. autofunction:: unix_connect(path, uri="ws://localhost/", *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, compression='deflate', origin=None, extensions=None, subprotocols=None, extra_headers=None, **kwds) :async: Using a connection ------------------ - .. autoclass:: WebSocketClientProtocol(*, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origin=None, extensions=None, subprotocols=None, extra_headers=None) + .. autoclass:: WebSocketClientProtocol(*, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, origin=None, extensions=None, subprotocols=None, extra_headers=None) .. autoattribute:: local_address diff --git a/docs/api/server.rst b/docs/api/server.rst index 16c8f6359..849f03fab 100644 --- a/docs/api/server.rst +++ b/docs/api/server.rst @@ -6,10 +6,10 @@ Server Starting a server ----------------- - .. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, compression='deflate', origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, **kwds) + .. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, compression='deflate', origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, **kwds) :async: - .. autofunction:: unix_serve(ws_handler, path, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, compression='deflate', origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, **kwds) + .. autofunction:: unix_serve(ws_handler, path, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, compression='deflate', origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, **kwds) :async: Stopping a server @@ -25,7 +25,7 @@ Server Using a connection ------------------ - .. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None) + .. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None) .. autoattribute:: local_address diff --git a/docs/changelog.rst b/docs/changelog.rst index 35f292e49..70102cade 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -34,6 +34,14 @@ They may change at any time. **Version 10.0 drops compatibility with Python 3.6.** +.. note:: + + **Version 10.0 deprecates the** ``loop`` **parameter from all APIs for the + same reasons the same change was made in Python 3.8. See the release notes + of Python 3.10 for details.** + +* Added compatibility with Python 3.10. + 9.0.1 ..... diff --git a/docs/heroku.rst b/docs/heroku.rst index 5ef4b120e..8af2ebd3d 100644 --- a/docs/heroku.rst +++ b/docs/heroku.rst @@ -56,7 +56,7 @@ Here's the implementation of the app, an echo server. Save it in a file called async with websockets.serve(echo, "", int(os.environ["PORT"])): await asyncio.Future() # run forever - asyncio.run(main) + asyncio.run(main()) The server relies on the ``$PORT`` environment variable to tell on which port it will listen, according to Heroku's conventions. diff --git a/example/basic_auth_server.py b/example/basic_auth_server.py index a11910445..532c5bc51 100755 --- a/example/basic_auth_server.py +++ b/example/basic_auth_server.py @@ -18,4 +18,4 @@ async def main(): ): await asyncio.Future() # run forever -asyncio.run(main) +asyncio.run(main()) diff --git a/example/echo.py b/example/echo.py index b285f1664..024f8d8ac 100755 --- a/example/echo.py +++ b/example/echo.py @@ -11,4 +11,4 @@ async def main(): async with websockets.serve(echo, "localhost", 8765): await asyncio.Future() # run forever -asyncio.run(main) +asyncio.run(main()) diff --git a/example/health_check_server.py b/example/health_check_server.py index bb861fad3..2565f9c48 100755 --- a/example/health_check_server.py +++ b/example/health_check_server.py @@ -20,4 +20,4 @@ async def main(): ): await asyncio.Future() # run forever -asyncio.run(main) +asyncio.run(main()) diff --git a/example/secure_server.py b/example/secure_server.py index cc1cea876..55b5a4231 100755 --- a/example/secure_server.py +++ b/example/secure_server.py @@ -26,4 +26,4 @@ async def main(): ): await asyncio.Future() # run forever -asyncio.run(main) +asyncio.run(main()) diff --git a/example/server.py b/example/server.py index 3ad6c8226..98dbb5acd 100755 --- a/example/server.py +++ b/example/server.py @@ -18,4 +18,4 @@ async def main(): async with websockets.serve(hello, "localhost", 8765): await asyncio.Future() # run forever -asyncio.run(main) +asyncio.run(main()) diff --git a/example/show_time.py b/example/show_time.py index 2f2ad897a..8e39f1776 100755 --- a/example/show_time.py +++ b/example/show_time.py @@ -17,4 +17,4 @@ async def main(): async with websockets.serve(time, "localhost", 5678): await asyncio.Future() # run forever -asyncio.run(main) +asyncio.run(main()) diff --git a/example/unix_server.py b/example/unix_server.py index 80bc824cf..223d97301 100755 --- a/example/unix_server.py +++ b/example/unix_server.py @@ -20,4 +20,4 @@ async def main(): async with websockets.unix_serve(hello, socket_path): await asyncio.Future() # run forever -asyncio.run(main) +asyncio.run(main()) diff --git a/setup.cfg b/setup.cfg index b15a7515c..d8877aa2e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bdist_wheel] -python-tag = py37.py38.py39 +python-tag = py37.py38.py39.py310 [metadata] license_file = LICENSE diff --git a/setup.py b/setup.py index d493a12f1..b2d07737d 100644 --- a/setup.py +++ b/setup.py @@ -50,6 +50,7 @@ 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.10', ], package_dir = {'': 'src'}, package_data = {'websockets': ['py.typed']}, diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index e0860c743..9ab9d3ebe 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -314,7 +314,7 @@ def __init__( self.status = status self.headers = Headers(headers) self.body = body - message = f"HTTP {status}, {len(self.headers)} headers, {len(body)} bytes" + message = f"HTTP {status:d}, {len(self.headers)} headers, {len(body)} bytes" super().__init__(message) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 1b5bd303f..b77b4e86d 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -33,6 +33,7 @@ from ..http import USER_AGENT, build_host from ..typing import ExtensionHeader, Origin, Subprotocol from ..uri import WebSocketURI, parse_uri +from .compatibility import asyncio_get_running_loop from .handshake import build_request, check_response from .http import read_response from .protocol import WebSocketCommonProtocol @@ -472,7 +473,6 @@ def __init__( max_queue: Optional[int] = 2 ** 5, read_limit: int = 2 ** 16, write_limit: int = 2 ** 16, - loop: Optional[asyncio.AbstractEventLoop] = None, compression: Optional[str] = "deflate", origin: Optional[Origin] = None, extensions: Optional[Sequence[ClientExtensionFactory]] = None, @@ -503,8 +503,12 @@ def __init__( # Backwards compatibility: recv() used to return None on closed connections legacy_recv: bool = kwargs.pop("legacy_recv", False) + # Backwards compatibility: the loop parameter used to be supported. + loop: Optional[asyncio.AbstractEventLoop] = kwargs.pop("loop", None) if loop is None: - loop = asyncio.get_event_loop() + loop = asyncio_get_running_loop() + else: + warnings.warn("remove loop argument", DeprecationWarning) wsuri = parse_uri(uri) if wsuri.secure: diff --git a/src/websockets/legacy/compatibility.py b/src/websockets/legacy/compatibility.py new file mode 100644 index 000000000..86f6715fd --- /dev/null +++ b/src/websockets/legacy/compatibility.py @@ -0,0 +1,22 @@ +import asyncio +import sys +from typing import Any, Dict + + +def loop_if_py_lt_38(loop: asyncio.AbstractEventLoop) -> Dict[str, Any]: + """ + Helper for the removal of the loop argument in Python 3.10. + + """ + return {"loop": loop} if sys.version_info[:2] < (3, 8) else {} + + +def asyncio_get_running_loop() -> asyncio.AbstractEventLoop: + """ + Helper for the deprecation of get_event_loop in Python 3.10. + + """ + if sys.version_info[:2] < (3, 10): # pragma: no cover + return asyncio.get_event_loop() + else: # pragma: no cover + return asyncio.get_running_loop() diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index c155c3bd7..4e8958b60 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -14,7 +14,6 @@ import logging import random import struct -import sys import warnings from typing import ( Any, @@ -55,6 +54,7 @@ serialize_close, ) from ..typing import Data, Subprotocol +from .compatibility import loop_if_py_lt_38 from .framing import Frame @@ -107,13 +107,13 @@ def __init__( max_queue: Optional[int] = 2 ** 5, read_limit: int = 2 ** 16, write_limit: int = 2 ** 16, - loop: Optional[asyncio.AbstractEventLoop] = None, # The following arguments are kept only for backwards compatibility. host: Optional[str] = None, port: Optional[int] = None, secure: Optional[bool] = None, - legacy_recv: bool = False, timeout: Optional[float] = None, + legacy_recv: bool = False, + loop: Optional[asyncio.AbstractEventLoop] = None, ) -> None: # Backwards compatibility: close_timeout used to be called timeout. if timeout is None: @@ -132,8 +132,8 @@ def __init__( self.read_limit = read_limit self.write_limit = write_limit - if loop is None: - loop = asyncio.get_event_loop() + assert loop is not None + # Remove when dropping Python < 3.10 - use get_running_loop instead. self.loop = loop self._host = host @@ -151,9 +151,7 @@ def __init__( self._paused = False self._drain_waiter: Optional[asyncio.Future[None]] = None - self._drain_lock = asyncio.Lock( - loop=loop if sys.version_info[:2] < (3, 8) else None - ) + self._drain_lock = asyncio.Lock(**loop_if_py_lt_38(loop)) # This class implements the data transfer and closing handshake, which # are shared between the client-side and the server-side. @@ -231,9 +229,7 @@ async def _drain(self) -> None: # pragma: no cover # write(...); yield from drain() # in a loop would never call connection_lost(), so it # would not see an error when the socket is closed. - await asyncio.sleep( - 0, loop=self.loop if sys.version_info[:2] < (3, 8) else None - ) + await asyncio.sleep(0, **loop_if_py_lt_38(self.loop)) await self._drain_helper() def connection_open(self) -> None: @@ -403,8 +399,8 @@ async def recv(self) -> Data: # pop_message_waiter and self.transfer_data_task. await asyncio.wait( [pop_message_waiter, self.transfer_data_task], - loop=self.loop if sys.version_info[:2] < (3, 8) else None, return_when=asyncio.FIRST_COMPLETED, + **loop_if_py_lt_38(self.loop), ) finally: self._pop_message_waiter = None @@ -601,7 +597,7 @@ async def close(self, code: int = 1000, reason: str = "") -> None: await asyncio.wait_for( self.write_close_frame(serialize_close(code, reason)), self.close_timeout, - loop=self.loop if sys.version_info[:2] < (3, 8) else None, + **loop_if_py_lt_38(self.loop), ) except asyncio.TimeoutError: # If the close frame cannot be sent because the send buffers @@ -622,7 +618,7 @@ async def close(self, code: int = 1000, reason: str = "") -> None: await asyncio.wait_for( self.transfer_data_task, self.close_timeout, - loop=self.loop if sys.version_info[:2] < (3, 8) else None, + **loop_if_py_lt_38(self.loop), ) except (asyncio.TimeoutError, asyncio.CancelledError): pass @@ -1052,7 +1048,7 @@ async def keepalive_ping(self) -> None: while True: await asyncio.sleep( self.ping_interval, - loop=self.loop if sys.version_info[:2] < (3, 8) else None, + **loop_if_py_lt_38(self.loop), ) # ping() raises CancelledError if the connection is closed, @@ -1068,7 +1064,7 @@ async def keepalive_ping(self) -> None: await asyncio.wait_for( pong_waiter, self.ping_timeout, - loop=self.loop if sys.version_info[:2] < (3, 8) else None, + **loop_if_py_lt_38(self.loop), ) except asyncio.TimeoutError: logger.debug("%s ! timed out waiting for pong", self.side) @@ -1168,7 +1164,7 @@ async def wait_for_connection_lost(self) -> bool: await asyncio.wait_for( asyncio.shield(self.connection_lost_waiter), self.close_timeout, - loop=self.loop if sys.version_info[:2] < (3, 8) else None, + **loop_if_py_lt_38(self.loop), ) except asyncio.TimeoutError: pass diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 1daf3a9ad..5bd7d0f56 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -10,7 +10,6 @@ import http import logging import socket -import sys import warnings from types import TracebackType from typing import ( @@ -43,6 +42,7 @@ from ..headers import build_extension, parse_extension, parse_subprotocol from ..http import USER_AGENT from ..typing import ExtensionHeader, Origin, Subprotocol +from .compatibility import asyncio_get_running_loop, loop_if_py_lt_38 from .handshake import build_response, check_request from .http import read_request from .protocol import WebSocketCommonProtocol @@ -798,9 +798,7 @@ async def _close(self) -> None: # Wait until all accepted connections reach connection_made() and call # register(). See https://bugs.python.org/issue34852 for details. - await asyncio.sleep( - 0, loop=self.loop if sys.version_info[:2] < (3, 8) else None - ) + await asyncio.sleep(0, **loop_if_py_lt_38(self.loop)) # Close OPEN connections with status code 1001. Since the server was # closed, handshake() closes OPENING conections with a HTTP 503 error. @@ -813,7 +811,7 @@ async def _close(self) -> None: asyncio.create_task(websocket.close(1001)) for websocket in self.websockets ], - loop=self.loop if sys.version_info[:2] < (3, 8) else None, + **loop_if_py_lt_38(self.loop), ) # Wait until all connection handlers are complete. @@ -822,7 +820,7 @@ async def _close(self) -> None: if self.websockets: await asyncio.wait( [websocket.handler_task for websocket in self.websockets], - loop=self.loop if sys.version_info[:2] < (3, 8) else None, + **loop_if_py_lt_38(self.loop), ) # Tell wait_closed() to return. @@ -953,7 +951,6 @@ def __init__( max_queue: Optional[int] = 2 ** 5, read_limit: int = 2 ** 16, write_limit: int = 2 ** 16, - loop: Optional[asyncio.AbstractEventLoop] = None, compression: Optional[str] = "deflate", origins: Optional[Sequence[Optional[Origin]]] = None, extensions: Optional[Sequence[ServerExtensionFactory]] = None, @@ -990,10 +987,14 @@ def __init__( # Backwards compatibility: recv() used to return None on closed connections legacy_recv: bool = kwargs.pop("legacy_recv", False) + # Backwards compatibility: the loop parameter used to be supported. + loop: Optional[asyncio.AbstractEventLoop] = kwargs.pop("loop", None) if loop is None: - loop = asyncio.get_event_loop() + loop = asyncio_get_running_loop() + else: + warnings.warn("remove loop argument", DeprecationWarning) - ws_server = WebSocketServer(loop) + ws_server = WebSocketServer(loop=loop) secure = kwargs.get("ssl") is not None diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 499ea1d59..80c9d56bb 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -6,6 +6,7 @@ import random import socket import ssl +import sys import tempfile import unittest import unittest.mock @@ -51,7 +52,7 @@ testcert = bytes(pathlib.Path(__file__).parent.with_name("test_localhost.pem")) -async def handler(ws, path): +async def default_handler(ws, path): if path == "/deprecated_attributes": await ws.recv() # delay that allows catching warnings await ws.send(repr((ws.host, ws.port, ws.secure))) @@ -84,9 +85,9 @@ def temp_test_server(test, **kwargs): @contextlib.contextmanager def temp_test_redirecting_server( - test, status, include_location=True, force_insecure=False + test, status, include_location=True, force_insecure=False, **kwargs ): - test.start_redirecting_server(status, include_location, force_insecure) + test.start_redirecting_server(status, include_location, force_insecure, **kwargs) try: yield finally: @@ -206,21 +207,36 @@ def server_context(self): return None def start_server(self, deprecation_warnings=None, **kwargs): + handler = kwargs.pop("handler", default_handler) # Disable compression by default in tests. kwargs.setdefault("compression", None) # Disable pings by default in tests. kwargs.setdefault("ping_interval", None) + # Python 3.10 dislikes not having a running event loop + if sys.version_info[:2] >= (3, 10): # pragma: no cover + kwargs.setdefault("loop", self.loop) with warnings.catch_warnings(record=True) as recorded_warnings: start_server = serve(handler, "localhost", 0, **kwargs) self.server = self.loop.run_until_complete(start_server) expected_warnings = [] if deprecation_warnings is None else deprecation_warnings + if sys.version_info[:2] >= (3, 10): # pragma: no cover + expected_warnings += ["remove loop argument"] self.assertDeprecationWarnings(recorded_warnings, expected_warnings) def start_redirecting_server( - self, status, include_location=True, force_insecure=False + self, + status, + include_location=True, + force_insecure=False, + deprecation_warnings=None, + **kwargs, ): + # Python 3.10 dislikes not having a running event loop + if sys.version_info[:2] >= (3, 10): # pragma: no cover + kwargs.setdefault("loop", self.loop) + async def process_request(path, headers): server_uri = get_server_uri(self.server, self.secure, path) if force_insecure: @@ -228,16 +244,23 @@ async def process_request(path, headers): headers = {"Location": server_uri} if include_location else [] return status, headers, b"" - start_server = serve( - handler, - "localhost", - 0, - compression=None, - ping_interval=None, - process_request=process_request, - ssl=self.server_context, - ) - self.redirecting_server = self.loop.run_until_complete(start_server) + with warnings.catch_warnings(record=True) as recorded_warnings: + start_server = serve( + default_handler, + "localhost", + 0, + compression=None, + ping_interval=None, + process_request=process_request, + ssl=self.server_context, + **kwargs, + ) + self.redirecting_server = self.loop.run_until_complete(start_server) + + expected_warnings = [] if deprecation_warnings is None else deprecation_warnings + if sys.version_info[:2] >= (3, 10): # pragma: no cover + expected_warnings += ["remove loop argument"] + self.assertDeprecationWarnings(recorded_warnings, expected_warnings) def start_client( self, resource_name="/", user_info=None, deprecation_warnings=None, **kwargs @@ -246,6 +269,10 @@ def start_client( kwargs.setdefault("compression", None) # Disable pings by default in tests. kwargs.setdefault("ping_interval", None) + # Python 3.10 dislikes not having a running event loop + if sys.version_info[:2] >= (3, 10): # pragma: no cover + kwargs.setdefault("loop", self.loop) + secure = kwargs.get("ssl") is not None try: server_uri = kwargs.pop("uri") @@ -258,6 +285,8 @@ def start_client( self.client = self.loop.run_until_complete(start_client) expected_warnings = [] if deprecation_warnings is None else deprecation_warnings + if sys.version_info[:2] >= (3, 10): # pragma: no cover + expected_warnings += ["remove loop argument"] self.assertDeprecationWarnings(recorded_warnings, expected_warnings) def stop_client(self): @@ -376,7 +405,12 @@ def test_redirect(self): self.assertEqual(reply, "Hello!") def test_infinite_redirect(self): - with temp_test_redirecting_server(self, http.HTTPStatus.FOUND): + with temp_test_redirecting_server( + self, + http.HTTPStatus.FOUND, + loop=self.loop, + deprecation_warnings=["remove loop argument"], + ): self.server = self.redirecting_server with self.assertRaises(InvalidHandshake): with temp_test_client(self): @@ -385,15 +419,23 @@ def test_infinite_redirect(self): @with_server() def test_redirect_missing_location(self): with temp_test_redirecting_server( - self, http.HTTPStatus.FOUND, include_location=False + self, + http.HTTPStatus.FOUND, + include_location=False, + loop=self.loop, + deprecation_warnings=["remove loop argument"], ): with self.assertRaises(InvalidHeader): with temp_test_client(self): self.fail("Did not raise") # pragma: no cover def test_explicit_event_loop(self): - with self.temp_server(loop=self.loop): - with self.temp_client(loop=self.loop): + with self.temp_server( + loop=self.loop, deprecation_warnings=["remove loop argument"] + ): + with self.temp_client( + loop=self.loop, deprecation_warnings=["remove loop argument"] + ): self.loop.run_until_complete(self.client.send("Hello!")) reply = self.loop.run_until_complete(self.client.recv()) self.assertEqual(reply, "Hello!") @@ -460,12 +502,19 @@ def test_unix_socket(self): path = bytes(pathlib.Path(temp_dir) / "websockets") # Like self.start_server() but with unix_serve(). - unix_server = unix_serve(handler, path) - self.server = self.loop.run_until_complete(unix_server) + with warnings.catch_warnings(record=True) as recorded_warnings: + unix_server = unix_serve(default_handler, path, loop=self.loop) + self.server = self.loop.run_until_complete(unix_server) + self.assertDeprecationWarnings(recorded_warnings, ["remove loop argument"]) + try: # Like self.start_client() but with unix_connect() - unix_client = unix_connect(path) - self.client = self.loop.run_until_complete(unix_client) + with warnings.catch_warnings(record=True) as recorded_warnings: + unix_client = unix_connect(path, loop=self.loop) + self.client = self.loop.run_until_complete(unix_client) + self.assertDeprecationWarnings( + recorded_warnings, ["remove loop argument"] + ) try: self.loop.run_until_complete(self.client.send("Hello!")) reply = self.loop.run_until_complete(self.client.recv()) @@ -1214,7 +1263,9 @@ class SecureClientServerTests( @with_server() def test_ws_uri_is_rejected(self): with self.assertRaises(ValueError): - connect(get_server_uri(self.server, secure=False), ssl=self.client_context) + self.start_client( + uri=get_server_uri(self.server, secure=False), ssl=self.client_context + ) @with_server() def test_redirect_insecure(self): @@ -1226,107 +1277,59 @@ def test_redirect_insecure(self): self.fail("Did not raise") # pragma: no cover -class ClientServerOriginTests(AsyncioTestCase): +class ClientServerOriginTests(ClientServerTestsMixin, AsyncioTestCase): + @with_server(origins=["http://localhost"]) + @with_client(origin="http://localhost") def test_checking_origin_succeeds(self): - server = self.loop.run_until_complete( - serve(handler, "localhost", 0, origins=["http://localhost"]) - ) - client = self.loop.run_until_complete( - connect(get_server_uri(server), origin="http://localhost") - ) - - self.loop.run_until_complete(client.send("Hello!")) - self.assertEqual(self.loop.run_until_complete(client.recv()), "Hello!") - - self.loop.run_until_complete(client.close()) - server.close() - self.loop.run_until_complete(server.wait_closed()) + self.loop.run_until_complete(self.client.send("Hello!")) + self.assertEqual(self.loop.run_until_complete(self.client.recv()), "Hello!") + @with_server(origins=["http://localhost"]) def test_checking_origin_fails(self): - server = self.loop.run_until_complete( - serve(handler, "localhost", 0, origins=["http://localhost"]) - ) with self.assertRaisesRegex( InvalidHandshake, "server rejected WebSocket connection: HTTP 403" ): - self.loop.run_until_complete( - connect(get_server_uri(server), origin="http://otherhost") - ) - - server.close() - self.loop.run_until_complete(server.wait_closed()) + self.start_client(origin="http://otherhost") + @with_server(origins=["http://localhost"]) def test_checking_origins_fails_with_multiple_headers(self): - server = self.loop.run_until_complete( - serve(handler, "localhost", 0, origins=["http://localhost"]) - ) with self.assertRaisesRegex( InvalidHandshake, "server rejected WebSocket connection: HTTP 400" ): - self.loop.run_until_complete( - connect( - get_server_uri(server), - origin="http://localhost", - extra_headers=[("Origin", "http://otherhost")], - ) + self.start_client( + origin="http://localhost", + extra_headers=[("Origin", "http://otherhost")], ) - server.close() - self.loop.run_until_complete(server.wait_closed()) - + @with_server(origins=[None]) + @with_client() def test_checking_lack_of_origin_succeeds(self): - server = self.loop.run_until_complete( - serve(handler, "localhost", 0, origins=[None]) - ) - client = self.loop.run_until_complete(connect(get_server_uri(server))) - - self.loop.run_until_complete(client.send("Hello!")) - self.assertEqual(self.loop.run_until_complete(client.recv()), "Hello!") - - self.loop.run_until_complete(client.close()) - server.close() - self.loop.run_until_complete(server.wait_closed()) + self.loop.run_until_complete(self.client.send("Hello!")) + self.assertEqual(self.loop.run_until_complete(self.client.recv()), "Hello!") + @with_server(origins=[""]) + @with_client(deprecation_warnings=["use None instead of '' in origins"]) def test_checking_lack_of_origin_succeeds_backwards_compatibility(self): - with warnings.catch_warnings(record=True) as recorded_warnings: - server = self.loop.run_until_complete( - serve(handler, "localhost", 0, origins=[""]) - ) - client = self.loop.run_until_complete(connect(get_server_uri(server))) - - self.assertDeprecationWarnings( - recorded_warnings, ["use None instead of '' in origins"] - ) - - self.loop.run_until_complete(client.send("Hello!")) - self.assertEqual(self.loop.run_until_complete(client.recv()), "Hello!") - - self.loop.run_until_complete(client.close()) - server.close() - self.loop.run_until_complete(server.wait_closed()) + self.loop.run_until_complete(self.client.send("Hello!")) + self.assertEqual(self.loop.run_until_complete(self.client.recv()), "Hello!") -class YieldFromTests(AsyncioTestCase): +class YieldFromTests(ClientServerTestsMixin, AsyncioTestCase): + @with_server() def test_client(self): - start_server = serve(handler, "localhost", 0) - server = self.loop.run_until_complete(start_server) - # @asyncio.coroutine is deprecated on Python ≥ 3.8 with warnings.catch_warnings(record=True): @asyncio.coroutine def run_client(): # Yield from connect. - client = yield from connect(get_server_uri(server)) + client = yield from connect(get_server_uri(self.server)) self.assertEqual(client.state, State.OPEN) yield from client.close() self.assertEqual(client.state, State.CLOSED) self.loop.run_until_complete(run_client()) - server.close() - self.loop.run_until_complete(server.wait_closed()) - def test_server(self): # @asyncio.coroutine is deprecated on Python ≥ 3.8 with warnings.catch_warnings(record=True): @@ -1334,7 +1337,7 @@ def test_server(self): @asyncio.coroutine def run_server(): # Yield from serve. - server = yield from serve(handler, "localhost", 0) + server = yield from serve(default_handler, "localhost", 0) self.assertTrue(server.sockets) server.close() yield from server.wait_closed() @@ -1343,27 +1346,22 @@ def run_server(): self.loop.run_until_complete(run_server()) -class AsyncAwaitTests(AsyncioTestCase): +class AsyncAwaitTests(ClientServerTestsMixin, AsyncioTestCase): + @with_server() def test_client(self): - start_server = serve(handler, "localhost", 0) - server = self.loop.run_until_complete(start_server) - async def run_client(): # Await connect. - client = await connect(get_server_uri(server)) + client = await connect(get_server_uri(self.server)) self.assertEqual(client.state, State.OPEN) await client.close() self.assertEqual(client.state, State.CLOSED) self.loop.run_until_complete(run_client()) - server.close() - self.loop.run_until_complete(server.wait_closed()) - def test_server(self): async def run_server(): # Await serve. - server = await serve(handler, "localhost", 0) + server = await serve(default_handler, "localhost", 0) self.assertTrue(server.sockets) server.close() await server.wait_closed() @@ -1372,14 +1370,12 @@ async def run_server(): self.loop.run_until_complete(run_server()) -class ContextManagerTests(AsyncioTestCase): +class ContextManagerTests(ClientServerTestsMixin, AsyncioTestCase): + @with_server() def test_client(self): - start_server = serve(handler, "localhost", 0) - server = self.loop.run_until_complete(start_server) - async def run_client(): # Use connect as an asynchronous context manager. - async with connect(get_server_uri(server)) as client: + async with connect(get_server_uri(self.server)) as client: self.assertEqual(client.state, State.OPEN) # Check that exiting the context manager closed the connection. @@ -1387,13 +1383,10 @@ async def run_client(): self.loop.run_until_complete(run_client()) - server.close() - self.loop.run_until_complete(server.wait_closed()) - def test_server(self): async def run_server(): # Use serve as an asynchronous context manager. - async with serve(handler, "localhost", 0) as server: + async with serve(default_handler, "localhost", 0) as server: self.assertTrue(server.sockets) # Check that exiting the context manager closed the server. @@ -1404,7 +1397,7 @@ async def run_server(): @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") def test_unix_server(self): async def run_server(path): - async with unix_serve(handler, path) as server: + async with unix_serve(default_handler, path) as server: self.assertTrue(server.sockets) # Check that exiting the context manager closed the server. @@ -1415,26 +1408,24 @@ async def run_server(path): self.loop.run_until_complete(run_server(path)) -class AsyncIteratorTests(AsyncioTestCase): +class AsyncIteratorTests(ClientServerTestsMixin, AsyncioTestCase): # This is a protocol-level feature, but since it's a high-level API, it is # much easier to exercise at the client or server level. MESSAGES = ["3", "2", "1", "Fire!"] - def test_iterate_on_messages(self): - async def handler(ws, path): - for message in self.MESSAGES: - await ws.send(message) - - start_server = serve(handler, "localhost", 0) - server = self.loop.run_until_complete(start_server) + async def echo_handler(ws, path): + for message in AsyncIteratorTests.MESSAGES: + await ws.send(message) + @with_server(handler=echo_handler) + def test_iterate_on_messages(self): messages = [] async def run_client(): nonlocal messages - async with connect(get_server_uri(server)) as ws: + async with connect(get_server_uri(self.server)) as ws: async for message in ws: messages.append(message) @@ -1442,23 +1433,18 @@ async def run_client(): self.assertEqual(messages, self.MESSAGES) - server.close() - self.loop.run_until_complete(server.wait_closed()) + async def echo_handler_1001(ws, path): + for message in AsyncIteratorTests.MESSAGES: + await ws.send(message) + await ws.close(1001) + @with_server(handler=echo_handler_1001) def test_iterate_on_messages_going_away_exit_ok(self): - async def handler(ws, path): - for message in self.MESSAGES: - await ws.send(message) - await ws.close(1001) - - start_server = serve(handler, "localhost", 0) - server = self.loop.run_until_complete(start_server) - messages = [] async def run_client(): nonlocal messages - async with connect(get_server_uri(server)) as ws: + async with connect(get_server_uri(self.server)) as ws: async for message in ws: messages.append(message) @@ -1466,23 +1452,18 @@ async def run_client(): self.assertEqual(messages, self.MESSAGES) - server.close() - self.loop.run_until_complete(server.wait_closed()) + async def echo_handler_1011(ws, path): + for message in AsyncIteratorTests.MESSAGES: + await ws.send(message) + await ws.close(1011) + @with_server(handler=echo_handler_1011) def test_iterate_on_messages_internal_error_exit_not_ok(self): - async def handler(ws, path): - for message in self.MESSAGES: - await ws.send(message) - await ws.close(1011) - - start_server = serve(handler, "localhost", 0) - server = self.loop.run_until_complete(start_server) - messages = [] async def run_client(): nonlocal messages - async with connect(get_server_uri(server)) as ws: + async with connect(get_server_uri(self.server)) as ws: async for message in ws: messages.append(message) @@ -1490,6 +1471,3 @@ async def run_client(): self.loop.run_until_complete(run_client()) self.assertEqual(messages, self.MESSAGES) - - server.close() - self.loop.run_until_complete(server.wait_closed()) diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index f928322ca..58444ce5a 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -1,6 +1,5 @@ import asyncio import contextlib -import sys import unittest import unittest.mock import warnings @@ -15,6 +14,7 @@ OP_TEXT, serialize_close, ) +from websockets.legacy.compatibility import loop_if_py_lt_38 from websockets.legacy.framing import Frame from websockets.legacy.protocol import State, WebSocketCommonProtocol @@ -87,7 +87,7 @@ class CommonTests: def setUp(self): super().setUp() # Disable pings to make it easier to test what frames are sent exactly. - self.protocol = WebSocketCommonProtocol(ping_interval=None) + self.protocol = WebSocketCommonProtocol(ping_interval=None, loop=self.loop) self.transport = TransportMock() self.transport.setup_mock(self.loop, self.protocol) @@ -105,9 +105,7 @@ def make_drain_slow(self, delay=MS): original_drain = self.protocol._drain async def delayed_drain(): - await asyncio.sleep( - delay, loop=self.loop if sys.version_info[:2] < (3, 8) else None - ) + await asyncio.sleep(delay, **loop_if_py_lt_38(self.loop)) await original_drain() self.protocol._drain = delayed_drain @@ -312,14 +310,13 @@ def assertCompletesWithin(self, min_time, max_time): def test_timeout_backwards_compatibility(self): with warnings.catch_warnings(record=True) as recorded_warnings: - protocol = WebSocketCommonProtocol(timeout=5) + protocol = WebSocketCommonProtocol(timeout=5, loop=self.loop) self.assertEqual(protocol.close_timeout, 5) - self.assertEqual(len(recorded_warnings), 1) - warning = recorded_warnings[0].message - self.assertEqual(str(warning), "rename timeout to close_timeout") - self.assertEqual(type(warning), DeprecationWarning) + self.assertDeprecationWarnings( + recorded_warnings, ["rename timeout to close_timeout"] + ) # Test public attributes. @@ -647,7 +644,14 @@ async def send_concurrent(): await asyncio.sleep(MS) await self.protocol.send(b"tea") - self.loop.run_until_complete(asyncio.gather(send_iterable(), send_concurrent())) + async def run_concurrently(): + await asyncio.gather( + send_iterable(), + send_concurrent(), + ) + + self.loop.run_until_complete(run_concurrently()) + self.assertFramesSent( (False, OP_TEXT, "ca".encode("utf-8")), (False, OP_CONT, "fé".encode("utf-8")), @@ -714,9 +718,14 @@ async def send_concurrent(): await asyncio.sleep(MS) await self.protocol.send(b"tea") - self.loop.run_until_complete( - asyncio.gather(send_async_iterable(), send_concurrent()) - ) + async def run_concurrently(): + await asyncio.gather( + send_async_iterable(), + send_concurrent(), + ) + + self.loop.run_until_complete(run_concurrently()) + self.assertFramesSent( (False, OP_TEXT, "ca".encode("utf-8")), (False, OP_CONT, "fé".encode("utf-8")), @@ -1098,7 +1107,9 @@ def restart_protocol_with_keepalive_ping( self.loop.run_until_complete(self.protocol.close()) # copied from setUp, but enables keepalive pings self.protocol = WebSocketCommonProtocol( - ping_interval=ping_interval, ping_timeout=ping_timeout + ping_interval=ping_interval, + ping_timeout=ping_timeout, + loop=self.loop, ) self.transport = TransportMock() self.transport.setup_mock(self.loop, self.protocol) diff --git a/tests/legacy/utils.py b/tests/legacy/utils.py index 983a91edf..4d4306232 100644 --- a/tests/legacy/utils.py +++ b/tests/legacy/utils.py @@ -74,11 +74,12 @@ def assertDeprecationWarnings(self, recorded_warnings, expected_warnings): Check recorded deprecation warnings match a list of expected messages. """ - self.assertEqual(len(recorded_warnings), len(expected_warnings)) - for recorded, expected in zip(recorded_warnings, expected_warnings): - actual = recorded.message - self.assertEqual(str(actual), expected) - self.assertEqual(type(actual), DeprecationWarning) + for recorded in recorded_warnings: + self.assertEqual(type(recorded.message), DeprecationWarning) + self.assertEqual( + set(str(recorded.message) for recorded in recorded_warnings), + set(expected_warnings), + ) # Unit for timeouts. May be increased on slow machines by setting the diff --git a/tox.ini b/tox.ini index e74c979ba..c243e9880 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py37,py38,py39,coverage,black,flake8,isort,mypy +envlist = py37,py38,py39,py310,coverage,black,flake8,isort,mypy [testenv] commands = python -W default -m unittest {posargs} From eb856f2ba31d1154e18db4ca33b0ab9586ae129c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 3 May 2021 09:21:18 +0200 Subject: [PATCH 007/762] Add missing asyncio.run in example. --- README.rst | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/README.rst b/README.rst index 5db7d9d0b..fa42a9be3 100644 --- a/README.rst +++ b/README.rst @@ -52,7 +52,7 @@ Here's how a client sends and receives messages: await websocket.send("Hello world!") await websocket.recv() - asyncio.run(hello('ws://localhost:8765')) + asyncio.run(hello("ws://localhost:8765")) And here's an echo server: @@ -68,9 +68,11 @@ And here's an echo server: await websocket.send(message) async def main(): - async with serve(echo, 'localhost', 8765): + async with serve(echo, "localhost", 8765): await asyncio.Future() # run forever + asyncio.run(main()) + Does that look good? `Get started with the tutorial! `_ From 9834fca95204c517adedca2478a53058ecf72ae3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 7 May 2021 21:21:37 +0200 Subject: [PATCH 008/762] Use relative imports everywhere, for consistency. Fix #946. --- docs/api/extensions.rst | 2 +- docs/extensions.rst | 7 +++---- src/websockets/extensions/__init__.py | 4 ++++ src/websockets/frames.py | 6 +++--- src/websockets/legacy/framing.py | 6 +++--- 5 files changed, 14 insertions(+), 11 deletions(-) diff --git a/docs/api/extensions.rst b/docs/api/extensions.rst index 635c5c426..71f015bb2 100644 --- a/docs/api/extensions.rst +++ b/docs/api/extensions.rst @@ -13,7 +13,7 @@ Per-Message Deflate Abstract classes ---------------- -.. automodule:: websockets.extensions.base +.. automodule:: websockets.extensions .. autoclass:: Extension :members: diff --git a/docs/extensions.rst b/docs/extensions.rst index 151a7e297..042ed3d9a 100644 --- a/docs/extensions.rst +++ b/docs/extensions.rst @@ -91,9 +91,8 @@ As a consequence, writing an extension requires implementing several classes: ``websockets`` provides abstract base classes for extension factories and extensions. See the API documentation for details on their methods: -* :class:`~base.ClientExtensionFactory` and - :class:`~base.ServerExtensionFactory` for extension factories, - -* :class:`~base.Extension` for extensions. +* :class:`ClientExtensionFactory` and class:`ServerExtensionFactory` for + :extension factories, +* :class:`Extension` for extensions. diff --git a/src/websockets/extensions/__init__.py b/src/websockets/extensions/__init__.py index e69de29bb..02838b98a 100644 --- a/src/websockets/extensions/__init__.py +++ b/src/websockets/extensions/__init__.py @@ -0,0 +1,4 @@ +from .base import * + + +__all__ = ["Extension", "ClientExtensionFactory", "ServerExtensionFactory"] diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 71783e176..6e5ef1b73 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -103,7 +103,7 @@ def parse( *, mask: bool, max_size: Optional[int] = None, - extensions: Optional[Sequence["websockets.extensions.base.Extension"]] = None, + extensions: Optional[Sequence["extensions.Extension"]] = None, ) -> Generator[None, None, "Frame"]: """ Read a WebSocket frame. @@ -172,7 +172,7 @@ def serialize( self, *, mask: bool, - extensions: Optional[Sequence["websockets.extensions.base.Extension"]] = None, + extensions: Optional[Sequence["extensions.Extension"]] = None, ) -> bytes: """ Write a WebSocket frame. @@ -338,4 +338,4 @@ def check_close(code: int) -> None: # at the bottom to allow circular import, because Extension depends on Frame -import websockets.extensions.base # isort:skip # noqa +from . import extensions # isort:skip # noqa diff --git a/src/websockets/legacy/framing.py b/src/websockets/legacy/framing.py index e41c295dd..627e6922c 100644 --- a/src/websockets/legacy/framing.py +++ b/src/websockets/legacy/framing.py @@ -31,7 +31,7 @@ async def read( *, mask: bool, max_size: Optional[int] = None, - extensions: Optional[Sequence["websockets.extensions.base.Extension"]] = None, + extensions: Optional[Sequence["extensions.Extension"]] = None, ) -> "Frame": """ Read a WebSocket frame. @@ -102,7 +102,7 @@ def write( write: Callable[[bytes], Any], *, mask: bool, - extensions: Optional[Sequence["websockets.extensions.base.Extension"]] = None, + extensions: Optional[Sequence["extensions.Extension"]] = None, ) -> None: """ Write a WebSocket frame. @@ -132,4 +132,4 @@ def write( # at the bottom to allow circular import, because Extension depends on Frame -import websockets.extensions.base # isort:skip # noqa +from .. import extensions # isort:skip # noqa From 60e9531eb978b41deee1db8552b27faa749aa515 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 3 May 2021 23:49:45 +0200 Subject: [PATCH 009/762] Document how to use with Django. Fix #890. --- docs/howto/django.rst | 276 +++++++++++++++++++++++++++++++ docs/index.rst | 1 + docs/spelling_wordlist.txt | 28 ++-- example/django/authentication.py | 53 ++++++ example/django/notifications.py | 70 ++++++++ example/django/signals.py | 23 +++ 6 files changed, 439 insertions(+), 12 deletions(-) create mode 100644 docs/howto/django.rst create mode 100644 example/django/authentication.py create mode 100644 example/django/notifications.py create mode 100644 example/django/signals.py diff --git a/docs/howto/django.rst b/docs/howto/django.rst new file mode 100644 index 000000000..fd170c387 --- /dev/null +++ b/docs/howto/django.rst @@ -0,0 +1,276 @@ +Using websockets with Django +============================ + +If you're looking at adding real-time capabilities to a Django project with +WebSocket, you have two main options. + +1. Using Django Channels_, a project adding WebSocket to Django, among other + features. This approach is fully supported by Django. However, it requires + switching to a new deployment architecture. + +2. Deploying a separate WebSocket server next to your Django project. This + technique is well suited when you need to add a small set of real-time + features — maybe a notification service — to a HTTP application. + +.. _Channels: https://channels.readthedocs.io/en/latest/ + +This guide shows how to implement the second technique with websockets. It +assumes familiarity with Django. + +Authenticating connections +-------------------------- + +Since the websockets server will run outside of Django, we need to connect it +to ``django.contrib.auth``. + +Our clients are running in browser. The `WebSocket API`_ doesn't support +setting `custom headers`_ so our options boil down to: + +* HTTP Basic Auth: this seems technically possible but isn't supported by + Firefox (`bug 1229443`_) so browser support is clearly insufficient. +* Sharing cookies: this is technically possible if there's a common parent + domain between the Django server (e.g. api.example.com) and the websockets + server (e.g. ws.example.com). However, there's a risk to share cookies too + widely (e.g. with anything under .example.com here). For authentication + cookies, this risk seems unacceptable. +* Sending an authentication ticket: Django generates a secure single-use token + with the user ID. The browser includes this token in the WebSocket URI when + it connects to the server in order to authenticate. It could also send the + ticket over the WebSocket connection in the first message, however this is a + bit more difficult to monitor, as you can't detect authentication failures + simply by looking at HTTP response codes. + +.. _custom headers: https://github.com/whatwg/html/issues/3062 +.. _WebSocket API: https://developer.mozilla.org/en-US/docs/Web/API/WebSockets_API +.. _bug 1229443: https://bugzilla.mozilla.org/show_bug.cgi?id=1229443 + +To generate our authentication tokens, we'll use `django-sesame`_, a small +library designed exactly for this purpose. + + .. _django-sesame: https://github.com/aaugustin/django-sesame + +Add django-sesame to the dependencies of your Django project, install it and configure it in the settings of the project: + +.. code:: python + + AUTHENTICATION_BACKENDS = [ + "django.contrib.auth.backends.ModelBackend", + "sesame.backends.ModelBackend", + ] + +(If your project already uses another authentication backend than the default +``"django.contrib.auth.backends.ModelBackend"``, adjust accordingly.) + +You don't need ``"sesame.middleware.AuthenticationMiddleware"``. It is for +authenticating users in the Django server, while we're authenticating them in +the websockets server. + +We'd like our tokens to be valid for 30 seconds and usable only once. A +shorter lifespan is possible but it would make manual testing difficult. + +Configure django-sesame accordingly in the settings of your Django project: + +.. code:: python + + SESAME_MAX_AGE = 30 + SESAME_ONE_TIME = True + +Now you can generate tokens in a ``django-admin shell`` as follows: + +.. code:: pycon + + >>> from django.contrib.auth import get_user_model + >>> User = get_user_model() + >>> user = User.objects.get(username="") + >>> from sesame.utils import get_token + >>> get_token(user) + '' + +Keep this console open: since tokens are single use, we'll have to generate a +new token every time we want to test our server. + +Let's move on to the websockets server. Add websockets to the dependencies of +your Django project and install it. + +Now here's a way to implement authentication. We're taking advantage of the +:meth:`~websockets.server.WebSocketServerProtocol.process_request` hook to +authenticate requests. If authentication succeeds, we store the user as an +attribute of the connection in order to make it available to the connection +handler. If authentication fails, we return a HTTP 401 Unauthorized error. + +.. literalinclude:: ../../example/django/authentication.py + +Let's unpack this code. + +We're using Django in a `standalone script`_. This requires setting the +``DJANGO_SETTINGS_MODULE`` environment variable and calling ``django.setup()`` +before doing anything with Django. + +.. _standalone script: https://docs.djangoproject.com/en/stable/topics/settings/#calling-django-setup-is-required-for-standalone-django-usage + +We subclass :class:`~websockets.server.WebSocketServerProtocol` and override +:meth:`~websockets.server.WebSocketServerProtocol.process_request`, where: + +* We extract the token from the URL with the ``get_sesame()`` utility function + defined just above. If the token is missing, we return a HTTP 401 error. +* We authenticate the user with ``get_user()``, the API for `authentication + outside views`_. If authentication fails, we return a HTTP 401 error. + +.. _authentication outside views: https://github.com/aaugustin/django-sesame#authentication-outside-views + +When we call an API that makes a database query, we wrap the call in +:func:`~asyncio.to_thread`. Indeed, the Django ORM doesn't support +asynchronous I/O. We would block the event loop if we didn't run these calls +in a separate thread. :func:`~asyncio.to_thread` is available since Python +3.9; in earlier versions, use :meth:`~asyncio.loop.run_in_executor` instead. + +The connection handler accesses the logged-in user that we stored as an +attribute of the connection object so we can test that authentication works. + +Finally, we start a server with :func:`~websockets.serve`, with the +``create_protocol`` pointing to our subclass of +:class:`~websockets.server.WebSocketServerProtocol`. + +We're ready to test! + +Make sure the ``DJANGO_SETTINGS_MODULE`` environment variable is set to the +Python path to your settings module and start the websockets server. If you +saved the server implementation to a file called ``authentication.py``: + +.. code:: console + + $ python authentication.py + +Open a new shell, generate a new token — remember, they're only valid for +30 seconds — and use it to connect to your server: + +.. code:: console + + $ python -m websockets "ws://localhost:8888/?sesame=" + Connected to ws://localhost:8888/?sesame= + < Hello ! + Connection closed: code = 1000 (OK), no reason. + +It works! + +If we try to reuse the same token, the connection is now rejected: + +.. code:: console + + $ python -m websockets "ws://localhost:8888/?sesame=" + Failed to connect to ws://localhost:8888/?sesame=: + server rejected WebSocket connection: HTTP 401. + +You can also test from a browser by generating a new token and running the +following code in the JavaScript console of the browser: + +.. code:: javascript + + webSocket = new WebSocket("ws://localhost:8888/?sesame="); + webSocket.onmessage = (event) => console.log(event.data); + +Streaming events +---------------- + +We can connect and authenticate but our server doesn't do anything useful yet! + +Let's send a message every time any user makes any action in the admin. This +message will be broadcast to all users who can access the model on which the +action was made. This may be used for showing notifications to other users. + +Many use cases for WebSocket with Django follow a similar pattern. + +We need a event bus to enable communications between Django and websockets. +Both sides connect permanently to the bus. Then Django writes events and +websockets reads them. For the sake of simplicity, we'll rely on `Redis +Pub/Sub`_. + +.. _Redis Pub/Sub: https://redis.io/topics/pubsub + +Let's start by writing events. The easiest way to add Redis to a Django +project is by configuring a cache backend with `django-redis`_. This library +manages connections to Redis efficiently, persisting them between requests, +and provides an API to access the Redis connection directly. + +.. _django-redis: https://github.com/jazzband/django-redis + +Install Redis, add django-redis to the dependencies of your Django project, +install it and configure it in the settings of the project: + +.. code:: python + + CACHES = { + "default": { + "BACKEND": "django_redis.cache.RedisCache", + "LOCATION": "redis://127.0.0.1:6379/1", + }, + } + +If you already have a default cache, add a new one with a different name and +change ``get_redis_connection("default")`` in the code below to the same name. + +Add the following code to a module that is imported when your Django project +starts. Typically, you would put it in a ``signals.py`` module, which you +would import in the ``AppConfig.ready()`` method of one of your apps: + +.. literalinclude:: ../../example/django/signals.py + +This code runs every time the admin saves a ``LogEntry`` object to keep track +of a change. It extracts interesting data, serializes it to JSON, and writes +an event to Redis. + +Let's check that it works: + +.. code:: console + + $ redis-cli + 127.0.0.1:6379> SELECT 1 + OK + 127.0.0.1:6379[1]> SUBSCRIBE events + Reading messages... (press Ctrl-C to quit) + 1) "subscribe" + 2) "events" + 3) (integer) 1 + +Leave this command running, start the Django development server and make +changes in the admin: add, modify, or delete objects. You should see +corresponding events published to the ``"events"`` stream. + +Now let's turn to reading events and broadcasting them to connected clients. +We'll reuse our custom ``ServerProtocol`` class for authentication. Then we +need to add several features: + +* Keep track of connected clients so we can broadcast messages. +* Tell which content types the user has permission to view or to change. +* Connect to the message bus and read events. +* Broadcast these events to users who have corresponding permissions. + +Here's a complete implementation. + +.. literalinclude:: ../../example/django/notifications.py + +Since the ``get_content_types()`` function makes a database query, it is +wrapped inside ``asyncio.to_thread()``. It runs once when each WebSocket +connection is open; then its result is cached for the lifetime of the +connection. Indeed, running it for each message would trigger database queries +for all connected users at the same time, which could hurt the database. + +The connection handler merely registers the connection in a global variable, +associated to the list of content types for which events should be sent to +that connection, and waits until the client disconnects. + +The ``process_events()`` function reads events from Redis and broadcasts them +to all connections that should receive them. We don't care much if a sending a +notification fails — this happens when a connection drops between the moment +we iterate on connections and the moment the corresponding message is sent — +so we start a task with for each message and forget about it. Also, this means +we're immediately ready to process the next event, even if it takes time to +send a message to a slow client. + +Since Redis can publish a message to multiple subscribers, multiple instances +of this server can safely run in parallel. + +In theory, given enough servers, this design can scale to a hundred million +clients, since Redis can handle ten thousand servers and each server can +handle ten thousand clients. In practice, you would need a more scalable +message bus before reaching that scale, due to the volume of messages. diff --git a/docs/index.rst b/docs/index.rst index f0c5f8d00..257a806ea 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -59,6 +59,7 @@ These guides will help you build and deploy a ``websockets`` application. cheatsheet deployment extensions + howto/django heroku Reference diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index d7c744147..e168035c1 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -1,16 +1,16 @@ +api attr augustin -Auth +auth awaitable aymeric +backend backpressure -Backpressure balancer balancers -Bitcoin +bitcoin bottlenecked bufferbloat -Bufferbloat bugfix bytestring bytestrings @@ -19,38 +19,42 @@ coroutine coroutines cryptocurrencies cryptocurrency -Ctrl +ctrl daemonize datastructures +django fractalideas IPv iterable keepalive KiB lifecycle -Lifecycle lookups MiB nginx +onmessage parsers permessage pong pongs -Pythonic +pythonic +redis +scalable serializers -Subclasses subclasses subclassing subprotocol subprotocols -Tidelift -TLS +tidelift +tls tox -Unparse +unparse unregister uple username virtualenv -websocket WebSocket +websocket websockets +ws +wss diff --git a/example/django/authentication.py b/example/django/authentication.py new file mode 100644 index 000000000..c0a061109 --- /dev/null +++ b/example/django/authentication.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python + +import asyncio +import http +import urllib.parse + +import django +import websockets + +django.setup() + +from sesame.utils import get_user + + +def get_sesame(path): + """Utility function to extract sesame token from request path.""" + query = urllib.parse.urlparse(path).query + params = urllib.parse.parse_qs(query) + sesame = params.get("sesame", []) + if len(sesame) == 1: + return sesame[0] + + +class ServerProtocol(websockets.WebSocketServerProtocol): + async def process_request(self, path, headers): + """Authenticate users with a django-sesame token.""" + sesame = get_sesame(path) + if sesame is None: + return http.HTTPStatus.UNAUTHORIZED, [], b"Missing token\n" + + user = await asyncio.to_thread(get_user, sesame) + if user is None: + return http.HTTPStatus.UNAUTHORIZED, [], b"Invalid token\n" + + self.user = user + + +async def handler(websocket, path): + await websocket.send(f"Hello {websocket.user}!") + + +async def main(): + async with websockets.serve( + handler, + "localhost", + 8888, + create_protocol=ServerProtocol, + ): + await asyncio.Future() # run forever + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/django/notifications.py b/example/django/notifications.py new file mode 100644 index 000000000..ad2751d98 --- /dev/null +++ b/example/django/notifications.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python + +import asyncio +import json + +import aioredis +import django +import websockets + +django.setup() + +from django.contrib.contenttypes.models import ContentType + +# Reuse our custom protocol to authenticate connections +from authentication import ServerProtocol + + +CONNECTIONS = {} + + +def get_content_types(user): + """Return the set of IDs of content types visible by user.""" + # This does only three database queries because Django caches + # all permissions on the first call to user.has_perm(...). + return { + ct.id + for ct in ContentType.objects.all() + if user.has_perm(f"{ct.app_label}.view_{ct.model}") + or user.has_perm(f"{ct.app_label}.change_{ct.model}") + } + + +async def handler(websocket, path): + """Register connection in CONNECTIONS dict, until it's closed.""" + ct_ids = await asyncio.to_thread(get_content_types, websocket.user) + CONNECTIONS[websocket] = {"content_type_ids": ct_ids} + try: + await websocket.wait_closed() + finally: + del CONNECTIONS[websocket] + + +async def process_events(): + """Listen to events in Redis and process them.""" + redis = aioredis.from_url("redis://127.0.0.1:6379/1") + pubsub = redis.pubsub() + await pubsub.subscribe("events") + async for message in pubsub.listen(): + if message["type"] != "message": + continue + payload = message["data"].decode() + # Broadcast event to all users who have permissions to see it. + event = json.loads(payload) + for websocket, connection in CONNECTIONS.items(): + if event["content_type_id"] in connection["content_type_ids"]: + asyncio.create_task(websocket.send(payload)) + + +async def main(): + async with websockets.serve( + handler, + "localhost", + 8888, + create_protocol=ServerProtocol, + ): + await process_events() # runs forever + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/django/signals.py b/example/django/signals.py new file mode 100644 index 000000000..6dc827f72 --- /dev/null +++ b/example/django/signals.py @@ -0,0 +1,23 @@ +import json + +from django.contrib.admin.models import LogEntry +from django.db.models.signals import post_save +from django.dispatch import receiver + +from django_redis import get_redis_connection + + +@receiver(post_save, sender=LogEntry) +def publish_event(instance, **kwargs): + event = { + "model": instance.content_type.name, + "object": instance.object_repr, + "message": instance.get_change_message(), + "timestamp": instance.action_time.isoformat(), + "user": str(instance.user), + "content_type_id": instance.content_type_id, + "object_id": instance.object_id, + } + connection = get_redis_connection("default") + payload = json.dumps(event) + connection.publish("events", payload) From cceff406270c47037357a5adab9665bc05b3ab15 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 8 May 2021 08:19:38 +0200 Subject: [PATCH 010/762] Add FAQ about connection handler terminating early. Ref #948. --- docs/faq.rst | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/docs/faq.rst b/docs/faq.rst index 20b74ba98..8173f7b61 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -14,6 +14,23 @@ FAQ Server side ----------- +Why does my server close the connection prematurely? +.................................................... + +Your connection handler exits prematurely. Wait for the work to be finished +before returning. + +For example, if your handler has a structure similar to:: + + async def handler(websocket, path): + ... + asyncio.create_task(do_some_work()) + +change it to:: + + async def handler(websocket, path): + await do_some_work() + Why does the server close the connection after processing one message? ...................................................................... From aa79b6c057c5f11c85b0ac0f932e0abf59e8033c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 8 May 2021 08:22:03 +0200 Subject: [PATCH 011/762] Add FAQ about context manager terminating early. Ref #947. --- docs/faq.rst | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/docs/faq.rst b/docs/faq.rst index 8173f7b61..5f70de43a 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -144,6 +144,22 @@ If you need more, pick a HTTP server and run it separately. Client side ----------- +Why does my client close the connection prematurely? +.................................................... + +You're exiting the context manager prematurely. Wait for the work to be +finished before exiting. + +For example, if your code has a structure similar to:: + + async with connect(...) as websocket: + asyncio.create_task(do_some_work()) + +change it to:: + + async with connect(...) as websocket: + await do_some_work() + How do I close a connection properly? ..................................... From 088c59bb895f24e46e6606d702376c9e7a229d29 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 8 May 2021 08:26:15 +0200 Subject: [PATCH 012/762] Update FAQ answer on Python 2. --- docs/faq.rst | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/faq.rst b/docs/faq.rst index 5f70de43a..53da0f004 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -346,6 +346,8 @@ Is there a Python 2 version? No, there isn't. -websockets builds upon asyncio which requires Python 3. +Python 2 reached end of life on January 1st, 2020. + +Before that date, websockets required asyncio and therefore Python 3. From bc19676a06d410ba35363d4630fb417c24d90b66 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Fri, 14 May 2021 07:07:41 +0100 Subject: [PATCH 013/762] Add project_urls metadata (#943) --- setup.cfg | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/setup.cfg b/setup.cfg index d8877aa2e..8c7a4984a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -3,6 +3,11 @@ python-tag = py37.py38.py39.py310 [metadata] license_file = LICENSE +project_urls = + Changelog = https://websockets.readthedocs.io/en/stable/changelog.html + Documentation = https://websockets.readthedocs.io/ + Funding = https://tidelift.com/subscription/pkg/pypi-websockets?utm_source=pypi-websockets&utm_medium=referral&utm_campaign=readme + Tracker = https://github.com/aaugustin/websockets/issues [flake8] ignore = E203,E731,F403,F405,W503 From ccfe98e5111a74d602df7a2a195951acf69057f0 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 15 May 2021 10:42:29 +0200 Subject: [PATCH 014/762] Add FAQ on blocking send loops. Refs #867 (and others). --- docs/faq.rst | 53 ++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 41 insertions(+), 12 deletions(-) diff --git a/docs/faq.rst b/docs/faq.rst index 53da0f004..abd396cde 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -23,7 +23,6 @@ before returning. For example, if your handler has a structure similar to:: async def handler(websocket, path): - ... asyncio.create_task(do_some_work()) change it to:: @@ -185,7 +184,6 @@ Here's an example that terminates cleanly when it receives SIGTERM on Unix: .. literalinclude:: ../example/shutdown_client.py :emphasize-lines: 10-13 - How do I disable TLS/SSL certificate verification? .................................................. @@ -194,27 +192,50 @@ Look at the ``ssl`` argument of :meth:`~asyncio.loop.create_connection`. :func:`connect` accepts the same arguments as :meth:`~asyncio.loop.create_connection`. -Both sides ----------- +asyncio usage +------------- How do I do two things in parallel? How do I integrate with another coroutine? .............................................................................. You must start two tasks, which the event loop will run concurrently. You can -achieve this with :func:`asyncio.gather` or :func:`asyncio.wait`. - -This is also part of learning asyncio and not specific to websockets. +achieve this with :func:`asyncio.gather` or :func:`asyncio.create_task`. Keep track of the tasks and make sure they terminate or you cancel them when the connection terminates. -How do I create channels or topics? -................................... +Why does my program never receives any messages? +................................................ -websockets doesn't have built-in publish / subscribe for these use cases. +Your program runs a coroutine that never yield control to the event loop. The +coroutine that receives messages never gets a chance to run. -Depending on the scale of your service, a simple in-memory implementation may -do the job or you may need an external publish / subscribe component. +Putting an ``await`` statement in a ``for`` or a ``while`` loop isn't enough +to yield control. Awaiting a coroutine may yield control, but there's no +guarantee that it will. + +For example, ``send()`` only yields control when send buffers are full, which +never happens in most practical cases. + +If you run a loop that contains only synchronous operations and a ``send()`` +call, you must yield control explicitly with :func:`asyncio.sleep`:: + + async def producer(websocket): + message = generate_next_message() + await websocket.send(message) + await asyncio.sleep(0) # yield control to the event loop + +:func:`asyncio.sleep` always suspends the current task, allowing other tasks +to run. This behavior is documented precisely because it isn't expected from +every coroutine. + +See `issue 867`_. + +.. _issue 867: https://github.com/aaugustin/websockets/issues/867 + + +Both sides +---------- What does ``ConnectionClosedError: code = 1006`` mean? ...................................................... @@ -315,6 +336,14 @@ If this turns out to be impractical, you should use another library. Miscellaneous ------------- +How do I create channels or topics? +................................... + +websockets doesn't have built-in publish / subscribe for these use cases. + +Depending on the scale of your service, a simple in-memory implementation may +do the job or you may need an external publish / subscribe component. + How do I set a timeout on ``recv()``? ..................................... From b8517b11f98582d4ed3c0bb0c20c5ecf1c31df47 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 15 May 2021 14:18:14 +0200 Subject: [PATCH 015/762] Optimize default compression settings. --- .gitignore | 1 + benchmark/compression.py | 163 ++++++++++++++++++ {performance => benchmark}/mem_client.py | 19 +- {performance => benchmark}/mem_server.py | 16 +- docs/changelog.rst | 2 + docs/deployment.rst | 64 ++++--- .../extensions/permessage_deflate.py | 14 +- tests/extensions/test_permessage_deflate.py | 36 ++-- tests/legacy/test_client_server.py | 4 +- 9 files changed, 263 insertions(+), 56 deletions(-) create mode 100644 benchmark/compression.py rename {performance => benchmark}/mem_client.py (77%) rename {performance => benchmark}/mem_server.py (82%) diff --git a/.gitignore b/.gitignore index c23cf5210..ac68ff739 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ .mypy_cache .tox build/ +benchmark/corpus.pkl compliance/reports/ dist/ docs/_build/ diff --git a/benchmark/compression.py b/benchmark/compression.py new file mode 100644 index 000000000..15fb8653e --- /dev/null +++ b/benchmark/compression.py @@ -0,0 +1,163 @@ +#!/usr/bin/env python + +import getpass +import json +import pickle +import subprocess +import sys +import time +import zlib + + +CORPUS_FILE = "corpus.pkl" + +REPEAT = 10 + +WB, ML = 12, 5 # defaults used as a reference + + +def _corpus(): + OAUTH_TOKEN = getpass.getpass("OAuth Token? ") + COMMIT_API = ( + f'curl -H "Authorization: token {OAUTH_TOKEN}" ' + f"https://api.github.com/repos/aaugustin/websockets/git/commits/:sha" + ) + + commits = [] + + head = subprocess.check_output("git rev-parse HEAD", shell=True).decode().strip() + todo = [head] + seen = set() + + while todo: + sha = todo.pop(0) + commit = subprocess.check_output(COMMIT_API.replace(":sha", sha), shell=True) + commits.append(commit) + seen.add(sha) + for parent in json.loads(commit)["parents"]: + sha = parent["sha"] + if sha not in seen and sha not in todo: + todo.append(sha) + time.sleep(1) # rate throttling + + return commits + + +def corpus(): + data = _corpus() + with open(CORPUS_FILE, "wb") as handle: + pickle.dump(data, handle) + + +def _benchmark(data): + size = {} + duration = {} + + for wbits in range(9, 16): + size[wbits] = {} + duration[wbits] = {} + + for memLevel in range(1, 10): + encoder = zlib.compressobj(wbits=-wbits, memLevel=memLevel) + encoded = [] + + t0 = time.perf_counter() + + for _ in range(REPEAT): + for item in data: + if isinstance(item, str): + item = item.encode("utf-8") + # Taken from PerMessageDeflate.encode + item = encoder.compress(item) + encoder.flush(zlib.Z_SYNC_FLUSH) + if item.endswith(b"\x00\x00\xff\xff"): + item = item[:-4] + encoded.append(item) + + t1 = time.perf_counter() + + size[wbits][memLevel] = sum(len(item) for item in encoded) + duration[wbits][memLevel] = (t1 - t0) / REPEAT + + raw_size = sum(len(item) for item in data) + + print("=" * 79) + print("Compression ratio") + print("=" * 79) + print("\t".join(["wb \\ ml"] + [str(memLevel) for memLevel in range(1, 10)])) + for wbits in range(9, 16): + print( + "\t".join( + [str(wbits)] + + [ + f"{100 * (1 - size[wbits][memLevel] / raw_size):.1f}%" + for memLevel in range(1, 10) + ] + ) + ) + print("=" * 79) + print() + + print("=" * 79) + print("CPU time") + print("=" * 79) + print("\t".join(["wb \\ ml"] + [str(memLevel) for memLevel in range(1, 10)])) + for wbits in range(9, 16): + print( + "\t".join( + [str(wbits)] + + [ + f"{1000 * duration[wbits][memLevel]:.1f}ms" + for memLevel in range(1, 10) + ] + ) + ) + print("=" * 79) + print() + + print("=" * 79) + print(f"Size vs. {WB} \\ {ML}") + print("=" * 79) + print("\t".join(["wb \\ ml"] + [str(memLevel) for memLevel in range(1, 10)])) + for wbits in range(9, 16): + print( + "\t".join( + [str(wbits)] + + [ + f"{100 * (size[wbits][memLevel] / size[WB][ML] - 1):.1f}%" + for memLevel in range(1, 10) + ] + ) + ) + print("=" * 79) + print() + + print("=" * 79) + print(f"Time vs. {WB} \\ {ML}") + print("=" * 79) + print("\t".join(["wb \\ ml"] + [str(memLevel) for memLevel in range(1, 10)])) + for wbits in range(9, 16): + print( + "\t".join( + [str(wbits)] + + [ + f"{100 * (duration[wbits][memLevel] / duration[WB][ML] - 1):.1f}%" + for memLevel in range(1, 10) + ] + ) + ) + print("=" * 79) + print() + + +def benchmark(): + with open(CORPUS_FILE, "rb") as handle: + data = pickle.load(handle) + _benchmark(data) + + +try: + run = globals()[sys.argv[1]] +except (KeyError, IndexError): + print(f"Usage: {sys.argv[0]} [corpus|benchmark]") +else: + run() diff --git a/performance/mem_client.py b/benchmark/mem_client.py similarity index 77% rename from performance/mem_client.py rename to benchmark/mem_client.py index 6eab690d8..db68eb995 100644 --- a/performance/mem_client.py +++ b/benchmark/mem_client.py @@ -8,9 +8,11 @@ from websockets.extensions import permessage_deflate -CLIENTS = 10 +CLIENTS = 20 INTERVAL = 1 / 10 # seconds +WB, ML = 12, 5 + MEM_SIZE = [] @@ -24,9 +26,9 @@ async def mem_client(client): "ws://localhost:8765", extensions=[ permessage_deflate.ClientPerMessageDeflateFactory( - server_max_window_bits=10, - client_max_window_bits=10, - compress_settings={"memLevel": 3}, + server_max_window_bits=WB, + client_max_window_bits=WB, + compress_settings={"memLevel": ML}, ) ], ) as ws: @@ -43,9 +45,12 @@ async def mem_client(client): await asyncio.sleep(CLIENTS * INTERVAL) -asyncio.run( - asyncio.gather(*[mem_client(client) for client in range(CLIENTS + 1)]) -) +async def mem_clients(): + await asyncio.gather(*[mem_client(client) for client in range(CLIENTS + 1)]) + + +asyncio.run(mem_clients()) + # First connection incurs non-representative setup costs. del MEM_SIZE[0] diff --git a/performance/mem_server.py b/benchmark/mem_server.py similarity index 82% rename from performance/mem_server.py rename to benchmark/mem_server.py index 81490a0e7..852796249 100644 --- a/performance/mem_server.py +++ b/benchmark/mem_server.py @@ -1,6 +1,7 @@ #!/usr/bin/env python import asyncio +import os import signal import statistics import tracemalloc @@ -9,9 +10,11 @@ from websockets.extensions import permessage_deflate -CLIENTS = 10 +CLIENTS = 20 INTERVAL = 1 / 10 # seconds +WB, ML = 12, 5 + MEM_SIZE = [] @@ -34,17 +37,22 @@ async def handler(ws, path): async def mem_server(): loop = asyncio.get_running_loop() stop = loop.create_future() + # Set the stop condition when receiving SIGTERM. + print("Stop the server with:") + print(f"kill -TERM {os.getpid()}") + print() loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) + async with websockets.serve( handler, "localhost", 8765, extensions=[ permessage_deflate.ServerPerMessageDeflateFactory( - server_max_window_bits=10, - client_max_window_bits=10, - compress_settings={"memLevel": 3}, + server_max_window_bits=WB, + client_max_window_bits=WB, + compress_settings={"memLevel": ML}, ) ], ): diff --git a/docs/changelog.rst b/docs/changelog.rst index d2d58c57b..434cdea61 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -42,6 +42,8 @@ They may change at any time. * Added compatibility with Python 3.10. +* Optimized default compression settings to reduce memory usage. + 9.0.2 ..... diff --git a/docs/deployment.rst b/docs/deployment.rst index 8baa8836c..aa1af211c 100644 --- a/docs/deployment.rst +++ b/docs/deployment.rst @@ -6,8 +6,8 @@ Deployment Application server ------------------ -The author of ``websockets`` isn't aware of best practices for deploying -network services based on :mod:`asyncio`, let alone application servers. +The author of websockets isn't aware of best practices for deploying network +services based on :mod:`asyncio`, let alone application servers. You can run a script similar to the :ref:`server example `, inside a supervisor if you deem that useful. @@ -59,7 +59,7 @@ memory usage can become a bottleneck. Memory usage of a single connection is the sum of: -1. the baseline amount of memory ``websockets`` requires for each connection, +1. the baseline amount of memory websockets requires for each connection, 2. the amount of data held in buffers before the application processes it, 3. any additional memory allocated by the application itself. @@ -71,60 +71,78 @@ Baseline Compression settings are the main factor affecting the baseline amount of memory used by each connection. -By default ``websockets`` maximizes compression rate at the expense of memory -usage. If memory usage is an issue, lowering compression settings can help: +If you'd like to customize compression settings, here are the main knobs. - Context Takeover is necessary to get good performance for almost all applications. It should remain enabled. - Window Bits is a trade-off between memory usage and compression rate. - It defaults to 15 and can be lowered. The default value isn't optimal - for small, repetitive messages which are typical of WebSocket servers. + It should be an integer between 9 (lowest memory usage) and 15 (highest + compression rate). Setting it to 8 is possible but triggers a bug in some + versions of zlib. - Memory Level is a trade-off between memory usage and compression speed. - It defaults to 8 and can be lowered. A lower memory level can actually - increase speed thanks to memory locality, even if the CPU does more work! + However, a lower memory level can increase speed thanks to memory locality, + even if the CPU does more work! It should be an integer between 1 (lowest + memory usage) and 9 (highest compression speed in theory, not in practice). -See this :ref:`example ` for how to -configure compression settings. +By default, websockets enables compression with conservative settings that +optimize memory usage at the cost of a slightly worse compression rate: Window +Bits = 12 and Memory Level = 5. This strikes a good balance for small messages +that are typical of WebSocket servers. + +If you'd like to configure different compression settings, see this +:ref:`example `. If you don't set +limits on Window Bits and neither does the remote endpoint, it defaults to the +maximum value of 15. If you don't set Memory Level, it defaults to 8 — more +accurately, to ``zlib.DEF_MEM_LEVEL`` which is 8. Here's how various compression settings affect memory usage of a single -connection on a 64-bit system, as well a benchmark_ of compressed size and +connection on a 64-bit system, as well a benchmark of compressed size and compression time for a corpus of small JSON documents. +-------------+-------------+--------------+--------------+------------------+------------------+ | Compression | Window Bits | Memory Level | Memory usage | Size vs. default | Time vs. default | +=============+=============+==============+==============+==================+==================+ -| *default* | 15 | 8 | 325 KiB | +0% | +0% + +| | 15 | 8 | 322 KiB | -4.0% | +15% + ++-------------+-------------+--------------+--------------+------------------+------------------+ +| | 14 | 7 | 178 KiB | -2.6% | +10% | +-------------+-------------+--------------+--------------+------------------+------------------+ -| | 14 | 7 | 181 KiB | +1.5% | -5.3% | +| | 13 | 6 | 106 KiB | -1.4% | +5% | +-------------+-------------+--------------+--------------+------------------+------------------+ -| | 13 | 6 | 110 KiB | +2.8% | -7.5% | +| *default* | 12 | 5 | 70 KiB | = | = | +-------------+-------------+--------------+--------------+------------------+------------------+ -| | 12 | 5 | 73 KiB | +4.4% | -18.9% | +| | 11 | 4 | 52 KiB | +3.7% | -5% | +-------------+-------------+--------------+--------------+------------------+------------------+ -| | 11 | 4 | 55 KiB | +8.5% | -18.8% | +| | 10 | 3 | 43 KiB | +90% | +50% | +-------------+-------------+--------------+--------------+------------------+------------------+ -| *disabled* | N/A | N/A | 22 KiB | N/A | N/A | +| | 9 | 2 | 39 KiB | +160% | +100% | ++-------------+-------------+--------------+--------------+------------------+------------------+ +| *disabled* | N/A | N/A | 19 KiB | N/A | N/A | +-------------+-------------+--------------+--------------+------------------+------------------+ *Don't assume this example is representative! Compressed size and compression time depend heavily on the kind of messages exchanged by the application!* -You can run the same benchmark for your application by creating a list of -typical messages and passing it to the ``_benchmark`` function_. +You can adapt the `compression.py`_ benchmark for your application by creating +a list of typical messages and passing it to the ``_benchmark`` function. -.. _benchmark: https://gist.github.com/aaugustin/fbea09ce8b5b30c4e56458eb081fe599 -.. _function: https://gist.github.com/aaugustin/fbea09ce8b5b30c4e56458eb081fe599#file-compression-py-L48-L144 +.. _compression.py: https://github.com/aaugustin/websockets/blob/main/performance/compression.py This `blog post by Ilya Grigorik`_ provides more details about how compression settings affect memory usage and how to optimize them. .. _blog post by Ilya Grigorik: https://www.igvita.com/2013/11/27/configuring-and-optimizing-websocket-compression/ -This `experiment by Peter Thorson`_ suggests Window Bits = 11, Memory Level = +This `experiment by Peter Thorson`_ suggests Window Bits = 11 and Memory Level = 4 as a sweet spot for optimizing memory usage. .. _experiment by Peter Thorson: https://www.ietf.org/mail-archive/web/hybi/current/msg10222.html +websockets defaults to Window Bits = 12 and Memory Level = 5 in order to stay +away from Window Bits = 10 or Memory Level = 3, where performance craters in +the benchmark. This raises doubts on what could happen at Window Bits = 11 and +Memory Level = 4 on a different set of messages. The defaults needs to be safe +for all applications, hence a more conservative choice. + Buffers ....... diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index 4f520af38..34cc1f950 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -448,7 +448,11 @@ def enable_client_permessage_deflate( for extension_factory in extensions ): extensions = list(extensions) + [ - ClientPerMessageDeflateFactory(client_max_window_bits=True) + ClientPerMessageDeflateFactory( + server_max_window_bits=12, + client_max_window_bits=12, + compress_settings={"memLevel": 5}, + ) ] return extensions @@ -631,5 +635,11 @@ def enable_server_permessage_deflate( ext_factory.name == ServerPerMessageDeflateFactory.name for ext_factory in extensions ): - extensions = list(extensions) + [ServerPerMessageDeflateFactory()] + extensions = list(extensions) + [ + ServerPerMessageDeflateFactory( + server_max_window_bits=12, + client_max_window_bits=12, + compress_settings={"memLevel": 5}, + ) + ] return extensions diff --git a/tests/extensions/test_permessage_deflate.py b/tests/extensions/test_permessage_deflate.py index 908cd91a4..d3c6f0ac6 100644 --- a/tests/extensions/test_permessage_deflate.py +++ b/tests/extensions/test_permessage_deflate.py @@ -512,33 +512,33 @@ def test_enable_client_permessage_deflate(self): ) in [ ( None, - (1, 0, None), + (1, 0, {"memLevel": 5}), ), ( [], - (1, 0, None), + (1, 0, {"memLevel": 5}), ), ( [ClientNoOpExtensionFactory()], - (2, 1, None), + (2, 1, {"memLevel": 5}), ), ( - [ClientPerMessageDeflateFactory(compress_settings={"level": 1})], - (1, 0, {"level": 1}), + [ClientPerMessageDeflateFactory(compress_settings={"memLevel": 7})], + (1, 0, {"memLevel": 7}), ), ( [ - ClientPerMessageDeflateFactory(compress_settings={"level": 1}), + ClientPerMessageDeflateFactory(compress_settings={"memLevel": 7}), ClientNoOpExtensionFactory(), ], - (2, 0, {"level": 1}), + (2, 0, {"memLevel": 7}), ), ( [ ClientNoOpExtensionFactory(), - ClientPerMessageDeflateFactory(compress_settings={"level": 1}), + ClientPerMessageDeflateFactory(compress_settings={"memLevel": 7}), ], - (2, 1, {"level": 1}), + (2, 1, {"memLevel": 7}), ), ]: with self.subTest(extensions=extensions): @@ -849,33 +849,33 @@ def test_enable_server_permessage_deflate(self): ) in [ ( None, - (1, 0, None), + (1, 0, {"memLevel": 5}), ), ( [], - (1, 0, None), + (1, 0, {"memLevel": 5}), ), ( [ServerNoOpExtensionFactory()], - (2, 1, None), + (2, 1, {"memLevel": 5}), ), ( - [ServerPerMessageDeflateFactory(compress_settings={"level": 1})], - (1, 0, {"level": 1}), + [ServerPerMessageDeflateFactory(compress_settings={"memLevel": 7})], + (1, 0, {"memLevel": 7}), ), ( [ - ServerPerMessageDeflateFactory(compress_settings={"level": 1}), + ServerPerMessageDeflateFactory(compress_settings={"memLevel": 7}), ServerNoOpExtensionFactory(), ], - (2, 0, {"level": 1}), + (2, 0, {"memLevel": 7}), ), ( [ ServerNoOpExtensionFactory(), - ServerPerMessageDeflateFactory(compress_settings={"level": 1}), + ServerPerMessageDeflateFactory(compress_settings={"memLevel": 7}), ], - (2, 1, {"level": 1}), + (2, 1, {"memLevel": 7}), ), ]: with self.subTest(extensions=extensions): diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 80c9d56bb..353e5b370 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -985,11 +985,11 @@ def test_extensions_error_no_extensions(self, _process_extensions): def test_compression_deflate(self): server_extensions = self.loop.run_until_complete(self.client.recv()) self.assertEqual( - server_extensions, repr([PerMessageDeflate(False, False, 15, 15)]) + server_extensions, repr([PerMessageDeflate(False, False, 12, 12)]) ) self.assertEqual( repr(self.client.extensions), - repr([PerMessageDeflate(False, False, 15, 15)]), + repr([PerMessageDeflate(False, False, 12, 12)]), ) def test_compression_unsupported_server(self): From 66ded17c39cd6e5bb408edd9b53b9afad7ab19fc Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 15 May 2021 17:54:40 +0200 Subject: [PATCH 016/762] Main branch is now called main. --- README.rst | 12 ++++++------ docs/contributing.rst | 2 +- docs/heroku.rst | 12 ++++++------ docs/index.rst | 4 ++-- 4 files changed, 15 insertions(+), 15 deletions(-) diff --git a/README.rst b/README.rst index fa42a9be3..cd635835c 100644 --- a/README.rst +++ b/README.rst @@ -28,8 +28,8 @@ What is ``websockets``? ``websockets`` is a library for building WebSocket servers_ and clients_ in Python with a focus on correctness and simplicity. -.. _servers: https://github.com/aaugustin/websockets/blob/master/example/server.py -.. _clients: https://github.com/aaugustin/websockets/blob/master/example/client.py +.. _servers: https://github.com/aaugustin/websockets/blob/main/example/server.py +.. _clients: https://github.com/aaugustin/websockets/blob/main/example/client.py Built on top of ``asyncio``, Python's standard asynchronous I/O framework, it provides an elegant coroutine-based API. @@ -80,7 +80,7 @@ Does that look good? .. raw:: html
- +

websockets for enterprise

Available as part of the Tidelift Subscription

The maintainers of websockets and thousands of other packages are working with Tidelift to deliver commercial support and maintenance for the open source dependencies you use to build your applications. Save time, reduce risk, and improve code health, while paying the maintainers of the exact dependencies you use. Learn more.

@@ -113,7 +113,7 @@ Docs`_ and see for yourself. .. _Read the Docs: https://websockets.readthedocs.io/ .. _handle backpressure correctly: https://vorpus.org/blog/some-thoughts-on-asynchronous-api-design-in-a-post-asyncawait-world/#websocket-servers -.. _Autobahn Testsuite: https://github.com/aaugustin/websockets/blob/master/compliance/README.rst +.. _Autobahn Testsuite: https://github.com/aaugustin/websockets/blob/main/compliance/README.rst Why shouldn't I use ``websockets``? ----------------------------------- @@ -145,8 +145,8 @@ For anything else, please open an issue_ or send a `pull request`_. Participants must uphold the `Contributor Covenant code of conduct`_. -.. _Contributor Covenant code of conduct: https://github.com/aaugustin/websockets/blob/master/CODE_OF_CONDUCT.md +.. _Contributor Covenant code of conduct: https://github.com/aaugustin/websockets/blob/main/CODE_OF_CONDUCT.md ``websockets`` is released under the `BSD license`_. -.. _BSD license: https://github.com/aaugustin/websockets/blob/master/LICENSE +.. _BSD license: https://github.com/aaugustin/websockets/blob/main/LICENSE diff --git a/docs/contributing.rst b/docs/contributing.rst index 61c0b979c..59d7451e0 100644 --- a/docs/contributing.rst +++ b/docs/contributing.rst @@ -10,7 +10,7 @@ This project and everyone participating in it is governed by the `Code of Conduct`_. By participating, you are expected to uphold this code. Please report inappropriate behavior to aymeric DOT augustin AT fractalideas DOT com. -.. _Code of Conduct: https://github.com/aaugustin/websockets/blob/master/CODE_OF_CONDUCT.md +.. _Code of Conduct: https://github.com/aaugustin/websockets/blob/main/CODE_OF_CONDUCT.md *(If I'm the person with the inappropriate behavior, please accept my apologies. I know I can mess up. I can't expect you to tell me, but if you diff --git a/docs/heroku.rst b/docs/heroku.rst index 8af2ebd3d..d23dc64c0 100644 --- a/docs/heroku.rst +++ b/docs/heroku.rst @@ -16,10 +16,10 @@ Deploying to Heroku requires a git repository. Let's initialize one: $ mkdir websockets-echo $ cd websockets-echo - $ git init . + $ git init -b main Initialized empty Git repository in websockets-echo/.git/ $ git commit --allow-empty -m "Initial commit." - [master (root-commit) 1e7947d] Initial commit. + [main (root-commit) 1e7947d] Initial commit. Follow the `set-up instructions`_ to install the Heroku CLI and to log in, if you haven't done that yet. @@ -32,7 +32,7 @@ you'll have to pick a different name because I'm already using .. code:: console - $ $ heroku create websockets-echo + $ heroku create websockets-echo Creating ⬢ websockets-echo... done https://websockets-echo.herokuapp.com/ | https://git.heroku.com/websockets-echo.git @@ -86,7 +86,7 @@ Confirm that you created the correct files and commit them to git: Procfile app.py requirements.txt $ git add . $ git commit -m "Deploy echo server to Heroku." - [master 8418c62] Deploy echo server to Heroku. + [main 8418c62] Deploy echo server to Heroku.  3 files changed, 19 insertions(+)  create mode 100644 Procfile  create mode 100644 app.py @@ -99,7 +99,7 @@ Our app is ready. Let's deploy it! .. code:: console - $ git push heroku master + $ git push heroku main ... lots of output... @@ -109,7 +109,7 @@ Our app is ready. Let's deploy it! remote: remote: Verifying deploy... done. To https://git.heroku.com/websockets-echo.git -  * [new branch] master -> master +  * [new branch] main -> main Validate deployment ------------------- diff --git a/docs/index.rst b/docs/index.rst index 257a806ea..4ec682f69 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -21,8 +21,8 @@ websockets ``websockets`` is a library for building WebSocket servers_ and clients_ in Python with a focus on correctness and simplicity. -.. _servers: https://github.com/aaugustin/websockets/blob/master/example/server.py -.. _clients: https://github.com/aaugustin/websockets/blob/master/example/client.py +.. _servers: https://github.com/aaugustin/websockets/blob/main/example/server.py +.. _clients: https://github.com/aaugustin/websockets/blob/main/example/client.py Built on top of :mod:`asyncio`, Python's standard asynchronous I/O framework, it provides an elegant coroutine-based API. From f9371fca175f799ae3f1cc1cb0d5122cfd25d8de Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 15 May 2021 19:03:11 +0200 Subject: [PATCH 017/762] Improve example in intro. * Refactor to increase clarity. * Avoid deprecated usage of asyncio.wait. * Clarify what happens when clients disconnect. * Restore call to main() -- it disappeared. Fix #757. --- example/counter.py | 49 +++++++++++++++++++++++----------------------- 1 file changed, 24 insertions(+), 25 deletions(-) diff --git a/example/counter.py b/example/counter.py index 81cbdb55c..a9ed61893 100755 --- a/example/counter.py +++ b/example/counter.py @@ -22,47 +22,46 @@ def users_event(): return json.dumps({"type": "users", "count": len(USERS)}) -async def notify_state(): - if USERS: # asyncio.wait doesn't accept an empty list - message = state_event() - await asyncio.wait([user.send(message) for user in USERS]) - - -async def notify_users(): - if USERS: # asyncio.wait doesn't accept an empty list - message = users_event() - await asyncio.wait([user.send(message) for user in USERS]) - - -async def register(websocket): - USERS.add(websocket) - await notify_users() - - -async def unregister(websocket): - USERS.remove(websocket) - await notify_users() +async def broadcast(message): + # asyncio.wait doesn't accept an empty list + if not USERS: + return + # Ignore return value. If a user disconnects before we send + # the message to them, there's nothing we can do anyway. + await asyncio.wait([ + asyncio.create_task(user.send(message)) + for user in USERS + ]) async def counter(websocket, path): - # register(websocket) sends user_event() to websocket - await register(websocket) try: + # Register user + USERS.add(websocket) + await broadcast(users_event()) + # Send current state to user await websocket.send(state_event()) + # Manage state changes async for message in websocket: data = json.loads(message) if data["action"] == "minus": STATE["value"] -= 1 - await notify_state() + await broadcast(state_event()) elif data["action"] == "plus": STATE["value"] += 1 - await notify_state() + await broadcast(state_event()) else: logging.error("unsupported event: %s", data) finally: - await unregister(websocket) + # Unregister user + USERS.remove(websocket) + await broadcast(users_event()) async def main(): async with websockets.serve(counter, "localhost", 6789): await asyncio.Future() # run forever + + +if __name__ == "__main__": + asyncio.run(main()) From 32a135e4b33020eb1a2a45cfa5d90dc2e22cbeb1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 16 May 2021 08:26:01 +0200 Subject: [PATCH 018/762] Generate Github social preview. --- logo/github-social-preview.html | 38 ++++++++++++++++++++++++++++++++ logo/github-social-preview.png | Bin 0 -> 28346 bytes 2 files changed, 38 insertions(+) create mode 100644 logo/github-social-preview.html create mode 100644 logo/github-social-preview.png diff --git a/logo/github-social-preview.html b/logo/github-social-preview.html new file mode 100644 index 000000000..187a183b0 --- /dev/null +++ b/logo/github-social-preview.html @@ -0,0 +1,38 @@ + + + + GitHub social preview + + + +

Take a screenshot of this DOM node to make a PNG.

+

For 2x DPI screens.

+

+

For regular screens.

+

+ diff --git a/logo/github-social-preview.png b/logo/github-social-preview.png new file mode 100644 index 0000000000000000000000000000000000000000..59a51b6e33810bc0bc34b7d2ea6e47b0fd8be78b GIT binary patch literal 28346 zcmb5WcT|&2^e!4e!3Nk-ilEp5QITG4fFPh!A~jf$CPi9k31CA-Km|mlsWc%#=q->$ z5tJqfhENlVv_J?offTrt;CJs?_uRAA@A^mNGc&XI?0NRHpFQu(_xzfvf$$dbEf5Gq z_=@4B8xY7w@FRctCVueO)z=s35Qt~Ql}qO>0{Nzie8LdOHvW$cMMUe+$Bj^&2pF?r zANb%C{2zVjg>U@-HJT3s(F^~deQb9G9{f)~cs~D6KmLb%Md{F|G`_C>4DKO z3u14>obIH$l~*jLtJbhqSY+vj!IlU?1<#1Vy{bEtAM1JZ^X)vW*FVsuBo|lG+n%EI zJ};$*(Za^2O|{C7pY}qx)%q$_SgHirRd$xH`LxD-oA)(B-G_9M}@u5U9PscqM{k(IC_c{;5O8m@*&R_!w19OYA1zO>>_@ew3RIAPh)EHwzbyQwmz;2( z`BM>aa^0fL;L`<4dWAv0?%KBxzKVG(4_Ml}CQ6)>>bxQ@E-C*eKz~@kc9|nsKHn49 zmJoA# zrE16ObWnFB`FtMEJ8(vis1Dx&3ngt(8%pxeyh<6R})2 zfH}G7O>0`6)s!i!T6@--QaX3Z@(zX{63)L3qPIcKylat}6?K+gFn_TKo_Qx1eTPLL z5dT)y3$>B*EFq9<6@`=4WauyM|_`qH?f9CEW1F)DIC4k4yWqN>*H4&Aiau zH6>#pkZF5F<@lBm=p!kl35$dIr32F;*>wNu%2`jZ?fSn+p5hW==Z8rN1z3$N=F{7> zCP}0Tn?Nijt8@03OdGm7#r<>e=JCqkvrWAqb}2C*YUSdRzxL9|V|Q8ef7$&?s&0GM z&*!nGU)fA%JJo)VAsOQX4)PsE^CJh@mlpP`Rp**T!Mb4?tC>;ph!AK%BrIiz7MpUH zQ3i{XhplR8x*M$egW|xY@Q532(91m}Lc}ea{lG#`Ne(uPLqFnsoq0Ly6CV!@cYAz> z3xqnfTo#k=Y&9mPY3lFiFSia|!@}A^t>rq!5NJdn4d!>2)aREhvW*9g z!P=J3ZVY;XY107L)n1zYdy4xo8mob&^rgMuPxHtS5a1{J=+j9B-h=0SDdTz%Q=Qn>%mx?0rgGKz+$vz){3Wtc@}3q(rtjt-wWKX!ep zt&eJNEr}eZ+V@B!*C^i0%2ZtoLkJ{oX&q zL4JeR{t>|asbIOG;dbEnlU5P|euG{hh@@*3a*%NpQ7N_EJUBtHrd4}$Yf&TDtKg&U z$hEt@G*PMLeNyWQ=eU8YZN6V81Zr>?_#ufDDkr1OSGyBW_<;-qp%35TlUY2Psmw<& zkQc#q&TmThl84!>Hmg75|zu~xRRC6>^hp^v_0pJS_stIDg-J*B`!LIy;)mM z5cy{i@1)de?6>59hemN_FbXcrw!e)Op~-BuhC*Cqr|Tf(Blf|2pL)u+ABT))aJ}rshXx%wOlW0$snmw?%NCAg?W8 ziHw4)f(MZPH02de;`0$Jsm}toe7wYd_d2+Bc-9>QUwX_TCH#quk=F@1a1G&)oqZy5 zpO1IcgFgux^n#WlB9caF0}ifILcCas83P*uxXT3&Df_9T%we>SJuFp?s z_D)+tf@%OTW117mSIJAKqykU>qhs2+sv>wj$%2Qm zr)Tb0v1Dh&o^`B_-L>1=R*MoSrWf!_A_sW9!TgLud*U>*{w?z{t4bD6N~j-#;z^dK zmuU9FF1Rs3q|#ntA=akGA=&(pOF`TUkJ-}PgW2M}oPdFkhAt};7)5Ji(MNg!VxQ3g z%|dRgq`}qLMtv9`9d;8LV^ltXDOO|a&@Yy2(|?z%?sfhh!6y^E6~wgCre~T$vP?yV zK&S1g2bd&6?CoH*D^4R8SxQ~1f6raAgS%vxWR>!T5a`}WSWXC3DT8J|sRo6GmFqBv zx`<-j<(S+fnwXrFDxtUFDRN;nx3%y>_@1k*rEz7A(~wKHT*r?U5DmyA%eS+&a;!RZ z43CaKXw473HPGq4AS4(*aS|vG#|NG{&LPlWk+78zXwE#N0F)>VTQ-%LR`IqSSaRc9 z3PjYs*GtpJge*JJ+`OKFBP^%9tD}X}F+FKev3utLYsn0AclgESz9^-{IjRm_S}QkR zhi3CkyW&J~G_!B0t(}v2$o3&F6aRI{E$_hue+!_U+z<()bFF4^m4R~@l zVC-u=5H{iRonCNZvHkZ?9?g0D8;i|eXb(whIAxeFAfWesFJPJNao$~?v?h>^weTf4 zs}|*=&Zek=tX(4X{B7%2<1TxfyKGr&%4=5pYgPz!WCPW?C4Mwyc7V-v z@KWHWoyP-ik!2=sS1IX1N*l`XP20q*CIN;{t8Vu!&6wlIVgcofDC zk%;%jU&XjlG2}*NM-^N-=pxd@9J8f^^l%Dj#!n$oIv|%H6jUE=_W;sFzs*>L1_~W_ zd2dv? zO@O6d4w4c~$1;Rhr0lhuY)s5q%~3vnON287)tB^wt4o9;P)74GVQ+p#`OmkCi|ziPQpCU2yxQ5AXi6 zBZG!d-%DNaqq#*X-0w0 z<{!MUtzaKKG^!CW-8DYW`J)7+2~KuVO7-)m^s)}Gw$Cxmpm6%bFg-{^wV z7_rjUbwB1M%NIVur97~{nE$X&i1^7O{=LS)!S90TQ9yn@7y67o45K zZd>W-`U>$vzmpHrP8b-WkO-^Wr|sX@fy#gg(m7@DBzbkO#;ImZEz0Bo$gWayg6`pl z#67yP)^mS|eG_i{V&uOu&c}$ zu6wFo$(z_bhU~3Mc6=~C)o!IzZMO3PX#^qSq68N4%&Ihf!Q|BOo4qVK?{2p~MbmC`O zTdGMYfq0kl^bS+|Fj{j~?!hF`l^5pPe+Qc**6Fch2EVR6P@;AcJ+&w=dlE)GaG7%* zYeaYAtTs2fFDX#dhdVNL{QjQERcF*5U|Q1S%{DsF`4;ZuVy=dLF4+%CFR&t*NaOZUWzw3lJD0 z27d9HoCVlu&99X+M@aRe71O9IRAs;yL!j;(rtP28{v8xLzp0Y|CGm6| zkVruPpkc>SYp*M02pIQZd>FHGxPK9BF~M#JMH2{}PEh$~nhSKpstgPPBO#5#+ft8h z?Z<3c`{(d0i<^5X05knSH-c7@f-0F(MXYp+h5T|z1U?@z_)1DB{BLt*HNSZ||MdWp zLxql;!va<(qXv953gc_#e3vuVt%v{Y>g;26yr~Naf$qNmII*0}2O9C%G?s>U!ebG( z&;2(*h=Ogf|1Dc6&nnIifi{+dCKE^vc5k9t!1DZ$Hu#>m7Df~J^zmb{coM;A2hb%! z!_5k1CizYN|D<*kr7no(#u%i+$Ibd^VVbk12jPQ>V@F@~VRpW9#`Sdl>HTCU#H-?` zDXhBlK)p*Nk4)POv!{D3U~#b(y7W1^ZfNsVG=9W-+P)D~!MkH)7n|(Y9lBgGB_)Cbmk^-w)0E0Z zHyZSXu5hvsX?XkEj(-WdGE6Vn@OQ9k>c8=K-bw|y(oS0bA`xV$L@2`l05ysxV&|P{ z7nGlqf2XD9@Q*t5yIMKX`R;gS%O!=T%jLNo94zpA~c!s~-Va%mNlwi_!x(2jIF50hpzkrZaH#uSlL| z;g?!DK6SsrnLdRxCXIUeXY9aPj zo7Pp*h{CaY9lyIc_^98qa?f;g9c@Yy<7az8=zx;q9CtmtRZ;?Kzc##Lejm@ZiJHt0 z^qi*yxu?z(3a%53EtmK$Yi}w~-!M_nEz>K5T^|}*Cgcv4ks*;VAY{pvrSQ>M^Cjgg zfzx%ZXF=zPZt0N)(iIn0>^ei2wG~Zn)3y(uU`N{?0FSszOAU|HM zNt~S!N?N$j|FOxF@O1q=?`Px5p*=zoPkk5OCEm{up&4B~mFC?u0=U%Nli zqO&{q>Bsz%sien`3fSqmlyXzv@rPea4oxbytoQSWfBq$v@)1Q7eSUp3EhkI<2qAts zGz!$*^%jN#sQFi^Nh}t3g4Q9mUn(#*xh+-KCT?J$6ZNp@)IzD;+Tx@;oM5=#bl0Zj zszWV9&%UHg{XyOl%wYQxkXIVbZF+Y1aH)sP2SisiR%?DW>!O(F33OS2Uv$iX>@Awx zFvGQJQ;?8D!Alc;8`CkgiG$1A5s?SP)msFkxv_=Afh zikdW%)3f|h*IZR+o}_<5Dsbk27o#AR?c;9ioG4zvfI3;Ph8Gou#5b>B3YuYZ&GyYv z+7a95Hu7SgIgAl0ZMI78|y}L0v@J@Vfwlr|R zRHg5q`||&_^_PnsrYTgCW%!QS!E;M6-*U}k{9Med3td?eNKe^LsaE>;!E2%ZVSeANrfBb zya)}(B2i+20BSQ=%~LAkw`ndIX^}8YHNOz3Krbymf(oB!viv`n3WSx-h5KB&Fr%IX zD?YoFu?|w;TJ5n8w!i7@bfgYN^vrJDylMext^4;Fa-D^&R^8TI?Z=1*usC% z()CHxk*aKC59F2pQ%O~7yn|n!uL^!%--(^Ef=Ri8vR=i>>Anj1Y2X{9fL|JEax{No zJfiV8Ap$Oa;>Ia;M1KERVoToxi*;b%$|Y;u=3cfs0;bJZwd^^3P-?k*)Q1N67j*Uo zl1L7LfnCk3^&0NYGtVT!2kE~+x_r)t{P5KM%)4>LdzbFDiP6Hq0PZX*nCY6{KZK9q zzy0wYs)Cof1_GTR+KVJou1=fXyLIK9pCftUa?_EXY@_?gMn>btlP)|u{)oCUkTPby zM0iku6&(2Ls8sss5pEBQX}7YdN}!bH)cye5JZl+0b4uu;PIAOvhF~J-8}u_-6YtX;F_Bog?J!2C;f*Cy|Lu)i96~v0=VJ@ zvnzWntC@@`)?tH7Xh0?9h$^pd;+whADR_+$5KAN+cQ^5n!o9TiTu}NIACHn* zzt)@)^oP40PoB>-E9SRSn+B#PmFcD_Wi_N;bV3mfP1;vu-BbjMbEb6_6bkNru|ZfbQjHdCCiON9N|3sDc`WGe-5YVOVk&Y14=(ZC%%}Rd?^9;3J*U&U|n~ZhzTw9`l znKzpzsrulUA2t9Dil$76gqj2?1b7XBr#Ba>;7h5Xj|R=lHZ3;&SVP*L#+7=baMh81 zoIqIVQ49)Wuu+yYGbtPQqWa!?v3Z(XYl!c}st}+C7D6$`8dBIu7^^Kc?12}{eyF*P z^~TEk$PF4E-DZ~OdS9wnbRxBC<(1sp#6%Sa1eaTe^#+3Z%vgf#v6V7B73!QrkP82h-IxQ5G(ikqt zzRcdr!S+q?(Pz9abCg6taAhU9ajSYx!FwqfZt3K=vvJ6Nxc(%?jrby@mxyJ4=nu93 zCKVXeP&FGe@srw?DkaxnPP(jj+44|+`x*SEMzlR{@03VDznxC8JW%ZSh$NUw?HX(S%CcHjH#S^XWCQHna zEH5n2*b8U=qH6hF-@LsEaZCR%Ij7wj)8JpNZJ@i1F^b~wJ%VcI^Q z5kPBttXU|xU)i~0i#X4>3D=#grD((wp^B6!aY$p9GQ9wvNi~uPCjq0%6pyrT)bPe@ zsXZW2Df)ZPtq0FuZp{@IH)I8igEBP=84B(ff7NqIW<)XMhvJnzw&9Vi$7k19_uQ18 zjmXcXiS7rD8@jD(dq*wm${~8e1pe*&+NHwvno#$r^XQZ-j78x-&NFeaBK|O(z$ekg zE)S6QkT2N&nzu2Z9R)KS#JykaaM$!ilw! z7dLp;3uG^mLvcTA1G+TWrWF$519(JcefS_Fd8gyI8#LS#heA!WDhO{?iS&)Sp)lOsla6(0k2HaevG?C{^@OaW)Hx5weUCtTql2)CBG;D;DUL9X1GmhiS9s6QgAwk&2f= ze~Q@LB6o5DD*&GoksPCn@y7bacZ2#S3LZ3iEGY+F#!KCVJ3i-KIAXZRQ7@4&tXasW zZ{?N5;E+Gm&HSnKC2((2)1H`k*>CV-;zKgV<`|IEXL0Y#=>h)fY>}%VIN%ahuja$( zZ4|ox=fX7()aEcz(0Js5#-l2X-6U7x{`dbGNjQ4b0`|a+WdjF|gO2 z5yupbv)A{z-6}Z5Vr3iP)-5?$tX5~@_y#{m4g)Vjh2k$7IE+SYmh-YyK7J;v;O8a^ zqZ3&x$a^#m&2*^SX$fvN{Xvb`W`!_Gjhj{Q^I6?EZc=tCcH1i4#Ukc*^F@r+2vfNsWApiW#q*U4izptBsp$aiAOMZ(Se!(S_x{Z+bh*2Ht6p+k87xaZr?FDC}lRh?m<50?R zY(6M=gR}e)z!+)tJs=gIS|Ch-%qe|{s^)A_V+SWaY_>1~MAt0^(GyA)`@o2n&qbTM z$)J#1#oPz<(|aKw1-1tD5+6}OT%;L**SVI|#*O86b8Bh^1R|w249Xu0 z3tdfW5ZI|Pt9Zxik~@p=zFO-d_vMX%9l%rB8CMvFf?AMHd5UOq?|E_r_l(4!XAbHO zx7=WaAQSka9rRMKYe2y1EfD~^5V13teECv~RzC9FI0f7=24wdLEaD|KgHIhRI^kYj z9l~=0$_J4O^xOcV@#+)c6OF=fNu4lro^dH|m~^QU-#`K?5?+}I(NA^~$B*+EPUgB< ziJNx1_#4ZCy_O|bmWq}M3znA2xz_oV@|-M0ZoZ*EQLgCZ$ObKLmPxpV&BegJ*m`MLz)}uclKD;~eGxCQ@XC(I37tyY z6=BBe1th^dx0!swz>*K?_^qit=F~>siz6=Kfe$9_Q0<`tiYZ zUQ%XJ`R;-Ax$eFn&OvPpov;;=*nrTkw6J9p(N1Rc5rjR*0cJ-zHM!DgC=lt}Td`Q@ zHMvNjs+k)I5WOq3XJ335f#`u-X+h@QwWfESa2HOsROk))?pigV7&ALni;>C*|F`1H zYcn>h(B=K!*lR!$-rfcz;V0XGwwyuPHitk=qQimpD;G_|3psbpn^JdLtXGMtX`$_keQgj%zI^|Rk(RHv=HF%}(1_HddxDV2f zgPl!}{Nr_D66OB;%Dsmwh0o!G#*wEY776>Wz*Gfj_%wE}xHlwC<<=wc;?5}7h-M}t zw10BSN(13foH#q;ac=hWD^r$!Uf!aAAF|pH`&H9BZ)&o6GkDwaG{jXOr4@m%18vxj>Q3QQh+ypVPc)G*{TmsgUDH3`UZ1_q_rk5Xw znT_ttsdplOju`14V}$L6K-dENARpUuCqp6AD&NUP*Etgmgg1!PM*tn1C?*PbBW?ZhYMV z#KlnXMiV$&+S^{R0W(;|qM8IFLQKfFm#B+~lxM{p_L$j~wqX}utS#7PeJnmT$f^Sh zJ&Ev#?oB*!XGALSeC1?F@R3POzS>q{t-$lDFCnYMSHFUV=Dxb03ahn z^aLPrW>(uD%^6Z5ftzncFJoIQOaDv;t&0pVzDtffEU6l70^wud8Sw=2;SIonBvSJ; zvNhA|E=NBDos=NxnKu3&OE+a%C1X90O%&6NT2pj6jS&9jW{(s9rc5K+yIf&wi~h#Z zr{{6)PiG`nve5-E7YV{9u!>%?es#dm0^f#NAzfwQ7#GTilp4MvWMdbj^T@_g#=<3Y zB3i1IomeEuS_fz!d~hL4dxstbw{|BkUU;1xe6q|~%nY<`U+q2d?p0S<+9Dz0HMZJF z{2dt*R-S3nBnXiKkC+e&UEFbABZH@YmkGzMnqIN{!4JjhcS4~h;Ji$Wy{@o>nLf8d zH!}NDs5ex=E5zF%69aDQ2>;Z{sjlF}Md{L6kz~9pZ0jT{$`CfN*It+up}U-;4lq;} z6#LJeYFq`79I6O^u_+tcx!?Hr>O(2?B|})nosrd}T`kelPj=vi-wXV=KJ#qXBBA~b zw)!mKou!%Y5;w~fx6PEI&6sYSy~LFd-_tRwaP~ymtqrpgb=v>UTJ7p~CBN$+QV+S- z)-`C(PD9p1FfOt4krNsrafw6O@`Ue%1dp&5*>-I_kL@j()nzwb2M0}A!5Qe8cf+{@ z_!QjI>YeVU(-ioulr4lCsE<2q`#e(V(Q*h0B>72)4<=eLbq_GNbIJB3V#H<8PA81@ z2ran44*B}>7$`b`;7-1F7u80{QlQQ`+*F?e(TBx$`D0I!=AV4~;^g(@2ClmAK;Y$n z>kDuWOeO*O9J{F8jps|HE2LmwFQF!<#>n_xP@LE`4&|N~({O$8g~`{yvk?NOuz`*B z`5vL^IN9NQ7(T^8KeohGD|>Ok#SI{oYoBV3wL4UQcVc(Q$sz{i z`-QBi3j!Kkw+fI08VBB-MaR4*YrU zVUwRxtM%yx38^Xb(t;)4vmLaJN~wR0wgvy=nHzQ6lIhX12>(n#da1vR1lkr%VJVFO zw^zv>eLfGTEA_;^60$IUeZSblkDe11x_5_i9Msp;=lFS07E0@2nrI=k*AIQ6wNbSY zDHi+`-07UjM6Ve+olXlEPAryJp{h%w+^fHWthnE^mjotREsuUNBPE6Y z=wPyfyzi~u`ZuRh`fXRzBOWfs#MQJl!#lA7U31hu?r9-^U#Kt5Z5?VJ*8%Lxg_P3( z`7kIAdVsIdGxGMGOKHF*JwChl`} z1V{(7XKirUTRol4r`}l@DnPt9MY?f8>Q#CcaT1aLwtDh=_sdZWY| z%XHUL$Xs6bdaB)KpYe9P>ynMo?jW|dXwJ%aR0vM7SJ`3?_FC*u(QB3P%MbZ#CTr9F z4Ucbg6NvQhC4xlw^{N5oQ6tb+6xOSnfmojTbE27;rtWG_6;SIVtnYI>diBv8@g2Rc zFlfyYis~W6KxqW4_sE0EsthkGd=Rss8|d6CkYq)w2^`Ul%zueN8YqMqy8%tkt8Kk_ z00aR4gTl#Eh3)vpRJv22E`~FiNK2{-j1<{-6kA(lCw|TLaYfrRcYyQq?VBir=>ykP5;XAXNSSNrC{5_{Q>A72B{vI=wC{R0r7%(yU2*OC#mO|f@Lw0 zz$AY+1ai(>FB^gAfWmVN4z9o*JqG`5Gi3=DD;&ELW(C~hmZz0G74Mtb2wD1gC((nL zX(j${QP?rWiYxF6&p=!T^?0P+ccOS|r-ZOpEmn8Fr99=-3MkFt5EbA>s2LD(>1*RF zrOz&Bq5DkbgzbJzHq$+JaslZJX^56JQ%2`BDc}b@CB_7k_->&AB zk|eshsZL-jc#dawB)<8c=}aUWq*{+l@PpXh-wh!}*wRF(Tmr79*8Kdg>D^ZvYy|5% z3w}vQCcITZPax8di#7jadz^D4sM(5r!7mo2fhgbPxBy&$Kk@5Ty*-+Nt`$tR5n5q* z(P!rKi%{f|S)88(abRGe?@aFx@E7+_{tJ=qc3bY~y^45Dz9l{M`rf@wn-7OjSy-%p zN+YVRjQ*=wmK@9G9)ZiC`_t%aVDo?<<%((@-qvu>Uu{%y|J9%jQnL_CEG?Oz%=( z?qX~C)Fj3lmo$a35g@)Z!D@2N9{!Dw&i1mkqxR`BJ!g_9Fh&A>IFCWnO;}X20ME8S zG}-IV>LjMhJcBlNZgwnh{LP{RXCv>shK>s10(yQ?a0e7DnIJ99<(cPSTAYOYmkC(c z46pK%68{vbTIi1^@6!I|DNC1Tl?Cu7K3}ppUVV0UHNc2`U-JXD$Mg+0X{z?DyW=>% zzJEDM0Ej5PT;cwj#n!hYU{_shCIVu<>FA%xUhbAM`tjLc$SQ`sxk*c?h^XqJ@jb)t zX?(!CcqQ=2BMHp9S5A^U=Y8F(SDG-w2P*u>eZn>9CkzSYSJ<8pY1p1~-nUei9W4_1 zxz=8bSIXw25E6#uR{zu;+KKf(VWs^OsCY_D-k%)_J^-JB$L$0}sUFR&=wI$^WI7c@ zqgwI#=vS9~irY4wkIOHV(v^P#rh78p0~yjqHzaLj&kSgX$#k= z3h|~_-5)GVZ=(O1sy(`Q^#C_C-K3Is0Or0tK?h1SXd2A0&nxtu&Tu_$2arhp16(#0 zfK{Ss0%)n%mB;#j8sDx0FQ10zJ7O#9=w9pHm^+KNXV7#ly6n0PPCZjVzfWV203JTqT37dZH`Dof%LaF=av?!1b!O z3hD&f4s|kK?hBgqL@@SozdgD2dT-J<0=Cs_FzR;Bv4Ed9%szEiu-?22h338FSq`X; zkOpt}U^gcVh-zb=4fOe&zQoFpAtk2VxrZFSSH)YWfNnAZtg$!vHI`4yBI@ z;(@bYCGDt54BtV*`-$4}?Y~22yWWL5w{T9KjN(Crqb5X(>m^`KFcLYM;?Va|t9_k_ zZ!7)#Dis^Z`wXyg{yMoYSsfk?+;5} zW@w4(Dsl(C>L_Vf?o@rkivF-H^&6q8@9K4zVQ6phu@8~0vwf0>7^gE+)y`e$Mc@Gs zV29x%HBjd2`2F>xH=LKOjA2E;Nn5hxf~ZsQV?-v}Ng~8DV;jibaKQtCdQ6*gS#s1O z9{@@_zJnNJcA1Y0nsz)x4u1hnYdFpw2_Xp~!6hMlmzOAj@uUD=y)a>I2X{rGaKQk+ zj$q}TLflqB_$w2PfpEd{BBDN%yZ?_79!`hmK5jrDm!}i@yqS zc?gmbCUQT@?DFuxnSYPG*5GEGi=(DejG6v_|2#~6t+xKpQyPk0W4xJvva5Gz3IsLj zYAz;lSB5+X4&gaP4$<-dEE21V&hGusy&lEOzFbz*6G+tX%Zg0172=u~?^e~7mKKG$ z90dyy63X2jlUV)`$a5j?S6L+RXk?hTxen4z^ajW)AMW<@8ND|(_d@iVc;6KG-Ub8| zo&j!9gzq~yuxSDiLoMz>1T1FrK@K0_?S6|dS-L-)(UR!-wKgY1>2b-ybF!|nC}%wdrkq1uWWUMyyFbN{ni$IDvHYpuSS3f zzoQ%dtHiVmX+?QP0C&WHb;9`^hiJcjt)pj43KoXjp?rS7LpSQhJM1G- zeJ%!G?7>3{yY~w)_Kf+H4KfrVyVbz@wc-Y0lC{udDcH8>nKL}8dCC7o>gpvwSh5W> z27C5x@CcdLx8^oKn3(zvzc_3P3CskY)CT&xXwuxag?JwKNzZ)%GB;-YgzFXhMvqkCTqB?(ElquaNE*cuJCTpuTa zQ^Eym-T2-;-A~gr(B*Xg&|N6`vDU7vp`$4c-1yj-yY-hEbOKUw?(p9P2cc)tMc@-AkK%&RBtI?cmzGPRQI>c?{JbhaN9h)7gXB z)U7tsIy`RC93oIMhhMPLnLw{*DRB$>#Rgc~64k0J3_88nQP$l(@>;&{ARD{0C0)SW zmi4j3Kej~>xX_)I@G0+p1-1&-Pyjg<(j!(*x{O|KQG66yJx(Q)m_h?{5z-ARe4sQ= ziW0@}vWiBn(v%wc4e+ z<8P}fl4ws8;-EaR^TGgm06af25Tfni zLmwi*zYh%L-^aHNSmfJQdTV_G7|8b$A_(da_k({QaGLwE$r1dp$x#nNG~-SJ6aLNI zaQeUH3n7C4-!p+3?o6IPGT{+)k^hrFxeE^_@PJN}06XI}2VeNTKb znSr`N*eV~G;aO%=GL9+GlA28V&v6`yy+i0D2@(Wzk%1LdXGR|ZGH5>NE>M`Qx z&S>^k1nZ0Yi87?AiPyC{B?|_m>#DHXjZWJ?<>A$DN9c zy-j1OY<{yV&Zj7=EOv3EtNr^5r5w>SY8>|~tnkRzxo)!)@Pf5b!m;SJu#v`{?3-2D z&bNEbu5ZDn=N0QGkG;w8mG*26b|o|#glVu7FVDdpUj0bSXz5GA$S!O_&fohHuDJb$ zIA23+xbF|RhvX*sx6gWEH-_8HUR%yKg)DJmQ4cqVXvfqX4L`MNu>aA{ z5uxOb2Q*(n2I_wW^-nxa81HIr^gEq1J=nh)i zx>bwM@`S9qpB(K-@}X-i;=S#+^2t3Fy%_yj?;`ECs)w^*uqC z(Yl7F5bGzp`&P8A`J5hF$V6WCyyJM^W=2-WNT5qnjw12iCmk|koV-w}o2c??*VSLk zN=B5qw?oi0YT~l{J4C&}ktlIK{m^nHgC<)2(_Kdb0Ty5VyOV7wd-LXp^MPKb~cSEw8 zj?XpFu|>zXJ6?L=aqPYG-P?kk>Md^*Z{9Tdy_t8U?GJp-*vmm1{X(7HEZyGV_2){S z7cMz{y{oP)PWa^}aHL7r`qIH#qmRy>j#*->7^O4Y&D>7pQn$Atk8j@`QDEvHasTvT zlZp>?hKh4;&8B@)f4A^8R4SSq=t2kX!P9B6`izKPJ%lpTBC^4`{}le6)#anE)8fIi zth>t1b%Uo~81eGuYOK+>HaVI3jM5#oeHPo&Ri-Yhn202$WE40phaRh2IBNo(JOtkB zn4|rIZMSUg2+V`B@zKBZi-j|mPG@g1cyRjKy}ZU}dHIu06tgu#Sx&)r#}c&RHAR#1 z2khJVgT}|_^k=lQ3#x3wcPL?HE?>3(@o|`Vm!N>y{e%aXn}J=#Ys!0ERxF-8kdH)t zM0osJme;X+v@9olzzWKrAJ)&-jK2G(H5uxOdp4C1et92BNQ+7BeDvD%+*73b&Q@*N zEgNoHb4qsHJlzwnyPc^s^YSP|;YGytozo6{4KsGH`xI5=ZiEJP8dEm&EEYmEvp&e< z?~G53=`Iyx>53=LHS&jQR^N6}aU>faHaJjxYaqanEk2S`x@D8P`gz3oAcuTJJ#g4L z*G5^)aktKd44#v2u9&CwiC_|?nO&Vp%;bNGh&{nS`)~{mQ*S@CBN29GTOEJP>6owl zE&N*B(K+)6WvL$b_^?t6($`%&V&;B^y_25$xhudLk?6-#Zc0_&wr%)OnPW@0*SM57 zJ?{+7z?Y8{LisSlKRtQi0_@9ERfXz@ySC)^AII-8(LN_;?yz=ISJLP2w*$Cq#ynFZ zcwkCW#JufZEh>zbo+G|@ux}{n2|wp>0<GA5^(JS!;e-3i>JcrgtFgR6Z7@lU0d^z^H@N<1m)N?Pb?b}|v@i56XL+C^Q)c@L9 zmJDLn{key1@NF^Xs>aPjNq*J6A}lANU-BO9<9E-DUv=hd7~dy7@l(ZcqOks_Q$>FT zY7l>Fhhq)o^V$odv|phrk~4SvhA=nl06lH>UOS(7LRvB~D`MkvfFpWraNqNx!Jj?j zw_{lR2dxn}@PG3tMmjEXK6{xbZtjjU-+lj4hL4b=Y{ttaPpontcHX*H043&FX zarG^Kt}{C+#wgPZUfV%w(<@rloehj$+r|0(&Q5{xth?fhk+PZ(Ea2s&L!qiD41OPZ zM8)!3R;x40rt`&=xBTB?2iqii+VpWM{jc*z{=Dh{h$_x8u6)spf`-+HVdajE&0a1o z&r1B0*yX>U;qf~<2FR)f#@|~f=cUD(`llpd?}04Z|Jvf|v3v4q%1@8ph0X0#%-{F) z!J~C(&fffVZt`Khf9O`U((iapWK1vaO~9gFn#2CVr^j+al%vzKd#ZNI-qJOBd!y~B zTc~O8Z(E(CaWb;EJgTo6w{4n?_^uEpHq#n8cj(BzPk|SoerhU?!46&8{QP%^=Aqvn zuL>Uy2Uka);6M2xcs9Us<#8P&l|be!bdmYm zt|b>Z^?KtTnBaYLF+O_D)jJz6mE4VZZu~Q?*}CE9iL-Ub`bt`L_;t#rkJCjCYM-uF ztrxgrABY~6^4qZU+gbmaeBsGq_vO}cq1!i3(PK?tpZv46a5>Am@|GdQ?$!y_Z4Rn~ z!cUwwlap;uVg+9o9!R5CnT8V&$Q93a31~L_UyXfbP+URNE-daAT!K5n-GWPSm*DR1 z8eA6F0D<7{?(QxD7J?-NhsA?)H_3bNulv>a{+X(+bE-~vPftHjchAgX4YIw%`L|Sh znhp8DzM}b6_>?T@L&{y`pp_qWP`o|w3qg_;*pv!pR%2Hc=dsa|41DrfpY?awkc0>dWTZ5=$hxm2@xdZMVNe~?fVm-0VfXFj@JQS0K(KsbM(!DlubqP zPB|ygXCDG~8=P>1E*jJJS zUQP(DnS^E9Ja=foHYrr3hp*VqIR~}CAUNq9zuu_-em29=9kHjh%<+(hn9peAcHIYo znbqFq2fr5*8)NG<3wdV_#IryD^X$4_b^G;vG;;A)2e+0s(;xR|{i|nt;yujaLI~d7d$N6%`i4w7970O7|qFR|vX!SnQ(?mf`>N zzzB5RE$^0JRLBT&?|Ljxl46A_nrwG3P>wkI#%e1}s3(_Q<_voe$+kl-Ef=7!&noho zsyJjE7{$<{n@==C2b79?p(FxSdl_@@xHog7>{p)526XL4rCeOieq+j6xQ@IvQn3K` zr(hG5{Gx;oJV0<0?t`FllnDVZ5*s9pH$o-s>5#fS35AcyQKndw>7@6f%+Dc70lvSW zg%1`1bE~x`2YSCE-falE*(5?pT+QYqeza~iP|06;TdUIVl@f;bK;r?0ZIcSItakTo zi_oa=&Z7~tlu2Kfj6f%x_#^?v>|z`n#>N~UsSYG^4(E3T`8fW6J}wbi<;bxl-|${( z;R;z~F`Yo5LOfWQRee<^#^!nHhb>INLaW%rRRLcRxnw+`vi>%>$Jt4w91bLYO}wrL z7pINATqAvpi}C3I+xR>2uUxVwyMi^{+~1pyq#63SwJ+7`5OU^KUwZM`bI6T-)PTAV zlACfg2O9;pB=AC868uK|XM{-UD#x= zj0>N&nnM`8qq8H`9eET~EL3NPbB*|=7MX1c)QU?=z`c;?v3)!+=LlitYn5 zjVai|5@x4##NLfDGJUIai~@7g;)IY#Hu`-HtJ4`4G?uaSzSTRgqcwu9$LrNo6UPL{ zt^z&3yDi|oPFFaqxO(>+?ClUn-cE0k$CK(!=diyTUrWL2z_uN^&1i1l^c?X^)Hv}Q zjT&0is!UaW46f{@wX(^aV%+hZr|Q>HO~L39t>JU!s%0dI7hp{}$yGZtDatZar7W>@ z-Fw*8+%l0wg%BZ1|l8G`;k~`$(2qR6Kq&9Se-d^ zEMqKVrP`l6=20)09%m&nt@mV;N`dH`v|sgt*Mlgm4n?Lxo-6F0SPWFG0>{^iOLDe4 zP*YoJh3gTgjRGKAYqRA%cVagP18rx^6YQ(uW)rx>ujQGmT77Bc0gXzu(zgVPI^1dA+*YwQ1 zGog|n=wj|mYo$9us9Qqr)2c(#T`!orXr7Z!o0VSswh>TfLRIs#IZZdHG?9pGroCDg zGlfCegWU`^iQ)a;#HQm}d2XNIBO{;J?t^s#dFDP6fuY&0V_?Oc6@uX)7Jyy#3g;FZ zRYdxP!!ymzQ2*3q<(?jofCv%cj!8g;bL@PE*h+;`uM-eNl9PY1a~CW1$l0C{iYoI1 z-?;+8V{G}h|L6dTqh>7n$rj7_adGeM+vaZv@LHs8V$ttu7?UdzB0ht9N1Q|qKudDu z+;9X^lgja`npDqcRR)=22B){lK_1{vk? zIl?1>AAN;me&o*!c)Zh|5POtIW%LJwY9VJB)|seSS^#QGF1=bn$6N?8y9sg7*&m0% zlnj;%7Cw$km?ivxQFu&H0?l986Sv5qBfn=N87$~3i{Ous2JZSnDjy=C+$x+fGaElC z**7uuQT-}p;5vRsY{P!eVfYQb(!8jLJVX+*|lNi>HrxK+sK zkA7(e9R{zb4HEG=F>kfW+O?KtQLpigop+_f9ZD~pM4bleiM?UQ#2||QbBFG+^HEA` z{|s4_X68{zU4$k4o{blj!nqY+X!P9_bV>cVypbT)g{Qg!4P^~!Cq`-L{16lI2PK#Y zOg+ItLLw@Fy*p`_p3+{M=Hf&ql)@{9A?|6X+U+vX=v36x9U-&r1{kY4mL&i{-8<_t z-`5Q=yuKGboCx^GtP;FyK}83g zkzP!|62)?+KB$V<+bNQ4Wsh5sei77vG?e>f7$I8z7_!36TF@ii@M$6e)rr+R(w%&B zlE3xYg(GD2ILK^GNiYxaD&ntx$)66DA^iCvr#`fYR z&@}M}02PqmQ=75si|`US_rOGhM*YW~jM@LFR!yi^=>&=mi|(l>^X2N9)DQ$RZqZ#2 zG5TaBQpK@(lqC~r1etlK^u!jy1fb>cqKsVJ+)mH7+wX^N$KnKu&8|Znrf6PZ6Yilg z{82J^`Z|0h0iLilzgouaM#|COCO}Xa+np`R_m(80re- z+yqs-d*c6Wbx~#ff9Kfvpt$5n|ANQktT=3_qs;S8LbS+)Qc7?PvsnVw>04)>``6ah zyT@?ts1NF$PE~P>S`k(Jwz9z%sMw}5z!NPjV@|uFho4>7dJYNJGx4!!N(a(|q=U;fY1?v89y-#TI~dwipsgN=P!FoA)?9UJ(i zz)(InVmyNPd&_6{ulv9YPt&{<{&VLp^8E&3iX=e~hg8Jt_?J_}^8o|xD#fbj<5^;< z;8Uk4+_+)xsRKigZE0N)A@lXL5l2Ul&(-TwpGD^acI z3R2su(X`{v-I|$%RH0n%kyc@#dtdj^`OGn0BJN#ifOUJ3uhK#p5QmWjX-|$Q7*@~D zi|i=&a*LxV7Ru}^m!b3b#<@D=QgpkC+eU9Q({=$4?0lo#oa*R@f<-~!k2`z$yQ_W- zARB$zf@Py%Lh@(4oEyXeCHr^AnPp@8}p1Cy2O&n)bw24s%LIB~`UUPHFM`j(SqpG@>#^ zS=s_%3K3lBgDXTB8cw1R_7Zq+GWzZJBHZJ)E2510f@DFx)u;vSR=Zvhv?WgbkP>c3 ziAAoT{SqFo_CCoP@HqS*EIJPh77?sgVMD{CS2*2gzaF}LlR~ZN)lxf|K#6{Qm)IC) zkmKNok;qY#!A&t{>7e61{BoR5hi%p`%6%3*+$zY`%<=H~?d*ECC&2zI7Vl4{sUX(O zw19rGn=f_0xL7ZHC%4O3xGX;~K<~{quelajuXdQ-re&0h{^LOss;LV}gR5z|fT>fz zsz8%9l@-+jf{Ozdj-?&ZNlz7k2 z&$S9OmO39}klAWYL@vka9E7hr4(DB7N$Ni@ApDuGu$Wd49Aodoc)&+dr8KD4ss_z~ zN)ehSI}W))>cC1y+``}wd8L_oIR&E9xL~5Kt48<;D1Cf9!=BfU&ok^u#8+%0SL?-* zYVThE(ls#22;CsffvmZ^^cGyi{2tEmwyNW@{OuFVSpEYyyHgFlrq^-CQL%QHBdiORC`ZRB7)WdQ?nne z5I*IRILL^aXp}h$7)I#@a7-&RC9pYGyJH)}&wC49T~7zatq%BerqTXgQe4w}z;RWi z#WTJ16ah-d@b(||WWSE|d2;~w=D3QDe16vJN^%PUF`x;iWvSc6Ety}nEra#BomwuZ}_Z6r6Sa9iD9L0xy>r;IONgG zn15cqa=4pih*ATDoCO&W0i0Lg2#*s&IaPdOGZZ0FzX8!bCNNZUNxF zV&U!IwcX<^kO^lX2%9&n%^Yx5jf5p4td6Z?yB zAC*S<7^hVww`dalbp+)rQTy;E8)9uV7H!*r=e5%dFr1#3CUSC``JGgVozNyLV5(CA zh-Mr=BO2kbz%$J}vpXQz%I=O`_)GA^+7fD=zAI7= zV~#Y77*d}=jb_Ioh|?d(`zmc;&rW@0{bhQ^9?KZV=SGMAyriiAjqZmoeA3W+;Tmk0 z>Z+txrY>VjO04b9D(<`4CAXiPe?VaItr(Hm6B%`rLMO5c62z=7t$sTbA% z_(AUKCOPzkDcq;AD7hixc2vf}!l|~PTmts47vxnuHOF4ygjbC5mjUI=zZCkHsRd-Q zQ+4gR&5hs3H4xKgB~CfflTqRh4KA^PmT(XGl~GN*GgIwMzhBTT?Ni`NcDl6O6H81x zH^nxYK!RvQF2Dp4xYGgTV{^@Kj(O%7=-rFusSgyj4uG&9%D$6U#ykB#&%@hI`3%(t z+kdR}K6gcTw^G=}V`D<|T%2L-1pGBhBD3`*WsN~QuvTe1535z3u=h#fM5VpF&@;d; z!0XIfu>G!qm%Lym>`Q&3Bq=kE9ny-6KtXxl%5b3|rnF&Re(8IK*!*3Bp0hSW1nES) zh6=eArj+W6B%Ig|70#{Ly>-gIu2#mjq%p(;kD(rgG36OEl7{{Q-w|<7UK&M6Zfqr> zS8;K;X|%G5Q(SNxBMb>w#W!SK%hFgj4KVjpKph24i`nox)w0J8@QR*o!z#Z z(`$CQ9`w-zb#&sxM!)5F6(!zfhKH7r| z8>=Q9_@==*7IKyF8_{EO6A-=}vEH1ME*e>@?M$Wp zWK&})gnZQbA}b<#3Znc1LB@prQ+81ePv2!9uhLguratg+BxL@FZ|2WlL0 zJL76xQ4!gkTW=WiFt-uq%e+WONIFI@)o0HLyyWo(j-i&iOZ;JF`4;W>Uzx%<8PMX} z2tOIU)R2q17^xlEMJLL9*P{;m84|%bhUp^eakjHBwACn|TM`vHr7pPC02=$!$*?9p zj&n>27PDKzM?ok+`-OI@+#tY7h%7EH0?tF4szDxc;}@Az8L_|ptbcCxfhTyi@cz=Q z5`JoCqQ|+gD;gL~CVT##*&SQS0D1(ikI$etr`_UOKbNF?`pM$Kbp8c`RpyDu7K#U@ z+$!0|fexaY_Oq4svv_BrI?1=*T`#;Pr{*NvzXZ`z9si<{uAt#14$E>-mnFKt_CVAV{I=CA7kx0mzNq1g~$qG2au|EpyV`4`7t8g2ePk)7v$j+ z<;V2x`0QQ0{UHK+;zgh3bL-nAzc;1>?w5X(hZD3h>G8TM$%f2>$QGbrOK9!+r~NVm z1I5!lrv2_|QLCqG-9=R*D7x=h%muHye^YeVGTTgq>okRTK8&7E7jh_((_%FD1RYgP zo(daiGj0hB)yQ$Am|dvFk*VagrWQCehuqm6e%)JP#NKCwvcR^fzW)tN)0~n9mM(% zAa&}qPcjBKn9i17pE} z&rq|5W06ljr}ZnP9H#i`EUT##cq0BZJY36B_00C8H{XKeD|8QzukHt)Srw+_U2AOi z4ztUt`kvO21nGXhPk&=TTSU{#K|}2MaFj|a57sZ^uJ^XGjBgg~r>+BMyp^W!VZx*g zwHMR{KCHL0tYP=Yd9G3ZNt^l@g3^kO*<8QUhtEpfXD=&1*kq**&<%BGDg;MO#NrS} z4=uBr&t^rh9d_IyHKD2`jY)uo?xr|SWLvrRI=gGuy*FgLoT$Oc1j-Dhu3kuRKK z2O0RY)^w>lajU3!84nHrRvXM|>#DN)%~?)C<%IA-@dM4~{1@aunizbZJrC>8f4P)CS0tbJKrWx< zdHH<@<$U?RIo6Vv8a#;pDQc0QTU||v@8<04J;nNG$T8HvXXJOjKYn9FhlBdY?Wj(r zO14Cd#sQh1O?)}gvzj)kLjPbZn-PGT37B%qT^az7+dmR@>{&CD(Jh{I=nQN{9QI7W zd|7wbIX_o2O8jfJJ6ZO+)E5rxtiqd}b)j9JGrAQF24|i*HKvc*JRxE8%&^R^#`Rm! zuJ69j?|LWfX+}ed3fKwLka&+B6`GSCEET)AidOL{U;w(|ulb7wXE#c=kpJSKOP(h> z@cBKr`c@#n_p%L(AHn_cEY*5cWVvVCCc_0-wJ2@sQ|@e*5MXAi=RK<0@+T8tp+^hn zlPyh1d2Hv7qHw+j(Rtvo?UktfCP0woio~>cc{8WtxBsgOece$XnxfL7*Ccw<-&BXM zE(Iro((zP2wVrVOV71(-z^N?6zC=_WO`Thq;;uaUHqmA?=cN)oAvllqJAU+V6h<$d z7Ncjksk3J2fvL1(52CX;o8J9rm-U-2Ctgw}YPhGS%cup1jHTB3fUiP2k?JvrTf*It zQYU2kF%dkc0wvPc(uM~D9Y(w$fn{FdwkN-AWqkY_jq;I)I1#O?kOFEGGt>xK2aJgt z6p79tNH&xGgHmuAUv3MjkTVxCKD1Iq%UPd{+iz=+MbnD&_8)p`FqBQ;=}ika!t(^C zc)lR5NC($u@4p$hn6IggNw3%RD2V0$T_RPOU6NsC>A%dv|%8J*-| z_231;Ap{}lrTnF72-swcdZc%*sOeu=@aiQYt^9)dNBOz z0xzX;h(pxT!$vWgB*z#()g5M(51saff7YK>Z8rMpsMYexHbGu>fq-#+goITMX5gmH zf85Brc;|WGDBa0n${@(fas~nst|89U0+1#I(<0DmWl`eJzF*|=^>8p_xz%$K=AS(% z9r;Qb(IOr^e{n!kMGZ@PxeunW$8G_&Eqm=2+*0}s+)+5R*xEg{t|MTkrH)Spq(ubO z6zlg-zfal}MjD%u(IT4Y;~pr(#xf|0$1ANP++G3`E7aM&H64wpEf0>UeQfF+)R0c} z;v%nzC6#9dYMZiuf7NU!kWo--e3d7B$!81VqrWbW487yY@&nx}zij|{88PP%{Y7lq z(1T^gM2XnZr;z}n&pyw5;^YFFh%3of-NP&mn2M+0B`(N|$y|`1va|M84`Y`t1Vts{ z7RS7zAE?an+j$vJ$9*>4yV#cSX@rBn4Cpi0qDAQCX)d#B+uZ(vBS;y-&Vd|@{}z(3 z;hN>R?A(Fz@6b>D^__z1FaqxeV+Ml=QcX+*fA9;wv&H?4Dr1`2VO|1Dw3tSZ#XFJx zBI2KdCftM%ZjZ6|#D?upg!B!T;3Wj}YoC^W#FNCX%iq7#siP_Lj=^RwZmHI|!r8k1 zxRqo@3U|pMd~Q^jj;nFQ@SCOYao13xt^8qd7Z%c*G0Z8mm7YmpN_bB#o;4U=!*tFVXflyB#j4rLp6JZKB` z4y;$)MwN{pkcP?eCCP&f$8lc%=+`(10_48RRF0XuAvu%$``}&e4)w{2uLo`Y&Nnia zId}FuJqR-j;y5-RGju`tj^CP+#ag~Flhwc6t#i6fd9?&t$lv~+#2EBT(xgbR4>=Dy zU8#rcu7~O)4?5O8VT3fxAwT|H1pdDbfzXhx{(X@5{>K3k@<#?x|LySSK53}`atQu+ wxZv*zga38-AI1CsJ@{{>{rld3DJ>K Date: Sun, 16 May 2021 13:55:52 +0200 Subject: [PATCH 019/762] Standardize badges --- README.rst | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/README.rst b/README.rst index cd635835c..70dec88db 100644 --- a/README.rst +++ b/README.rst @@ -2,26 +2,26 @@ :width: 480px :alt: websockets -|rtd| |pypi-v| |pypi-pyversions| |pypi-l| |pypi-wheel| |tests| +|licence| |version| |pyversions| |wheel| |tests| |docs| -.. |rtd| image:: https://readthedocs.org/projects/websockets/badge/?version=latest - :target: https://websockets.readthedocs.io/ - -.. |pypi-v| image:: https://img.shields.io/pypi/v/websockets.svg +.. |licence| image:: https://img.shields.io/pypi/l/websockets.svg :target: https://pypi.python.org/pypi/websockets -.. |pypi-pyversions| image:: https://img.shields.io/pypi/pyversions/websockets.svg +.. |version| image:: https://img.shields.io/pypi/v/websockets.svg :target: https://pypi.python.org/pypi/websockets -.. |pypi-l| image:: https://img.shields.io/pypi/l/websockets.svg +.. |pyversions| image:: https://img.shields.io/pypi/pyversions/websockets.svg :target: https://pypi.python.org/pypi/websockets -.. |pypi-wheel| image:: https://img.shields.io/pypi/wheel/websockets.svg +.. |wheel| image:: https://img.shields.io/pypi/wheel/websockets.svg :target: https://pypi.python.org/pypi/websockets -.. |tests| image:: https://github.com/aaugustin/websockets/actions/workflows/tests.yml/badge.svg +.. |tests| image:: https://img.shields.io/github/checks-status/aaugustin/websockets/main :target: https://github.com/aaugustin/websockets/actions/workflows/tests.yml +.. |docs| image:: https://img.shields.io/readthedocs/websockets.svg + :target: https://websockets.readthedocs.io/ + What is ``websockets``? ----------------------- From f029e81605a80f37da1aea3c73644b7edca7f7a9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 16 May 2021 09:16:54 +0200 Subject: [PATCH 020/762] Move Heroku docs inside how-to directory. --- docs/howto/django.rst | 4 ++-- docs/{ => howto}/heroku.rst | 25 ++++--------------------- docs/index.rst | 2 +- example/deployment/heroku/app.py | 16 ++++++++++++++++ 4 files changed, 23 insertions(+), 24 deletions(-) rename docs/{ => howto}/heroku.rst (89%) create mode 100644 example/deployment/heroku/app.py diff --git a/docs/howto/django.rst b/docs/howto/django.rst index fd170c387..1b4b4d3b9 100644 --- a/docs/howto/django.rst +++ b/docs/howto/django.rst @@ -1,5 +1,5 @@ -Using websockets with Django -============================ +Integrate with Django +===================== If you're looking at adding real-time capabilities to a Django project with WebSocket, you have two main options. diff --git a/docs/heroku.rst b/docs/howto/heroku.rst similarity index 89% rename from docs/heroku.rst rename to docs/howto/heroku.rst index d23dc64c0..876a9160a 100644 --- a/docs/heroku.rst +++ b/docs/howto/heroku.rst @@ -1,9 +1,9 @@ -Deploying to Heroku -=================== +Deploy to Heroku +================ This guide describes how to deploy a websockets server to Heroku_. We're going to deploy a very simple app. The process would be identical for a more -realistic app. +realistic app. It would be similiar on other Platorm as a Service providers. .. _Heroku: https://www.heroku.com/ @@ -39,24 +39,7 @@ you'll have to pick a different name because I'm already using Here's the implementation of the app, an echo server. Save it in a file called ``app.py``: -.. code:: python - - #!/usr/bin/env python - - import asyncio - import os - - import websockets - - async def echo(websocket, path): - async for message in websocket: - await websocket.send(message) - - async def main(): - async with websockets.serve(echo, "", int(os.environ["PORT"])): - await asyncio.Future() # run forever - - asyncio.run(main()) +.. literalinclude:: ../../example/deployment/heroku/app.py The server relies on the ``$PORT`` environment variable to tell on which port it will listen, according to Heroku's conventions. diff --git a/docs/index.rst b/docs/index.rst index 4ec682f69..57d158e8e 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -60,7 +60,7 @@ These guides will help you build and deploy a ``websockets`` application. deployment extensions howto/django - heroku + howto/heroku Reference --------- diff --git a/example/deployment/heroku/app.py b/example/deployment/heroku/app.py new file mode 100644 index 000000000..aceb754a3 --- /dev/null +++ b/example/deployment/heroku/app.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python + +import asyncio +import os + +import websockets + +async def echo(websocket, path): + async for message in websocket: + await websocket.send(message) + +async def main(): + async with websockets.serve(echo, "", int(os.environ["PORT"])): + await asyncio.Future() # run forever + +asyncio.run(main()) From ad8ea999391ccd3a7d97edd7a36bd228fdc6c09e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 16 May 2021 14:31:48 +0200 Subject: [PATCH 021/762] Restructure documentation. Lots of small improvements while proof-reading. --- docs/extensions.rst | 97 ------------------------ docs/{ => howto}/cheatsheet.rst | 6 +- docs/howto/compression.rst | 51 +++++++++++++ docs/{ => howto}/deployment.rst | 20 ++--- docs/howto/extensions.rst | 31 ++++++++ docs/{ => howto}/faq.rst | 6 +- docs/howto/heroku.rst | 2 +- docs/howto/index.rst | 35 +++++++++ docs/index.rst | 49 ++++++------ docs/{intro.rst => intro/index.rst} | 42 +++++------ docs/license.rst | 4 - docs/limitations.rst | 10 --- docs/{ => project}/changelog.rst | 35 +++++---- docs/{ => project}/contributing.rst | 20 ++--- docs/project/index.rst | 10 +++ docs/project/license.rst | 4 + docs/{ => project}/tidelift.rst | 2 +- docs/{api => reference}/client.rst | 0 docs/{api => reference}/extensions.rst | 10 +++ docs/{api => reference}/index.rst | 9 ++- docs/reference/limitations.rst | 35 +++++++++ docs/{api => reference}/server.rst | 0 docs/{api => reference}/utilities.rst | 0 docs/{ => topics}/design.rst | 99 +++++++++++++------------ docs/topics/index.rst | 8 ++ docs/{ => topics}/lifecycle.graffle | Bin docs/{ => topics}/lifecycle.svg | 0 docs/{ => topics}/protocol.graffle | Bin docs/{ => topics}/protocol.svg | 0 docs/{ => topics}/security.rst | 18 +++-- example/health_check_server.py | 3 +- example/secure_client.py | 3 +- example/secure_server.py | 3 +- example/shutdown_client.py | 2 +- example/shutdown_server.py | 2 +- 35 files changed, 345 insertions(+), 271 deletions(-) delete mode 100644 docs/extensions.rst rename docs/{ => howto}/cheatsheet.rst (93%) create mode 100644 docs/howto/compression.rst rename docs/{ => howto}/deployment.rst (92%) create mode 100644 docs/howto/extensions.rst rename docs/{ => howto}/faq.rst (98%) create mode 100644 docs/howto/index.rst rename docs/{intro.rst => intro/index.rst} (85%) delete mode 100644 docs/license.rst delete mode 100644 docs/limitations.rst rename docs/{ => project}/changelog.rst (94%) rename docs/{ => project}/contributing.rst (76%) create mode 100644 docs/project/index.rst create mode 100644 docs/project/license.rst rename docs/{ => project}/tidelift.rst (99%) rename docs/{api => reference}/client.rst (100%) rename docs/{api => reference}/extensions.rst (53%) rename docs/{api => reference}/index.rst (92%) create mode 100644 docs/reference/limitations.rst rename docs/{api => reference}/server.rst (100%) rename docs/{api => reference}/utilities.rst (100%) rename docs/{ => topics}/design.rst (87%) create mode 100644 docs/topics/index.rst rename docs/{ => topics}/lifecycle.graffle (100%) rename docs/{ => topics}/lifecycle.svg (100%) rename docs/{ => topics}/protocol.graffle (100%) rename docs/{ => topics}/protocol.svg (100%) rename docs/{ => topics}/security.rst (56%) diff --git a/docs/extensions.rst b/docs/extensions.rst deleted file mode 100644 index f5e2f497f..000000000 --- a/docs/extensions.rst +++ /dev/null @@ -1,97 +0,0 @@ -Extensions -========== - -.. currentmodule:: websockets.extensions - -The WebSocket protocol supports extensions_. - -At the time of writing, there's only one `registered extension`_ with a public -specification, WebSocket Per-Message Deflate, specified in :rfc:`7692`. - -.. _extensions: https://tools.ietf.org/html/rfc6455#section-9 -.. _registered extension: https://www.iana.org/assignments/websocket/websocket.xhtml#extension-name - -Per-Message Deflate -------------------- - -:func:`~websockets.client.connect` and :func:`~websockets.server.serve` enable -the Per-Message Deflate extension by default. - -If you want to disable it, set ``compression=None``:: - - import websockets - - websockets.connect(..., compression=None) - - websockets.serve(..., compression=None) - - -.. _per-message-deflate-configuration-example: - -You can also configure the Per-Message Deflate extension explicitly if you -want to customize compression settings:: - - import websockets - from websockets.extensions import permessage_deflate - - websockets.connect( - ..., - extensions=[ - permessage_deflate.ClientPerMessageDeflateFactory( - server_max_window_bits=11, - client_max_window_bits=11, - compress_settings={'memLevel': 4}, - ), - ], - ) - - websockets.serve( - ..., - extensions=[ - permessage_deflate.ServerPerMessageDeflateFactory( - server_max_window_bits=11, - client_max_window_bits=11, - compress_settings={'memLevel': 4}, - ), - ], - ) - -The window bits and memory level values chosen in these examples reduce memory -usage. You can read more about :ref:`optimizing compression settings -`. - -Refer to the API documentation of -:class:`~permessage_deflate.ClientPerMessageDeflateFactory` and -:class:`~permessage_deflate.ServerPerMessageDeflateFactory` for details. - -Writing an extension --------------------- - -During the opening handshake, WebSocket clients and servers negotiate which -extensions will be used with which parameters. Then each frame is processed by -extensions before being sent or after being received. - -As a consequence, writing an extension requires implementing several classes: - -* Extension Factory: it negotiates parameters and instantiates the extension. - - Clients and servers require separate extension factories with distinct APIs. - - Extension factories are the public API of an extension. - -* Extension: it decodes incoming frames and encodes outgoing frames. - - If the extension is symmetrical, clients and servers can use the same - class. - - Extensions are initialized by extension factories, so they don't need to be - part of the public API of an extension. - -``websockets`` provides abstract base classes for extension factories and -extensions. See the API documentation for details on their methods: - -* :class:`ClientExtensionFactory` and class:`ServerExtensionFactory` for - :extension factories, -* :class:`Extension` for extensions. - - diff --git a/docs/cheatsheet.rst b/docs/howto/cheatsheet.rst similarity index 93% rename from docs/cheatsheet.rst rename to docs/howto/cheatsheet.rst index a71f08d74..86684c44c 100644 --- a/docs/cheatsheet.rst +++ b/docs/howto/cheatsheet.rst @@ -65,7 +65,7 @@ Client Debugging --------- -If you don't understand what ``websockets`` is doing, enable logging:: +If you don't understand what websockets is doing, enable logging:: import logging logger = logging.getLogger('websockets') @@ -79,8 +79,8 @@ The logs contain: * All frames at the ``DEBUG`` level — this can be very verbose If you're new to ``asyncio``, you will certainly encounter issues that are -related to asynchronous programming in general rather than to ``websockets`` -in particular. Fortunately Python's official documentation provides advice to +related to asynchronous programming in general rather than to websockets in +particular. Fortunately Python's official documentation provides advice to `develop with asyncio`_. Check it out: it's invaluable! .. _develop with asyncio: https://docs.python.org/3/library/asyncio-dev.html diff --git a/docs/howto/compression.rst b/docs/howto/compression.rst new file mode 100644 index 000000000..9023cec56 --- /dev/null +++ b/docs/howto/compression.rst @@ -0,0 +1,51 @@ +Compression +=========== + +:func:`~websockets.client.connect` and :func:`~websockets.server.serve` enable +compression by default. + +If you want to disable it, set ``compression=None``:: + + import websockets + + websockets.connect(..., compression=None) + + websockets.serve(..., compression=None) + +.. _per-message-deflate-configuration-example: + +You can also configure the Per-Message Deflate extension explicitly if you +want to customize compression settings:: + + import websockets + from websockets.extensions import permessage_deflate + + websockets.connect( + ..., + extensions=[ + permessage_deflate.ClientPerMessageDeflateFactory( + server_max_window_bits=11, + client_max_window_bits=11, + compress_settings={'memLevel': 4}, + ), + ], + ) + + websockets.serve( + ..., + extensions=[ + permessage_deflate.ServerPerMessageDeflateFactory( + server_max_window_bits=11, + client_max_window_bits=11, + compress_settings={'memLevel': 4}, + ), + ], + ) + +The window bits and memory level values chosen in these examples reduce memory +usage. You can read more about :ref:`optimizing compression settings +`. + +Refer to the API documentation of +:class:`~permessage_deflate.ClientPerMessageDeflateFactory` and +:class:`~permessage_deflate.ServerPerMessageDeflateFactory` for details. diff --git a/docs/deployment.rst b/docs/howto/deployment.rst similarity index 92% rename from docs/deployment.rst rename to docs/howto/deployment.rst index aa1af211c..35720e8f1 100644 --- a/docs/deployment.rst +++ b/docs/howto/deployment.rst @@ -34,8 +34,8 @@ On Unix systems, shutdown is usually triggered by sending a signal. Here's a full example for handling SIGTERM on Unix: -.. literalinclude:: ../example/shutdown_server.py - :emphasize-lines: 13,17-19 +.. literalinclude:: ../../example/shutdown_server.py + :emphasize-lines: 12-15,17 This example is easily adapted to handle other signals. If you override the default handler for SIGINT, which raises :exc:`KeyboardInterrupt`, be aware @@ -151,8 +151,8 @@ Under normal circumstances, buffers are almost always empty. Under high load, if a server receives more messages than it can process, bufferbloat can result in excessive memory use. -By default ``websockets`` has generous limits. It is strongly recommended to -adapt them to your application. When you call :func:`~legacy.server.serve`: +By default websockets has generous limits. It is strongly recommended to adapt +them to your application. When you call :func:`~legacy.server.serve`: - Set ``max_size`` (default: 1 MiB, UTF-8 encoded) to the maximum size of messages your application generates. @@ -171,12 +171,12 @@ Port sharing The WebSocket protocol is an extension of HTTP/1.1. It can be tempting to serve both HTTP and WebSocket on the same port. -The author of ``websockets`` doesn't think that's a good idea, due to the -widely different operational characteristics of HTTP and WebSocket. +The author of websockets doesn't think that's a good idea, due to the widely +different operational characteristics of HTTP and WebSocket. -``websockets`` provide minimal support for responding to HTTP requests with -the :meth:`~legacy.server.WebSocketServerProtocol.process_request` hook. Typical +websockets provide minimal support for responding to HTTP requests with the +:meth:`~legacy.server.WebSocketServerProtocol.process_request` hook. Typical use cases include health checks. Here's an example: -.. literalinclude:: ../example/health_check_server.py - :emphasize-lines: 9-11,17-19 +.. literalinclude:: ../../example/health_check_server.py + :emphasize-lines: 9-11,20 diff --git a/docs/howto/extensions.rst b/docs/howto/extensions.rst new file mode 100644 index 000000000..fdaf09f63 --- /dev/null +++ b/docs/howto/extensions.rst @@ -0,0 +1,31 @@ +Writing an extension +==================== + +During the opening handshake, WebSocket clients and servers negotiate which +extensions will be used with which parameters. Then each frame is processed by +extensions before being sent or after being received. + +As a consequence, writing an extension requires implementing several classes: + +* Extension Factory: it negotiates parameters and instantiates the extension. + + Clients and servers require separate extension factories with distinct APIs. + + Extension factories are the public API of an extension. + +* Extension: it decodes incoming frames and encodes outgoing frames. + + If the extension is symmetrical, clients and servers can use the same + class. + + Extensions are initialized by extension factories, so they don't need to be + part of the public API of an extension. + +websockets provides abstract base classes for extension factories and +extensions. See the API documentation for details on their methods: + +* :class:`ClientExtensionFactory` and class:`ServerExtensionFactory` for + :extension factories, +* :class:`Extension` for extensions. + + diff --git a/docs/faq.rst b/docs/howto/faq.rst similarity index 98% rename from docs/faq.rst rename to docs/howto/faq.rst index abd396cde..0196fd2c0 100644 --- a/docs/faq.rst +++ b/docs/howto/faq.rst @@ -181,7 +181,7 @@ You can close the connection. Here's an example that terminates cleanly when it receives SIGTERM on Unix: -.. literalinclude:: ../example/shutdown_client.py +.. literalinclude:: ../../example/shutdown_client.py :emphasize-lines: 10-13 How do I disable TLS/SSL certificate verification? @@ -324,8 +324,8 @@ coroutines make it easier to manage control flow in concurrent code. If you prefer callback-based APIs, you should use another library. -Can I use ``websockets`` synchronously, without ``async`` / ``await``? -...................................................................... +Can I use websockets synchronously, without ``async`` / ``await``? +.................................................................. You can convert every asynchronous call to a synchronous call by wrapping it in ``asyncio.get_event_loop().run_until_complete(...)``. Unfortunately, this diff --git a/docs/howto/heroku.rst b/docs/howto/heroku.rst index 876a9160a..c6d81ab1d 100644 --- a/docs/howto/heroku.rst +++ b/docs/howto/heroku.rst @@ -3,7 +3,7 @@ Deploy to Heroku This guide describes how to deploy a websockets server to Heroku_. We're going to deploy a very simple app. The process would be identical for a more -realistic app. It would be similiar on other Platorm as a Service providers. +realistic app. It would be similar on other Platform as a Service providers. .. _Heroku: https://www.heroku.com/ diff --git a/docs/howto/index.rst b/docs/howto/index.rst new file mode 100644 index 000000000..bcacbc174 --- /dev/null +++ b/docs/howto/index.rst @@ -0,0 +1,35 @@ +How-to guides +============= + +If you're stuck, perhaps you'll find the answer in the FAQ or the cheat sheet. + +.. toctree:: + :maxdepth: 2 + + faq + cheatsheet + +The following guides will help you integrate websockets into a broader system. + +.. toctree:: + :maxdepth: 2 + + django + +The WebSocket protocol makes provisions for extending or specializing its +features, which websockets supports fully. + +.. toctree:: + :maxdepth: 2 + + extensions + +Once your application is ready, learn how to deploy it with a convenient, +optimized, and secure setup. + +.. toctree:: + :maxdepth: 2 + + deployment + compression + heroku diff --git a/docs/index.rst b/docs/index.rst index 57d158e8e..000c19ca7 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,25 +1,28 @@ websockets ========== -|pypi-v| |pypi-pyversions| |pypi-l| |pypi-wheel| |tests| +|licence| |version| |pyversions| |wheel| |tests| |docs| -.. |pypi-v| image:: https://img.shields.io/pypi/v/websockets.svg +.. |licence| image:: https://img.shields.io/pypi/l/websockets.svg :target: https://pypi.python.org/pypi/websockets -.. |pypi-pyversions| image:: https://img.shields.io/pypi/pyversions/websockets.svg +.. |version| image:: https://img.shields.io/pypi/v/websockets.svg :target: https://pypi.python.org/pypi/websockets -.. |pypi-l| image:: https://img.shields.io/pypi/l/websockets.svg +.. |pyversions| image:: https://img.shields.io/pypi/pyversions/websockets.svg :target: https://pypi.python.org/pypi/websockets -.. |pypi-wheel| image:: https://img.shields.io/pypi/wheel/websockets.svg +.. |wheel| image:: https://img.shields.io/pypi/wheel/websockets.svg :target: https://pypi.python.org/pypi/websockets -.. |tests| image:: https://github.com/aaugustin/websockets/actions/workflows/tests.yml/badge.svg +.. |tests| image:: https://img.shields.io/github/checks-status/aaugustin/websockets/main :target: https://github.com/aaugustin/websockets/actions/workflows/tests.yml -``websockets`` is a library for building WebSocket servers_ and clients_ in -Python with a focus on correctness and simplicity. +.. |docs| image:: https://img.shields.io/readthedocs/websockets.svg + :target: https://websockets.readthedocs.io/ + +websockets is a library for building WebSocket servers_ and clients_ in Python +with a focus on correctness and simplicity. .. _servers: https://github.com/aaugustin/websockets/blob/main/example/server.py .. _clients: https://github.com/aaugustin/websockets/blob/main/example/client.py @@ -40,27 +43,22 @@ Do you like it? Let's dive in! Tutorials --------- -If you're new to ``websockets``, this is the place to start. +If you're new to websockets, this is the place to start. .. toctree:: :maxdepth: 2 - intro - faq + intro/index How-to guides ------------- -These guides will help you build and deploy a ``websockets`` application. +These guides will help you build and deploy a websockets application. .. toctree:: :maxdepth: 2 - cheatsheet - deployment - extensions - howto/django - howto/heroku + howto/index Reference --------- @@ -70,19 +68,17 @@ Find all the details you could ask for, and then some. .. toctree:: :maxdepth: 2 - api/index + reference/index -Discussions ------------ +Topics +------ -Get a deeper understanding of how ``websockets`` is built and why. +Get a deeper understanding of how websockets is built and why. .. toctree:: :maxdepth: 2 - design - limitations - security + topics/index Project ------- @@ -92,7 +88,4 @@ This is about websockets-the-project rather than websockets-the-software. .. toctree:: :maxdepth: 2 - changelog - contributing - license - For enterprise + project/index diff --git a/docs/intro.rst b/docs/intro/index.rst similarity index 85% rename from docs/intro.rst rename to docs/intro/index.rst index 58d482a09..49e17b668 100644 --- a/docs/intro.rst +++ b/docs/intro/index.rst @@ -6,7 +6,7 @@ Getting started Requirements ------------ -``websockets`` requires Python ≥ 3.7. +websockets requires Python ≥ 3.7. You should use the latest version of Python if possible. If you're using an older version, be aware that for each minor version (3.x), only the latest @@ -15,7 +15,7 @@ bugfix release (3.x.y) is officially supported. Installation ------------ -Install ``websockets`` with:: +Install websockets with:: pip install websockets @@ -28,19 +28,19 @@ Here's a WebSocket server example. It reads a name from the client, sends a greeting, and closes the connection. -.. literalinclude:: ../example/server.py - :emphasize-lines: 8,17 +.. literalinclude:: ../../example/server.py + :emphasize-lines: 8,18 .. _client-example: -On the server side, ``websockets`` executes the handler coroutine ``hello`` -once for each WebSocket connection. It closes the connection when the handler +On the server side, websockets executes the handler coroutine ``hello`` once +for each WebSocket connection. It closes the connection when the handler coroutine returns. Here's a corresponding WebSocket client example. -.. literalinclude:: ../example/client.py - :emphasize-lines: 8,10 +.. literalinclude:: ../../example/client.py + :emphasize-lines: 10 Using :func:`connect` as an asynchronous context manager ensures the connection is closed before exiting the ``hello`` coroutine. @@ -60,13 +60,13 @@ Sockets Layer (SSL). WSS requires TLS certificates like HTTPS. Here's how to adapt the server example to provide secure connections. See the documentation of the :mod:`ssl` module for configuring the context securely. -.. literalinclude:: ../example/secure_server.py - :emphasize-lines: 19,23-25 +.. literalinclude:: ../../example/secure_server.py + :emphasize-lines: 19-21,26 Here's how to adapt the client. -.. literalinclude:: ../example/secure_client.py - :emphasize-lines: 10,15-18 +.. literalinclude:: ../../example/secure_client.py + :emphasize-lines: 10-12,18 This client needs a context because the server uses a self-signed certificate. @@ -81,11 +81,11 @@ Here's an example of how to run a WebSocket server and connect from a browser. Run this script in a console: -.. literalinclude:: ../example/show_time.py +.. literalinclude:: ../../example/show_time.py Then open this HTML file in a browser. -.. literalinclude:: ../example/show_time.html +.. literalinclude:: ../../example/show_time.html :language: html Synchronization example @@ -102,11 +102,11 @@ serialized. Run this script in a console: -.. literalinclude:: ../example/counter.py +.. literalinclude:: ../../example/counter.py Then open this HTML file in several browsers. -.. literalinclude:: ../example/counter.html +.. literalinclude:: ../../example/counter.html :language: html Common patterns @@ -147,8 +147,8 @@ messages to send on the WebSocket connection. :exc:`~exceptions.ConnectionClosed` exception when the client disconnects, which breaks out of the ``while True`` loop. -Both -.... +Both sides +.......... You can read and write messages on the same connection by combining the two patterns shown above and running the two tasks in parallel:: @@ -194,16 +194,16 @@ handler may subscribe to some channels on a message broker, for example. That's all! ----------- -The design of the ``websockets`` API was driven by simplicity. +The design of the websockets API was driven by simplicity. You don't have to worry about performing the opening or the closing handshake, answering pings, or any other behavior required by the specification. -``websockets`` handles all this under the hood so you don't have to. +websockets handles all this under the hood so you don't have to. One more thing... ----------------- -``websockets`` provides an interactive client:: +websockets provides an interactive client:: $ python -m websockets wss://echo.websocket.org/ diff --git a/docs/license.rst b/docs/license.rst deleted file mode 100644 index 842d3b07f..000000000 --- a/docs/license.rst +++ /dev/null @@ -1,4 +0,0 @@ -License -------- - -.. literalinclude:: ../LICENSE diff --git a/docs/limitations.rst b/docs/limitations.rst deleted file mode 100644 index bd6d32b2f..000000000 --- a/docs/limitations.rst +++ /dev/null @@ -1,10 +0,0 @@ -Limitations ------------ - -The client doesn't attempt to guarantee that there is no more than one -connection to a given IP address in a CONNECTING state. - -The client doesn't support connecting through a proxy. - -There is no way to fragment outgoing messages. A message is always sent in a -single frame. diff --git a/docs/changelog.rst b/docs/project/changelog.rst similarity index 94% rename from docs/changelog.rst rename to docs/project/changelog.rst index 17f640816..db29a7290 100644 --- a/docs/changelog.rst +++ b/docs/project/changelog.rst @@ -1,5 +1,5 @@ Changelog ---------- +========= .. currentmodule:: websockets @@ -8,9 +8,9 @@ Changelog Backwards-compatibility policy .............................. -``websockets`` is intended for production use. Therefore, stability is a goal. +websockets is intended for production use. Therefore, stability is a goal. -``websockets`` also aims at providing the best API for WebSocket in Python. +websockets also aims at providing the best API for WebSocket in Python. While we value stability, we value progress more. When an improvement requires changing a public API, we make the change and document it in this changelog. @@ -77,25 +77,25 @@ They may change at any time. them, you should adjust the import path. * The ``client``, ``server``, ``protocol``, and ``auth`` modules were - moved from the ``websockets`` package to ``websockets.legacy`` - sub-package, as part of an upcoming refactoring. Despite the name, - they're still fully supported. The refactoring should be a transparent - upgrade for most uses when it's available. The legacy implementation - will be preserved according to the `backwards-compatibility policy`_. + moved from the websockets package to ``websockets.legacy`` sub-package, + as part of an upcoming refactoring. Despite the name, they're still + fully supported. The refactoring should be a transparent upgrade for + most uses when it's available. The legacy implementation will be + preserved according to the `backwards-compatibility policy`_. * The ``framing``, ``handshake``, ``headers``, ``http``, and ``uri`` - modules in the ``websockets`` package are deprecated. These modules - provided low-level APIs for reuse by other WebSocket implementations, - but that never happened. Keeping these APIs public makes it more - difficult to improve websockets for no actual benefit. + modules in the websockets package are deprecated. These modules provided + low-level APIs for reuse by other WebSocket implementations, but that + never happened. Keeping these APIs public makes it more difficult to + improve websockets for no actual benefit. .. note:: **Version 9.0 may require changes if you use static code analysis tools.** - Convenience imports from the ``websockets`` module are performed lazily. - While this is supported by Python, static code analysis tools such as mypy - are unable to understand the behavior. + Convenience imports from the websockets module are performed lazily. While + this is supported by Python, static code analysis tools such as mypy are + unable to understand the behavior. If you depend on such tools, use the real import path, which can be found in the API documentation:: @@ -148,8 +148,7 @@ They may change at any time. *July 21, 2019* -* Restored the ability to import ``WebSocketProtocolError`` from - ``websockets``. +* Restored the ability to import ``WebSocketProtocolError`` from websockets. 8.0 ... @@ -261,7 +260,7 @@ Also: .. warning:: - ``websockets`` **now sends Ping frames at regular intervals and closes the + websockets **now sends Ping frames at regular intervals and closes the connection if it doesn't receive a matching Pong frame.** See :class:`~legacy.protocol.WebSocketCommonProtocol` for details. diff --git a/docs/contributing.rst b/docs/project/contributing.rst similarity index 76% rename from docs/contributing.rst rename to docs/project/contributing.rst index 59d7451e0..c3e8dfc4c 100644 --- a/docs/contributing.rst +++ b/docs/project/contributing.rst @@ -24,11 +24,12 @@ Bug reports, patches and suggestions are welcome! Please open an issue_ or send a `pull request`_. -Feedback about the documentation is especially valuable — the authors of -``websockets`` feel more confident about writing code than writing docs :-) +Feedback about the documentation is especially valuable, as the primary author +feels more confident about writing code than writing docs :-) If you're wondering why things are done in a certain way, the :doc:`design -document ` provides lots of details about the internals of websockets. +document <../topics/design>` provides lots of details about the internals of +websockets. .. _issue: https://github.com/aaugustin/websockets/issues/new .. _pull request: https://github.com/aaugustin/websockets/compare/ @@ -41,14 +42,14 @@ places to ask questions, for example Stack Overflow. If you want to ask a question anyway, please make sure that: -- it's a question about ``websockets`` and not about :mod:`asyncio`; -- it isn't answered by the documentation; +- it's a question about websockets and not about :mod:`asyncio`; +- it isn't answered in the documentation; - it wasn't asked already. A good question can be written as a suggestion to improve the documentation. -Bitcoin users -------------- +Cryptocurrency users +-------------------- websockets appears to be quite popular for interfacing with Bitcoin or other cryptocurrency trackers. I'm strongly opposed to Bitcoin's carbon footprint. @@ -58,7 +59,8 @@ carbon footprint of all cryptocurrencies drops to a non-bullshit level. Please stop heating the planet where my children are supposed to live, thanks. -Since ``websockets`` is released under an open-source license, you can use it -for any purpose you like. However, I won't spend any of my time to help. +Since websockets is released under an open-source license, you can use it for +any purpose you like. However, I won't spend any of my time to help you. I will summarily close issues related to Bitcoin or cryptocurrency in any way. +Thanks for your understanding. diff --git a/docs/project/index.rst b/docs/project/index.rst new file mode 100644 index 000000000..931fbe2d4 --- /dev/null +++ b/docs/project/index.rst @@ -0,0 +1,10 @@ +Project +======= + +.. toctree:: + :maxdepth: 2 + + changelog + contributing + license + For enterprise diff --git a/docs/project/license.rst b/docs/project/license.rst new file mode 100644 index 000000000..426110020 --- /dev/null +++ b/docs/project/license.rst @@ -0,0 +1,4 @@ +License +======= + +.. literalinclude:: ../../LICENSE diff --git a/docs/tidelift.rst b/docs/project/tidelift.rst similarity index 99% rename from docs/tidelift.rst rename to docs/project/tidelift.rst index 43b457aaf..42100fade 100644 --- a/docs/tidelift.rst +++ b/docs/project/tidelift.rst @@ -4,7 +4,7 @@ websockets for enterprise Available as part of the Tidelift Subscription ---------------------------------------------- -.. image:: _static/tidelift.png +.. image:: ../_static/tidelift.png :height: 150px :width: 150px :align: left diff --git a/docs/api/client.rst b/docs/reference/client.rst similarity index 100% rename from docs/api/client.rst rename to docs/reference/client.rst diff --git a/docs/api/extensions.rst b/docs/reference/extensions.rst similarity index 53% rename from docs/api/extensions.rst rename to docs/reference/extensions.rst index 71f015bb2..6ca20dbc8 100644 --- a/docs/api/extensions.rst +++ b/docs/reference/extensions.rst @@ -1,6 +1,16 @@ Extensions ========== +.. currentmodule:: websockets.extensions + +The WebSocket protocol supports extensions_. + +At the time of writing, there's only one `registered extension`_ with a public +specification, WebSocket Per-Message Deflate, specified in :rfc:`7692`. + +.. _extensions: https://tools.ietf.org/html/rfc6455#section-9 +.. _registered extension: https://www.iana.org/assignments/websocket/websocket.xhtml#extension-name + Per-Message Deflate ------------------- diff --git a/docs/api/index.rst b/docs/reference/index.rst similarity index 92% rename from docs/api/index.rst rename to docs/reference/index.rst index 0a616cbce..9fc4a0092 100644 --- a/docs/api/index.rst +++ b/docs/reference/index.rst @@ -1,8 +1,8 @@ -API -=== +Reference +========= -``websockets`` provides complete client and server implementations, as shown -in the :doc:`getting started guide <../intro>`. +websockets provides complete client and server implementations, as shown in +the :doc:`getting started guide <../intro/index>`. The process for opening and closing a WebSocket connection depends on which side you're implementing. @@ -44,6 +44,7 @@ both in the client API and server API. server extensions utilities + limitations All public APIs can be imported from the :mod:`websockets` package, unless noted otherwise. This convenience feature is incompatible with static code diff --git a/docs/reference/limitations.rst b/docs/reference/limitations.rst new file mode 100644 index 000000000..505186770 --- /dev/null +++ b/docs/reference/limitations.rst @@ -0,0 +1,35 @@ +Limitations +=========== + +Client +------ + +The client doesn't attempt to guarantee that there is no more than one +connection to a given IP address in a CONNECTING state. This behavior is +mandated by :rfc:`6455`. However, :func:`~websockets.connect()` isn't the +right layer for enforcing this constraint. It's the caller's responsibility. + +The client doesn't support connecting through a HTTP proxy (`issue 364`_) or a +SOCKS proxy (`issue 475`_). + +.. _issue 364: https://github.com/aaugustin/websockets/issues/364 +.. _issue 475: https://github.com/aaugustin/websockets/issues/475 + +Server +------ + +At this time, there are no known limitations affecting only the server. + +Both sides +---------- + +There is no way to control compression of outgoing frames on a per-frame basis +(`issue 538`_). If compression is enabled, all frames are compressed. + +.. _issue 538: https://github.com/aaugustin/websockets/issues/538 + +There is no way to receive each fragment of a fragmented messages as it +arrives (`issue 479`_). websockets always reassembles framented messages +before returning them. + +.. _issue 479: https://github.com/aaugustin/websockets/issues/479 diff --git a/docs/api/server.rst b/docs/reference/server.rst similarity index 100% rename from docs/api/server.rst rename to docs/reference/server.rst diff --git a/docs/api/utilities.rst b/docs/reference/utilities.rst similarity index 100% rename from docs/api/utilities.rst rename to docs/reference/utilities.rst diff --git a/docs/design.rst b/docs/topics/design.rst similarity index 87% rename from docs/design.rst rename to docs/topics/design.rst index 61b42b528..fa2093433 100644 --- a/docs/design.rst +++ b/docs/topics/design.rst @@ -3,8 +3,8 @@ Design .. currentmodule:: websockets -This document describes the design of ``websockets``. It assumes familiarity -with the specification of the WebSocket protocol in :rfc:`6455`. +This document describes the design of websockets. It assumes familiarity with +the specification of the WebSocket protocol in :rfc:`6455`. It's primarily intended at maintainers. It may also be useful for users who wish to understand what happens under the hood. @@ -13,8 +13,8 @@ wish to understand what happens under the hood. Internals described in this document may change at any time. - Backwards compatibility is only guaranteed for :doc:`public APIs `. - + Backwards compatibility is only guaranteed for :doc:`public APIs + <../reference/index>`. Lifecycle --------- @@ -91,7 +91,7 @@ the same :attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task` the opening handshake fails, in order to close the TCP connection. Splitting the responsibilities between two tasks makes it easier to guarantee -that ``websockets`` can terminate connections: +that websockets can terminate connections: - within a fixed timeout, - without leaking pending tasks, @@ -112,15 +112,16 @@ the TCP connection is closed. Opening handshake ----------------- -``websockets`` performs the opening handshake when establishing a WebSocket -connection. On the client side, :meth:`~legacy.client.connect` executes it before -returning the protocol to the caller. On the server side, it's executed before -passing the protocol to the ``ws_handler`` coroutine handling the connection. +websockets performs the opening handshake when establishing a WebSocket +connection. On the client side, :meth:`~legacy.client.connect` executes it +before returning the protocol to the caller. On the server side, it's executed +before passing the protocol to the ``ws_handler`` coroutine handling the +connection. While the opening handshake is asymmetrical — the client sends an HTTP Upgrade request and the server replies with an HTTP Switching Protocols response — -``websockets`` aims at keeping the implementation of both sides consistent -with one another. +websockets aims at keeping the implementation of both sides consistent with +one another. On the client side, :meth:`~legacy.client.WebSocketClientProtocol.handshake`: @@ -151,8 +152,8 @@ lies in the negotiation of extensions and, to a lesser extent, of the subprotocol. The server knows everything about both sides and decides what the parameters should be for the connection. The client merely applies them. -If anything goes wrong during the opening handshake, ``websockets`` -:ref:`fails the connection `. +If anything goes wrong during the opening handshake, websockets :ref:`fails +the connection `. .. _data-transfer: @@ -192,8 +193,8 @@ Data flow ......... The following diagram shows how data flows between an application built on top -of ``websockets`` and a remote endpoint. It applies regardless of which side -is the server or the client. +of websockets and a remote endpoint. It applies regardless of which side is +the server or the client. .. image:: protocol.svg :target: _images/protocol.svg @@ -205,7 +206,7 @@ termination is discussed in another section below. Receiving data .............. -The left side of the diagram shows how ``websockets`` receives data. +The left side of the diagram shows how websockets receives data. Incoming data is written to a :class:`~asyncio.StreamReader` in order to implement flow control and provide backpressure on the TCP connection. @@ -224,8 +225,8 @@ When it encounters a control frame: unsolicited pong). Running this process in a task guarantees that control frames are processed -promptly. Without such a task, ``websockets`` would depend on the application -to drive the connection by having exactly one coroutine awaiting +promptly. Without such a task, websockets would depend on the application to +drive the connection by having exactly one coroutine awaiting :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` at any time. While this happens naturally in many use cases, it cannot be relied upon. @@ -236,7 +237,7 @@ complexity added for handling backpressure and termination correctly. Sending data ............ -The right side of the diagram shows how ``websockets`` sends data. +The right side of the diagram shows how websockets sends data. :meth:`~legacy.protocol.WebSocketCommonProtocol.send` writes one or several data frames containing the message. While sending a fragmented message, concurrent @@ -274,13 +275,13 @@ state and sends a close frame. When the other side sends a close frame, :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` to terminate. If the other side doesn't send a close frame within the connection's close -timeout, ``websockets`` :ref:`fails the connection `. +timeout, websockets :ref:`fails the connection `. The closing handshake can take up to ``2 * close_timeout``: one ``close_timeout`` to write a close frame and one ``close_timeout`` to receive a close frame. -Then ``websockets`` terminates the TCP connection. +Then websockets terminates the TCP connection. .. _connection-termination: @@ -330,12 +331,12 @@ the connection drops regardless of what happens on the network. Connection failure ------------------ -If the opening handshake doesn't complete successfully, ``websockets`` fails -the connection by closing the TCP connection. +If the opening handshake doesn't complete successfully, websockets fails the +connection by closing the TCP connection. -Once the opening handshake has completed, ``websockets`` fails the connection -by canceling :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` and -sending a close frame if appropriate. +Once the opening handshake has completed, websockets fails the connection by +canceling :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` +and sending a close frame if appropriate. :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` exits, unblocking :attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task`, which closes @@ -369,8 +370,8 @@ Cancellation User code ......... -``websockets`` provides a WebSocket application server. It manages connections -and passes them to user-provided connection handlers. This is an *inversion of +websockets provides a WebSocket application server. It manages connections and +passes them to user-provided connection handlers. This is an *inversion of control* scenario: library code calls user code. If a connection drops, the corresponding handler should terminate. If the @@ -387,15 +388,15 @@ interacts with finalization logic. In the example above, what if a handler gets interrupted with :exc:`~asyncio.CancelledError` while it's finalizing the tasks it started, after detecting that the connection dropped? -``websockets`` considers that cancellation may only be triggered by the caller -of a coroutine when it doesn't care about the results of that coroutine -anymore. (Source: `Guido van Rossum `_). Since connection handlers run -arbitrary user code, ``websockets`` has no way of deciding whether that code -is still doing something worth caring about. +arbitrary user code, websockets has no way of deciding whether that code is +still doing something worth caring about. -For these reasons, ``websockets`` never cancels connection handlers. Instead -it expects them to detect when the connection is closed, execute finalization +For these reasons, websockets never cancels connection handlers. Instead it +expects them to detect when the connection is closed, execute finalization logic if needed, and exit. Conversely, cancellation isn't a concern for WebSocket clients because they @@ -404,14 +405,14 @@ don't involve inversion of control. Library ....... -Most :doc:`public APIs ` of ``websockets`` are coroutines. They may -be canceled, for example if the user starts a task that calls these coroutines -and cancels the task later. ``websockets`` must handle this situation. +Most :doc:`public APIs <../reference/index>` of websockets are coroutines. +They may be canceled, for example if the user starts a task that calls these +coroutines and cancels the task later. websockets must handle this situation. Cancellation during the opening handshake is handled like any other exception: the TCP connection is closed and the exception is re-raised. This can only happen on the client side. On the server side, the opening handshake is -managed by ``websockets`` and nothing results in a cancellation. +managed by websockets and nothing results in a cancellation. Once the WebSocket connection is established, internal tasks :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` and @@ -473,16 +474,16 @@ The solution to this problem is backpressure. Any part of the server that receives inputs faster than it can process them and send the outputs must propagate that information back to the previous part in the chain. -``websockets`` is designed to make it easy to get backpressure right. +websockets is designed to make it easy to get backpressure right. -For incoming data, ``websockets`` builds upon :class:`~asyncio.StreamReader` -which propagates backpressure to its own buffer and to the TCP stream. Frames -are parsed from the input stream and added to a bounded queue. If the queue -fills up, parsing halts until the application reads a frame. +For incoming data, websockets builds upon :class:`~asyncio.StreamReader` which +propagates backpressure to its own buffer and to the TCP stream. Frames are +parsed from the input stream and added to a bounded queue. If the queue fills +up, parsing halts until the application reads a frame. -For outgoing data, ``websockets`` builds upon :class:`~asyncio.StreamWriter` -which implements flow control. If the output buffers grow too large, it waits -until they're drained. That's why all APIs that write frames are asynchronous. +For outgoing data, websockets builds upon :class:`~asyncio.StreamWriter` which +implements flow control. If the output buffers grow too large, it waits until +they're drained. That's why all APIs that write frames are asynchronous. Of course, it's still possible for an application to create its own unbounded buffers and break the backpressure. Be careful with queues. @@ -510,8 +511,8 @@ capacity — typically because the system is bottlenecked by the output and constantly regulated by backpressure — reducing the size of buffers minimizes negative consequences. -By default ``websockets`` has rather high limits. You can decrease them -according to your application's characteristics. +By default websockets has rather high limits. You can decrease them according +to your application's characteristics. Bufferbloat can happen at every level in the stack where there is a buffer. For each connection, the receiving side contains these buffers: diff --git a/docs/topics/index.rst b/docs/topics/index.rst new file mode 100644 index 000000000..157278f76 --- /dev/null +++ b/docs/topics/index.rst @@ -0,0 +1,8 @@ +Topics +====== + +.. toctree:: + :maxdepth: 2 + + design + security diff --git a/docs/lifecycle.graffle b/docs/topics/lifecycle.graffle similarity index 100% rename from docs/lifecycle.graffle rename to docs/topics/lifecycle.graffle diff --git a/docs/lifecycle.svg b/docs/topics/lifecycle.svg similarity index 100% rename from docs/lifecycle.svg rename to docs/topics/lifecycle.svg diff --git a/docs/protocol.graffle b/docs/topics/protocol.graffle similarity index 100% rename from docs/protocol.graffle rename to docs/topics/protocol.graffle diff --git a/docs/protocol.svg b/docs/topics/protocol.svg similarity index 100% rename from docs/protocol.svg rename to docs/topics/protocol.svg diff --git a/docs/security.rst b/docs/topics/security.rst similarity index 56% rename from docs/security.rst rename to docs/topics/security.rst index e9acf0629..39de08120 100644 --- a/docs/security.rst +++ b/docs/topics/security.rst @@ -17,9 +17,9 @@ Memory use An attacker who can open an arbitrary number of connections will be able to perform a denial of service by memory exhaustion. If you're concerned by denial of service attacks, you must reject suspicious connections - before they reach ``websockets``, typically in a reverse proxy. + before they reach websockets, typically in a reverse proxy. -With the default settings, opening a connection uses 325 KiB of memory. +With the default settings, opening a connection uses 70 KiB of memory. Sending some highly compressed messages could use up to 128 MiB of memory with an amplification factor of 1000 between network traffic and memory use. @@ -30,10 +30,12 @@ improve security in addition to improving performance. Other limits ------------ -``websockets`` implements additional limits on the amount of data it accepts -in order to minimize exposure to security vulnerabilities. +websockets implements additional limits on the amount of data it accepts in +order to minimize exposure to security vulnerabilities. -In the opening handshake, ``websockets`` limits the number of HTTP headers to -256 and the size of an individual header to 4096 bytes. These limits are 10 to -20 times larger than what's expected in standard use cases. They're hard-coded. -If you need to change them, monkey-patch the constants in ``websockets.http``. +In the opening handshake, websockets limits the number of HTTP headers to 256 +and the size of an individual header to 4096 bytes. These limits are 10 to 20 +times larger than what's expected in standard use cases. They're hard-coded. + +If you need to change these limits, you can monkey-patch the constants in +``websockets.http11``. diff --git a/example/health_check_server.py b/example/health_check_server.py index 2565f9c48..c5bb6d5ab 100755 --- a/example/health_check_server.py +++ b/example/health_check_server.py @@ -16,7 +16,8 @@ async def echo(websocket, path): async def main(): async with websockets.serve( - echo, "localhost", 8765, process_request=health_check + echo, "localhost", 8765, + process_request=health_check, ): await asyncio.Future() # run forever diff --git a/example/secure_client.py b/example/secure_client.py index 455b6492a..2657ba68b 100755 --- a/example/secure_client.py +++ b/example/secure_client.py @@ -14,7 +14,8 @@ async def hello(): uri = "wss://localhost:8765" async with websockets.connect( - uri, ssl=ssl_context + uri, + ssl=ssl_context, ) as websocket: name = input("What's your name? ") diff --git a/example/secure_server.py b/example/secure_server.py index 55b5a4231..e0ef6e53b 100755 --- a/example/secure_server.py +++ b/example/secure_server.py @@ -22,7 +22,8 @@ async def hello(websocket, path): async def main(): async with websockets.serve( - hello, "localhost", 8765, ssl=ssl_context + hello, "localhost", 8765, + ssl=ssl_context, ): await asyncio.Future() # run forever diff --git a/example/shutdown_client.py b/example/shutdown_client.py index ba1287801..539dd0304 100755 --- a/example/shutdown_client.py +++ b/example/shutdown_client.py @@ -7,8 +7,8 @@ async def client(): uri = "ws://localhost:8765" async with websockets.connect(uri) as websocket: - loop = asyncio.get_running_loop() # Close the connection when receiving SIGTERM. + loop = asyncio.get_running_loop() loop.add_signal_handler( signal.SIGTERM, loop.create_task, websocket.close()) diff --git a/example/shutdown_server.py b/example/shutdown_server.py index 5732313cb..1ae44af1e 100755 --- a/example/shutdown_server.py +++ b/example/shutdown_server.py @@ -9,9 +9,9 @@ async def echo(websocket, path): await websocket.send(message) async def server(): + # Set the stop condition when receiving SIGTERM. loop = asyncio.get_running_loop() stop = loop.create_future() - # Set the stop condition when receiving SIGTERM. loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) async with websockets.serve(echo, "localhost", 8765): await stop From ad2e643b4a7a98afeb0bfe6c0d7b7105f2061779 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 16 May 2021 18:10:31 +0200 Subject: [PATCH 022/762] Add graceful shutdown to Heroku guide. --- docs/howto/heroku.rst | 62 ++++++++++++++++++++++++++------ example/deployment/heroku/app.py | 20 +++++++++-- 2 files changed, 68 insertions(+), 14 deletions(-) diff --git a/docs/howto/heroku.rst b/docs/howto/heroku.rst index c6d81ab1d..6d4346dcd 100644 --- a/docs/howto/heroku.rst +++ b/docs/howto/heroku.rst @@ -1,9 +1,11 @@ Deploy to Heroku ================ -This guide describes how to deploy a websockets server to Heroku_. We're going -to deploy a very simple app. The process would be identical for a more -realistic app. It would be similar on other Platform as a Service providers. +This guide describes how to deploy a websockets server to Heroku_. The same +principles should apply to other Platform as a Service providers. + +We're going to deploy a very simple app. The process would be identical for a +more realistic app. .. _Heroku: https://www.heroku.com/ @@ -41,8 +43,17 @@ Here's the implementation of the app, an echo server. Save it in a file called .. literalinclude:: ../../example/deployment/heroku/app.py -The server relies on the ``$PORT`` environment variable to tell on which port -it will listen, according to Heroku's conventions. +Heroku expects the server to `listen on a specific port`_, which is provided +in the ``$PORT`` environment variable. The app reads it and passes it to +:func:`~websockets.server.serve`. + +.. _listen on a specific port: https://devcenter.heroku.com/articles/preparing-a-codebase-for-heroku-deployment#4-listen-on-the-correct-port + +Heroku sends a ``SIGTERM`` signal to all processes when `shutting down a +dyno`_. When the app receives this signal, it closes connections and exits +cleanly. + +.. _shutting down a dyno: https://devcenter.heroku.com/articles/dynos#shutdown Configure deployment -------------------- @@ -70,7 +81,7 @@ Confirm that you created the correct files and commit them to git: $ git add . $ git commit -m "Deploy echo server to Heroku." [main 8418c62] Deploy echo server to Heroku. -  3 files changed, 19 insertions(+) +  3 files changed, 32 insertions(+)  create mode 100644 Procfile  create mode 100644 app.py  create mode 100644 requirements.txt @@ -78,7 +89,7 @@ Confirm that you created the correct files and commit them to git: Deploy ------ -Our app is ready. Let's deploy it! +The app is ready. Let's deploy it! .. code:: console @@ -87,7 +98,7 @@ Our app is ready. Let's deploy it! ... lots of output... remote: -----> Launching... - remote: Released v3 + remote: Released v1 remote: https://websockets-echo.herokuapp.com/ deployed to Heroku remote: remote: Verifying deploy... done. @@ -97,9 +108,9 @@ Our app is ready. Let's deploy it! Validate deployment ------------------- -Of course we'd like to confirm that our application is running as expected! +Of course you'd like to confirm that your application is running as expected! -Since it's a WebSocket server, we need a WebSocket client, such as the +Since it's a WebSocket server, you need a WebSocket client, such as the interactive client that comes with websockets. If you're currently building a websockets server, perhaps you're already in a @@ -121,7 +132,7 @@ Connect the interactive client — using the name of your Heroku app instead of Connected to wss://websockets-echo.herokuapp.com/. > -Great! Our app is running! +Great! Your app is running! In this example, I used a secure connection (``wss://``). It worked because Heroku served a valid TLS certificate for ``websockets-echo.herokuapp.com``. @@ -135,3 +146,32 @@ then press Ctrl-D to terminate the connection: > Hello! < Hello! Connection closed: code = 1000 (OK), no reason. + +You can also confirm that your application shuts down gracefully. Connect an +interactive client again — remember to replace ``websockets-echo`` with your app: + +.. code:: console + + $ python -m websockets wss://websockets-echo.herokuapp.com/ + Connected to wss://websockets-echo.herokuapp.com/. + > + +In another shell, restart the dyno — again, replace ``websockets-echo`` with your app: + +.. code:: console + + $ heroku dyno:restart -a websockets-echo + Restarting dynos on ⬢ websockets-echo... done + +Go back to the first shell. The connection is closed with code 1001 (going +away). + +.. code:: console + + $ python -m websockets wss://websockets-echo.herokuapp.com/ + Connected to wss://websockets-echo.herokuapp.com/. + Connection closed: code = 1001 (going away), no reason. + +If graceful shutdown wasn't working, the server wouldn't perform a closing +handshake and the connection would be closed with code 1006 (connection closed +abnormally). diff --git a/example/deployment/heroku/app.py b/example/deployment/heroku/app.py index aceb754a3..ff9ba2775 100644 --- a/example/deployment/heroku/app.py +++ b/example/deployment/heroku/app.py @@ -1,16 +1,30 @@ #!/usr/bin/env python import asyncio +import signal import os import websockets + async def echo(websocket, path): async for message in websocket: await websocket.send(message) + async def main(): - async with websockets.serve(echo, "", int(os.environ["PORT"])): - await asyncio.Future() # run forever + # Set the stop condition when receiving SIGTERM. + loop = asyncio.get_running_loop() + stop = loop.create_future() + loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) + + async with websockets.serve( + echo, + host="", + port=int(os.environ["PORT"]), + ): + await stop + -asyncio.run(main()) +if __name__ == "__main__": + asyncio.run(main()) From 83d2f1dedcd104ad7172d471a8ebc3d4971257b3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 17 May 2021 22:10:37 +0200 Subject: [PATCH 023/762] Document deployment to Kubernetes. Refs #445. --- docs/howto/heroku.rst | 7 +- docs/howto/index.rst | 1 + docs/howto/kubernetes.rst | 213 ++++++++++++++++++ example/deployment/kubernetes/Dockerfile | 7 + example/deployment/kubernetes/app.py | 49 ++++ example/deployment/kubernetes/benchmark.py | 27 +++ example/deployment/kubernetes/deployment.yaml | 35 +++ 7 files changed, 334 insertions(+), 5 deletions(-) create mode 100644 docs/howto/kubernetes.rst create mode 100644 example/deployment/kubernetes/Dockerfile create mode 100755 example/deployment/kubernetes/app.py create mode 100755 example/deployment/kubernetes/benchmark.py create mode 100644 example/deployment/kubernetes/deployment.yaml diff --git a/docs/howto/heroku.rst b/docs/howto/heroku.rst index 6d4346dcd..92dee5f27 100644 --- a/docs/howto/heroku.rst +++ b/docs/howto/heroku.rst @@ -55,8 +55,8 @@ cleanly. .. _shutting down a dyno: https://devcenter.heroku.com/articles/dynos#shutdown -Configure deployment --------------------- +Deploy application +------------------ In order to build the app, Heroku needs to know that it depends on websockets. Create a ``requirements.txt`` file containing this line: @@ -86,9 +86,6 @@ Confirm that you created the correct files and commit them to git:  create mode 100644 app.py  create mode 100644 requirements.txt -Deploy ------- - The app is ready. Let's deploy it! .. code:: console diff --git a/docs/howto/index.rst b/docs/howto/index.rst index bcacbc174..c5cb4b8f6 100644 --- a/docs/howto/index.rst +++ b/docs/howto/index.rst @@ -33,3 +33,4 @@ optimized, and secure setup. deployment compression heroku + kubernetes diff --git a/docs/howto/kubernetes.rst b/docs/howto/kubernetes.rst new file mode 100644 index 000000000..ef5c963d7 --- /dev/null +++ b/docs/howto/kubernetes.rst @@ -0,0 +1,213 @@ +Deploy to Kubernetes +==================== + +This guide describes how to deploy a websockets server to Kubernetes_. It +assumes familiarity with Docker and Kubernetes. + +We're going to deploy a simple app to a local Kubernetes cluster and to ensure +that it scales as expected. + +In a more realistic context, you would follow your organization's practices +for deploying to Kubernetes, but you would apply the same principles as far as +websockets is concerned. + +.. _Kubernetes: https://kubernetes.io/ + +Containerize application +------------------------ + +Here's the app we're going to deploy. Save it in a file called +``app.py``: + +.. literalinclude:: ../../example/deployment/kubernetes/app.py + +This is an echo server with one twist: every message blocks the server for +100ms, which creates artificial starvation of CPU time. This makes it easier +to saturate the server for load testing. + +The app exposes a health check on ``/healthz``. It also provides two other +endpoints for testing purposes: ``/inemuri`` will make the app unresponsive +for 10 seconds and ``/seppuku`` will terminate it. + +The quest for the perfect Python container image is out of scope of this +guide, so we'll go for the simplest possible configuration instead: + +.. literalinclude:: ../../example/deployment/kubernetes/Dockerfile + +After saving this ``Dockerfile``, build the image: + +.. code:: console + + $ docker build -t websockets-test:1.0 . + +Test your image by running: + +.. code:: console + + $ docker run --name run-websockets-test --publish 32080:80 --rm \ + websockets-test:1.0 + +Then, in another shell, in a virtualenv where websockets is installed, connect +to the app and check that it echoes anything you send: + +.. code:: console + + $ python -m websockets ws://localhost:32080/ + Connected to ws://localhost:32080/. + > Hey there! + < Hey there! + > + +Now, in yet another shell, stop the app with: + +.. code:: console + + $ docker kill -s TERM run-websockets-test + +Going to the shell where you connected to the app, you can confirm that it +shut down gracefully: + +.. code:: console + + $ python -m websockets ws://localhost:32080/ + Connected to ws://localhost:32080/. + > Hey there! + < Hey there! + Connection closed: code = 1001 (going away), no reason. + +If it didn't, you'd get code 1006 (connection closed abnormally). + +Deploy application +------------------ + +Configuring Kubernetes is even further beyond the scope of this guide, so +we'll use a basic configuration for testing, with just one Service_ and one +Deployment_: + +.. literalinclude:: ../../example/deployment/kubernetes/deployment.yaml + +For local testing, a service of type NodePort_ is good enough. For deploying +to production, you would configure an Ingress_. + +.. _Service: https://kubernetes.io/docs/concepts/services-networking/service/ +.. _Deployment: https://kubernetes.io/docs/concepts/workloads/controllers/deployment/ +.. _NodePort: https://kubernetes.io/docs/concepts/services-networking/service/#nodeport +.. _Ingress: https://kubernetes.io/docs/concepts/services-networking/ingress/ + +After saving this to a file called ``deployment.yaml``, you can deploy: + +.. code:: console + + $ kubectl apply -f deployment.yaml + service/websockets-test created + deployment.apps/websockets-test created + +Now you have a deployment with one pod running: + +.. code:: console + + $ kubectl get deployment websockets-test + NAME READY UP-TO-DATE AVAILABLE AGE + websockets-test 1/1 1 1 10s + $ kubectl get pods -l app=websockets-test + NAME READY STATUS RESTARTS AGE + websockets-test-86b48f4bb7-nltfh 1/1 Running 0 10s + +You can connect to the service — press Ctrl-D to exit: + +.. code:: console + + $ python -m websockets ws://localhost:32080/ + Connected to ws://localhost:32080/. + Connection closed: code = 1000 (OK), no reason. + +Validate deployment +------------------- + +First, let's ensure the liveness probe works by making the app unresponsive: + +.. code:: console + + $ curl http://localhost:32080/inemuri + Sleeping for 10s + +Since we have only one pod, we know that this pod will go to sleep. + +The liveness probe is configured to run every second. By default, liveness +probes time out after one second and have a threshold of three failures. +Therefore Kubernetes should restart the pod after at most 5 seconds. + +Indeed, after a few seconds, the pod reports a restart: + +.. code:: console + + $ kubectl get pods -l app=websockets-test + NAME READY STATUS RESTARTS AGE + websockets-test-86b48f4bb7-nltfh 1/1 Running 1 42s + +Next, let's take it one step further and crash the app: + +.. code:: console + + $ curl http://localhost:32080/seppuku + Terminating + +The pod reports a second restart: + +.. code:: console + + $ kubectl get pods -l app=websockets-test + NAME READY STATUS RESTARTS AGE + websockets-test-86b48f4bb7-nltfh 1/1 Running 2 72s + +All good — Kubernetes delivers on its promise to keep our app alive! + +Scale deployment +---------------- + +Of course, Kubernetes is for scaling. Let's scale — modestly — to 10 pods: + +.. code:: console + + $ kubectl scale deployment.apps/websockets-test --replicas=10 + deployment.apps/websockets-test scaled + +After a few seconds, we have 10 pods running: + +.. code:: console + + $ kubectl get deployment websockets-test + NAME READY UP-TO-DATE AVAILABLE AGE + websockets-test 10/10 10 10 10m + +Now let's generate load. We'll use this script: + +.. literalinclude:: ../../example/deployment/kubernetes/benchmark.py + +We'll connect 500 clients in parallel, meaning 50 clients per pod, and have +each client send 6 messages. Since the app blocks for 100ms before responding, +if connections are perfectly distributed, we expect a total run time slightly +over 50 * 6 * 0.1 = 30 seconds. + +Let's try it: + +.. code:: console + + $ ulimit -n 512 + $ time python benchmark.py 500 6 + python benchmark.py 500 6 2.40s user 0.51s system 7% cpu 36.471 total + +A total runtime of 36 seconds is in the right ballpark. Repeating this +experiment with other parameters shows roughly consistent results, with the +high variability you'd expect from a quick benchmark without any effort to +stabilize the test setup. + +Finally, we can scale back to one pod. + +.. code:: console + + $ kubectl scale deployment.apps/websockets-test --replicas=1 + deployment.apps/websockets-test scaled + $ kubectl get deployment websockets-test + NAME READY UP-TO-DATE AVAILABLE AGE + websockets-test 1/1 1 1 15m diff --git a/example/deployment/kubernetes/Dockerfile b/example/deployment/kubernetes/Dockerfile new file mode 100644 index 000000000..83ed8722c --- /dev/null +++ b/example/deployment/kubernetes/Dockerfile @@ -0,0 +1,7 @@ +FROM python:3.9-alpine + +RUN pip install websockets + +COPY app.py . + +CMD ["python", "app.py"] diff --git a/example/deployment/kubernetes/app.py b/example/deployment/kubernetes/app.py new file mode 100755 index 000000000..dcc29bd1c --- /dev/null +++ b/example/deployment/kubernetes/app.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python + +import asyncio +import http +import signal +import sys +import time + +import websockets + + +async def slow_echo(websocket, path): + async for message in websocket: + # Block the event loop! This allows saturating a single asyncio + # process without opening an impractical number of connections. + time.sleep(0.1) # 100ms + await websocket.send(message) + + +async def health_check(path, request_headers): + if path == "/healthz": + return http.HTTPStatus.OK, [], b"OK\n" + if path == "/inemuri": + loop = asyncio.get_running_loop() + loop.call_later(1, time.sleep, 10) + return http.HTTPStatus.OK, [], b"Sleeping for 10s\n" + if path == "/seppuku": + loop = asyncio.get_running_loop() + loop.call_later(1, sys.exit, 69) + return http.HTTPStatus.OK, [], b"Terminating\n" + + +async def main(): + # Set the stop condition when receiving SIGTERM. + loop = asyncio.get_running_loop() + stop = loop.create_future() + loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) + + async with websockets.serve( + slow_echo, + host="", + port=80, + process_request=health_check, + ): + await stop + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/deployment/kubernetes/benchmark.py b/example/deployment/kubernetes/benchmark.py new file mode 100755 index 000000000..600c47316 --- /dev/null +++ b/example/deployment/kubernetes/benchmark.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python + +import asyncio +import sys +import websockets + + +URI = "ws://localhost:32080" + + +async def run(client_id, messages): + async with websockets.connect(URI) as websocket: + for message_id in range(messages): + await websocket.send("{client_id}:{message_id}") + await websocket.recv() + + +async def benchmark(clients, messages): + await asyncio.wait([ + asyncio.create_task(run(client_id, messages)) + for client_id in range(clients) + ]) + + +if __name__ == "__main__": + clients, messages = int(sys.argv[1]), int(sys.argv[2]) + asyncio.run(benchmark(clients, messages)) diff --git a/example/deployment/kubernetes/deployment.yaml b/example/deployment/kubernetes/deployment.yaml new file mode 100644 index 000000000..ba58dd62b --- /dev/null +++ b/example/deployment/kubernetes/deployment.yaml @@ -0,0 +1,35 @@ +apiVersion: v1 +kind: Service +metadata: + name: websockets-test +spec: + type: NodePort + ports: + - port: 80 + nodePort: 32080 + selector: + app: websockets-test +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: websockets-test +spec: + selector: + matchLabels: + app: websockets-test + template: + metadata: + labels: + app: websockets-test + spec: + containers: + - name: websockets-test + image: websockets-test:1.0 + livenessProbe: + httpGet: + path: /healthz + port: 80 + periodSeconds: 1 + ports: + - containerPort: 80 From af5bfd1b52b70ace251abe2e0ca7541b316ea2c3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 19 May 2021 09:09:58 +0200 Subject: [PATCH 024/762] Deployment deployment under Supervisor. Refs #445. --- docs/howto/index.rst | 1 + docs/howto/supervisor.rst | 131 ++++++++++++++++++ example/deployment/supervisor/app.py | 28 ++++ .../deployment/supervisor/supervisord.conf | 7 + 4 files changed, 167 insertions(+) create mode 100644 docs/howto/supervisor.rst create mode 100644 example/deployment/supervisor/app.py create mode 100644 example/deployment/supervisor/supervisord.conf diff --git a/docs/howto/index.rst b/docs/howto/index.rst index c5cb4b8f6..033dfe15d 100644 --- a/docs/howto/index.rst +++ b/docs/howto/index.rst @@ -34,3 +34,4 @@ optimized, and secure setup. compression heroku kubernetes + supervisor diff --git a/docs/howto/supervisor.rst b/docs/howto/supervisor.rst new file mode 100644 index 000000000..6d679ca2b --- /dev/null +++ b/docs/howto/supervisor.rst @@ -0,0 +1,131 @@ +Deploy with Supervisor +====================== + +This guide proposes a simple way to deploy a websockets server directly on a +Linux or BSD operating system. + +We'll configure Supervisor_ to run several server processes and to restart +them if needed. + +.. _Supervisor: http://supervisord.org/ + +We'll bind all servers to the same port. The OS will take care of balancing +connections. + +Create and activate a virtualenv: + +.. code:: console + + $ python -m venv supervisor-websockets + $ . supervisor-websockets/bin/activate + +Install websockets and Supervisor: + +.. code:: console + + $ pip install websockets + $ pip install supervisor + +Save this app to a file called ``app.py``: + +.. literalinclude:: ../../example/deployment/supervisor/app.py + +This is an echo server with two features added for the purpose of this guide: + +* It shuts down gracefully when receiving a ``SIGTERM`` signal; +* It enables the ``reuse_port`` option of :meth:`~asyncio.loop.create_server`, + which in turns sets ``SO_REUSEPORT`` on the accept socket. + +Save this Supervisor configuration to ``supervisord.conf``: + +.. literalinclude:: ../../example/deployment/supervisor/supervisord.conf + +This is the minimal configuration required to keep four instances of the app +running, restarting them if they exit. + +Now start Supervisor in the foreground: + +.. code:: console + + $ supervisord -c supervisord.conf -n + INFO Increased RLIMIT_NOFILE limit to 1024 + INFO supervisord started with pid 43596 + INFO spawned: 'websockets-test_00' with pid 43597 + INFO spawned: 'websockets-test_01' with pid 43598 + INFO spawned: 'websockets-test_02' with pid 43599 + INFO spawned: 'websockets-test_03' with pid 43600 + INFO success: websockets-test_00 entered RUNNING state, process has stayed up for > than 1 seconds (startsecs) + INFO success: websockets-test_01 entered RUNNING state, process has stayed up for > than 1 seconds (startsecs) + INFO success: websockets-test_02 entered RUNNING state, process has stayed up for > than 1 seconds (startsecs) + INFO success: websockets-test_03 entered RUNNING state, process has stayed up for > than 1 seconds (startsecs) + +In another shell, after activating the virtualenv, we can connect to the app — +press Ctrl-D to exit: + +.. code:: console + + $ python -m websockets ws://localhost:8080/ + Connected to ws://localhost:8080/. + > Hello! + < Hello! + Connection closed: code = 1000 (OK), no reason. + +Look at the pid of an instance of the app in the logs and terminate it: + +.. code:: console + + $ kill -TERM 43597 + +The logs show that Supervisor restarted this instance: + +.. code:: console + + INFO exited: websockets-test_00 (exit status 0; expected) + INFO spawned: 'websockets-test_00' with pid 43629 + INFO success: websockets-test_00 entered RUNNING state, process has stayed up for > than 1 seconds (startsecs) + +Now let's check what happens when we shut down Supervisor, but first let's +establish a connection and leave it open: + +.. code:: console + + $ python -m websockets ws://localhost:8080/ + Connected to ws://localhost:8080/. + > + +Look at the pid of supervisord itself in the logs and terminate it: + +.. code:: console + + $ kill -TERM 43596 + +The logs show that Supervisor terminated all instances of the app before +exiting: + +.. code:: console + + WARN received SIGTERM indicating exit request + INFO waiting for websockets-test_00, websockets-test_01, websockets-test_02, websockets-test_03 to die + INFO stopped: websockets-test_02 (exit status 0) + INFO stopped: websockets-test_03 (exit status 0) + INFO stopped: websockets-test_01 (exit status 0) + INFO stopped: websockets-test_00 (exit status 0) + +And you can see that the connection to the app was closed gracefully: + +.. code:: console + + $ python -m websockets ws://localhost:8080/ + Connected to ws://localhost:8080/. + Connection closed: code = 1001 (going away), no reason. + +In this example, we've been sharing the same virtualenv for supervisor and +websockets. + +In a real deployment, you would likely: + +* Install Supervisor with the package manager of the OS. +* Create a virtualenv dedicated to your application. +* Add ``environment=PATH="path/to/your/virtualenv/bin"`` in the Supervisor + configuration. Then ``python app.py`` runs in that virtualenv. + diff --git a/example/deployment/supervisor/app.py b/example/deployment/supervisor/app.py new file mode 100644 index 000000000..5fa596d0a --- /dev/null +++ b/example/deployment/supervisor/app.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python + +import asyncio +import signal + +import websockets + + +async def echo(websocket, path): + async for message in websocket: + await websocket.send(message) + + +async def main(): + # Set the stop condition when receiving SIGTERM. + loop = asyncio.get_running_loop() + stop = loop.create_future() + loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) + + async with websockets.serve( + echo, "", 8080, + reuse_port=True, + ): + await stop + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/deployment/supervisor/supervisord.conf b/example/deployment/supervisor/supervisord.conf new file mode 100644 index 000000000..76a664d91 --- /dev/null +++ b/example/deployment/supervisor/supervisord.conf @@ -0,0 +1,7 @@ +[supervisord] + +[program:websockets-test] +command = python app.py +process_name = %(program_name)s_%(process_num)02d +numprocs = 4 +autorestart = true From b3e50f341de7df9b0b82787111017787f9e82ce4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 21 May 2021 22:39:41 +0200 Subject: [PATCH 025/762] Include all files in Heroku example. --- example/deployment/heroku/Procfile | 1 + example/deployment/heroku/requirements.txt | 1 + 2 files changed, 2 insertions(+) create mode 100644 example/deployment/heroku/Procfile create mode 100644 example/deployment/heroku/requirements.txt diff --git a/example/deployment/heroku/Procfile b/example/deployment/heroku/Procfile new file mode 100644 index 000000000..2e35818f6 --- /dev/null +++ b/example/deployment/heroku/Procfile @@ -0,0 +1 @@ +web: python app.py diff --git a/example/deployment/heroku/requirements.txt b/example/deployment/heroku/requirements.txt new file mode 100644 index 000000000..14774b465 --- /dev/null +++ b/example/deployment/heroku/requirements.txt @@ -0,0 +1 @@ +websockets From 559a5724a96e1e4d921ed48d37f49876069e65da Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 21 May 2021 22:40:57 +0200 Subject: [PATCH 026/762] Spell check docs. --- docs/reference/limitations.rst | 2 +- docs/spelling_wordlist.txt | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/docs/reference/limitations.rst b/docs/reference/limitations.rst index 505186770..ecdde23bf 100644 --- a/docs/reference/limitations.rst +++ b/docs/reference/limitations.rst @@ -29,7 +29,7 @@ There is no way to control compression of outgoing frames on a per-frame basis .. _issue 538: https://github.com/aaugustin/websockets/issues/538 There is no way to receive each fragment of a fragmented messages as it -arrives (`issue 479`_). websockets always reassembles framented messages +arrives (`issue 479`_). websockets always reassembles fragmented messages before returning them. .. _issue 479: https://github.com/aaugustin/websockets/issues/479 diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 363413b4b..19c4b7c44 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -23,12 +23,15 @@ ctrl daemonize datastructures django +dyno fractalideas IPv iterable keepalive KiB +Kubernetes lifecycle +liveness lookups MiB mypy @@ -36,16 +39,19 @@ nginx onmessage parsers permessage +pid pong pongs pythonic redis +runtime scalable serializers subclasses subclassing subprotocol subprotocols +supervisord tidelift tls tox From 2ced669fa9ad1976d4bf197f5d966eaae7fafb21 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 23 May 2021 20:35:46 +0200 Subject: [PATCH 027/762] Make changelog more readable. --- docs/project/changelog.rst | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index db29a7290..f28ca7f0b 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -36,9 +36,10 @@ They may change at any time. .. note:: - **Version 10.0 deprecates the** ``loop`` **parameter from all APIs for the - same reasons the same change was made in Python 3.8. See the release notes - of Python 3.10 for details.** + **Version 10.0 deprecates the** ``loop`` **parameter from all APIs.** + + This reflects a decision made in Python 3.8. See the release notes of + Python 3.10 for details. * Added compatibility with Python 3.10. From 4d2cd03bae4558253d51997e6cee9bd87103b236 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 23 May 2021 20:36:58 +0200 Subject: [PATCH 028/762] Standardize style in Makefile. --- Makefile | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/Makefile b/Makefile index 06832945c..b15cd13c9 100644 --- a/Makefile +++ b/Makefile @@ -2,6 +2,7 @@ export PYTHONASYNCIODEBUG=1 export PYTHONPATH=src +export PYTHONWARNINGS=default default: coverage style @@ -12,13 +13,13 @@ style: mypy --strict src test: - python -W default -m unittest + python -m unittest coverage: - python -m coverage erase - python -W default -m coverage run -m unittest - python -m coverage html - python -m coverage report --show-missing --fail-under=100 + coverage erase + coverage run -m unittest + coverage html + coverage report --show-missing --fail-under=100 build: python setup.py build_ext --inplace From c0750dac4bbf30138fa825054fb065f0355e4150 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 23 May 2021 07:41:41 +0200 Subject: [PATCH 029/762] Add low tech icon generator. --- logo/github-social-preview.html | 7 ++++--- logo/icon.html | 25 +++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 3 deletions(-) create mode 100644 logo/icon.html diff --git a/logo/github-social-preview.html b/logo/github-social-preview.html index 187a183b0..7f2b45bad 100644 --- a/logo/github-social-preview.html +++ b/logo/github-social-preview.html @@ -1,5 +1,5 @@ - + GitHub social preview + + +

Take a screenshot of these DOM nodes to2x make a PNG.

+

8x8 / 16x16 @ 2x

+

16x16 / 32x32 @ 2x

+

32x32 / 32x32 @ 2x

+

32x32 / 64x64 @ 2x

+

64x64 / 128x128 @ 2x

+

128x128 / 256x256 @ 2x

+

256x256 / 512x512 @ 2x

+

512x512 / 1024x1024 @ 2x

+ + From dd6d6bce2d1ac7e50a0a90e44c8073a8a17a2c05 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 20 May 2021 22:10:15 +0200 Subject: [PATCH 030/762] Make it easier to customize authentication. --- docs/project/changelog.rst | 3 +++ docs/reference/server.rst | 9 +++++++-- src/websockets/legacy/auth.py | 36 ++++++++++++++++++++++++++++------- tests/legacy/test_auth.py | 18 ++++++++++++++++++ 4 files changed, 57 insertions(+), 9 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index f28ca7f0b..c2c869c13 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -45,6 +45,9 @@ They may change at any time. * Optimized default compression settings to reduce memory usage. +* Made it easier to customize authentication with + :meth:`~auth.BasicAuthWebSocketServerProtocol.check_credentials`. + 9.0.2 ..... diff --git a/docs/reference/server.rst b/docs/reference/server.rst index c1822dda6..1a2dd1c88 100644 --- a/docs/reference/server.rst +++ b/docs/reference/server.rst @@ -96,10 +96,15 @@ Basic authentication .. autoclass:: BasicAuthWebSocketServerProtocol - .. automethod:: process_request + .. attribute:: realm + + Scope of protection. + + If provided, it should contain only ASCII characters because the + encoding of non-ASCII characters is undefined. .. attribute:: username Username of the authenticated user. - + .. automethod:: check_credentials diff --git a/src/websockets/legacy/auth.py b/src/websockets/legacy/auth.py index e0beede57..e7e69cac1 100644 --- a/src/websockets/legacy/auth.py +++ b/src/websockets/legacy/auth.py @@ -35,22 +35,44 @@ class BasicAuthWebSocketServerProtocol(WebSocketServerProtocol): """ + realm = "" + def __init__( self, *args: Any, - realm: str, - check_credentials: Callable[[str, str], Awaitable[bool]], + realm: Optional[str] = None, + check_credentials: Optional[Callable[[str, str], Awaitable[bool]]] = None, **kwargs: Any, ) -> None: - self.realm = realm - self.check_credentials = check_credentials + if realm is not None: + self.realm = realm # shadow class attribute + self._check_credentials = check_credentials super().__init__(*args, **kwargs) + async def check_credentials(self, username: str, password: str) -> bool: + """ + Check whether credentials are authorized. + + If ``check_credentials`` returns ``True``, the WebSocket handshake + continues. If it returns ``False``, the handshake fails with a HTTP + 401 error. + + This coroutine may be overridden in a subclass, for example to + authenticate against a database or an external service. + + """ + if self._check_credentials is not None: + return await self._check_credentials(username, password) + + return False + async def process_request( - self, path: str, request_headers: Headers + self, + path: str, + request_headers: Headers, ) -> Optional[HTTPResponse]: """ - Check HTTP Basic Auth and return a HTTP 401 or 403 response if needed. + Check HTTP Basic Auth and return a HTTP 401 response if needed. """ try: @@ -84,7 +106,7 @@ async def process_request( def basic_auth_protocol_factory( - realm: str, + realm: Optional[str] = None, credentials: Optional[Union[Credentials, Iterable[Credentials]]] = None, check_credentials: Optional[Callable[[str, str], Awaitable[bool]]] = None, create_protocol: Optional[Callable[[Any], BasicAuthWebSocketServerProtocol]] = None, diff --git a/tests/legacy/test_auth.py b/tests/legacy/test_auth.py index bb8c6a6eb..c4dbd88ad 100644 --- a/tests/legacy/test_auth.py +++ b/tests/legacy/test_auth.py @@ -25,6 +25,11 @@ async def process_request(self, path, request_headers): return await super().process_request(path, request_headers) +class CheckWebSocketServerProtocol(BasicAuthWebSocketServerProtocol): + async def check_credentials(self, username, password): + return password == "letmein" + + class AuthClientServerTests(ClientServerTestsMixin, AsyncioTestCase): create_protocol = basic_auth_protocol_factory( @@ -103,6 +108,19 @@ def test_basic_auth_custom_protocol(self): self.loop.run_until_complete(self.client.send("Hello!")) self.loop.run_until_complete(self.client.recv()) + @with_server(create_protocol=CheckWebSocketServerProtocol) + @with_client(user_info=("hello", "letmein")) + def test_basic_auth_custom_protocol_subclass(self): + self.loop.run_until_complete(self.client.send("Hello!")) + self.loop.run_until_complete(self.client.recv()) + + # CustomWebSocketServerProtocol doesn't override check_credentials + @with_server(create_protocol=CustomWebSocketServerProtocol) + def test_basic_auth_defaults_to_deny_all(self): + with self.assertRaises(InvalidStatusCode) as raised: + self.start_client(user_info=("hello", "iloveyou")) + self.assertEqual(raised.exception.status_code, 401) + @with_server(create_protocol=create_protocol) def test_basic_auth_missing_credentials(self): with self.assertRaises(InvalidStatusCode) as raised: From 8ab85e54a764a26529b801e263e5223b1b2921c9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 23 May 2021 08:11:46 +0200 Subject: [PATCH 031/762] Experiment authentication techniques. --- experiments/authentication/app.py | 225 ++++++++++++++++++ experiments/authentication/cookie.html | 15 ++ experiments/authentication/cookie.js | 36 +++ experiments/authentication/cookie_iframe.html | 9 + experiments/authentication/cookie_iframe.js | 36 +++ experiments/authentication/favicon.ico | Bin 0 -> 5430 bytes experiments/authentication/first_message.html | 14 ++ experiments/authentication/first_message.js | 11 + experiments/authentication/index.html | 12 + experiments/authentication/query_param.html | 14 ++ experiments/authentication/query_param.js | 11 + experiments/authentication/script.js | 51 ++++ experiments/authentication/style.css | 69 ++++++ experiments/authentication/test.html | 15 ++ experiments/authentication/test.js | 6 + experiments/authentication/user_info.html | 14 ++ experiments/authentication/user_info.js | 11 + 17 files changed, 549 insertions(+) create mode 100644 experiments/authentication/app.py create mode 100644 experiments/authentication/cookie.html create mode 100644 experiments/authentication/cookie.js create mode 100644 experiments/authentication/cookie_iframe.html create mode 100644 experiments/authentication/cookie_iframe.js create mode 100644 experiments/authentication/favicon.ico create mode 100644 experiments/authentication/first_message.html create mode 100644 experiments/authentication/first_message.js create mode 100644 experiments/authentication/index.html create mode 100644 experiments/authentication/query_param.html create mode 100644 experiments/authentication/query_param.js create mode 100644 experiments/authentication/script.js create mode 100644 experiments/authentication/style.css create mode 100644 experiments/authentication/test.html create mode 100644 experiments/authentication/test.js create mode 100644 experiments/authentication/user_info.html create mode 100644 experiments/authentication/user_info.js diff --git a/experiments/authentication/app.py b/experiments/authentication/app.py new file mode 100644 index 000000000..c3c9a0557 --- /dev/null +++ b/experiments/authentication/app.py @@ -0,0 +1,225 @@ +#!/usr/bin/env python + +import asyncio +import http +import http.cookies +import pathlib +import signal +import urllib.parse +import uuid + +import websockets + + +# User accounts database + +USERS = {} + + +def create_token(user, lifetime=1): + """Create token for user and delete it once its lifetime is over.""" + token = uuid.uuid4().hex + USERS[token] = user + asyncio.get_running_loop().call_later(lifetime, USERS.pop, token) + return token + + +def get_user(token): + """Find user authenticated by token or return None.""" + return USERS.get(token) + + +# Utilities + + +def get_cookie(raw, key): + cookie = http.cookies.SimpleCookie(raw) + morsel = cookie.get(key) + if morsel is not None: + return morsel.value + + +def get_query_param(path, key): + query = urllib.parse.urlparse(path).query + params = urllib.parse.parse_qs(query) + values = params.get(key, []) + if len(values) == 1: + return values[0] + + +# Main HTTP server + +CONTENT_TYPES = { + ".css": "text/css", + ".html": "text/html; charset=utf-8", + ".ico": "image/x-icon", + ".js": "text/javascript", +} + + +async def serve_html(path, request_headers): + user = get_query_param(path, "user") + path = urllib.parse.urlparse(path).path + if path == "/": + if user is None: + page = "index.html" + else: + page = "test.html" + else: + page = path[1:] + + try: + template = pathlib.Path(__file__).with_name(page) + except ValueError: + pass + else: + if template.is_file(): + headers = {"Content-Type": CONTENT_TYPES[template.suffix]} + body = template.read_bytes() + if user is not None: + token = create_token(user) + body = body.replace(b"TOKEN", token.encode()) + return http.HTTPStatus.OK, headers, body + + return http.HTTPStatus.NOT_FOUND, {}, b"Not found\n" + + +async def noop_handler(websocket, path): + pass + + +# Send credentials as the first message in the WebSocket connection + + +async def first_message_handler(websocket, path): + token = await websocket.recv() + user = get_user(token) + if user is None: + await websocket.close(1011, "authentication failed") + return + + await websocket.send(f"Hello {user}!") + message = await websocket.recv() + assert message == f"Goodbye {user}." + + +# Add credentials to the WebSocket URI in a query parameter + + +class QueryParamProtocol(websockets.WebSocketServerProtocol): + async def process_request(self, path, headers): + token = get_query_param(path, "token") + if token is None: + return http.HTTPStatus.UNAUTHORIZED, [], b"Missing token\n" + + user = get_user(token) + if user is None: + return http.HTTPStatus.UNAUTHORIZED, [], b"Invalid token\n" + + self.user = user + + +async def query_param_handler(websocket, path): + user = websocket.user + + await websocket.send(f"Hello {user}!") + message = await websocket.recv() + assert message == f"Goodbye {user}." + + +# Set a cookie on the domain of the WebSocket URI + + +class CookieProtocol(websockets.WebSocketServerProtocol): + async def process_request(self, path, headers): + if "Upgrade" not in headers: + template = pathlib.Path(__file__).with_name(path[1:]) + headers = {"Content-Type": CONTENT_TYPES[template.suffix]} + body = template.read_bytes() + return http.HTTPStatus.OK, headers, body + + token = get_cookie(headers.get("Cookie", ""), "token") + if token is None: + return http.HTTPStatus.UNAUTHORIZED, [], b"Missing token\n" + + user = get_user(token) + if user is None: + return http.HTTPStatus.UNAUTHORIZED, [], b"Invalid token\n" + + self.user = user + + +async def cookie_handler(websocket, path): + user = websocket.user + + await websocket.send(f"Hello {user}!") + message = await websocket.recv() + assert message == f"Goodbye {user}." + + +# Adding credentials to the WebSocket URI in user information + + +class UserInfoProtocol(websockets.BasicAuthWebSocketServerProtocol): + async def check_credentials(self, username, password): + if username != "token": + return False + + user = get_user(password) + if user is None: + return False + + self.user = user + return True + + +async def user_info_handler(websocket, path): + user = websocket.user + + await websocket.send(f"Hello {user}!") + message = await websocket.recv() + assert message == f"Goodbye {user}." + + +# Start all five servers + + +async def main(): + # Set the stop condition when receiving SIGINT or SIGTERM. + loop = asyncio.get_running_loop() + stop = loop.create_future() + loop.add_signal_handler(signal.SIGINT, stop.set_result, None) + loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) + + async with websockets.serve( + noop_handler, + host="", + port=8000, + process_request=serve_html, + ), websockets.serve( + first_message_handler, + host="", + port=8001, + ), websockets.serve( + query_param_handler, + host="", + port=8002, + create_protocol=QueryParamProtocol, + ), websockets.serve( + cookie_handler, + host="", + port=8003, + create_protocol=CookieProtocol, + ), websockets.serve( + user_info_handler, + host="", + port=8004, + create_protocol=UserInfoProtocol, + ): + print("Running on http://localhost:8000/") + await stop + print("\rExiting") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/experiments/authentication/cookie.html b/experiments/authentication/cookie.html new file mode 100644 index 000000000..ca17358fd --- /dev/null +++ b/experiments/authentication/cookie.html @@ -0,0 +1,15 @@ + + + + Cookie | WebSocket Authentication + + + +

[??] Cookie

+

[OK] Cookie

+

[KO] Cookie

+ + + + + diff --git a/experiments/authentication/cookie.js b/experiments/authentication/cookie.js new file mode 100644 index 000000000..9c5d5f078 --- /dev/null +++ b/experiments/authentication/cookie.js @@ -0,0 +1,36 @@ +// wait for "load" rather than "DOMContentLoaded" +// to ensure that the iframe has finished loading +window.addEventListener("load", () => { + // create message channel to communicate + // with the iframe + const channel = new MessageChannel(); + const port1 = channel.port1; + + // receive WebSocket events from the iframe + const expected = getExpectedEvents(); + var actual = []; + port1.onmessage = ({ data }) => { + // respond to messages + if (data.type == "message") { + port1.postMessage({ + type: "message", + message: `Goodbye ${data.data.slice(6, -1)}.`, + }); + } + // run tests + actual.push(data); + testStep(expected, actual); + }; + + // send message channel to the iframe + const iframe = document.querySelector("iframe"); + const origin = "http://localhost:8003"; + const ports = [channel.port2]; + iframe.contentWindow.postMessage("", origin, ports); + + // send token to the iframe + port1.postMessage({ + type: "open", + token: token, + }); +}); diff --git a/experiments/authentication/cookie_iframe.html b/experiments/authentication/cookie_iframe.html new file mode 100644 index 000000000..9f49ebb9a --- /dev/null +++ b/experiments/authentication/cookie_iframe.html @@ -0,0 +1,9 @@ + + + + Cookie iframe | WebSocket Authentication + + + + + diff --git a/experiments/authentication/cookie_iframe.js b/experiments/authentication/cookie_iframe.js new file mode 100644 index 000000000..88b5d1f2b --- /dev/null +++ b/experiments/authentication/cookie_iframe.js @@ -0,0 +1,36 @@ +// receive message channel from the parent window +window.addEventListener("message", ({ ports }) => { + const port2 = ports[0]; + var websocket; + port2.onmessage = ({ data }) => { + switch (data.type) { + case "open": + websocket = initWebSocket(data.token, port2); + break; + case "message": + websocket.send(data.message); + break; + case "close": + websocket.close(data.code, data.reason); + break; + } + }; +}); + +// open WebSocket connection and relay events to the parent window +function initWebSocket(token, port2) { + document.cookie = `token=${token}; SameSite=Strict`; + const websocket = new WebSocket("ws://localhost:8003/"); + + websocket.addEventListener("open", ({ type }) => { + port2.postMessage({ type }); + }); + websocket.addEventListener("message", ({ type, data }) => { + port2.postMessage({ type, data }); + }); + websocket.addEventListener("close", ({ type, code, reason, wasClean }) => { + port2.postMessage({ type, code, reason, wasClean }); + }); + + return websocket; +} diff --git a/experiments/authentication/favicon.ico b/experiments/authentication/favicon.ico new file mode 100644 index 0000000000000000000000000000000000000000..36e855029d705e72d44428bda6e8cb6d3dd317ed GIT binary patch literal 5430 zcmeH~eNdFw6~-?im3Ahb_(wY9G?@uAnl{s!P5}k6EMLo)Am6YM4N*)C21rB%KSG3E zv;hMxiW&u@(MF93wOCQ3rD-sRXbhdyP7o3K#!S*q0u${7l(0J*XQZVahM#RhT&Mw%^BqMk{U@J1~1H2BJEU64r)0!B+s<_@|*2 zqhiBw1#eC^Kd~FBvDYyg*$#wV#Xwjm9t&v2p)n2r@@vI8+WgsgcG9|#V(P|dLI)7j zfdSf?v#0~lgtfyurUjuNe*Ta2+T*8ca~J%;(ZdETkY^gva%XNcqcj2kdN-G|0yeVF}ZFP=4Z!J2R#i0#02+Ir+*zVG@y z*5YS?yJuBTDvF9 z+e4qHUo$?zisgL(ZFDT{$HRZz5|LeB6ok5h0Mx8E{O!%Wpod<2-phQZDi;R}!Uptt z_Sffbz;E>+X1=#0_Q!9(68Y81O(FP$*?76Z67toHt(cKzH39p~78|EMbIvF7dZI9CxZHU=B0 z<9Jap>Rt-q+)$@;bvu`KC6(K3*mRuQ9MaslWomPE8-LktGQPhv{>AfKV-~kmXM9f| z|2D6HQbVN2tW*=RJWA^l;TTP>949?wSa4Hnl)qW|Kbb4f8Fvyy2h!?Rc2a%__sk=;L*m&%aOk`2^Q3=D%J$`lrna32wt&`ua?J=IYOX`)ey| zsDH*}F--J3_a!wE;q=XC^%sGy0H6D|y~p17*k>&_PGWEtANNEx9#Nf`$LIWDXd7VU z-lc0`F21g#6;|0_CyK+_IPED%Vmjd+bq#|NSGn)ee~4%xL9&NlK@0!?spEQZeA9Sw zOg84YE{XKby>e*LH9+>d2-$me4ei*@T2#ST@l)pX-aNEoMD}B9TkIH*)9-fc#{JDF zsS`tE`y^thLD#X6zKIf~ACX*0P0;qeI{2O1>Ze+hd$YWhu%sV8DP0&!?gSFKPsfuu z=^dyc`Wh;U66imQy~wci5ZZc7tYMe4y3>E-zOC46crNM2&=cI_Q|MnbFb{~Qf9lvq zl)zjpqW|G=)`eCR4jO{~dh5&OHAgQLZpnV9ydbWbam~@=o9J4d>8Y0WqGf3>kS6_3 zH=ygFmaRky%tJKeAJJ-p{_kqA)#Yh(*$efw^_AAoNKZthCz1G^u_xP8J>82TYJl`x zPtw0=VBb|il)!vMb4^CrJ8A0?#hMfy%!Slv9#Q+&`23lKphoC2`RmZF?C`@YT|BPQis$fd)qZc3H#T~U(G%W)mfglo{sU;_HV?#Giog+ zWdZKb7(gC1)GY7CGNOdEKE$SWVT5(5)r3};^sUup2XfhuD&bJR@3U{@`dIhp%!9oK z`hCKgZ~0C9j|Y3#aMr|#eN9k{rX#rbA2Ga;A`w_nTG-C ztBDe-vb<91XNIG%>P!!=&(~otdZi|$F7=v_x=f{SqJ-l`K`Vb(5MVvJ!Jze!ng;g7 z>?QBK{=AowD1llEwR+6*IO(WiJm0l|D{Ep{uL)g4S}$^l5>9OnTX~|`Xtnjy{y6%g zO~ax51$Wq&ClMvEKQ9vBhc`yq?ukr~>U(@1IJr6Gc0;i-fhgfV}k+ptmI=o*Qx0@RDGY}eIiR4)*!i%Vr#UY#Q>p{*{nuh%O zw-{rDA=?=E^vn-xf;^a)yc&4S#>DB;t~+h17%G6M9Y7Y%ttpeEt)~ z*kB1&)1jq0b@m6ll1AkWCmP=Q^&;&&DfRHy%i&r*$hok-P}6XevH3fT@AR)C)pTj8 zN-n>!BLP|-wu%PrJgTk5T@A${HyM{c{o!N z{uyKPol0D`gW)%O|F9}G?&6Mwv&}o=zc^nRJKR(eeeKWX(PfP1M^laIn|c0j*0+@b zZDY{i*oZ&{!b(QfqZlSb-*-dLEEhC+xWJOGU};h)wgh3bYC?&N1*5L&qtO{IrNa-n mx{c17;Wp~=fE$Kp5swEk + + + First message | WebSocket Authentication + + + +

[??] First message

+

[OK] First message

+

[KO] First message

+ + + + diff --git a/experiments/authentication/first_message.js b/experiments/authentication/first_message.js new file mode 100644 index 000000000..1acf048ba --- /dev/null +++ b/experiments/authentication/first_message.js @@ -0,0 +1,11 @@ +window.addEventListener("DOMContentLoaded", () => { + const websocket = new WebSocket("ws://localhost:8001/"); + websocket.onopen = () => websocket.send(token); + + websocket.onmessage = ({ data }) => { + // event.data is expected to be "Hello !" + websocket.send(`Goodbye ${data.slice(6, -1)}.`); + }; + + runTest(websocket); +}); diff --git a/experiments/authentication/index.html b/experiments/authentication/index.html new file mode 100644 index 000000000..c37deef27 --- /dev/null +++ b/experiments/authentication/index.html @@ -0,0 +1,12 @@ + + + + WebSocket Authentication + + + +
+ +
+ + diff --git a/experiments/authentication/query_param.html b/experiments/authentication/query_param.html new file mode 100644 index 000000000..27aa454a4 --- /dev/null +++ b/experiments/authentication/query_param.html @@ -0,0 +1,14 @@ + + + + Query parameter | WebSocket Authentication + + + +

[??] Query parameter

+

[OK] Query parameter

+

[KO] Query parameter

+ + + + diff --git a/experiments/authentication/query_param.js b/experiments/authentication/query_param.js new file mode 100644 index 000000000..6a54d0b6c --- /dev/null +++ b/experiments/authentication/query_param.js @@ -0,0 +1,11 @@ +window.addEventListener("DOMContentLoaded", () => { + const uri = `ws://localhost:8002/?token=${token}`; + const websocket = new WebSocket(uri); + + websocket.onmessage = ({ data }) => { + // event.data is expected to be "Hello !" + websocket.send(`Goodbye ${data.slice(6, -1)}.`); + }; + + runTest(websocket); +}); diff --git a/experiments/authentication/script.js b/experiments/authentication/script.js new file mode 100644 index 000000000..ec4e5e670 --- /dev/null +++ b/experiments/authentication/script.js @@ -0,0 +1,51 @@ +var token = window.parent.token; + +function getExpectedEvents() { + return [ + { + type: "open", + }, + { + type: "message", + data: `Hello ${window.parent.user}!`, + }, + { + type: "close", + code: 1000, + reason: "", + wasClean: true, + }, + ]; +} + +function isEqual(expected, actual) { + // good enough for our purposes here! + return JSON.stringify(expected) === JSON.stringify(actual); +} + +function testStep(expected, actual) { + if (isEqual(expected, actual)) { + document.body.className = "ok"; + } else if (isEqual(expected.slice(0, actual.length), actual)) { + document.body.className = "test"; + } else { + document.body.className = "ko"; + } +} + +function runTest(websocket) { + const expected = getExpectedEvents(); + var actual = []; + websocket.addEventListener("open", ({ type }) => { + actual.push({ type }); + testStep(expected, actual); + }); + websocket.addEventListener("message", ({ type, data }) => { + actual.push({ type, data }); + testStep(expected, actual); + }); + websocket.addEventListener("close", ({ type, code, reason, wasClean }) => { + actual.push({ type, code, reason, wasClean }); + testStep(expected, actual); + }); +} diff --git a/experiments/authentication/style.css b/experiments/authentication/style.css new file mode 100644 index 000000000..6e3918cca --- /dev/null +++ b/experiments/authentication/style.css @@ -0,0 +1,69 @@ +/* page layout */ + +body { + display: flex; + flex-direction: column; + justify-content: center; + align-items: center; + margin: 0; + height: 100vh; +} +div.title, iframe { + width: 100vw; + height: 20vh; + border: none; +} +div.title { + display: flex; + flex-direction: column; + justify-content: center; + align-items: center; +} +h1, p { + margin: 0; + width: 24em; +} + +/* text style */ + +h1, input, p { + font-family: monospace; + font-size: 3em; +} +input { + color: #333; + border: 3px solid #999; + padding: 1em; +} +input:focus { + border-color: #333; + outline: none; +} +input::placeholder { + color: #999; + opacity: 1; +} + +/* test results */ + +body.test { + background-color: #666; + color: #fff; +} +body.ok { + background-color: #090; + color: #fff; +} +body.ko { + background-color: #900; + color: #fff; +} +body > p { + display: none; +} +body > p.title, +body.test > p.test, +body.ok > p.ok, +body.ko > p.ko { + display: block; +} diff --git a/experiments/authentication/test.html b/experiments/authentication/test.html new file mode 100644 index 000000000..3883d6a39 --- /dev/null +++ b/experiments/authentication/test.html @@ -0,0 +1,15 @@ + + + + WebSocket Authentication + + + +

WebSocket Authentication

+ + + + + + + diff --git a/experiments/authentication/test.js b/experiments/authentication/test.js new file mode 100644 index 000000000..428830ff3 --- /dev/null +++ b/experiments/authentication/test.js @@ -0,0 +1,6 @@ +// for connecting to WebSocket servers +var token = document.body.dataset.token; + +// for test assertions only +const params = new URLSearchParams(window.location.search); +var user = params.get("user"); diff --git a/experiments/authentication/user_info.html b/experiments/authentication/user_info.html new file mode 100644 index 000000000..7b9c99c73 --- /dev/null +++ b/experiments/authentication/user_info.html @@ -0,0 +1,14 @@ + + + + User information | WebSocket Authentication + + + +

[??] User information

+

[OK] User information

+

[KO] User information

+ + + + diff --git a/experiments/authentication/user_info.js b/experiments/authentication/user_info.js new file mode 100644 index 000000000..1dab2ce4c --- /dev/null +++ b/experiments/authentication/user_info.js @@ -0,0 +1,11 @@ +window.addEventListener("DOMContentLoaded", () => { + const uri = `ws://token:${token}@localhost:8004/`; + const websocket = new WebSocket(uri); + + websocket.onmessage = ({ data }) => { + // event.data is expected to be "Hello !" + websocket.send(`Goodbye ${data.slice(6, -1)}.`); + }; + + runTest(websocket); +}); From f7a62680bc7df0f17efac83109daf3bdc14bc0f5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 23 May 2021 15:14:52 +0200 Subject: [PATCH 032/762] Simplify cookie-based authentication. --- experiments/authentication/cookie.js | 51 ++++++++------------- experiments/authentication/cookie_iframe.js | 43 ++++------------- 2 files changed, 27 insertions(+), 67 deletions(-) diff --git a/experiments/authentication/cookie.js b/experiments/authentication/cookie.js index 9c5d5f078..2cca34fcb 100644 --- a/experiments/authentication/cookie.js +++ b/experiments/authentication/cookie.js @@ -1,36 +1,23 @@ -// wait for "load" rather than "DOMContentLoaded" -// to ensure that the iframe has finished loading -window.addEventListener("load", () => { - // create message channel to communicate - // with the iframe - const channel = new MessageChannel(); - const port1 = channel.port1; +// send token to iframe +window.addEventListener("DOMContentLoaded", () => { + const iframe = document.querySelector("iframe"); + iframe.addEventListener("load", () => { + iframe.contentWindow.postMessage(token, "http://localhost:8003"); + }); +}); - // receive WebSocket events from the iframe - const expected = getExpectedEvents(); - var actual = []; - port1.onmessage = ({ data }) => { - // respond to messages - if (data.type == "message") { - port1.postMessage({ - type: "message", - message: `Goodbye ${data.data.slice(6, -1)}.`, - }); - } - // run tests - actual.push(data); - testStep(expected, actual); - }; +// once iframe has set cookie, open WebSocket connection +window.addEventListener("message", ({ origin }) => { + if (origin !== "http://localhost:8003") { + return; + } - // send message channel to the iframe - const iframe = document.querySelector("iframe"); - const origin = "http://localhost:8003"; - const ports = [channel.port2]; - iframe.contentWindow.postMessage("", origin, ports); + const websocket = new WebSocket("ws://localhost:8003/"); - // send token to the iframe - port1.postMessage({ - type: "open", - token: token, - }); + websocket.onmessage = ({ data }) => { + // event.data is expected to be "Hello !" + websocket.send(`Goodbye ${data.slice(6, -1)}.`); + }; + + runTest(websocket); }); diff --git a/experiments/authentication/cookie_iframe.js b/experiments/authentication/cookie_iframe.js index 88b5d1f2b..2d2e692e8 100644 --- a/experiments/authentication/cookie_iframe.js +++ b/experiments/authentication/cookie_iframe.js @@ -1,36 +1,9 @@ -// receive message channel from the parent window -window.addEventListener("message", ({ ports }) => { - const port2 = ports[0]; - var websocket; - port2.onmessage = ({ data }) => { - switch (data.type) { - case "open": - websocket = initWebSocket(data.token, port2); - break; - case "message": - websocket.send(data.message); - break; - case "close": - websocket.close(data.code, data.reason); - break; - } - }; -}); - -// open WebSocket connection and relay events to the parent window -function initWebSocket(token, port2) { - document.cookie = `token=${token}; SameSite=Strict`; - const websocket = new WebSocket("ws://localhost:8003/"); +// receive token from the parent window, set cookie and notify parent +window.addEventListener("message", ({ origin, data }) => { + if (origin !== "http://localhost:8000") { + return; + } - websocket.addEventListener("open", ({ type }) => { - port2.postMessage({ type }); - }); - websocket.addEventListener("message", ({ type, data }) => { - port2.postMessage({ type, data }); - }); - websocket.addEventListener("close", ({ type, code, reason, wasClean }) => { - port2.postMessage({ type, code, reason, wasClean }); - }); - - return websocket; -} + document.cookie = `token=${data}; SameSite=Strict`; + window.parent.postMessage("", "http://localhost:8000"); +}); From 920207c6d4d2b31def30970fc47c968806c2b64e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 23 May 2021 14:58:00 +0200 Subject: [PATCH 033/762] Discuss authentication in docs. --- docs/spelling_wordlist.txt | 2 + docs/topics/authentication.rst | 348 +++++++++++++++++++++++++++++++++ docs/topics/authentication.svg | 63 ++++++ docs/topics/index.rst | 1 + 4 files changed, 414 insertions(+) create mode 100644 docs/topics/authentication.rst create mode 100644 docs/topics/authentication.svg diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 19c4b7c44..030917491 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -25,6 +25,7 @@ datastructures django dyno fractalideas +iframe IPv iterable keepalive @@ -65,3 +66,4 @@ websocket websockets ws wss +www diff --git a/docs/topics/authentication.rst b/docs/topics/authentication.rst new file mode 100644 index 000000000..44c9c6151 --- /dev/null +++ b/docs/topics/authentication.rst @@ -0,0 +1,348 @@ +Authentication +============== + +The WebSocket protocol was designed for creating web applications that need +bidirectional communication between clients running in browsers and servers. + +In most practical use cases, WebSocket servers need to authenticate clients in +order to route communications appropriately and securely. + +:rfc:`6455` stays elusive when it comes to authentication: + + This protocol doesn't prescribe any particular way that servers can + authenticate clients during the WebSocket handshake. The WebSocket + server can use any client authentication mechanism available to a + generic HTTP server, such as cookies, HTTP authentication, or TLS + authentication. + +None of these three mechanisms works well in practice. Using cookies is +cumbersome, HTTP authentication isn't supported by all mainstream browsers, +and TLS authentication in a browser is an esoteric user experience. + +Fortunately, there are better alternatives! Let's discuss them. + +System design +------------- + +Consider a setup where the WebSocket server is separate from the HTTP server. + +Most servers built with websockets to complement a web application adopt this +design because websockets doesn't aim at supporting HTTP. + +The following diagram illustrates the authentication flow. + +.. image:: authentication.svg + +Assuming the current user is authenticated with the HTTP server (1), the +application needs to obtain credentials from the HTTP server (2) in order to +send them to the WebSocket server (3), who can check them against the database +of user accounts (4). + +Usernames and passwords aren't a good choice of credentials here, if only +because passwords aren't available in clear text in the database. + +Tokens linked to user accounts are a better choice. These tokens must be +impossible to forge by an attacker. For additional security, they can be +short-lived or even single-use. + +Sending credentials +------------------- + +Assume the web application obtained authentication credentials, likely a +token, from the HTTP server. There's four options for passing them to the +WebSocket server. + +1. **Sending credentials as the first message in the WebSocket connection.** + + This is fully reliable and the most secure mechanism in this discussion. It + has two minor downsides: + + * Authentication is performed at the application layer. Ideally, it would + be managed at the protocol layer. + + * Authentication is performed after the WebSocket handshake, making it + impossible to monitor authentication failures with HTTP response codes. + +2. **Adding credentials to the WebSocket URI in a query parameter.** + + This is also fully reliable but less secure. Indeed, it has a major + downside: + + * URIs end up in logs, which leaks credentials. Even if that risk could be + lowered with single-use tokens, it is usually considered unacceptable. + + Authentication is still performed at the application layer but it can + happen before the WebSocket handshake, which improves separation of + concerns and enables responding to authentication failures with HTTP 401. + +3. **Setting a cookie on the domain of the WebSocket URI.** + + Cookies are undoubtedly the most common and hardened mechanism for sending + credentials from a web application to a server. In a HTTP application, + credentials would be a session identifier or a serialized, signed session. + + Unfortunately, when the WebSocket server runs on a different domain from + the web application, this idea bumps into the `Same-Origin Policy`_. For + security reasons, setting a cookie on a different origin is impossible. + + The proper workaround consists in: + + * creating a hidden iframe_ served from the domain of the WebSocket server + * sending the token to the iframe with postMessage_ + * setting the cookie in the iframe + + before opening the WebSocket connection. + + Sharing a parent domain (e.g. example.com) between the HTTP server (e.g. + www.example.com) and the WebSocket server (e.g. ws.example.com) and setting + the cookie on that parent domain would work too. + + However, the cookie would be shared with all subdomains of the parent + domain. For a cookie containing credentials, this is unacceptable. + +.. _Same-Origin Policy: https://developer.mozilla.org/en-US/docs/Web/Security/Same-origin_policy +.. _iframe: https://developer.mozilla.org/en-US/docs/Web/HTML/Element/iframe +.. _postMessage: https://developer.mozilla.org/en-US/docs/Web/API/MessagePort/postMessage + +4. **Adding credentials to the WebSocket URI in user information.** + + Letting the browser perform HTTP Basic Auth is a nice idea in theory. + + In practice it doesn't work due to poor support in browsers. + + As of May 2021: + + * Chrome 90 behaves as expected. + + * Firefox 88 caches credentials too aggressively. + + When connecting again to the same server with new credentials, it reuses + the old credentials, which may be expired, resulting in an HTTP 401. Then + the next connection succeeds. Perhaps errors clear the cache. + + When tokens are short-lived on single-use, this bug produces an + interesting effect: every other WebSocket connection fails. + + * Safari 14 ignores credentials entirely. + +Two other options are off the table: + +1. **Setting a custom HTTP header** + + This would be the most elegant mechanism, solving all issues with the options + discussed above. + + Unfortunately, it doesn't work because the `WebSocket API`_ doesn't support + `setting custom headers`_. + +.. _WebSocket API: https://developer.mozilla.org/en-US/docs/Web/API/WebSockets_API +.. _setting custom headers: https://github.com/whatwg/html/issues/3062 + +2. **Authenticating with a TLS certificate** + + While this is suggested by the RFC, installing a TLS certificate is too far + from the mainstream experience of browser users. This could make sense in + high security contexts. I hope developers working on such projects don't + take security advice from the documentation of random open source projects. + +Let's experiment! +----------------- + +The `experiments/authentication`_ directory demonstrates these techniques. + +Run the experiment in an environment where websockets is installed: + +.. _experiments/authentication: https://github.com/aaugustin/websockets/tree/main/experiments/authentication + +.. code:: console + + $ python experiments/authentication/app.py + Running on http://localhost:8000/ + +When you browse to the HTTP server at http://localhost:8000/ and you submit a +username, the server creates a token and returns a testing web page. + +This page opens WebSocket connections to four WebSocket servers running on +four different origins. It attempts to authenticate with the token in four +different ways. + +First message +............. + +As soon as the connection is open, the client sends a message containing the +token: + +.. code:: javascript + + const websocket = new WebSocket("ws://.../"); + websocket.onopen = () => websocket.send(token); + + // ... + +At the beginning of the connection handler, the server receives this message +and authenticates the user. If authentication fails, the server closes the +connection: + +.. code:: python + + async def first_message_handler(websocket, path): + token = await websocket.recv() + user = get_user(token) + if user is None: + await websocket.close(1011, "authentication failed") + return + + ... + +Query parameter +............... + +The client adds the token to the WebSocket URI in a query parameter before +opening the connection: + +.. code:: javascript + + const uri = `ws://.../?token=${token}`; + const websocket = new WebSocket(uri); + + // ... + +The server intercepts the HTTP request, extracts the token and authenticates +the user. If authentication fails, it returns a HTTP 401: + +.. code:: python + + class QueryParamProtocol(websockets.WebSocketServerProtocol): + async def process_request(self, path, headers): + token = get_query_parameter(path, "token") + if token is None: + return http.HTTPStatus.UNAUTHORIZED, [], b"Missing token\n" + + user = get_user(token) + if user is None: + return http.HTTPStatus.UNAUTHORIZED, [], b"Invalid token\n" + + self.user = user + + async def query_param_handler(websocket, path): + user = websocket.user + + ... + +Cookie +...... + +The client sets a cookie containing the token before opening the connection. + +The cookie must be set by an iframe loaded from the same origin as the +WebSocket server. This requires passing the token to this iframe. + +.. code:: javascript + + // in main window + iframe.contentWindow.postMessage(token, "http://..."); + + // in iframe + document.cookie = `token=${data}; SameSite=Strict`; + + // in main window + const websocket = new WebSocket("ws://.../"); + + // ... + +This sequence must be synchronized between the main window and the iframe. +This involves several events. Look at the full implementation for details. + +The server intercepts the HTTP request, extracts the token and authenticates +the user. If authentication fails, it returns a HTTP 401: + +.. code:: python + + class CookieProtocol(websockets.WebSocketServerProtocol): + async def process_request(self, path, headers): + # Serve iframe on non-WebSocket requests + ... + + token = get_cookie(headers.get("Cookie", ""), "token") + if token is None: + return http.HTTPStatus.UNAUTHORIZED, [], b"Missing token\n" + + user = get_user(token) + if user is None: + return http.HTTPStatus.UNAUTHORIZED, [], b"Invalid token\n" + + self.user = user + + async def cookie_handler(websocket, path): + user = websocket.user + + ... + +User information +................ + +The client adds the token to the WebSocket URI in user information before +opening the connection: + +.. code:: javascript + + const uri = `ws://token:${token}@.../`; + const websocket = new WebSocket(uri); + + // ... + +Since HTTP Basic Auth is designed to accept a username and a password rather +than a token, we send ``token`` as username and the token as password. + +The server intercepts the HTTP request, extracts the token and authenticates +the user. If authentication fails, it returns a HTTP 401: + +.. code:: python + + class UserInfoProtocol(websockets.BasicAuthWebSocketServerProtocol): + async def check_credentials(self, username, password): + if username != "token": + return False + + user = get_user(password) + if user is None: + return False + + self.user = user + return True + + async def user_info_handler(websocket, path): + user = websocket.user + + ... + +Machine-to-machine authentication +--------------------------------- + +When the WebSocket client is a standalone program rather than a script running +in a browser, there are far fewer constraints. HTTP Authentication is the best +solution in this scenario. + +To authenticate a websockets client with HTTP Basic Authentication +(:rfc:`7617`), include the credentials in the URI: + +.. code:: python + + async with websockets.connect( + f"wss://{username}:{password}@example.com", + ) as websocket: + ... + +(You must :func:`~urllib.parse.quote` ``username`` and ``password`` if they +contain unsafe characters.) + +To authenticate a websockets client with HTTP Bearer Authentication +(:rfc:`6750`), add a suitable ``Authorization`` header: + +.. code:: python + + async with websockets.connect( + "wss://example.com", + extra_headers={"Authorization": f"Bearer {token}"} + ) as websocket: + ... diff --git a/docs/topics/authentication.svg b/docs/topics/authentication.svg new file mode 100644 index 000000000..ad2ad0e44 --- /dev/null +++ b/docs/topics/authentication.svg @@ -0,0 +1,63 @@ +HTTPserverWebSocketserverweb appin browseruser accounts(1) authenticate user(2) obtain credentials(3) send credentials(4) authenticate user \ No newline at end of file diff --git a/docs/topics/index.rst b/docs/topics/index.rst index 157278f76..5363de0ce 100644 --- a/docs/topics/index.rst +++ b/docs/topics/index.rst @@ -4,5 +4,6 @@ Topics .. toctree:: :maxdepth: 2 + authentication design security From c91b4c2a01bbf8dd41d521a29710470c8b73599b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 23 May 2021 18:51:27 +0200 Subject: [PATCH 034/762] Use constant-time comparison for passwords. --- docs/project/changelog.rst | 2 ++ src/websockets/legacy/auth.py | 28 +++++++++++++++------------- tests/legacy/test_auth.py | 13 ++++++++++--- 3 files changed, 27 insertions(+), 16 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index c2c869c13..a44dd0418 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -45,6 +45,8 @@ They may change at any time. * Optimized default compression settings to reduce memory usage. +* Protected against timing attacks on HTTP Basic Auth. + * Made it easier to customize authentication with :meth:`~auth.BasicAuthWebSocketServerProtocol.check_credentials`. diff --git a/src/websockets/legacy/auth.py b/src/websockets/legacy/auth.py index e7e69cac1..16016e6fd 100644 --- a/src/websockets/legacy/auth.py +++ b/src/websockets/legacy/auth.py @@ -6,6 +6,7 @@ import functools +import hmac import http from typing import Any, Awaitable, Callable, Iterable, Optional, Tuple, Union, cast @@ -154,24 +155,23 @@ def basic_auth_protocol_factory( if credentials is not None: if is_credentials(credentials): - - async def check_credentials(username: str, password: str) -> bool: - return (username, password) == credentials - + credentials_list = [cast(Credentials, credentials)] elif isinstance(credentials, Iterable): credentials_list = list(credentials) - if all(is_credentials(item) for item in credentials_list): - credentials_dict = dict(credentials_list) - - async def check_credentials(username: str, password: str) -> bool: - return credentials_dict.get(username) == password - - else: + if not all(is_credentials(item) for item in credentials_list): raise TypeError(f"invalid credentials argument: {credentials}") - else: raise TypeError(f"invalid credentials argument: {credentials}") + credentials_dict = dict(credentials_list) + + async def check_credentials(username: str, password: str) -> bool: + try: + expected_password = credentials_dict[username] + except KeyError: + return False + return hmac.compare_digest(expected_password, password) + if create_protocol is None: # Not sure why mypy cannot figure this out. create_protocol = cast( @@ -180,5 +180,7 @@ async def check_credentials(username: str, password: str) -> bool: ) return functools.partial( - create_protocol, realm=realm, check_credentials=check_credentials + create_protocol, + realm=realm, + check_credentials=check_credentials, ) diff --git a/tests/legacy/test_auth.py b/tests/legacy/test_auth.py index c4dbd88ad..2b670c31f 100644 --- a/tests/legacy/test_auth.py +++ b/tests/legacy/test_auth.py @@ -1,3 +1,4 @@ +import hmac import unittest import urllib.error @@ -27,7 +28,7 @@ async def process_request(self, path, request_headers): class CheckWebSocketServerProtocol(BasicAuthWebSocketServerProtocol): async def check_credentials(self, username, password): - return password == "letmein" + return hmac.compare_digest(password, "letmein") class AuthClientServerTests(ClientServerTestsMixin, AsyncioTestCase): @@ -81,7 +82,7 @@ def test_basic_auth_bad_multiple_credentials(self): ) async def check_credentials(username, password): - return password == "iloveyou" + return hmac.compare_digest(password, "iloveyou") create_protocol_check_credentials = basic_auth_protocol_factory( realm="auth-tests", @@ -158,7 +159,13 @@ def test_basic_auth_unsupported_credentials_details(self): self.assertEqual(raised.exception.read().decode(), "Unsupported credentials\n") @with_server(create_protocol=create_protocol) - def test_basic_auth_invalid_credentials(self): + def test_basic_auth_invalid_username(self): + with self.assertRaises(InvalidStatusCode) as raised: + self.start_client(user_info=("goodbye", "iloveyou")) + self.assertEqual(raised.exception.status_code, 401) + + @with_server(create_protocol=create_protocol) + def test_basic_auth_invalid_password(self): with self.assertRaises(InvalidStatusCode) as raised: self.start_client(user_info=("hello", "ihateyou")) self.assertEqual(raised.exception.status_code, 401) From 4454e7df32ab82b439dec1b2ab64071bc22b0b38 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 23 May 2021 21:30:37 +0200 Subject: [PATCH 035/762] Normalize code style. --- example/deployment/supervisor/app.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/example/deployment/supervisor/app.py b/example/deployment/supervisor/app.py index 5fa596d0a..484566bc8 100644 --- a/example/deployment/supervisor/app.py +++ b/example/deployment/supervisor/app.py @@ -18,7 +18,9 @@ async def main(): loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) async with websockets.serve( - echo, "", 8080, + echo, + host="", + port=8080, reuse_port=True, ): await stop From 2222beaaa2900847ecf7b3b487ac6106254504d1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 21 May 2021 22:38:10 +0200 Subject: [PATCH 036/762] Improve Django integration guide. Switch to a simpler and more secure authentication design. --- docs/howto/django.rst | 193 +++++++++++++++++-------------- example/django/authentication.py | 40 ++----- example/django/notifications.py | 21 ++-- 3 files changed, 121 insertions(+), 133 deletions(-) diff --git a/docs/howto/django.rst b/docs/howto/django.rst index 1b4b4d3b9..e776f5f8c 100644 --- a/docs/howto/django.rst +++ b/docs/howto/django.rst @@ -17,39 +17,32 @@ WebSocket, you have two main options. This guide shows how to implement the second technique with websockets. It assumes familiarity with Django. -Authenticating connections --------------------------- - -Since the websockets server will run outside of Django, we need to connect it -to ``django.contrib.auth``. - -Our clients are running in browser. The `WebSocket API`_ doesn't support -setting `custom headers`_ so our options boil down to: - -* HTTP Basic Auth: this seems technically possible but isn't supported by - Firefox (`bug 1229443`_) so browser support is clearly insufficient. -* Sharing cookies: this is technically possible if there's a common parent - domain between the Django server (e.g. api.example.com) and the websockets - server (e.g. ws.example.com). However, there's a risk to share cookies too - widely (e.g. with anything under .example.com here). For authentication - cookies, this risk seems unacceptable. -* Sending an authentication ticket: Django generates a secure single-use token - with the user ID. The browser includes this token in the WebSocket URI when - it connects to the server in order to authenticate. It could also send the - ticket over the WebSocket connection in the first message, however this is a - bit more difficult to monitor, as you can't detect authentication failures - simply by looking at HTTP response codes. - -.. _custom headers: https://github.com/whatwg/html/issues/3062 -.. _WebSocket API: https://developer.mozilla.org/en-US/docs/Web/API/WebSockets_API -.. _bug 1229443: https://bugzilla.mozilla.org/show_bug.cgi?id=1229443 - -To generate our authentication tokens, we'll use `django-sesame`_, a small -library designed exactly for this purpose. +Authenticate connections +------------------------ + +Since the websockets server runs outside of Django, we need to integrate it +with ``django.contrib.auth``. + +We will generate authentication tokens in the Django project. Then we will +send them to the websockets server, where they will authenticate the user. + +Generating a token for the current user and making it available in the browser +is up to you. You could render the token in a template or fetch it with an API +call. + +Refer to the topic guide on :doc:`authentication <../topics/authentication>` +for details on this design. + +Generate tokens +............... + +We want secure, short-lived tokens containing the user ID. We'll rely on +`django-sesame`_, a small library designed exactly for this purpose. .. _django-sesame: https://github.com/aaugustin/django-sesame -Add django-sesame to the dependencies of your Django project, install it and configure it in the settings of the project: +Add django-sesame to the dependencies of your Django project, install it, and +configure it in the settings of the project: .. code:: python @@ -65,15 +58,22 @@ You don't need ``"sesame.middleware.AuthenticationMiddleware"``. It is for authenticating users in the Django server, while we're authenticating them in the websockets server. -We'd like our tokens to be valid for 30 seconds and usable only once. A -shorter lifespan is possible but it would make manual testing difficult. - -Configure django-sesame accordingly in the settings of your Django project: +We'd like our tokens to be valid for 30 seconds. We expect web pages to load +and to establish the WebSocket connection within this delay. Configure +django-sesame accordingly in the settings of your Django project: .. code:: python SESAME_MAX_AGE = 30 - SESAME_ONE_TIME = True + +If you expect your web site to load faster for all clients, a shorter lifespan +is possible. However, in the context of this document, it would make manual +testing more difficult. + +You could also enable single-use tokens. However, this would update the last +login date of the user every time a WebSocket connection is established. This +doesn't seem like a good idea, both in terms of behavior and in terms of +performance. Now you can generate tokens in a ``django-admin shell`` as follows: @@ -86,100 +86,103 @@ Now you can generate tokens in a ``django-admin shell`` as follows: >>> get_token(user) '' -Keep this console open: since tokens are single use, we'll have to generate a -new token every time we want to test our server. +Keep this console open: since tokens expire after 30 seconds, you'll have to +generate a new token every time you want to test connecting to the server. -Let's move on to the websockets server. Add websockets to the dependencies of -your Django project and install it. +Validate tokens +............... -Now here's a way to implement authentication. We're taking advantage of the -:meth:`~websockets.server.WebSocketServerProtocol.process_request` hook to -authenticate requests. If authentication succeeds, we store the user as an -attribute of the connection in order to make it available to the connection -handler. If authentication fails, we return a HTTP 401 Unauthorized error. +Let's move on to the websockets server. + +Add websockets to the dependencies of your Django project and install it. +Indeed, we're going to reuse the environment of the Django project, so we can +call its APIs in the websockets server. + +Now here's how to implement authentication. .. literalinclude:: ../../example/django/authentication.py Let's unpack this code. -We're using Django in a `standalone script`_. This requires setting the -``DJANGO_SETTINGS_MODULE`` environment variable and calling ``django.setup()`` -before doing anything with Django. +We're calling ``django.setup()`` before doing anything with Django because +we're using Django in a `standalone script`_. This assumes that the +``DJANGO_SETTINGS_MODULE`` environment variable is set to the Python path to +your settings module. .. _standalone script: https://docs.djangoproject.com/en/stable/topics/settings/#calling-django-setup-is-required-for-standalone-django-usage -We subclass :class:`~websockets.server.WebSocketServerProtocol` and override -:meth:`~websockets.server.WebSocketServerProtocol.process_request`, where: - -* We extract the token from the URL with the ``get_sesame()`` utility function - defined just above. If the token is missing, we return a HTTP 401 error. -* We authenticate the user with ``get_user()``, the API for `authentication - outside views`_. If authentication fails, we return a HTTP 401 error. +The connection handler reads the first message received from the client, which +is expected to contain a django-sesame token. Then it authenticates the user +with ``get_user()``, the API for `authentication outside views`_. If +authentication fails, it closes the connection and exits. .. _authentication outside views: https://github.com/aaugustin/django-sesame#authentication-outside-views -When we call an API that makes a database query, we wrap the call in -:func:`~asyncio.to_thread`. Indeed, the Django ORM doesn't support -asynchronous I/O. We would block the event loop if we didn't run these calls -in a separate thread. :func:`~asyncio.to_thread` is available since Python -3.9; in earlier versions, use :meth:`~asyncio.loop.run_in_executor` instead. +When we call an API that makes a database query such as ``get_user()``, we +wrap the call in :func:`~asyncio.to_thread`. Indeed, the Django ORM doesn't +support asynchronous I/O. It would block the event loop if it didn't run in a +separate thread. :func:`~asyncio.to_thread` is available since Python 3.9. In +earlier versions, use :meth:`~asyncio.loop.run_in_executor` instead. -The connection handler accesses the logged-in user that we stored as an -attribute of the connection object so we can test that authentication works. - -Finally, we start a server with :func:`~websockets.serve`, with the -``create_protocol`` pointing to our subclass of -:class:`~websockets.server.WebSocketServerProtocol`. +Finally, we start a server with :func:`~websockets.serve`. We're ready to test! -Make sure the ``DJANGO_SETTINGS_MODULE`` environment variable is set to the -Python path to your settings module and start the websockets server. If you -saved the server implementation to a file called ``authentication.py``: +Save this code to a file called ``authentication.py``, make sure the +``DJANGO_SETTINGS_MODULE`` environment variable is set properly, and start the +websockets server: .. code:: console $ python authentication.py -Open a new shell, generate a new token — remember, they're only valid for -30 seconds — and use it to connect to your server: +Generate a new token — remember, they're only valid for 30 seconds — and use +it to connect to your server. Paste your token and press Enter when you get a +prompt: .. code:: console - $ python -m websockets "ws://localhost:8888/?sesame=" - Connected to ws://localhost:8888/?sesame= + $ python -m websockets ws://localhost:8888/ + Connected to ws://localhost:8888/ + > < Hello ! Connection closed: code = 1000 (OK), no reason. It works! -If we try to reuse the same token, the connection is now rejected: +If you enter an expired or invalid token, authentication fails and the server +closes the connection: .. code:: console - $ python -m websockets "ws://localhost:8888/?sesame=" - Failed to connect to ws://localhost:8888/?sesame=: - server rejected WebSocket connection: HTTP 401. + $ python -m websockets ws://localhost:8888/ + Connected to ws://localhost:8888. + > not a token + Connection closed: code = 1011 (unexpected error), reason = authentication failed. You can also test from a browser by generating a new token and running the following code in the JavaScript console of the browser: .. code:: javascript - webSocket = new WebSocket("ws://localhost:8888/?sesame="); - webSocket.onmessage = (event) => console.log(event.data); + websocket = new WebSocket("ws://localhost:8888/"); + websocket.onopen = (event) => websocket.send(""); + websocket.onmessage = (event) => console.log(event.data); -Streaming events ----------------- +Stream events +------------- We can connect and authenticate but our server doesn't do anything useful yet! -Let's send a message every time any user makes any action in the admin. This +Let's send a message every time a user makes an action in the admin. This message will be broadcast to all users who can access the model on which the action was made. This may be used for showing notifications to other users. Many use cases for WebSocket with Django follow a similar pattern. +Set up event bus +................ + We need a event bus to enable communications between Django and websockets. Both sides connect permanently to the bus. Then Django writes events and websockets reads them. For the sake of simplicity, we'll rely on `Redis @@ -187,15 +190,15 @@ Pub/Sub`_. .. _Redis Pub/Sub: https://redis.io/topics/pubsub -Let's start by writing events. The easiest way to add Redis to a Django -project is by configuring a cache backend with `django-redis`_. This library -manages connections to Redis efficiently, persisting them between requests, -and provides an API to access the Redis connection directly. +The easiest way to add Redis to a Django project is by configuring a cache +backend with `django-redis`_. This library manages connections to Redis +efficiently, persisting them between requests, and provides an API to access +the Redis connection directly. .. _django-redis: https://github.com/jazzband/django-redis Install Redis, add django-redis to the dependencies of your Django project, -install it and configure it in the settings of the project: +install it, and configure it in the settings of the project: .. code:: python @@ -209,6 +212,11 @@ install it and configure it in the settings of the project: If you already have a default cache, add a new one with a different name and change ``get_redis_connection("default")`` in the code below to the same name. +Publish events +.............. + +Now let's write events to the bus. + Add the following code to a module that is imported when your Django project starts. Typically, you would put it in a ``signals.py`` module, which you would import in the ``AppConfig.ready()`` method of one of your apps: @@ -236,9 +244,11 @@ Leave this command running, start the Django development server and make changes in the admin: add, modify, or delete objects. You should see corresponding events published to the ``"events"`` stream. +Broadcast events +................ + Now let's turn to reading events and broadcasting them to connected clients. -We'll reuse our custom ``ServerProtocol`` class for authentication. Then we -need to add several features: +We need to add several features: * Keep track of connected clients so we can broadcast messages. * Tell which content types the user has permission to view or to change. @@ -250,10 +260,10 @@ Here's a complete implementation. .. literalinclude:: ../../example/django/notifications.py Since the ``get_content_types()`` function makes a database query, it is -wrapped inside ``asyncio.to_thread()``. It runs once when each WebSocket +wrapped inside :func:`asyncio.to_thread()`. It runs once when each WebSocket connection is open; then its result is cached for the lifetime of the connection. Indeed, running it for each message would trigger database queries -for all connected users at the same time, which could hurt the database. +for all connected users at the same time, which would hurt the database. The connection handler merely registers the connection in a global variable, associated to the list of content types for which events should be sent to @@ -270,6 +280,9 @@ send a message to a slow client. Since Redis can publish a message to multiple subscribers, multiple instances of this server can safely run in parallel. +Does it scale? +-------------- + In theory, given enough servers, this design can scale to a hundred million clients, since Redis can handle ten thousand servers and each server can handle ten thousand clients. In practice, you would need a more scalable diff --git a/example/django/authentication.py b/example/django/authentication.py index c0a061109..bbb3db02a 100644 --- a/example/django/authentication.py +++ b/example/django/authentication.py @@ -1,8 +1,6 @@ #!/usr/bin/env python import asyncio -import http -import urllib.parse import django import websockets @@ -12,40 +10,18 @@ from sesame.utils import get_user -def get_sesame(path): - """Utility function to extract sesame token from request path.""" - query = urllib.parse.urlparse(path).query - params = urllib.parse.parse_qs(query) - sesame = params.get("sesame", []) - if len(sesame) == 1: - return sesame[0] - - -class ServerProtocol(websockets.WebSocketServerProtocol): - async def process_request(self, path, headers): - """Authenticate users with a django-sesame token.""" - sesame = get_sesame(path) - if sesame is None: - return http.HTTPStatus.UNAUTHORIZED, [], b"Missing token\n" - - user = await asyncio.to_thread(get_user, sesame) - if user is None: - return http.HTTPStatus.UNAUTHORIZED, [], b"Invalid token\n" - - self.user = user - - async def handler(websocket, path): - await websocket.send(f"Hello {websocket.user}!") + sesame = await websocket.recv() + user = await asyncio.to_thread(get_user, sesame) + if user is None: + await websocket.close(1011, "authentication failed") + return + + await websocket.send(f"Hello {user}!") async def main(): - async with websockets.serve( - handler, - "localhost", - 8888, - create_protocol=ServerProtocol, - ): + async with websockets.serve(handler, "localhost", 8888): await asyncio.Future() # run forever diff --git a/example/django/notifications.py b/example/django/notifications.py index ad2751d98..641643f92 100644 --- a/example/django/notifications.py +++ b/example/django/notifications.py @@ -10,9 +10,7 @@ django.setup() from django.contrib.contenttypes.models import ContentType - -# Reuse our custom protocol to authenticate connections -from authentication import ServerProtocol +from sesame.utils import get_user CONNECTIONS = {} @@ -31,8 +29,14 @@ def get_content_types(user): async def handler(websocket, path): - """Register connection in CONNECTIONS dict, until it's closed.""" - ct_ids = await asyncio.to_thread(get_content_types, websocket.user) + """Authenticate user and register connection in CONNECTIONS.""" + sesame = await websocket.recv() + user = await asyncio.to_thread(get_user, sesame) + if user is None: + await websocket.close(1011, "authentication failed") + return + + ct_ids = await asyncio.to_thread(get_content_types, user) CONNECTIONS[websocket] = {"content_type_ids": ct_ids} try: await websocket.wait_closed() @@ -57,12 +61,7 @@ async def process_events(): async def main(): - async with websockets.serve( - handler, - "localhost", - 8888, - create_protocol=ServerProtocol, - ): + async with websockets.serve(handler, "localhost", 8888): await process_events() # runs forever From 8acf7ccf39c96acec293a2ec668e299f6f29e46b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 15 May 2021 18:33:41 +0200 Subject: [PATCH 037/762] Setup issue template Fix #835. --- .github/ISSUE_TEMPLATE/config.yml | 1 + .github/ISSUE_TEMPLATE/issue.md | 26 ++++++++++++++++++++++++++ 2 files changed, 27 insertions(+) create mode 100644 .github/ISSUE_TEMPLATE/config.yml create mode 100644 .github/ISSUE_TEMPLATE/issue.md diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 000000000..3ba13e0ce --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1 @@ +blank_issues_enabled: false diff --git a/.github/ISSUE_TEMPLATE/issue.md b/.github/ISSUE_TEMPLATE/issue.md new file mode 100644 index 000000000..8efabc617 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/issue.md @@ -0,0 +1,26 @@ +--- +name: Report an issue +about: Let us know about a problem with websockets +title: '' +labels: '' +assignees: '' + +--- + + From 81edc5f4adf6054852d041a222863d350a3d92d7 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 24 May 2021 19:20:18 +0200 Subject: [PATCH 038/762] Move compression docs to a topic guide. --- .gitignore | 2 +- docs/howto/compression.rst | 51 ------ docs/howto/deployment.rst | 75 +------- docs/howto/index.rst | 1 - docs/topics/compression.rst | 172 ++++++++++++++++++ docs/topics/index.rst | 1 + .../compression/benchmark.py | 8 +- .../compression/client.py | 8 +- .../compression/server.py | 4 +- 9 files changed, 186 insertions(+), 136 deletions(-) delete mode 100644 docs/howto/compression.rst create mode 100644 docs/topics/compression.rst rename benchmark/compression.py => experiments/compression/benchmark.py (97%) rename benchmark/mem_client.py => experiments/compression/client.py (87%) rename benchmark/mem_server.py => experiments/compression/server.py (96%) diff --git a/.gitignore b/.gitignore index ac68ff739..8f9e7dc51 100644 --- a/.gitignore +++ b/.gitignore @@ -5,8 +5,8 @@ .mypy_cache .tox build/ -benchmark/corpus.pkl compliance/reports/ +experiments/compression/corpus.pkl dist/ docs/_build/ htmlcov/ diff --git a/docs/howto/compression.rst b/docs/howto/compression.rst deleted file mode 100644 index 9023cec56..000000000 --- a/docs/howto/compression.rst +++ /dev/null @@ -1,51 +0,0 @@ -Compression -=========== - -:func:`~websockets.client.connect` and :func:`~websockets.server.serve` enable -compression by default. - -If you want to disable it, set ``compression=None``:: - - import websockets - - websockets.connect(..., compression=None) - - websockets.serve(..., compression=None) - -.. _per-message-deflate-configuration-example: - -You can also configure the Per-Message Deflate extension explicitly if you -want to customize compression settings:: - - import websockets - from websockets.extensions import permessage_deflate - - websockets.connect( - ..., - extensions=[ - permessage_deflate.ClientPerMessageDeflateFactory( - server_max_window_bits=11, - client_max_window_bits=11, - compress_settings={'memLevel': 4}, - ), - ], - ) - - websockets.serve( - ..., - extensions=[ - permessage_deflate.ServerPerMessageDeflateFactory( - server_max_window_bits=11, - client_max_window_bits=11, - compress_settings={'memLevel': 4}, - ), - ], - ) - -The window bits and memory level values chosen in these examples reduce memory -usage. You can read more about :ref:`optimizing compression settings -`. - -Refer to the API documentation of -:class:`~permessage_deflate.ClientPerMessageDeflateFactory` and -:class:`~permessage_deflate.ServerPerMessageDeflateFactory` for details. diff --git a/docs/howto/deployment.rst b/docs/howto/deployment.rst index 35720e8f1..d8264b2cc 100644 --- a/docs/howto/deployment.rst +++ b/docs/howto/deployment.rst @@ -66,82 +66,11 @@ Memory usage of a single connection is the sum of: Baseline ........ -.. _compression-settings: - Compression settings are the main factor affecting the baseline amount of memory used by each connection. -If you'd like to customize compression settings, here are the main knobs. - -- Context Takeover is necessary to get good performance for almost all - applications. It should remain enabled. -- Window Bits is a trade-off between memory usage and compression rate. - It should be an integer between 9 (lowest memory usage) and 15 (highest - compression rate). Setting it to 8 is possible but triggers a bug in some - versions of zlib. -- Memory Level is a trade-off between memory usage and compression speed. - However, a lower memory level can increase speed thanks to memory locality, - even if the CPU does more work! It should be an integer between 1 (lowest - memory usage) and 9 (highest compression speed in theory, not in practice). - -By default, websockets enables compression with conservative settings that -optimize memory usage at the cost of a slightly worse compression rate: Window -Bits = 12 and Memory Level = 5. This strikes a good balance for small messages -that are typical of WebSocket servers. - -If you'd like to configure different compression settings, see this -:ref:`example `. If you don't set -limits on Window Bits and neither does the remote endpoint, it defaults to the -maximum value of 15. If you don't set Memory Level, it defaults to 8 — more -accurately, to ``zlib.DEF_MEM_LEVEL`` which is 8. - -Here's how various compression settings affect memory usage of a single -connection on a 64-bit system, as well a benchmark of compressed size and -compression time for a corpus of small JSON documents. - -+-------------+-------------+--------------+--------------+------------------+------------------+ -| Compression | Window Bits | Memory Level | Memory usage | Size vs. default | Time vs. default | -+=============+=============+==============+==============+==================+==================+ -| | 15 | 8 | 322 KiB | -4.0% | +15% + -+-------------+-------------+--------------+--------------+------------------+------------------+ -| | 14 | 7 | 178 KiB | -2.6% | +10% | -+-------------+-------------+--------------+--------------+------------------+------------------+ -| | 13 | 6 | 106 KiB | -1.4% | +5% | -+-------------+-------------+--------------+--------------+------------------+------------------+ -| *default* | 12 | 5 | 70 KiB | = | = | -+-------------+-------------+--------------+--------------+------------------+------------------+ -| | 11 | 4 | 52 KiB | +3.7% | -5% | -+-------------+-------------+--------------+--------------+------------------+------------------+ -| | 10 | 3 | 43 KiB | +90% | +50% | -+-------------+-------------+--------------+--------------+------------------+------------------+ -| | 9 | 2 | 39 KiB | +160% | +100% | -+-------------+-------------+--------------+--------------+------------------+------------------+ -| *disabled* | N/A | N/A | 19 KiB | N/A | N/A | -+-------------+-------------+--------------+--------------+------------------+------------------+ - -*Don't assume this example is representative! Compressed size and compression -time depend heavily on the kind of messages exchanged by the application!* - -You can adapt the `compression.py`_ benchmark for your application by creating -a list of typical messages and passing it to the ``_benchmark`` function. - -.. _compression.py: https://github.com/aaugustin/websockets/blob/main/performance/compression.py - -This `blog post by Ilya Grigorik`_ provides more details about how compression -settings affect memory usage and how to optimize them. - -.. _blog post by Ilya Grigorik: https://www.igvita.com/2013/11/27/configuring-and-optimizing-websocket-compression/ - -This `experiment by Peter Thorson`_ suggests Window Bits = 11 and Memory Level = -4 as a sweet spot for optimizing memory usage. - -.. _experiment by Peter Thorson: https://www.ietf.org/mail-archive/web/hybi/current/msg10222.html - -websockets defaults to Window Bits = 12 and Memory Level = 5 in order to stay -away from Window Bits = 10 or Memory Level = 3, where performance craters in -the benchmark. This raises doubts on what could happen at Window Bits = 11 and -Memory Level = 4 on a different set of messages. The defaults needs to be safe -for all applications, hence a more conservative choice. +Read to the topic guide on :doc:`../topics/compression` to learn more about +tuning compression settings. Buffers ....... diff --git a/docs/howto/index.rst b/docs/howto/index.rst index 033dfe15d..b6cff7237 100644 --- a/docs/howto/index.rst +++ b/docs/howto/index.rst @@ -31,7 +31,6 @@ optimized, and secure setup. :maxdepth: 2 deployment - compression heroku kubernetes supervisor diff --git a/docs/topics/compression.rst b/docs/topics/compression.rst new file mode 100644 index 000000000..206bfa7b7 --- /dev/null +++ b/docs/topics/compression.rst @@ -0,0 +1,172 @@ +Compression +=========== + +Most WebSocket servers exchange JSON messages because they're convenient to +parse and serialize in a browser. These messages contain text data and tend to +be repetitive. + +This makes the stream of messages highly compressible. Enabling compression +can reduce network traffic by more than 80%. + +There's a standard for compressing messages. :rfc:`7692` defines WebSocket +Per-Message Deflate, a compression extension based on the Deflate_ algorithm. + +.. _Deflate: https://en.wikipedia.org/wiki/Deflate + +Configuring compression +----------------------- + +:func:`~websockets.client.connect` and :func:`~websockets.server.serve` enable +compression by default because the reduction in network bandwidth is usually +worth the additional memory and CPU cost. + +If you want to disable compression, set ``compression=None``:: + + import websockets + + websockets.connect(..., compression=None) + + websockets.serve(..., compression=None) + +If you want to customize compression settings, you can enable the Per-Message +Deflate extension explicitly with +:class:`~permessage_deflate.ClientPerMessageDeflateFactory` or +:class:`~permessage_deflate.ServerPerMessageDeflateFactory`:: + + import websockets + from websockets.extensions import permessage_deflate + + websockets.connect( + ..., + extensions=[ + permessage_deflate.ClientPerMessageDeflateFactory( + server_max_window_bits=11, + client_max_window_bits=11, + compress_settings={"memLevel": 4}, + ), + ], + ) + + websockets.serve( + ..., + extensions=[ + permessage_deflate.ServerPerMessageDeflateFactory( + server_max_window_bits=11, + client_max_window_bits=11, + compress_settings={"memLevel": 4}, + ), + ], + ) + +The Window Bits and Memory Level values in these examples reduce memory usage at the expense of compression rate. + +Compression settings +-------------------- + +When a client and a server enable the Per-Message Deflate extension, they +negotiate two parameters to guarantee compatibility between compression and +decompression. This affects the trade-off between compression rate and memory +usage for both sides. + +* **Context Takeover** means that the compression context is retained between + messages. In other words, compression is applied to the stream of messages + rather than to each message individually. Context takeover should remain + enabled to get good performance on applications that send a stream of + messages with the same structure, that is, most applications. + +* **Window Bits** controls the size of the compression context. It must be + an integer between 9 (lowest memory usage) and 15 (best compression). + websockets defaults to 12. Setting it to 8 is possible but rejected by some + versions of zlib. + +:mod:`zlib` offers additional parameters for tuning compression. They control +the trade-off between compression rate and CPU and memory usage for the +compression side, transparently for the decompression side. + +* **Memory Level** controls the size of the compression state. It must be an + integer between 1 (lowest memory usage) and 9 (best compression). websockets + defaults to 5. A lower memory level can increase speed thanks to memory + locality. + +* **Compression Level** controls the effort to optimize compression. It must + be an integer between 1 (lowest CPU usage) and 9 (best compression). + +* **Strategy** selects the compression strategy. The best choice depends on + the type of data being compressed. + +Unless mentioned otherwise, websockets uses the defaults of +:func:`zlib.compressobj` for all these settings. + +Tuning compression +------------------ + +By default, websockets enables compression with conservative settings that +optimize memory usage at the cost of a slightly worse compression rate: Window +Bits = 12 and Memory Level = 5. This strikes a good balance for small messages +that are typical of WebSocket servers. + +Here's how various compression settings affect memory usage of a single +connection on a 64-bit system, as well a benchmark of compressed size and +compression time for a corpus of small JSON documents. + ++-------------+-------------+--------------+--------------+------------------+------------------+ +| Compression | Window Bits | Memory Level | Memory usage | Size vs. default | Time vs. default | ++=============+=============+==============+==============+==================+==================+ +| | 15 | 8 | 322 KiB | -4.0% | +15% + ++-------------+-------------+--------------+--------------+------------------+------------------+ +| | 14 | 7 | 178 KiB | -2.6% | +10% | ++-------------+-------------+--------------+--------------+------------------+------------------+ +| | 13 | 6 | 106 KiB | -1.4% | +5% | ++-------------+-------------+--------------+--------------+------------------+------------------+ +| *default* | 12 | 5 | 70 KiB | = | = | ++-------------+-------------+--------------+--------------+------------------+------------------+ +| | 11 | 4 | 52 KiB | +3.7% | -5% | ++-------------+-------------+--------------+--------------+------------------+------------------+ +| | 10 | 3 | 43 KiB | +90% | +50% | ++-------------+-------------+--------------+--------------+------------------+------------------+ +| | 9 | 2 | 39 KiB | +160% | +100% | ++-------------+-------------+--------------+--------------+------------------+------------------+ +| *disabled* | — | — | 19 KiB | +452% | — | ++-------------+-------------+--------------+--------------+------------------+------------------+ + +Window Bits and Memory Level don't have to move in lockstep. However, other +combinations don't yield significantly better results than those shown above. + +Compressed size and compression time depend heavily on the kind of messages +exchanged by the application so this example may not apply to your use case. + +You can adapt `compression/benchmark.py`_ by creating a list of typical +messages and passing it to the ``_run`` function. + +Window Bits = 11 and Memory Level = 4 looks like the sweet spot in this table. + +websockets defaults to Window Bits = 12 and Memory Level = 5 to stay away from +Window Bits = 10 or Memory Level = 3 where performance craters, raising doubts +on what could happen at Window Bits = 11 and Memory Level = 4 on a different +corpus. + +Defaults must be safe for all applications, hence a more conservative choice. + +.. _compression/benchmark.py: https://github.com/aaugustin/websockets/blob/main/experiments/compression/benchmark.py + +The benchmark focuses on compression because it's more expensive than +decompression. Indeed, leaving aside small allocations, theoretical memory +usage is: + +* ``(1 << (windowBits + 2)) + (1 << (memLevel + 9))`` for compression; +* ``1 << windowBits`` for decompression. + +CPU usage is also higher for compression than decompression. + +Further reading +--------------- + +This `blog post by Ilya Grigorik`_ provides more details about how compression +settings affect memory usage and how to optimize them. + +.. _blog post by Ilya Grigorik: https://www.igvita.com/2013/11/27/configuring-and-optimizing-websocket-compression/ + +This `experiment by Peter Thorson`_ recommends Window Bits = 11 and Memory +Level = 4 for optimizing memory usage. + +.. _experiment by Peter Thorson: https://www.ietf.org/mail-archive/web/hybi/current/msg10222.html diff --git a/docs/topics/index.rst b/docs/topics/index.rst index 5363de0ce..b18434a39 100644 --- a/docs/topics/index.rst +++ b/docs/topics/index.rst @@ -5,5 +5,6 @@ Topics :maxdepth: 2 authentication + compression design security diff --git a/benchmark/compression.py b/experiments/compression/benchmark.py similarity index 97% rename from benchmark/compression.py rename to experiments/compression/benchmark.py index 15fb8653e..bdcdd8e95 100644 --- a/benchmark/compression.py +++ b/experiments/compression/benchmark.py @@ -49,7 +49,7 @@ def corpus(): pickle.dump(data, handle) -def _benchmark(data): +def _run(data): size = {} duration = {} @@ -149,15 +149,15 @@ def _benchmark(data): print() -def benchmark(): +def run(): with open(CORPUS_FILE, "rb") as handle: data = pickle.load(handle) - _benchmark(data) + _run(data) try: run = globals()[sys.argv[1]] except (KeyError, IndexError): - print(f"Usage: {sys.argv[0]} [corpus|benchmark]") + print(f"Usage: {sys.argv[0]} [corpus|run]") else: run() diff --git a/benchmark/mem_client.py b/experiments/compression/client.py similarity index 87% rename from benchmark/mem_client.py rename to experiments/compression/client.py index db68eb995..3ee19ddc5 100644 --- a/benchmark/mem_client.py +++ b/experiments/compression/client.py @@ -16,7 +16,7 @@ MEM_SIZE = [] -async def mem_client(client): +async def client(client): # Space out connections to make them sequential. await asyncio.sleep(client * INTERVAL) @@ -45,11 +45,11 @@ async def mem_client(client): await asyncio.sleep(CLIENTS * INTERVAL) -async def mem_clients(): - await asyncio.gather(*[mem_client(client) for client in range(CLIENTS + 1)]) +async def clients(): + await asyncio.gather(*[client(client) for client in range(CLIENTS + 1)]) -asyncio.run(mem_clients()) +asyncio.run(clients()) # First connection incurs non-representative setup costs. diff --git a/benchmark/mem_server.py b/experiments/compression/server.py similarity index 96% rename from benchmark/mem_server.py rename to experiments/compression/server.py index 852796249..f7b147006 100644 --- a/benchmark/mem_server.py +++ b/experiments/compression/server.py @@ -34,7 +34,7 @@ async def handler(ws, path): await asyncio.sleep(CLIENTS * INTERVAL) -async def mem_server(): +async def server(): loop = asyncio.get_running_loop() stop = loop.create_future() @@ -60,7 +60,7 @@ async def mem_server(): await stop -asyncio.run(mem_server()) +asyncio.run(server()) # First connection may incur non-representative setup costs. From 773f0b6d542307ff94c83d2f6fb9e28786e63dd8 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 24 May 2021 21:46:13 +0200 Subject: [PATCH 039/762] Address objection to Django integration. --- docs/howto/django.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/howto/django.rst b/docs/howto/django.rst index e776f5f8c..001d1cb73 100644 --- a/docs/howto/django.rst +++ b/docs/howto/django.rst @@ -169,6 +169,11 @@ following code in the JavaScript console of the browser: websocket.onopen = (event) => websocket.send(""); websocket.onmessage = (event) => console.log(event.data); +If you don't want to import your entire Django project into the websockets +server, you can build a separate Django project with ``django.contrib.auth``, +``django-sesame``, a suitable ``User`` model, and a subset of the settings of +the main project. + Stream events ------------- From b455d30252d8eb66771b63c7b9ca26d37f76e4d9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 25 May 2021 22:31:32 +0200 Subject: [PATCH 040/762] Reorder FAQ. Add question on threads. --- docs/howto/faq.rst | 89 +++++++++++++++++++++++++--------------------- 1 file changed, 49 insertions(+), 40 deletions(-) diff --git a/docs/howto/faq.rst b/docs/howto/faq.rst index 0196fd2c0..23964dcfb 100644 --- a/docs/howto/faq.rst +++ b/docs/howto/faq.rst @@ -131,13 +131,17 @@ websockets takes care of closing the connection when the handler exits. How do I run a HTTP server and WebSocket server on the same port? ................................................................. -This isn't supported. +You don't. + +HTTP and WebSockets have widely different operational characteristics. +Running them on the same server is a bad idea. Providing a HTTP server is out of scope for websockets. It only aims at providing a WebSocket server. There's limited support for returning HTTP responses with the :attr:`~legacy.server.WebSocketServerProtocol.process_request` hook. + If you need more, pick a HTTP server and run it separately. Client side @@ -287,6 +291,18 @@ There are several reasons why long-lived connections may be lost: If you're facing a reproducible issue, :ref:`enable debug logs ` to see when and how connections are closed. +How do I set a timeout on ``recv()``? +..................................... + +Use :func:`~asyncio.wait_for`:: + + await asyncio.wait_for(websocket.recv(), timeout=10) + +This technique works for most APIs, except for asynchronous context managers. +See `issue 574`_. + +.. _issue 574: https://github.com/aaugustin/websockets/issues/574 + How can I pass additional arguments to a custom protocol subclass? .................................................................. @@ -307,31 +323,19 @@ You can bind additional arguments to the protocol factory with This example was for a server. The same pattern applies on a client. -Why do I get the error: ``module 'websockets' has no attribute '...'``? -....................................................................... - -Often, this is because you created a script called ``websockets.py`` in your -current working directory. Then ``import websockets`` imports this module -instead of the websockets library. - -Are there ``onopen``, ``onmessage``, ``onerror``, and ``onclose`` callbacks? -............................................................................ - -No, there aren't. +How do I keep idle connections open? +.................................... -websockets provides high-level, coroutine-based APIs. Compared to callbacks, -coroutines make it easier to manage control flow in concurrent code. +websockets sends pings at 20 seconds intervals to keep the connection open. -If you prefer callback-based APIs, you should use another library. +In closes the connection if it doesn't get a pong within 20 seconds. -Can I use websockets synchronously, without ``async`` / ``await``? -.................................................................. +You can adjust this behavior with ``ping_interval`` and ``ping_timeout``. -You can convert every asynchronous call to a synchronous call by wrapping it -in ``asyncio.get_event_loop().run_until_complete(...)``. Unfortunately, this -is deprecated as of Python 3.10. +How do I respond to pings? +.......................... -If this turns out to be impractical, you should use another library. +websockets takes care of responding to pings with pongs. Miscellaneous ------------- @@ -344,31 +348,24 @@ websockets doesn't have built-in publish / subscribe for these use cases. Depending on the scale of your service, a simple in-memory implementation may do the job or you may need an external publish / subscribe component. -How do I set a timeout on ``recv()``? -..................................... - -Use :func:`~asyncio.wait_for`:: - - await asyncio.wait_for(websocket.recv(), timeout=10) - -This technique works for most APIs, except for asynchronous context managers. -See `issue 574`_. - -.. _issue 574: https://github.com/aaugustin/websockets/issues/574 +Can I use websockets synchronously, without ``async`` / ``await``? +.................................................................. -How do I keep idle connections open? -.................................... +You can convert every asynchronous call to a synchronous call by wrapping it +in ``asyncio.get_event_loop().run_until_complete(...)``. Unfortunately, this +is deprecated as of Python 3.10. -websockets sends pings at 20 seconds intervals to keep the connection open. +If this turns out to be impractical, you should use another library. -In closes the connection if it doesn't get a pong within 20 seconds. +Are there ``onopen``, ``onmessage``, ``onerror``, and ``onclose`` callbacks? +............................................................................ -You can adjust this behavior with ``ping_interval`` and ``ping_timeout``. +No, there aren't. -How do I respond to pings? -.......................... +websockets provides high-level, coroutine-based APIs. Compared to callbacks, +coroutines make it easier to manage control flow in concurrent code. -websockets takes care of responding to pings with pongs. +If you prefer callback-based APIs, you should use another library. Is there a Python 2 version? ............................ @@ -379,4 +376,16 @@ Python 2 reached end of life on January 1st, 2020. Before that date, websockets required asyncio and therefore Python 3. +Why do I get the error: ``module 'websockets' has no attribute '...'``? +....................................................................... + +Often, this is because you created a script called ``websockets.py`` in your +current working directory. Then ``import websockets`` imports this module +instead of the websockets library. + +I'm having problems with threads +................................ + +You shouldn't use threads. Use tasks instead. +:func:`~asyncio.AbstractEventLoop.call_soon_threadsafe` may help. From 5e69983096359fdf87d26afa7b5143badfe2140e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 25 May 2021 23:00:21 +0200 Subject: [PATCH 041/762] Add deployment doc. Fix #445. --- docs/howto/deployment.rst | 111 -------------------- docs/howto/index.rst | 6 +- docs/spelling_wordlist.txt | 7 ++ docs/topics/deployment.rst | 181 +++++++++++++++++++++++++++++++++ docs/topics/deployment.svg | 63 ++++++++++++ docs/topics/index.rst | 2 + docs/topics/memory.rst | 43 ++++++++ docs/topics/security.rst | 4 +- example/health_check_server.py | 4 +- example/shutdown_server.py | 1 + 10 files changed, 303 insertions(+), 119 deletions(-) delete mode 100644 docs/howto/deployment.rst create mode 100644 docs/topics/deployment.rst create mode 100644 docs/topics/deployment.svg create mode 100644 docs/topics/memory.rst diff --git a/docs/howto/deployment.rst b/docs/howto/deployment.rst deleted file mode 100644 index d8264b2cc..000000000 --- a/docs/howto/deployment.rst +++ /dev/null @@ -1,111 +0,0 @@ -Deployment -========== - -.. currentmodule:: websockets - -Application server ------------------- - -The author of websockets isn't aware of best practices for deploying network -services based on :mod:`asyncio`, let alone application servers. - -You can run a script similar to the :ref:`server example `, -inside a supervisor if you deem that useful. - -You can also add a wrapper to daemonize the process. Third-party libraries -provide solutions for that. - -If you can share knowledge on this topic, please file an issue_. Thanks! - -.. _issue: https://github.com/aaugustin/websockets/issues/new - -Graceful shutdown ------------------ - -You may want to close connections gracefully when shutting down the server, -perhaps after executing some cleanup logic. There are two ways to achieve this -with the object returned by :func:`~legacy.server.serve`: - -- using it as a asynchronous context manager, or -- calling its ``close()`` method, then waiting for its ``wait_closed()`` - method to complete. - -On Unix systems, shutdown is usually triggered by sending a signal. - -Here's a full example for handling SIGTERM on Unix: - -.. literalinclude:: ../../example/shutdown_server.py - :emphasize-lines: 12-15,17 - -This example is easily adapted to handle other signals. If you override the -default handler for SIGINT, which raises :exc:`KeyboardInterrupt`, be aware -that you won't be able to interrupt a program with Ctrl-C anymore when it's -stuck in a loop. - -It's more difficult to achieve the same effect on Windows. Some third-party -projects try to help with this problem. - -If your server doesn't run in the main thread, look at -:func:`~asyncio.AbstractEventLoop.call_soon_threadsafe`. - -Memory usage ------------- - -.. _memory-usage: - -In most cases, memory usage of a WebSocket server is proportional to the -number of open connections. When a server handles thousands of connections, -memory usage can become a bottleneck. - -Memory usage of a single connection is the sum of: - -1. the baseline amount of memory websockets requires for each connection, -2. the amount of data held in buffers before the application processes it, -3. any additional memory allocated by the application itself. - -Baseline -........ - -Compression settings are the main factor affecting the baseline amount of -memory used by each connection. - -Read to the topic guide on :doc:`../topics/compression` to learn more about -tuning compression settings. - -Buffers -....... - -Under normal circumstances, buffers are almost always empty. - -Under high load, if a server receives more messages than it can process, -bufferbloat can result in excessive memory use. - -By default websockets has generous limits. It is strongly recommended to adapt -them to your application. When you call :func:`~legacy.server.serve`: - -- Set ``max_size`` (default: 1 MiB, UTF-8 encoded) to the maximum size of - messages your application generates. -- Set ``max_queue`` (default: 32) to the maximum number of messages your - application expects to receive faster than it can process them. The queue - provides burst tolerance without slowing down the TCP connection. - -Furthermore, you can lower ``read_limit`` and ``write_limit`` (default: -64 KiB) to reduce the size of buffers for incoming and outgoing data. - -The design document provides :ref:`more details about buffers`. - -Port sharing ------------- - -The WebSocket protocol is an extension of HTTP/1.1. It can be tempting to -serve both HTTP and WebSocket on the same port. - -The author of websockets doesn't think that's a good idea, due to the widely -different operational characteristics of HTTP and WebSocket. - -websockets provide minimal support for responding to HTTP requests with the -:meth:`~legacy.server.WebSocketServerProtocol.process_request` hook. Typical -use cases include health checks. Here's an example: - -.. literalinclude:: ../../example/health_check_server.py - :emphasize-lines: 9-11,20 diff --git a/docs/howto/index.rst b/docs/howto/index.rst index b6cff7237..38073a41e 100644 --- a/docs/howto/index.rst +++ b/docs/howto/index.rst @@ -24,13 +24,13 @@ features, which websockets supports fully. extensions -Once your application is ready, learn how to deploy it with a convenient, -optimized, and secure setup. +.. _deployment-howto: + +Once your application is ready, learn how to deploy it on various platforms. .. toctree:: :maxdepth: 2 - deployment heroku kubernetes supervisor diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 030917491..b7493bba6 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -2,6 +2,7 @@ api attr augustin auth +autoscaler awaitable aymeric backend @@ -25,13 +26,17 @@ datastructures django dyno fractalideas +gunicorn +hypercorn iframe IPv +istio iterable keepalive KiB Kubernetes lifecycle +linkerd liveness lookups MiB @@ -60,10 +65,12 @@ unparse unregister uple username +uvicorn virtualenv WebSocket websocket websockets ws +wsgi wss www diff --git a/docs/topics/deployment.rst b/docs/topics/deployment.rst new file mode 100644 index 000000000..a7b4bd3ad --- /dev/null +++ b/docs/topics/deployment.rst @@ -0,0 +1,181 @@ +Deployment +========== + +.. currentmodule:: websockets + +When you deploy your websockets server to production, at a high level, your +architecture will almost certainly look like the following diagram: + +.. image:: deployment.svg + +The basic unit for scaling a websockets server is "one server process". Each +blue box in the diagram represents one server process. + +There's more variation in routing. While the routing layer is shown as one big +box, it is likely to involve several subsystems. + +When you design a deployment, your should consider two questions: + +1. How will I run the appropriate number of server processes? +2. How will I route incoming connections to these processes? + +These questions are strongly related. There's a wide range of acceptable +answers, depending on your goals and your constraints. + +You can find a few concrete examples in the :ref:`deployment how-to guides +`. + +Running server processes +------------------------ + +How many processes do I need? +............................. + +Typically, one server process will manage a few hundreds or thousands +connections, depending on the frequency of messages and the amount of work +they require. + +CPU and memory usage increase with the number of connections to the server. + +Often CPU is the limiting factor. If a server process goes to 100% CPU, then +you reached the limit. How much headroom you want to keep is up to you. + +Once you know how many connections a server process can manage and how many +connections you need to handle, you can calculate how many processes to run. + +You can also automate this calculation by configuring an autoscaler to keep +CPU usage or connection count within acceptable limits. + +Don't scale with threads. Threads doesn't make sense for a server built with +:mod:`asyncio`. + +How do I run processes? +....................... + +Most solutions for running multiple instances of a server process fall into +one of these three buckets: + +1. Running N processes on a platform: + + * a Kubernetes Deployment + + * its equivalent on a Platform as a Service provider + +2. Running N servers: + + * an AWS Auto Scaling group, a GCP Managed instance group, etc. + + * a fixed set of long-lived servers + +3. Running N processes on a server: + + * preferrably via a process manager or supervisor + +Option 1 is easiest of you have access to such a platform. + +Option 2 almost always combines with option 3. + +How do I start a process? +......................... + +Run a Python program that invokes :func:`~serve`. That's it. + +Don't run an ASGI server such as Uvicorn, Hypercorn, or Daphne. They're +alternatives to websockets, not complements. + +Don't run a WSGI server such as Gunicorn, Waitress, or mod_wsgi. They aren't +designed to run WebSocket applications. + +Applications servers handle network connections and expose a Python API. You +don't need one because websockets handles network connections directly. + +How do I stop a process? +........................ + +Process managers send the SIGTERM signal to terminate processes. Catch this +signal and exit the server to ensure a graceful shutdown. + +Here's an example: + +.. literalinclude:: ../../example/shutdown_server.py + :emphasize-lines: 12-15,18 + +When exiting the context manager, :func:`~server.serve` closes all connections +with code 1001 (going away). As a consequence: + +* If the connection handler is awaiting + :meth:`~server.WebSocketServerProtocol.recv`, it receives a + :exc:`~exceptions.ConnectionClosedOK` exception. It can catch the exception + and clean up before exiting. + +* Otherwise, it should be waiting on + :meth:`~server.WebSocketServerProtocol.wait_closed`, so it can receive the + :exc:`~exceptions.ConnectionClosedOK` exception and exit. + +This example is easily adapted to handle other signals. + +If you override the default signal handler for SIGINT, which raises +:exc:`KeyboardInterrupt`, be aware that you won't be able to interrupt a +program with Ctrl-C anymore when it's stuck in a loop. + +Routing connections +------------------- + +What does routing involve? +.......................... + +Since the routing layer is directly exposed to the Internet, it should provide +appropriate protection against threats ranging from Internet background noise +to targeted attacks. + +You should always secure WebSocket connections with TLS. Since the routing +layer carries the public domain name, it should terminate TLS connections. + +Finally, it must route connections to the server processes, balancing new +connections across them. + +How do I route connections? +........................... + +Here are typical solutions for load balancing, matched to ways of running +processes: + +1. If you're running on a platform, it comes with a routing layer: + + * a Kubernetes Ingress and Service + + * a service mesh: Istio, Consul, Linkerd, etc. + + * the routing mesh of a Platform as a Service + +2. If you're running N servers, you may load balance with: + + * a cloud load balancer: AWS Elastic Load Balancing, GCP Cloud Load + Balancing, etc. + + * A software load balancer: HAProxy, NGINX, etc. + +3. If you're running N processes on a server, you may load balance with: + + * A software load balancer: HAProxy, NGINX, etc. + + * The operating system — all processes listen on the same port + +You may trust the load balancer to handle encryption and to provide security. +You may add another layer in front of the load balancer for these purposes. + +There are many possibilities. Don't add layers that you don't need, though. + +How do I implement a health check? +.................................. + +Load balancers need a way to check whether server processes are up and running +to avoid routing connections to a non-functional backend. + +websockets provide minimal support for responding to HTTP requests with the +:meth:`~server.WebSocketServerProtocol.process_request` hook. + +Here's an example: + +.. literalinclude:: ../../example/health_check_server.py + :emphasize-lines: 7-9,18 diff --git a/docs/topics/deployment.svg b/docs/topics/deployment.svg new file mode 100644 index 000000000..fbacb18c4 --- /dev/null +++ b/docs/topics/deployment.svg @@ -0,0 +1,63 @@ +Internetwebsocketswebsocketswebsocketsrouting \ No newline at end of file diff --git a/docs/topics/index.rst b/docs/topics/index.rst index b18434a39..41bdc3051 100644 --- a/docs/topics/index.rst +++ b/docs/topics/index.rst @@ -4,7 +4,9 @@ Topics .. toctree:: :maxdepth: 2 + deployment authentication compression design + memory security diff --git a/docs/topics/memory.rst b/docs/topics/memory.rst new file mode 100644 index 000000000..e5b9ed9a2 --- /dev/null +++ b/docs/topics/memory.rst @@ -0,0 +1,43 @@ +Memory usage +============ + +In most cases, memory usage of a WebSocket server is proportional to the +number of open connections. When a server handles thousands of connections, +memory usage can become a bottleneck. + +Memory usage of a single connection is the sum of: + +1. the baseline amount of memory websockets requires for each connection, +2. the amount of data held in buffers before the application processes it, +3. any additional memory allocated by the application itself. + +Baseline +-------- + +Compression settings are the main factor affecting the baseline amount of +memory used by each connection. + +Read to the topic guide on :doc:`../topics/compression` to learn more about +tuning compression settings. + +Buffers +------- + +Under normal circumstances, buffers are almost always empty. + +Under high load, if a server receives more messages than it can process, +bufferbloat can result in excessive memory use. + +By default websockets has generous limits. It is strongly recommended to adapt +them to your application. When you call :func:`~legacy.server.serve`: + +- Set ``max_size`` (default: 1 MiB, UTF-8 encoded) to the maximum size of + messages your application generates. +- Set ``max_queue`` (default: 32) to the maximum number of messages your + application expects to receive faster than it can process them. The queue + provides burst tolerance without slowing down the TCP connection. + +Furthermore, you can lower ``read_limit`` and ``write_limit`` (default: +64 KiB) to reduce the size of buffers for incoming and outgoing data. + +The design document provides :ref:`more details about buffers`. diff --git a/docs/topics/security.rst b/docs/topics/security.rst index 39de08120..6c541db06 100644 --- a/docs/topics/security.rst +++ b/docs/topics/security.rst @@ -24,8 +24,8 @@ With the default settings, opening a connection uses 70 KiB of memory. Sending some highly compressed messages could use up to 128 MiB of memory with an amplification factor of 1000 between network traffic and memory use. -Configuring a server to :ref:`optimize memory usage ` will -improve security in addition to improving performance. +Configuring a server to :doc:`memory` will improve security in addition to +improving performance. Other limits ------------ diff --git a/example/health_check_server.py b/example/health_check_server.py index c5bb6d5ab..2ca185cde 100755 --- a/example/health_check_server.py +++ b/example/health_check_server.py @@ -1,13 +1,11 @@ #!/usr/bin/env python -# WS echo server with HTTP endpoint at /health/ - import asyncio import http import websockets async def health_check(path, request_headers): - if path == "/health/": + if path == "/healthz": return http.HTTPStatus.OK, [], b"OK\n" async def echo(websocket, path): diff --git a/example/shutdown_server.py b/example/shutdown_server.py index 1ae44af1e..cabba4014 100755 --- a/example/shutdown_server.py +++ b/example/shutdown_server.py @@ -13,6 +13,7 @@ async def server(): loop = asyncio.get_running_loop() stop = loop.create_future() loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) + async with websockets.serve(echo, "localhost", 8765): await stop From dfecbd0307a94bc8704f895b66167e8838c222e6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 27 May 2021 13:23:30 +0200 Subject: [PATCH 042/762] Prep for releasing 9.1 with security fix. --- docs/project/changelog.rst | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index a44dd0418..dcaa06e9d 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -45,11 +45,22 @@ They may change at any time. * Optimized default compression settings to reduce memory usage. -* Protected against timing attacks on HTTP Basic Auth. - * Made it easier to customize authentication with :meth:`~auth.BasicAuthWebSocketServerProtocol.check_credentials`. +9.1 +... + +*May 27, 2021* + +.. note:: + + **Version 9.1 fixes a security issue introduced in version 8.0.** + + Version 8.0 was vulnerable to timing attacks on HTTP Basic Auth passwords. + + .. _CVE-2018-1000518: https://nvd.nist.gov/vuln/detail/CVE-2018-1000518 + 9.0.2 ..... From ac85304980285079ff871cde728964d0acfde569 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 27 May 2021 21:27:27 +0200 Subject: [PATCH 043/762] Fix link to changelog. --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 8c7a4984a..0625798a2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -4,7 +4,7 @@ python-tag = py37.py38.py39.py310 [metadata] license_file = LICENSE project_urls = - Changelog = https://websockets.readthedocs.io/en/stable/changelog.html + Changelog = https://websockets.readthedocs.io/en/stable/project/changelog.html Documentation = https://websockets.readthedocs.io/ Funding = https://tidelift.com/subscription/pkg/pypi-websockets?utm_source=pypi-websockets&utm_medium=referral&utm_campaign=readme Tracker = https://github.com/aaugustin/websockets/issues From 17210674d6d2f0987a0cd74d0d6ac37d88d28977 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 27 May 2021 22:52:49 +0200 Subject: [PATCH 044/762] Add deployment guide for nginx. --- docs/howto/index.rst | 1 + docs/howto/nginx.rst | 84 +++++++++++++++++++++++ example/deployment/nginx/app.py | 29 ++++++++ example/deployment/nginx/nginx.conf | 25 +++++++ example/deployment/nginx/supervisord.conf | 7 ++ 5 files changed, 146 insertions(+) create mode 100644 docs/howto/nginx.rst create mode 100644 example/deployment/nginx/app.py create mode 100644 example/deployment/nginx/nginx.conf create mode 100644 example/deployment/nginx/supervisord.conf diff --git a/docs/howto/index.rst b/docs/howto/index.rst index 38073a41e..eab6cab6e 100644 --- a/docs/howto/index.rst +++ b/docs/howto/index.rst @@ -34,3 +34,4 @@ Once your application is ready, learn how to deploy it on various platforms. heroku kubernetes supervisor + nginx diff --git a/docs/howto/nginx.rst b/docs/howto/nginx.rst new file mode 100644 index 000000000..cb4dc83f2 --- /dev/null +++ b/docs/howto/nginx.rst @@ -0,0 +1,84 @@ +Deploy behind nginx +=================== + +This guide demonstrates a way to load balance connections across multiple +websockets server processes running on the same machine with nginx_. + +We'll run server processes with Supervisor as described in :doc:`this guide +`. + +.. _nginx: https://nginx.org/ + +Run server processes +-------------------- + +Save this app to ``app.py``: + +.. literalinclude:: ../../example/deployment/nginx/app.py + :emphasize-lines: 21,23 + +We'd like to nginx to connect to websockets servers via Unix sockets in order +to avoid the overhead of TCP for communicating between processes running in +the same OS. + +We start the app with :func:`~websockets.server.unix_serve`. Each server +process listens on a different socket thanks to an environment variable set +by Supervisor to a different value. + +Save this configuration to ``supervisord.conf``: + +.. literalinclude:: ../../example/deployment/nginx/supervisord.conf + +This configuration runs four instances of the app. + +Install Supervisor and run it: + +.. code:: console + + $ supervisord -c supervisord.conf -n + +Configure and run nginx +----------------------- + +Here's a simple nginx configuration to load balance connections across four +processes: + +.. literalinclude:: ../../example/deployment/nginx/nginx.conf + +We set ``daemon off`` so we can run nginx in the foreground for testing. + +Then we combine the `WebSocket proxying`_ and `load balancing`_ guides: + +* The WebSocket protocol requires HTTP/1.1. We must set the HTTP protocol + version to 1.1, else nginx defaults to HTTP/1.0 for proxying. + +* The WebSocket handshake involves the ``Connection`` and ``Upgrade`` HTTP + headers. We must pass them to the upstream explicitly, else nginx drops + them because they're hop-by-hop headers. + + We deviate from the `WebSocket proxying`_ guide because its example adds a + ``Connection: Upgrade`` header to every upstream request, even if the + original request didn't contain that header. + +* In the upstream configuration, we set the load balancing method to + ``least_conn`` in order to balance the number of active connections across + servers. This is best for long running connections. + +.. _WebSocket proxying: http://nginx.org/en/docs/http/websocket.html +.. _load balancing: http://nginx.org/en/docs/http/load_balancing.html + +Save the configuration to ``nginx.conf``, install nginx, and run it: + +.. code:: console + + $ nginx -c nginx.conf -p . + +You can confirm that nginx proxies connections properly: + +.. code:: console + + $ PYTHONPATH=src python -m websockets ws://localhost:8080/ + Connected to ws://localhost:8080/. + > Hello! + < Hello! + Connection closed: code = 1000 (OK), no reason. diff --git a/example/deployment/nginx/app.py b/example/deployment/nginx/app.py new file mode 100644 index 000000000..ad42a8b3e --- /dev/null +++ b/example/deployment/nginx/app.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python + +import asyncio +import os +import signal + +import websockets + + +async def echo(websocket, path): + async for message in websocket: + await websocket.send(message) + + +async def main(): + # Set the stop condition when receiving SIGTERM. + loop = asyncio.get_running_loop() + stop = loop.create_future() + loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) + + async with websockets.unix_serve( + echo, + path=f"{os.environ['SUPERVISOR_PROCESS_NAME']}.sock", + ): + await stop + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/deployment/nginx/nginx.conf b/example/deployment/nginx/nginx.conf new file mode 100644 index 000000000..67aa0086d --- /dev/null +++ b/example/deployment/nginx/nginx.conf @@ -0,0 +1,25 @@ +daemon off; + +events { +} + +http { + server { + listen localhost:8080; + + location / { + proxy_http_version 1.1; + proxy_pass http://websocket; + proxy_set_header Connection $http_connection; + proxy_set_header Upgrade $http_upgrade; + } + } + + upstream websocket { + least_conn; + server unix:websockets-test_00.sock; + server unix:websockets-test_01.sock; + server unix:websockets-test_02.sock; + server unix:websockets-test_03.sock; + } +} diff --git a/example/deployment/nginx/supervisord.conf b/example/deployment/nginx/supervisord.conf new file mode 100644 index 000000000..76a664d91 --- /dev/null +++ b/example/deployment/nginx/supervisord.conf @@ -0,0 +1,7 @@ +[supervisord] + +[program:websockets-test] +command = python app.py +process_name = %(program_name)s_%(process_num)02d +numprocs = 4 +autorestart = true From 2990cf761177073e6d741e1ef81dc1d6d3c5dba8 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 27 May 2021 23:05:18 +0200 Subject: [PATCH 045/762] Add HAProxy deployment guide. Ref #445. --- docs/howto/haproxy.rst | 61 +++++++++++++++++++++ docs/howto/index.rst | 1 + docs/spelling_wordlist.txt | 2 + docs/topics/deployment.rst | 2 +- example/deployment/haproxy/app.py | 30 ++++++++++ example/deployment/haproxy/haproxy.cfg | 17 ++++++ example/deployment/haproxy/supervisord.conf | 7 +++ 7 files changed, 119 insertions(+), 1 deletion(-) create mode 100644 docs/howto/haproxy.rst create mode 100644 example/deployment/haproxy/app.py create mode 100644 example/deployment/haproxy/haproxy.cfg create mode 100644 example/deployment/haproxy/supervisord.conf diff --git a/docs/howto/haproxy.rst b/docs/howto/haproxy.rst new file mode 100644 index 000000000..d520d278a --- /dev/null +++ b/docs/howto/haproxy.rst @@ -0,0 +1,61 @@ +Deploy behind HAProxy +===================== + +This guide demonstrates a way to load balance connections across multiple +websockets server processes running on the same machine with HAProxy_. + +We'll run server processes with Supervisor as described in :doc:`this guide +`. + +.. _HAProxy: https://www.haproxy.org/ + +Run server processes +-------------------- + +Save this app to ``app.py``: + +.. literalinclude:: ../../example/deployment/haproxy/app.py + :emphasize-lines: 24 + +Each server process listens on a different port by extracting an incremental +index from an environment variable set by Supervisor. + +Save this configuration to ``supervisord.conf``: + +.. literalinclude:: ../../example/deployment/haproxy/supervisord.conf + +This configuration runs four instances of the app. + +Install Supervisor and run it: + +.. code:: console + + $ supervisord -c supervisord.conf -n + +Configure and run HAProxy +------------------------- + +Here's a simple HAProxy configuration to load balance connections across four +processes: + +.. literalinclude:: ../../example/deployment/haproxy/haproxy.cfg + +In the backend configuration, we set the load balancing method to +``leastconn`` in order to balance the number of active connections across +servers. This is best for long running connections. + +Save the configuration to ``haproxy.cfg``, install HAProxy, and run it: + +.. code:: console + + $ haproxy -f haproxy.cfg + +You can confirm that HAProxy proxies connections properly: + +.. code:: console + + $ PYTHONPATH=src python -m websockets ws://localhost:8080/ + Connected to ws://localhost:8080/. + > Hello! + < Hello! + Connection closed: code = 1000 (OK), no reason. diff --git a/docs/howto/index.rst b/docs/howto/index.rst index eab6cab6e..ac1182705 100644 --- a/docs/howto/index.rst +++ b/docs/howto/index.rst @@ -35,3 +35,4 @@ Once your application is ready, learn how to deploy it on various platforms. kubernetes supervisor nginx + haproxy diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index b7493bba6..676e13afc 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -27,6 +27,7 @@ django dyno fractalideas gunicorn +haproxy hypercorn iframe IPv @@ -48,6 +49,7 @@ permessage pid pong pongs +proxying pythonic redis runtime diff --git a/docs/topics/deployment.rst b/docs/topics/deployment.rst index a7b4bd3ad..d30c5568e 100644 --- a/docs/topics/deployment.rst +++ b/docs/topics/deployment.rst @@ -69,7 +69,7 @@ one of these three buckets: 3. Running N processes on a server: - * preferrably via a process manager or supervisor + * preferably via a process manager or supervisor Option 1 is easiest of you have access to such a platform. diff --git a/example/deployment/haproxy/app.py b/example/deployment/haproxy/app.py new file mode 100644 index 000000000..2b24790dd --- /dev/null +++ b/example/deployment/haproxy/app.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python + +import asyncio +import os +import signal + +import websockets + + +async def echo(websocket, path): + async for message in websocket: + await websocket.send(message) + + +async def main(): + # Set the stop condition when receiving SIGTERM. + loop = asyncio.get_running_loop() + stop = loop.create_future() + loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) + + async with websockets.serve( + echo, + host="localhost", + port=8000 + int(os.environ["SUPERVISOR_PROCESS_NAME"][-2:]), + ): + await stop + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/deployment/haproxy/haproxy.cfg b/example/deployment/haproxy/haproxy.cfg new file mode 100644 index 000000000..e63727d1c --- /dev/null +++ b/example/deployment/haproxy/haproxy.cfg @@ -0,0 +1,17 @@ +defaults + mode http + timeout connect 10s + timeout client 30s + timeout server 30s + +frontend websocket + bind localhost:8080 + default_backend websocket + +backend websocket + balance leastconn + server websockets-test_00 localhost:8000 + server websockets-test_01 localhost:8001 + server websockets-test_02 localhost:8002 + server websockets-test_03 localhost:8003 + diff --git a/example/deployment/haproxy/supervisord.conf b/example/deployment/haproxy/supervisord.conf new file mode 100644 index 000000000..76a664d91 --- /dev/null +++ b/example/deployment/haproxy/supervisord.conf @@ -0,0 +1,7 @@ +[supervisord] + +[program:websockets-test] +command = python app.py +process_name = %(program_name)s_%(process_num)02d +numprocs = 4 +autorestart = true From 58ac5228dabae9a734e2ee360be36b8db5f2e8c2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 28 May 2021 21:56:17 +0200 Subject: [PATCH 046/762] Explain keepalives. Ref #919. --- docs/spelling_wordlist.txt | 1 + docs/topics/index.rst | 1 + docs/topics/timeouts.rst | 66 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 68 insertions(+) create mode 100644 docs/topics/timeouts.rst diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 676e13afc..4fa44fbaa 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -52,6 +52,7 @@ pongs proxying pythonic redis +retransmit runtime scalable serializers diff --git a/docs/topics/index.rst b/docs/topics/index.rst index 41bdc3051..9269c585a 100644 --- a/docs/topics/index.rst +++ b/docs/topics/index.rst @@ -7,6 +7,7 @@ Topics deployment authentication compression + timeouts design memory security diff --git a/docs/topics/timeouts.rst b/docs/topics/timeouts.rst new file mode 100644 index 000000000..828d22a0b --- /dev/null +++ b/docs/topics/timeouts.rst @@ -0,0 +1,66 @@ +Timeouts +======== + +Since the WebSocket protocol is intended for real-time communications over +long-lived connections, it is desirable to ensure that connections don't +break, and if they do, to report the problem quickly. + +WebSocket is built on top of HTTP/1.1 where connections are short-lived, even +with ``Connection: keep-alive``. Typically, HTTP/1.1 infrastructure closes +idle connections after 30 to 120 seconds. + +As a consequence, proxies may terminate WebSocket connections prematurely, +when no message was exchanged in 30 seconds. + +In order to avoid this problem, websockets implements a keepalive mechanism +based on WebSocket Ping and Pong frames. Ping and Pong are designed for this +purpose. + +By default, websockets waits 20 seconds, then sends a Ping frame, and expects +to receive the corresponding Pong frame within 20 seconds. Else, it considers +the connection broken and closes it. + +Timings are configurable with ``ping_interval`` and ``ping_timeout``. + +While WebSocket runs on top of TCP, websockets doesn't rely on TCP keepalive +because it's disabled by default and, if enabled, the default interval is no +less than two hours, which doesn't meet requirements. + +Latency between a client and a server may increase for two reasons: + +* Network connectivity is poor. When network packets are lost, TCP attempts to + retransmit them, which manifests as latency. Excessive packet loss makes + the connection unusable in practice. At some point, timing out is a + reasonable choice. + +* Traffic is high. For example, if a client sends messages on the connection + faster than a server can process them, this manifests as latency as well, + because data is waiting in flight, mostly in OS buffers. + + If the server is more than 20 seconds behind, it doesn't see the Pong before + the default timeout elapses. As a consequence, it closes the connection. + This is a reasonable choice to prevent overload. + + If traffic spikes cause unwanted timeouts and you're confident that the + server will catch up eventually, you can increase ``ping_timeout`` or you + can disable keepalive entirely with ``ping_interval=None``. + + The same reasoning applies to situations where the server sends more traffic + than the client can accept. + +You can monitor latency as follows: + +.. code:: python + + import asyncio + import logging + import time + + async def log_latency(websocket, logger): + t0 = time.perf_counter() + pong_waiter = await websocket.ping() + await pong_waiter + t1 = time.perf_counter() + logger.info("Connection latency: %.3f seconds", t1 - t0) + + asyncio.create_task(log_latency(websocket, logging.getLogger())) From 820dad7b339a6dd653938d64b22cf732a9415133 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 29 May 2021 07:02:33 +0200 Subject: [PATCH 047/762] Add ToC to FAQ. --- docs/howto/faq.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/howto/faq.rst b/docs/howto/faq.rst index 23964dcfb..ce3704a2e 100644 --- a/docs/howto/faq.rst +++ b/docs/howto/faq.rst @@ -11,6 +11,9 @@ FAQ .. _developing with asyncio: https://docs.python.org/3/library/asyncio-dev.html +.. contents:: + :local: + Server side ----------- From f8e081da8d1ee76da19b64cadcb2adf1eaa140d4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 29 May 2021 21:04:26 +0200 Subject: [PATCH 048/762] isort now supports import at the bottom of modules. --- src/websockets/frames.py | 2 +- src/websockets/legacy/framing.py | 15 +++++++-------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 6e5ef1b73..62bea6814 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -338,4 +338,4 @@ def check_close(code: int) -> None: # at the bottom to allow circular import, because Extension depends on Frame -from . import extensions # isort:skip # noqa +from . import extensions # noqa diff --git a/src/websockets/legacy/framing.py b/src/websockets/legacy/framing.py index 627e6922c..e947c9383 100644 --- a/src/websockets/legacy/framing.py +++ b/src/websockets/legacy/framing.py @@ -124,12 +124,11 @@ def write( write(self.serialize(mask=mask, extensions=extensions)) -# Backwards compatibility with previously documented public APIs -from ..frames import parse_close # isort:skip # noqa -from ..frames import prepare_ctrl as encode_data # isort:skip # noqa -from ..frames import prepare_data # isort:skip # noqa -from ..frames import serialize_close # isort:skip # noqa - - # at the bottom to allow circular import, because Extension depends on Frame -from .. import extensions # isort:skip # noqa +from .. import extensions # noqa + +# Backwards compatibility with previously documented public APIs +from ..frames import parse_close # noqa +from ..frames import prepare_data # noqa +from ..frames import serialize_close # noqa +from ..frames import prepare_ctrl as encode_data # noqa From 07e8a636eeb7187532175b27e92a8c67564f631a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 30 May 2021 15:17:04 +0200 Subject: [PATCH 049/762] Check that GET requests don't have a body. It is technically possible but doesn't have a meaning in general. --- src/websockets/http11.py | 8 ++++++++ tests/test_http11.py | 18 ++++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/src/websockets/http11.py b/src/websockets/http11.py index 0754ddabb..6f3cbccc4 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -99,6 +99,14 @@ def parse( headers = yield from parse_headers(read_line) + # https://tools.ietf.org/html/rfc7230#section-3.3.3 + + if "Transfer-Encoding" in headers: + raise NotImplementedError("transfer codings aren't supported") + + if "Content-Length" in headers: + raise ValueError("unsupported request body") + return cls(path, headers) def serialize(self) -> bytes: diff --git a/tests/test_http11.py b/tests/test_http11.py index 1cca2053f..e73365cf0 100644 --- a/tests/test_http11.py +++ b/tests/test_http11.py @@ -77,6 +77,24 @@ def test_parse_invalid_header(self): "invalid HTTP header line: Oops", ) + def test_parse_body(self): + self.reader.feed_data(b"GET / HTTP/1.1\r\nContent-Length: 3\r\n\r\nYo\n") + with self.assertRaises(ValueError) as raised: + next(self.parse()) + self.assertEqual( + str(raised.exception), + "unsupported request body", + ) + + def test_parse_body_with_transfer_encoding(self): + self.reader.feed_data(b"GET / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n") + with self.assertRaises(NotImplementedError) as raised: + next(self.parse()) + self.assertEqual( + str(raised.exception), + "transfer codings aren't supported", + ) + def test_serialize(self): # Example from the protocol overview in RFC 6455 request = Request( From 58193e088ace6ce7c4c2dfcb6deea2af7cbe8601 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 30 May 2021 17:20:56 +0200 Subject: [PATCH 050/762] Add UUID to connections. --- docs/reference/client.rst | 6 ++++++ docs/reference/server.rst | 6 ++++++ src/websockets/connection.py | 4 ++++ src/websockets/legacy/protocol.py | 4 ++++ 4 files changed, 20 insertions(+) diff --git a/docs/reference/client.rst b/docs/reference/client.rst index 374a8197e..c2bb6266b 100644 --- a/docs/reference/client.rst +++ b/docs/reference/client.rst @@ -17,6 +17,12 @@ Client .. autoclass:: WebSocketClientProtocol(*, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, origin=None, extensions=None, subprotocols=None, extra_headers=None) + .. attribute:: id + + UUID for the connection. + + Useful for identifying connections in logs. + .. autoattribute:: local_address .. autoattribute:: remote_address diff --git a/docs/reference/server.rst b/docs/reference/server.rst index 1a2dd1c88..898403070 100644 --- a/docs/reference/server.rst +++ b/docs/reference/server.rst @@ -27,6 +27,12 @@ Server .. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None) + .. attribute:: id + + UUID for the connection. + + Useful for identifying connections in logs. + .. autoattribute:: local_address .. autoattribute:: remote_address diff --git a/src/websockets/connection.py b/src/websockets/connection.py index aeb774f00..4bd09e282 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -1,5 +1,6 @@ import enum import logging +import uuid from typing import Generator, List, Optional, Union from .exceptions import InvalidState, PayloadTooBig, ProtocolError @@ -68,6 +69,9 @@ def __init__( state: State = OPEN, max_size: Optional[int] = 2 ** 20, ) -> None: + # Unique identifier. For logs. + self.id = uuid.uuid4() + # Connection side. CLIENT or SERVER. self.side = side diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 4e8958b60..b544b93d0 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -14,6 +14,7 @@ import logging import random import struct +import uuid import warnings from typing import ( Any, @@ -132,6 +133,9 @@ def __init__( self.read_limit = read_limit self.write_limit = write_limit + # Unique identifier. For logs. + self.id = uuid.uuid4() + assert loop is not None # Remove when dropping Python < 3.10 - use get_running_loop instead. self.loop = loop From 2617b2c8f80c19e759a4e1364bc7007ff8d99acd Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 29 May 2021 20:59:26 +0200 Subject: [PATCH 051/762] Add human-friendly representation of frames. Fix #765. --- src/websockets/__main__.py | 3 +- src/websockets/exceptions.py | 46 ++----------- src/websockets/frames.py | 89 +++++++++++++++++++++++++ tests/test_frames.py | 126 +++++++++++++++++++++++++++++++++++ 4 files changed, 221 insertions(+), 43 deletions(-) diff --git a/src/websockets/__main__.py b/src/websockets/__main__.py index 46e746da5..4358c323c 100644 --- a/src/websockets/__main__.py +++ b/src/websockets/__main__.py @@ -6,7 +6,8 @@ import threading from typing import Any, Set -from .exceptions import ConnectionClosed, format_close +from .exceptions import ConnectionClosed +from .frames import format_close from .legacy.client import connect diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 9ab9d3ebe..4bd2a41a6 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -68,48 +68,6 @@ class WebSocketException(Exception): """ -# See https://www.iana.org/assignments/websocket/websocket.xhtml -CLOSE_CODES = { - 1000: "OK", - 1001: "going away", - 1002: "protocol error", - 1003: "unsupported type", - # 1004 is reserved - 1005: "no status code [internal]", - 1006: "connection closed abnormally [internal]", - 1007: "invalid data", - 1008: "policy violation", - 1009: "message too big", - 1010: "extension required", - 1011: "unexpected error", - 1012: "service restart", - 1013: "try again later", - 1014: "bad gateway", - 1015: "TLS failure [internal]", -} - - -def format_close(code: int, reason: str) -> str: - """ - Display a human-readable version of the close code and reason. - - """ - if 3000 <= code < 4000: - explanation = "registered" - elif 4000 <= code < 5000: - explanation = "private use" - else: - explanation = CLOSE_CODES.get(code, "unknown") - result = f"code = {code} ({explanation}), " - - if reason: - result += f"reason = {reason}" - else: - result += "no reason" - - return result - - class ConnectionClosed(WebSocketException): """ Raised when trying to interact with a closed connection. @@ -371,3 +329,7 @@ class ProtocolError(WebSocketException): WebSocketProtocolError = ProtocolError # for backwards compatibility + + +# at the bottom to allow circular import, because the frames module imports exceptions +from .frames import format_close # noqa diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 62bea6814..de7fdb941 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -52,6 +52,28 @@ class Opcode(enum.IntEnum): DATA_OPCODES = OP_CONT, OP_TEXT, OP_BINARY CTRL_OPCODES = OP_CLOSE, OP_PING, OP_PONG + +# See https://www.iana.org/assignments/websocket/websocket.xhtml +CLOSE_CODES = { + 1000: "OK", + 1001: "going away", + 1002: "protocol error", + 1003: "unsupported type", + # 1004 is reserved + 1005: "no status code [internal]", + 1006: "connection closed abnormally [internal]", + 1007: "invalid data", + 1008: "policy violation", + 1009: "message too big", + 1010: "extension required", + 1011: "unexpected error", + 1012: "service restart", + 1013: "try again later", + 1014: "bad gateway", + 1015: "TLS failure [internal]", +} + + # Close code that are allowed in a close frame. # Using a set optimizes `code in EXTERNAL_CLOSE_CODES`. EXTERNAL_CLOSE_CODES = { @@ -96,6 +118,52 @@ class Frame(NamedTuple): rsv2: bool = False rsv3: bool = False + def __str__(self) -> str: + """ + Return a human-readable represention of a frame. + + """ + coding = None + length = f"{len(self.data)} byte{'' if len(self.data) == 1 else 's'}" + non_final = "" if self.fin else "continued" + + if self.opcode is OP_TEXT: + # Decoding only the beginning and the end is needlessly hard. + # Decode the entire payload then elide later if necessary. + data = self.data.decode() + elif self.opcode is OP_BINARY: + # We'll show at most the first 16 bytes and the last 8 bytes. + # Encode just what we need, plus two dummy bytes to elide later. + binary = self.data + if len(binary) > 25: + binary = binary[:16] + b"\x00\x00" + binary[-8:] + data = " ".join(f"{byte:02x}" for byte in binary) + elif self.opcode is OP_CLOSE: + code, reason = parse_close(self.data) + data = format_close(code, reason) + elif self.data: + # We don't know if a Continuation frame contains text or binary. + # Ping and Pong frames could contain UTF-8. Attempt to decode as + # UTF-8 and display it as text; fallback to binary. + try: + data = self.data.decode() + coding = "text" + except UnicodeDecodeError: + binary = self.data + if len(binary) > 25: + binary = binary[:16] + b"\x00\x00" + binary[-8:] + data = " ".join(f"{byte:02x}" for byte in binary) + coding = "binary" + else: + data = "" + + if len(data) > 75: + data = data[:48] + "..." + data[-24:] + + metadata = ", ".join(filter(None, [coding, length, non_final])) + + return f"{self.opcode.name} {data} [{metadata}]" + @classmethod def parse( cls, @@ -337,5 +405,26 @@ def check_close(code: int) -> None: raise ProtocolError("invalid status code") +def format_close(code: int, reason: str) -> str: + """ + Display a human-readable version of the close code and reason. + + """ + if 3000 <= code < 4000: + explanation = "registered" + elif 4000 <= code < 5000: + explanation = "private use" + else: + explanation = CLOSE_CODES.get(code, "unknown") + result = f"code = {code} ({explanation}), " + + if reason: + result += f"reason = {reason}" + else: + result += "no reason" + + return result + + # at the bottom to allow circular import, because Extension depends on Frame from . import extensions # noqa diff --git a/tests/test_frames.py b/tests/test_frames.py index 13a712322..491386566 100644 --- a/tests/test_frames.py +++ b/tests/test_frames.py @@ -193,6 +193,132 @@ def decode(frame, *, max_size=None): ) +class StrTests(unittest.TestCase): + def test_cont_text(self): + self.assertEqual( + str(Frame(False, OP_CONT, b" cr\xc3\xa8me")), + "CONT crème [text, 7 bytes, continued]", + ) + + def test_cont_binary(self): + self.assertEqual( + str(Frame(False, OP_CONT, b"\xfc\xfd\xfe\xff")), + "CONT fc fd fe ff [binary, 4 bytes, continued]", + ) + + def test_cont_final_text(self): + self.assertEqual( + str(Frame(True, OP_CONT, b" cr\xc3\xa8me")), + "CONT crème [text, 7 bytes]", + ) + + def test_cont_final_binary(self): + self.assertEqual( + str(Frame(True, OP_CONT, b"\xfc\xfd\xfe\xff")), + "CONT fc fd fe ff [binary, 4 bytes]", + ) + + def test_cont_text_truncated(self): + self.assertEqual( + str(Frame(False, OP_CONT, b"caf\xc3\xa9 " * 16)), + "CONT café café café café café café café café café caf..." + "afé café café café café [text, 96 bytes, continued]", + ) + + def test_cont_binary_truncated(self): + self.assertEqual( + str(Frame(False, OP_CONT, bytes(range(256)))), + "CONT 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f ..." + " f8 f9 fa fb fc fd fe ff [binary, 256 bytes, continued]", + ) + + def test_text(self): + self.assertEqual( + str(Frame(True, OP_TEXT, b"caf\xc3\xa9")), + "TEXT café [5 bytes]", + ) + + def test_text_non_final(self): + self.assertEqual( + str(Frame(False, OP_TEXT, b"caf\xc3\xa9")), + "TEXT café [5 bytes, continued]", + ) + + def test_text_truncated(self): + self.assertEqual( + str(Frame(True, OP_TEXT, b"caf\xc3\xa9 " * 16)), + "TEXT café café café café café café café café café caf..." + "afé café café café café [96 bytes]", + ) + + def test_binary(self): + self.assertEqual( + str(Frame(True, OP_BINARY, b"\x00\x01\x02\x03")), + "BINARY 00 01 02 03 [4 bytes]", + ) + + def test_binary_non_final(self): + self.assertEqual( + str(Frame(False, OP_BINARY, b"\x00\x01\x02\x03")), + "BINARY 00 01 02 03 [4 bytes, continued]", + ) + + def test_binary_truncated(self): + self.assertEqual( + str(Frame(True, OP_BINARY, bytes(range(256)))), + "BINARY 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f ..." + " f8 f9 fa fb fc fd fe ff [256 bytes]", + ) + + def test_close(self): + self.assertEqual( + str(Frame(True, OP_CLOSE, b"\x03\xe8")), + "CLOSE code = 1000 (OK), no reason [2 bytes]", + ) + + def test_close_reason(self): + self.assertEqual( + str(Frame(True, OP_CLOSE, b"\x03\xe9Bye!")), + "CLOSE code = 1001 (going away), reason = Bye! [6 bytes]", + ) + + def test_ping(self): + self.assertEqual( + str(Frame(True, OP_PING, b"")), + "PING [0 bytes]", + ) + + def test_ping_text(self): + self.assertEqual( + str(Frame(True, OP_PING, b"ping")), + "PING ping [text, 4 bytes]", + ) + + def test_ping_binary(self): + self.assertEqual( + str(Frame(True, OP_PING, b"\xff\x00\xff\x00")), + "PING ff 00 ff 00 [binary, 4 bytes]", + ) + + def test_pong(self): + self.assertEqual( + str(Frame(True, OP_PONG, b"")), + "PONG [0 bytes]", + ) + + def test_pong_text(self): + self.assertEqual( + str(Frame(True, OP_PONG, b"pong")), + "PONG pong [text, 4 bytes]", + ) + + def test_pong_binary(self): + self.assertEqual( + str(Frame(True, OP_PONG, b"\xff\x00\xff\x00")), + "PONG ff 00 ff 00 [binary, 4 bytes]", + ) + + class PrepareDataTests(unittest.TestCase): def test_prepare_data_str(self): self.assertEqual( From 09f829f66a57ded2832279a18a64257b5fd3f875 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 15 May 2021 16:13:43 +0200 Subject: [PATCH 052/762] Inject logger in Sans-I/O layer. --- src/websockets/__init__.py | 2 ++ src/websockets/client.py | 18 +++++++++++------- src/websockets/connection.py | 22 ++++++++++++++-------- src/websockets/server.py | 26 +++++++++++++++----------- src/websockets/typing.py | 12 +++++++++++- tests/test_client.py | 9 +++++++++ tests/test_server.py | 9 +++++++++ 7 files changed, 71 insertions(+), 27 deletions(-) diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index 65d9fb913..f136a4e45 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -27,6 +27,7 @@ "InvalidStatusCode", "InvalidUpgrade", "InvalidURI", + "LoggerLike", "NegotiationError", "Origin", "parse_uri", @@ -92,6 +93,7 @@ "WebSocketServerProtocol": ".legacy.server", "WebSocketServer": ".legacy.server", "Data": ".typing", + "LoggerLike": ".typing", "Origin": ".typing", "ExtensionHeader": ".typing", "ExtensionParameter": ".typing", diff --git a/src/websockets/client.py b/src/websockets/client.py index 0ddf19f00..9d3f0e7f7 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -1,5 +1,4 @@ import collections -import logging from typing import Generator, List, Optional, Sequence from .connection import CLIENT, CONNECTING, OPEN, Connection @@ -27,6 +26,7 @@ from .typing import ( ConnectionOption, ExtensionHeader, + LoggerLike, Origin, Subprotocol, UpgradeProtocol, @@ -41,8 +41,6 @@ __all__ = ["ClientConnection"] -logger = logging.getLogger(__name__) - class ClientConnection(Connection): def __init__( @@ -53,8 +51,14 @@ def __init__( subprotocols: Optional[Sequence[Subprotocol]] = None, extra_headers: Optional[HeadersLike] = None, max_size: Optional[int] = 2 ** 20, + logger: Optional[LoggerLike] = None, ): - super().__init__(side=CLIENT, state=CONNECTING, max_size=max_size) + super().__init__( + side=CLIENT, + state=CONNECTING, + max_size=max_size, + logger=logger, + ) self.wsuri = parse_uri(uri) self.origin = origin self.available_extensions = extensions @@ -271,8 +275,8 @@ def send_request(self, request: Request) -> None: Send a WebSocket handshake request to the server. """ - logger.debug("%s > GET %s HTTP/1.1", self.side, request.path) - logger.debug("%s > %r", self.side, request.headers) + self.logger.debug("%s > GET %s HTTP/1.1", self.side, request.path) + self.logger.debug("%s > %r", self.side, request.headers) self.writes.append(request.serialize()) @@ -285,7 +289,7 @@ def parse(self) -> Generator[None, None, None]: self.process_response(response) except InvalidHandshake as exc: response = response._replace(exception=exc) - logger.debug("Invalid handshake", exc_info=True) + self.logger.debug("Invalid handshake", exc_info=True) else: self.set_state(OPEN) finally: diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 4bd09e282..4852560e3 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -18,7 +18,7 @@ ) from .http11 import Request, Response from .streams import StreamReader -from .typing import Origin, Subprotocol +from .typing import LoggerLike, Origin, Subprotocol __all__ = [ @@ -28,8 +28,6 @@ "SEND_EOF", ] -logger = logging.getLogger(__name__) - Event = Union[Request, Response, Frame] @@ -68,6 +66,7 @@ def __init__( side: Side, state: State = OPEN, max_size: Optional[int] = 2 ** 20, + logger: Optional[LoggerLike] = None, ) -> None: # Unique identifier. For logs. self.id = uuid.uuid4() @@ -76,12 +75,19 @@ def __init__( self.side = side # Connnection state. CONNECTING and CLOSED states are handled in subclasses. - logger.debug("%s - initial state: %s", self.side, state.name) self.state = state # Maximum size of incoming messages in bytes. self.max_size = max_size + # Logger or LoggerAdapter for this connection. + if logger is None: + logger = logging.getLogger(f"websockets.{side.name.lower()}") + self.logger = logger + + # Must wait until we have the logger to log the initial state! + self.logger.debug("%s - initial state: %s", self.side, state.name) + # Current size of incoming message in bytes. Only set while reading a # fragmented message i.e. a data frames with the FIN bit not set. self.cur_size: Optional[int] = None @@ -117,7 +123,7 @@ def __init__( self.parser_exc: Optional[Exception] = None def set_state(self, state: State) -> None: - logger.debug( + self.logger.debug( "%s - state change: %s > %s", self.side, self.state.name, state.name ) self.state = state @@ -286,7 +292,7 @@ def step_parser(self) -> None: self.parser_exc = exc raise except Exception as exc: - logger.error("unexpected exception in parser", exc_info=True) + self.logger.error("unexpected exception in parser", exc_info=True) # Don't include exception details, which may be security-sensitive. self.fail_connection(1011) self.parser_exc = exc @@ -401,7 +407,7 @@ def send_frame(self, frame: Frame) -> None: f"cannot write to a WebSocket in the {self.state.name} state" ) - logger.debug("%s > %r", self.side, frame) + self.logger.debug("%s > %r", self.side, frame) self.writes.append( frame.serialize(mask=self.side is CLIENT, extensions=self.extensions) ) @@ -409,5 +415,5 @@ def send_frame(self, frame: Frame) -> None: def send_eof(self) -> None: assert not self.eof_sent self.eof_sent = True - logger.debug("%s > EOF", self.side) + self.logger.debug("%s > EOF", self.side) self.writes.append(SEND_EOF) diff --git a/src/websockets/server.py b/src/websockets/server.py index f57d36b70..4a76ac886 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -3,7 +3,6 @@ import collections import email.utils import http -import logging from typing import Callable, Generator, List, Optional, Sequence, Tuple, Union, cast from .connection import CONNECTING, OPEN, SERVER, Connection @@ -29,6 +28,7 @@ from .typing import ( ConnectionOption, ExtensionHeader, + LoggerLike, Origin, Subprotocol, UpgradeProtocol, @@ -42,8 +42,6 @@ __all__ = ["ServerConnection"] -logger = logging.getLogger(__name__) - HeadersLikeOrCallable = Union[HeadersLike, Callable[[str, Headers], HeadersLike]] @@ -59,8 +57,14 @@ def __init__( subprotocols: Optional[Sequence[Subprotocol]] = None, extra_headers: Optional[HeadersLikeOrCallable] = None, max_size: Optional[int] = 2 ** 20, + logger: Optional[LoggerLike] = None, ): - super().__init__(side=SERVER, state=CONNECTING, max_size=max_size) + super().__init__( + side=SERVER, + state=CONNECTING, + max_size=max_size, + logger=logger, + ) self.origins = origins self.available_extensions = extensions self.available_subprotocols = subprotocols @@ -80,13 +84,13 @@ def accept(self, request: Request) -> Response: try: key, extensions_header, protocol_header = self.process_request(request) except InvalidOrigin as exc: - logger.debug("Invalid origin", exc_info=True) + self.logger.debug("Invalid origin", exc_info=True) return self.reject( http.HTTPStatus.FORBIDDEN, f"Failed to open a WebSocket connection: {exc}.\n", )._replace(exception=exc) except InvalidUpgrade as exc: - logger.debug("Invalid upgrade", exc_info=True) + self.logger.debug("Invalid upgrade", exc_info=True) return self.reject( http.HTTPStatus.UPGRADE_REQUIRED, ( @@ -98,13 +102,13 @@ def accept(self, request: Request) -> Response: headers=Headers([("Upgrade", "websocket")]), )._replace(exception=exc) except InvalidHandshake as exc: - logger.debug("Invalid handshake", exc_info=True) + self.logger.debug("Invalid handshake", exc_info=True) return self.reject( http.HTTPStatus.BAD_REQUEST, f"Failed to open a WebSocket connection: {exc}.\n", )._replace(exception=exc) except Exception as exc: - logger.warning("Error in opening handshake", exc_info=True) + self.logger.warning("Error in opening handshake", exc_info=True) return self.reject( http.HTTPStatus.INTERNAL_SERVER_ERROR, ( @@ -410,15 +414,15 @@ def send_response(self, response: Response) -> None: if response.status_code == 101: self.set_state(OPEN) - logger.debug( + self.logger.debug( "%s > HTTP/1.1 %d %s", self.side, response.status_code, response.reason_phrase, ) - logger.debug("%s > %r", self.side, response.headers) + self.logger.debug("%s > %r", self.side, response.headers) if response.body is not None: - logger.debug("%s > body (%d bytes)", self.side, len(response.body)) + self.logger.debug("%s > body (%d bytes)", self.side, len(response.body)) self.writes.append(response.serialize()) diff --git a/src/websockets/typing.py b/src/websockets/typing.py index 630a9fbe3..b1858d73e 100644 --- a/src/websockets/typing.py +++ b/src/websockets/typing.py @@ -1,7 +1,15 @@ +import logging from typing import List, NewType, Optional, Tuple, Union -__all__ = ["Data", "Origin", "ExtensionHeader", "ExtensionParameter", "Subprotocol"] +__all__ = [ + "Data", + "LoggerLike", + "Origin", + "ExtensionHeader", + "ExtensionParameter", + "Subprotocol", +] Data = Union[str, bytes] Data.__doc__ = """ @@ -12,6 +20,8 @@ """ +LoggerLike = Union[logging.Logger, logging.LoggerAdapter] +LoggerLike.__doc__ = """"Types accepted where :class:`~logging.Logger` is expected""" Origin = NewType("Origin", str) Origin.__doc__ = """Value of a Origin header""" diff --git a/tests/test_client.py b/tests/test_client.py index 747594bf3..b96ebd272 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,3 +1,4 @@ +import logging import unittest import unittest.mock @@ -568,3 +569,11 @@ def test_unsupported_subprotocol(self): with self.assertRaises(InvalidHandshake) as raised: raise response.exception self.assertEqual(str(raised.exception), "unsupported subprotocol: otherchat") + + +class MiscTests(unittest.TestCase): + def test_custom_logger(self): + logger = logging.getLogger("test") + with self.assertLogs("test", logging.DEBUG) as logs: + ClientConnection("wss://example.com/test", logger=logger) + self.assertEqual(len(logs.records), 1) diff --git a/tests/test_server.py b/tests/test_server.py index ad56a37bc..6a25f0d25 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,4 +1,5 @@ import http +import logging import unittest import unittest.mock @@ -625,3 +626,11 @@ def test_extra_headers_overrides_server(self): self.assertEqual(response.status_code, 101) self.assertEqual(response.headers["Server"], "Other") + + +class MiscTests(unittest.TestCase): + def test_custom_logger(self): + logger = logging.getLogger("test") + with self.assertLogs("test", logging.DEBUG) as logs: + ServerConnection(logger=logger) + self.assertEqual(len(logs.records), 1) From f5eae5657c2d948cdf20174a03190c6a4206341c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 15 May 2021 16:23:30 +0200 Subject: [PATCH 053/762] Inject logger in legacy asyncio layer. --- docs/reference/client.rst | 6 +-- docs/reference/server.rst | 6 +-- src/websockets/legacy/client.py | 19 +++++---- src/websockets/legacy/protocol.py | 68 ++++++++++++++++-------------- src/websockets/legacy/server.py | 41 ++++++++++-------- tests/legacy/test_client_server.py | 16 +++++++ 6 files changed, 94 insertions(+), 62 deletions(-) diff --git a/docs/reference/client.rst b/docs/reference/client.rst index c2bb6266b..c7c738c81 100644 --- a/docs/reference/client.rst +++ b/docs/reference/client.rst @@ -6,16 +6,16 @@ Client Opening a connection -------------------- - .. autofunction:: connect(uri, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, compression='deflate', origin=None, extensions=None, subprotocols=None, extra_headers=None, **kwds) + .. autofunction:: connect(uri, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, compression='deflate', origin=None, extensions=None, subprotocols=None, extra_headers=None, logger=None, **kwds) :async: - .. autofunction:: unix_connect(path, uri="ws://localhost/", *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, compression='deflate', origin=None, extensions=None, subprotocols=None, extra_headers=None, **kwds) + .. autofunction:: unix_connect(path, uri="ws://localhost/", *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, compression='deflate', origin=None, extensions=None, subprotocols=None, extra_headers=None, logger=None, **kwds) :async: Using a connection ------------------ - .. autoclass:: WebSocketClientProtocol(*, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, origin=None, extensions=None, subprotocols=None, extra_headers=None) + .. autoclass:: WebSocketClientProtocol(*, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, origin=None, extensions=None, subprotocols=None, extra_headers=None, logger=None) .. attribute:: id diff --git a/docs/reference/server.rst b/docs/reference/server.rst index 898403070..2d54eca7a 100644 --- a/docs/reference/server.rst +++ b/docs/reference/server.rst @@ -6,10 +6,10 @@ Server Starting a server ----------------- - .. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, compression='deflate', origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, **kwds) + .. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, compression='deflate', origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, logger=None, **kwds) :async: - .. autofunction:: unix_serve(ws_handler, path, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, compression='deflate', origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, **kwds) + .. autofunction:: unix_serve(ws_handler, path, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, compression='deflate', origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, logger=None, **kwds) :async: Stopping a server @@ -25,7 +25,7 @@ Server Using a connection ------------------ - .. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None) + .. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, logger=None) .. attribute:: id diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index b77b4e86d..760309c3f 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -31,7 +31,7 @@ parse_subprotocol, ) from ..http import USER_AGENT, build_host -from ..typing import ExtensionHeader, Origin, Subprotocol +from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol from ..uri import WebSocketURI, parse_uri from .compatibility import asyncio_get_running_loop from .handshake import build_request, check_response @@ -41,8 +41,6 @@ __all__ = ["connect", "unix_connect", "WebSocketClientProtocol"] -logger = logging.getLogger("websockets.server") - class WebSocketClientProtocol(WebSocketCommonProtocol): """ @@ -152,13 +150,16 @@ def __init__( extensions: Optional[Sequence[ClientExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, extra_headers: Optional[HeadersLike] = None, + logger: Optional[LoggerLike] = None, **kwargs: Any, ) -> None: + if logger is None: + logger = logging.getLogger("websockets.client") + super().__init__(logger=logger, **kwargs) self.origin = origin self.available_extensions = extensions self.available_subprotocols = subprotocols self.extra_headers = extra_headers - super().__init__(**kwargs) def write_http_request(self, path: str, headers: Headers) -> None: """ @@ -168,8 +169,8 @@ def write_http_request(self, path: str, headers: Headers) -> None: self.path = path self.request_headers = headers - logger.debug("%s > GET %s HTTP/1.1", self.side, path) - logger.debug("%s > %r", self.side, headers) + self.logger.debug("%s > GET %s HTTP/1.1", self.side, path) + self.logger.debug("%s > %r", self.side, headers) # Since the path and headers only contain ASCII characters, # we can keep this simple. @@ -198,8 +199,8 @@ async def read_http_response(self) -> Tuple[int, Headers]: except Exception as exc: raise InvalidMessage("did not receive a valid HTTP response") from exc - logger.debug("%s < HTTP/1.1 %d %s", self.side, status_code, reason) - logger.debug("%s < %r", self.side, headers) + self.logger.debug("%s < HTTP/1.1 %d %s", self.side, status_code, reason) + self.logger.debug("%s < %r", self.side, headers) self.response_headers = headers @@ -478,6 +479,7 @@ def __init__( extensions: Optional[Sequence[ClientExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, extra_headers: Optional[HeadersLike] = None, + logger: Optional[LoggerLike] = None, **kwargs: Any, ) -> None: # Backwards compatibility: close_timeout used to be called timeout. @@ -542,6 +544,7 @@ def __init__( extensions=extensions, subprotocols=subprotocols, extra_headers=extra_headers, + logger=logger, ) if kwargs.pop("unix", False): diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index b544b93d0..940e330a9 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -54,15 +54,13 @@ prepare_data, serialize_close, ) -from ..typing import Data, Subprotocol +from ..typing import Data, LoggerLike, Subprotocol from .compatibility import loop_if_py_lt_38 from .framing import Frame __all__ = ["WebSocketCommonProtocol"] -logger = logging.getLogger("websockets.protocol") - # A WebSocket connection goes through the following four states, in order: @@ -108,6 +106,7 @@ def __init__( max_queue: Optional[int] = 2 ** 5, read_limit: int = 2 ** 16, write_limit: int = 2 ** 16, + logger: Optional[LoggerLike] = None, # The following arguments are kept only for backwards compatibility. host: Optional[str] = None, port: Optional[int] = None, @@ -132,6 +131,9 @@ def __init__( self.max_queue = max_queue self.read_limit = read_limit self.write_limit = write_limit + if logger is None: + logger = logging.getLogger("websockets.protocol") + self.logger = logger # Unique identifier. For logs. self.id = uuid.uuid4() @@ -162,7 +164,7 @@ def __init__( # Subclasses implement the opening handshake and, on success, execute # :meth:`connection_open` to change the state to OPEN. self.state = State.CONNECTING - logger.debug("%s - state = CONNECTING", self.side) + self.logger.debug("%s - state = CONNECTING", self.side) # HTTP protocol parameters. self.path: str @@ -246,7 +248,7 @@ def connection_open(self) -> None: # 4.1. The WebSocket Connection is Established. assert self.state is State.CONNECTING self.state = State.OPEN - logger.debug("%s - state = OPEN", self.side) + self.logger.debug("%s - state = OPEN", self.side) # Start the task that receives incoming WebSocket messages. self.transfer_data_task = self.loop.create_task(self.transfer_data()) # Start the task that sends pings at regular intervals. @@ -812,7 +814,7 @@ async def transfer_data(self) -> None: # This shouldn't happen often because exceptions expected under # regular circumstances are handled above. If it does, consider # catching and handling more exceptions. - logger.error("Error in data transfer", exc_info=True) + self.logger.error("Error in data transfer", exc_info=True) self.transfer_data_exc = exc self.fail_connection(1011) @@ -923,7 +925,7 @@ async def read_data_frame(self, max_size: Optional[int]) -> Optional[Frame]: elif frame.opcode == OP_PING: # Answer pings. ping_hex = frame.data.hex() or "[empty]" - logger.debug( + self.logger.debug( "%s - received ping, sending pong: %s", self.side, ping_hex ) await self.pong(frame.data) @@ -931,7 +933,7 @@ async def read_data_frame(self, max_size: Optional[int]) -> Optional[Frame]: elif frame.opcode == OP_PONG: # Acknowledge pings on solicited pongs. if frame.data in self.pings: - logger.debug( + self.logger.debug( "%s - received solicited pong: %s", self.side, frame.data.hex() or "[empty]", @@ -956,14 +958,14 @@ async def read_data_frame(self, max_size: Optional[int]) -> Optional[Frame]: ping_id.hex() or "[empty]" for ping_id in ping_ids ) plural = "s" if len(ping_ids) > 1 else "" - logger.debug( + self.logger.debug( "%s - acknowledged previous ping%s: %s", self.side, plural, pings_hex, ) else: - logger.debug( + self.logger.debug( "%s - received unsolicited pong: %s", self.side, frame.data.hex() or "[empty]", @@ -984,7 +986,7 @@ async def read_frame(self, max_size: Optional[int]) -> Frame: max_size=max_size, extensions=self.extensions, ) - logger.debug("%s < %r", self.side, frame) + self.logger.debug("%s < %r", self.side, frame) return frame async def write_frame( @@ -997,7 +999,7 @@ async def write_frame( ) frame = Frame(fin, Opcode(opcode), data) - logger.debug("%s > %r", self.side, frame) + self.logger.debug("%s > %r", self.side, frame) frame.write( self.transport.write, mask=self.is_client, extensions=self.extensions ) @@ -1029,7 +1031,7 @@ async def write_close_frame(self, data: bytes = b"") -> None: if self.state is State.OPEN: # 7.1.3. The WebSocket Closing Handshake is Started self.state = State.CLOSING - logger.debug("%s - state = CLOSING", self.side) + self.logger.debug("%s - state = CLOSING", self.side) # 7.1.2. Start the WebSocket Closing Handshake await self.write_frame(True, OP_CLOSE, data, _expected_state=State.CLOSING) @@ -1071,7 +1073,7 @@ async def keepalive_ping(self) -> None: **loop_if_py_lt_38(self.loop), ) except asyncio.TimeoutError: - logger.debug("%s ! timed out waiting for pong", self.side) + self.logger.debug("%s ! timed out waiting for pong", self.side) self.fail_connection(1011) break @@ -1084,7 +1086,9 @@ async def keepalive_ping(self) -> None: pass except Exception: - logger.warning("Unexpected exception in keepalive ping task", exc_info=True) + self.logger.warning( + "Unexpected exception in keepalive ping task", exc_info=True + ) async def close_connection(self) -> None: """ @@ -1116,18 +1120,18 @@ async def close_connection(self) -> None: # Coverage marks this line as a partially executed branch. # I supect a bug in coverage. Ignore it for now. return # pragma: no cover - logger.debug("%s ! timed out waiting for TCP close", self.side) + self.logger.debug("%s ! timed out waiting for TCP close", self.side) # Half-close the TCP connection if possible (when there's no TLS). if self.transport.can_write_eof(): - logger.debug("%s x half-closing TCP connection", self.side) + self.logger.debug("%s x half-closing TCP connection", self.side) self.transport.write_eof() if await self.wait_for_connection_lost(): # Coverage marks this line as a partially executed branch. # I supect a bug in coverage. Ignore it for now. return # pragma: no cover - logger.debug("%s ! timed out waiting for TCP close", self.side) + self.logger.debug("%s ! timed out waiting for TCP close", self.side) finally: # The try/finally ensures that the transport never remains open, @@ -1140,15 +1144,15 @@ async def close_connection(self) -> None: return # Close the TCP connection. Buffers are flushed asynchronously. - logger.debug("%s x closing TCP connection", self.side) + self.logger.debug("%s x closing TCP connection", self.side) self.transport.close() if await self.wait_for_connection_lost(): return - logger.debug("%s ! timed out waiting for TCP close", self.side) + self.logger.debug("%s ! timed out waiting for TCP close", self.side) # Abort the TCP connection. Buffers are discarded. - logger.debug("%s x aborting TCP connection", self.side) + self.logger.debug("%s x aborting TCP connection", self.side) self.transport.abort() # connection_lost() is called quickly after aborting. @@ -1196,7 +1200,7 @@ def fail_connection(self, code: int = 1006, reason: str = "") -> None: (The specification describes these steps in the opposite order.) """ - logger.debug( + self.logger.debug( "%s ! failing %s WebSocket connection with code %d", self.side, self.state.name, @@ -1226,10 +1230,10 @@ def fail_connection(self, code: int = 1006, reason: str = "") -> None: # and write_frame(). self.state = State.CLOSING - logger.debug("%s - state = CLOSING", self.side) + self.logger.debug("%s - state = CLOSING", self.side) frame = Frame(True, OP_CLOSE, frame_data) - logger.debug("%s > %r", self.side, frame) + self.logger.debug("%s > %r", self.side, frame) frame.write( self.transport.write, mask=self.is_client, extensions=self.extensions ) @@ -1259,7 +1263,7 @@ def abort_pings(self) -> None: if self.pings: pings_hex = ", ".join(ping_id.hex() or "[empty]" for ping_id in self.pings) plural = "s" if len(self.pings) > 1 else "" - logger.debug( + self.logger.debug( "%s - aborted pending ping%s: %s", self.side, plural, pings_hex ) @@ -1279,7 +1283,7 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None: which means it's the best point for configuring it. """ - logger.debug("%s - event = connection_made(%s)", self.side, transport) + self.logger.debug("%s - event = connection_made(%s)", self.side, transport) transport = cast(asyncio.Transport, transport) transport.set_write_buffer_limits(self.write_limit) @@ -1293,14 +1297,14 @@ def connection_lost(self, exc: Optional[Exception]) -> None: 7.1.4. The WebSocket Connection is Closed. """ - logger.debug("%s - event = connection_lost(%s)", self.side, exc) + self.logger.debug("%s - event = connection_lost(%s)", self.side, exc) self.state = State.CLOSED - logger.debug("%s - state = CLOSED", self.side) + self.logger.debug("%s - state = CLOSED", self.side) if not hasattr(self, "close_code"): self.close_code = 1006 if not hasattr(self, "close_reason"): self.close_reason = "" - logger.debug( + self.logger.debug( "%s x code = %d, reason = %s", self.side, self.close_code, @@ -1351,7 +1355,9 @@ def resume_writing(self) -> None: # pragma: no cover waiter.set_result(None) def data_received(self, data: bytes) -> None: - logger.debug("%s - event = data_received(<%d bytes>)", self.side, len(data)) + self.logger.debug( + "%s - event = data_received(<%d bytes>)", self.side, len(data) + ) self.reader.feed_data(data) def eof_received(self) -> None: @@ -1367,5 +1373,5 @@ def eof_received(self) -> None: Besides, that doesn't work on TLS connections. """ - logger.debug("%s - event = eof_received()", self.side) + self.logger.debug("%s - event = eof_received()", self.side) self.reader.feed_eof() diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 5bd7d0f56..5f76500df 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -41,7 +41,7 @@ from ..extensions.permessage_deflate import enable_server_permessage_deflate from ..headers import build_extension, parse_extension, parse_subprotocol from ..http import USER_AGENT -from ..typing import ExtensionHeader, Origin, Subprotocol +from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol from .compatibility import asyncio_get_running_loop, loop_if_py_lt_38 from .handshake import build_response, check_request from .http import read_request @@ -50,8 +50,6 @@ __all__ = ["serve", "unix_serve", "WebSocketServerProtocol", "WebSocketServer"] -logger = logging.getLogger("websockets.server") - HeadersLikeOrCallable = Union[HeadersLike, Callable[[str, Headers], HeadersLike]] @@ -184,8 +182,12 @@ def __init__( select_subprotocol: Optional[ Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] ] = None, + logger: Optional[LoggerLike] = None, **kwargs: Any, ) -> None: + if logger is None: + logger = logging.getLogger("websockets.server") + super().__init__(logger=logger, **kwargs) # For backwards compatibility with 6.0 or earlier. if origins is not None and "" in origins: warnings.warn("use None instead of '' in origins", DeprecationWarning) @@ -198,7 +200,6 @@ def __init__( self.extra_headers = extra_headers self._process_request = process_request self._select_subprotocol = select_subprotocol - super().__init__(**kwargs) def connection_made(self, transport: asyncio.BaseTransport) -> None: """ @@ -236,20 +237,22 @@ async def handler(self) -> None: except asyncio.CancelledError: # pragma: no cover raise except ConnectionError: - logger.debug("Connection error in opening handshake", exc_info=True) + self.logger.debug( + "Connection error in opening handshake", exc_info=True + ) raise except Exception as exc: if isinstance(exc, AbortHandshake): status, headers, body = exc.status, exc.headers, exc.body elif isinstance(exc, InvalidOrigin): - logger.debug("Invalid origin", exc_info=True) + self.logger.debug("Invalid origin", exc_info=True) status, headers, body = ( http.HTTPStatus.FORBIDDEN, Headers(), f"Failed to open a WebSocket connection: {exc}.\n".encode(), ) elif isinstance(exc, InvalidUpgrade): - logger.debug("Invalid upgrade", exc_info=True) + self.logger.debug("Invalid upgrade", exc_info=True) status, headers, body = ( http.HTTPStatus.UPGRADE_REQUIRED, Headers([("Upgrade", "websocket")]), @@ -261,14 +264,14 @@ async def handler(self) -> None: ).encode(), ) elif isinstance(exc, InvalidHandshake): - logger.debug("Invalid handshake", exc_info=True) + self.logger.debug("Invalid handshake", exc_info=True) status, headers, body = ( http.HTTPStatus.BAD_REQUEST, Headers(), f"Failed to open a WebSocket connection: {exc}.\n".encode(), ) else: - logger.warning("Error in opening handshake", exc_info=True) + self.logger.warning("Error in opening handshake", exc_info=True) status, headers, body = ( http.HTTPStatus.INTERNAL_SERVER_ERROR, Headers(), @@ -292,7 +295,7 @@ async def handler(self) -> None: try: await self.ws_handler(self, path) except Exception: - logger.error("Error in connection handler", exc_info=True) + self.logger.error("Error in connection handler", exc_info=True) if not self.closed: self.fail_connection(1011) raise @@ -300,10 +303,12 @@ async def handler(self) -> None: try: await self.close() except ConnectionError: - logger.debug("Connection error in closing handshake", exc_info=True) + self.logger.debug( + "Connection error in closing handshake", exc_info=True + ) raise except Exception: - logger.warning("Error in closing handshake", exc_info=True) + self.logger.warning("Error in closing handshake", exc_info=True) raise except Exception: @@ -338,8 +343,8 @@ async def read_http_request(self) -> Tuple[str, Headers]: except Exception as exc: raise InvalidMessage("did not receive a valid HTTP request") from exc - logger.debug("%s < GET %s HTTP/1.1", self.side, path) - logger.debug("%s < %r", self.side, headers) + self.logger.debug("%s < GET %s HTTP/1.1", self.side, path) + self.logger.debug("%s < %r", self.side, headers) self.path = path self.request_headers = headers @@ -357,8 +362,8 @@ def write_http_response( """ self.response_headers = headers - logger.debug("%s > HTTP/1.1 %d %s", self.side, status.value, status.phrase) - logger.debug("%s > %r", self.side, headers) + self.logger.debug("%s > HTTP/1.1 %d %s", self.side, status.value, status.phrase) + self.logger.debug("%s > %r", self.side, headers) # Since the status line and headers only contain ASCII characters, # we can keep this simple. @@ -368,7 +373,7 @@ def write_http_response( self.transport.write(response.encode()) if body is not None: - logger.debug("%s > body (%d bytes)", self.side, len(body)) + self.logger.debug("%s > body (%d bytes)", self.side, len(body)) self.transport.write(body) async def process_request( @@ -962,6 +967,7 @@ def __init__( select_subprotocol: Optional[ Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] ] = None, + logger: Optional[LoggerLike] = None, **kwargs: Any, ) -> None: # Backwards compatibility: close_timeout used to be called timeout. @@ -1025,6 +1031,7 @@ def __init__( extra_headers=extra_headers, process_request=process_request, select_subprotocol=select_subprotocol, + logger=logger, ) if kwargs.pop("unix", False): diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 353e5b370..60c0a14ae 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -2,6 +2,7 @@ import contextlib import functools import http +import logging import pathlib import random import socket @@ -1471,3 +1472,18 @@ async def run_client(): self.loop.run_until_complete(run_client()) self.assertEqual(messages, self.MESSAGES) + + +class LoggerTests(ClientServerTestsMixin, AsyncioTestCase): + def test_logger_client(self): + with self.assertLogs("test.server", logging.DEBUG) as server_logs: + self.start_server(logger=logging.getLogger("test.server")) + with self.assertLogs("test.client", logging.DEBUG) as client_logs: + self.start_client(logger=logging.getLogger("test.client")) + self.loop.run_until_complete(self.client.send("Hello!")) + self.loop.run_until_complete(self.client.recv()) + self.stop_client() + self.stop_server() + + self.assertGreater(len(server_logs.records), 0) + self.assertGreater(len(client_logs.records), 0) From 57ff93f051d0731d28c8be5182740a0167a85d23 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 29 May 2021 07:25:52 +0200 Subject: [PATCH 054/762] Document how to disable logging. Fix #759. --- docs/howto/index.rst | 1 + docs/howto/logging.rst | 26 ++++++++++++++++++++++++++ 2 files changed, 27 insertions(+) create mode 100644 docs/howto/logging.rst diff --git a/docs/howto/index.rst b/docs/howto/index.rst index ac1182705..e5af8488e 100644 --- a/docs/howto/index.rst +++ b/docs/howto/index.rst @@ -15,6 +15,7 @@ The following guides will help you integrate websockets into a broader system. :maxdepth: 2 django + logging The WebSocket protocol makes provisions for extending or specializing its features, which websockets supports fully. diff --git a/docs/howto/logging.rst b/docs/howto/logging.rst new file mode 100644 index 000000000..214e05105 --- /dev/null +++ b/docs/howto/logging.rst @@ -0,0 +1,26 @@ +Configure logging +================= + +Disable logging +--------------- + +If your application doesn't configure :mod:`logging`, Python outputs messages +of severity :data:`~logging.WARNING` and higher to :data:`~sys.stderr`. If +you want to disable this behavior for websockets, you can add +a :class:`~logging.NullHandler`:: + + logging.getLogger("websockets").addHandler(logging.NullHandler()) + +Additionally, if your application configures :mod:`logging`, you must disable +propagation to the root logger, or else its handlers could output logs:: + + logging.getLogger("websockets").propagate = False + +Alternatively, you could set the log level to :data:`~logging.CRITICAL` for +websockets, as the highest level currently used is :data:`~logging.ERROR`:: + + logging.getLogger("websockets").setLevel(logging.CRITICAL) + +Or you could configure a filter to drop all messages:: + + logging.getLogger("websockets").addFilter(lambda record: None) From 3388225957c6bba1cf46748f6f83c6bb8f7ff7f6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 30 May 2021 14:13:07 +0200 Subject: [PATCH 055/762] Separate prints in examples from logs. They used the same < and > symbols which was confusing. --- example/client.py | 4 ++-- example/secure_client.py | 4 ++-- example/secure_server.py | 4 ++-- example/server.py | 4 ++-- example/unix_client.py | 4 ++-- example/unix_server.py | 4 ++-- 6 files changed, 12 insertions(+), 12 deletions(-) diff --git a/example/client.py b/example/client.py index e39df81f7..062540202 100755 --- a/example/client.py +++ b/example/client.py @@ -11,9 +11,9 @@ async def hello(): name = input("What's your name? ") await websocket.send(name) - print(f"> {name}") + print(f">>> {name}") greeting = await websocket.recv() - print(f"< {greeting}") + print(f"<<< {greeting}") asyncio.run(hello()) diff --git a/example/secure_client.py b/example/secure_client.py index 2657ba68b..518819dd1 100755 --- a/example/secure_client.py +++ b/example/secure_client.py @@ -20,9 +20,9 @@ async def hello(): name = input("What's your name? ") await websocket.send(name) - print(f"> {name}") + print(f">>> {name}") greeting = await websocket.recv() - print(f"< {greeting}") + print(f"<<< {greeting}") asyncio.run(hello()) diff --git a/example/secure_server.py b/example/secure_server.py index e0ef6e53b..96c300390 100755 --- a/example/secure_server.py +++ b/example/secure_server.py @@ -9,12 +9,12 @@ async def hello(websocket, path): name = await websocket.recv() - print(f"< {name}") + print(f"<<< {name}") greeting = f"Hello {name}!" await websocket.send(greeting) - print(f"> {greeting}") + print(f">>> {greeting}") ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) localhost_pem = pathlib.Path(__file__).with_name("localhost.pem") diff --git a/example/server.py b/example/server.py index 98dbb5acd..4dcf317f5 100755 --- a/example/server.py +++ b/example/server.py @@ -7,12 +7,12 @@ async def hello(websocket, path): name = await websocket.recv() - print(f"< {name}") + print(f"<<< {name}") greeting = f"Hello {name}!" await websocket.send(greeting) - print(f"> {greeting}") + print(f">>> {greeting}") async def main(): async with websockets.serve(hello, "localhost", 8765): diff --git a/example/unix_client.py b/example/unix_client.py index 434638c80..926156730 100755 --- a/example/unix_client.py +++ b/example/unix_client.py @@ -11,9 +11,9 @@ async def hello(): async with websockets.unix_connect(socket_path) as websocket: name = input("What's your name? ") await websocket.send(name) - print(f"> {name}") + print(f">>> {name}") greeting = await websocket.recv() - print(f"< {greeting}") + print(f"<<< {greeting}") asyncio.run(hello()) diff --git a/example/unix_server.py b/example/unix_server.py index 223d97301..192f31bb0 100755 --- a/example/unix_server.py +++ b/example/unix_server.py @@ -8,12 +8,12 @@ async def hello(websocket, path): name = await websocket.recv() - print(f"< {name}") + print(f"<<< {name}") greeting = f"Hello {name}!" await websocket.send(greeting) - print(f"> {greeting}") + print(f">>> {greeting}") async def main(): socket_path = os.path.join(os.path.dirname(__file__), "socket") From acb6166aa3ab20dbd4ca8e80abf1a614b76f92c1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 30 May 2021 15:06:06 +0200 Subject: [PATCH 056/762] Overhaul logging. * Log connections open and closed at the info level. * Put debug logs behind `if self.debug:` to minimize overhead. * Standardize logging style. * Add documentation. --- docs/howto/logging.rst | 129 +++++++++++++++++++++++++- docs/project/changelog.rst | 2 + setup.cfg | 8 +- src/websockets/client.py | 20 +++- src/websockets/connection.py | 38 ++++---- src/websockets/legacy/client.py | 12 ++- src/websockets/legacy/protocol.py | 148 ++++++++++++++++-------------- src/websockets/legacy/server.py | 42 +++++---- src/websockets/server.py | 35 ++++--- tests/test_connection.py | 8 +- tests/test_http11.py | 4 +- 11 files changed, 315 insertions(+), 131 deletions(-) diff --git a/docs/howto/logging.rst b/docs/howto/logging.rst index 214e05105..c83ec6906 100644 --- a/docs/howto/logging.rst +++ b/docs/howto/logging.rst @@ -1,12 +1,98 @@ +Logging +======= + +.. currentmodule:: websockets + +Logs contents +------------- + +When you run a WebSocket client, your code calls coroutines provided by +websockets. + +If an error occurs, websockets tells you by raising an exception. For example, +it raises a :exc:`~exception.ConnectionClosed` exception if the other side +closes the connection. + +When you run a WebSocket server, websockets accepts connections, performs the +opening handshake, runs the connection handler coroutine that you provided, +and performs the closing handshake. + +Given this `inversion of control`_, if an error happens in the opening +handshake or if the connection handler crashes, there is no way to raise an +exception that you can handle. + +.. _inversion of control: https://en.wikipedia.org/wiki/Inversion_of_control + +Logs tell you about these errors. + +Besides errors, you may want to record the activity of the server. + +In a request/response protocol such as HTTP, there's an obvious way to record +activity: log one event per request/response. Unfortunately, this solution +doesn't work well for a bidirectional protocol such as WebSocket. + +Instead, when running as a server, websockets logs one event when a +`connection is established`_ and another event when a `connection is +closed`_. + +.. _connection is established: https://datatracker.ietf.org/doc/html/rfc6455#section-4 +.. _connection is closed: https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.4 + +websockets doesn't log an event for every message because that would be +excessive for many applications exchanging small messages at a fast rate. +However, you could add this level of logging in your own code if necessary. + +See :ref:`log levels ` below for details of events logged by +websockets at each level. + Configure logging -================= +----------------- + +websockets relies on the :mod:`logging` module from the standard library in +order to maximize compatibility and integrate nicely with other libraries:: + + import logging + +websockets logs to the ``"websockets.client"`` and ``"websockets.server"`` +loggers. + +websockets doesn't provide a default logging configuration because +requirements vary a lot depending on the environment. + +Here's a basic configuration for a server in production:: + + logging.basicConfig( + format="%(asctime)s %(message)s", + level=logging.INFO, + ) + +Here's how to enable debug logs for development:: + + logging.basicConfig( + format="%(message)s", + level=logging.DEBUG, + ) + +You can select a different :class:`~logging.Logger` with the ``logger`` +argument:: + + import websockets + + async with websockets.serve( + ..., + logger=logging.getLogger("interface.websocket"), + ): + ... Disable logging --------------- If your application doesn't configure :mod:`logging`, Python outputs messages -of severity :data:`~logging.WARNING` and higher to :data:`~sys.stderr`. If -you want to disable this behavior for websockets, you can add +of severity :data:`~logging.WARNING` and higher to :data:`~sys.stderr`. As a +consequence, you will see a message and a stack trace if a connection handler +coroutine crashes or if you hit a bug in websockets. + +If you want to disable this behavior for websockets, you can add a :class:`~logging.NullHandler`:: logging.getLogger("websockets").addHandler(logging.NullHandler()) @@ -24,3 +110,40 @@ websockets, as the highest level currently used is :data:`~logging.ERROR`:: Or you could configure a filter to drop all messages:: logging.getLogger("websockets").addFilter(lambda record: None) + +.. _log-levels: + +Log levels +---------- + +Here's what websockets logs at each level. + +:attr:`~logging.ERROR` +...................... + +* Exceptions raised by connection handler coroutines in servers +* Exceptions resulting from bugs in websockets + +:attr:`~logging.INFO` +..................... + +* Connections opened and closed in servers + +:attr:`~logging.DEBUG` +...................... + +* Changes to the state of connections +* Handshake requests and responses +* All frames sent and received +* Steps to close a connection +* Keepalive pings and pongs +* Errors handled transparently + +Debug messages have cute prefixes that make logs easier to scan: + +* ``>`` - send something +* ``<`` - receive something +* ``=`` - set connection state +* ``x`` - shut down connection +* ``%`` - manage pings and pongs +* ``!`` - handle errors and timeouts diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 919081a3a..caf637614 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -43,6 +43,8 @@ They may change at any time. * Added compatibility with Python 3.10. +* Improved logging. + * Optimized default compression settings to reduce memory usage. * Made it easier to customize authentication with diff --git a/setup.cfg b/setup.cfg index 0625798a2..9ff1939c7 100644 --- a/setup.cfg +++ b/setup.cfg @@ -20,7 +20,8 @@ lines_after_imports = 2 [coverage:run] branch = True -omit = */__main__.py +omit = + */__main__.py source = websockets tests @@ -29,3 +30,8 @@ source = source = src/websockets .tox/*/lib/python*/site-packages/websockets + +[coverage:report] +exclude_lines = + if self.debug: + pragma: no cover diff --git a/src/websockets/client.py b/src/websockets/client.py index 9d3f0e7f7..738b6c998 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -275,21 +275,33 @@ def send_request(self, request: Request) -> None: Send a WebSocket handshake request to the server. """ - self.logger.debug("%s > GET %s HTTP/1.1", self.side, request.path) - self.logger.debug("%s > %r", self.side, request.headers) + if self.debug: + self.logger.debug("> GET %s HTTP/1.1", request.path) + for key, value in request.headers.raw_items(): + self.logger.debug("> %s: %s", key, value) self.writes.append(request.serialize()) def parse(self) -> Generator[None, None, None]: response = yield from Response.parse( - self.reader.read_line, self.reader.read_exact, self.reader.read_to_eof + self.reader.read_line, + self.reader.read_exact, + self.reader.read_to_eof, ) + + if self.debug: + code, phrase = response.status_code, response.reason_phrase + self.logger.debug("< HTTP/1.1 %d %s", code, phrase) + for key, value in response.headers.raw_items(): + self.logger.debug("< %s: %s", key, value) + if response.body is not None: + self.logger.debug("< [body] (%d bytes)", len(response.body)) + assert self.state == CONNECTING try: self.process_response(response) except InvalidHandshake as exc: response = response._replace(exception=exc) - self.logger.debug("Invalid handshake", exc_info=True) else: self.set_state(OPEN) finally: diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 4852560e3..5052df3f7 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -71,23 +71,23 @@ def __init__( # Unique identifier. For logs. self.id = uuid.uuid4() + # Logger or LoggerAdapter for this connection. + if logger is None: + logger = logging.getLogger(f"websockets.{side.name.lower()}") + self.logger = logger + + # Track if DEBUG is enabled. Shortcut logging calls if it isn't. + self.debug = logger.isEnabledFor(logging.DEBUG) + # Connection side. CLIENT or SERVER. self.side = side # Connnection state. CONNECTING and CLOSED states are handled in subclasses. - self.state = state + self.set_state(state) # Maximum size of incoming messages in bytes. self.max_size = max_size - # Logger or LoggerAdapter for this connection. - if logger is None: - logger = logging.getLogger(f"websockets.{side.name.lower()}") - self.logger = logger - - # Must wait until we have the logger to log the initial state! - self.logger.debug("%s - initial state: %s", self.side, state.name) - # Current size of incoming message in bytes. Only set while reading a # fragmented message i.e. a data frames with the FIN bit not set. self.cur_size: Optional[int] = None @@ -123,9 +123,8 @@ def __init__( self.parser_exc: Optional[Exception] = None def set_state(self, state: State) -> None: - self.logger.debug( - "%s - state change: %s > %s", self.side, self.state.name, state.name - ) + if self.debug: + self.logger.debug("= connection is %s", state.name) self.state = state # Public APIs for receiving data. @@ -273,7 +272,7 @@ def step_parser(self) -> None: # EOF because receive_data() or receive_eof() would fail earlier.) assert self.parser_exc is not None raise RuntimeError( - "cannot receive data or EOF after an error" + "parser cannot receive data or EOF after an error" ) from self.parser_exc except ProtocolError as exc: self.fail_connection(1002, str(exc)) @@ -292,7 +291,7 @@ def step_parser(self) -> None: self.parser_exc = exc raise except Exception as exc: - self.logger.error("unexpected exception in parser", exc_info=True) + self.logger.error("parser failed", exc_info=True) # Don't include exception details, which may be security-sensitive. self.fail_connection(1011) self.parser_exc = exc @@ -302,6 +301,8 @@ def parse(self) -> Generator[None, None, None]: while True: eof = yield from self.reader.at_eof() if eof: + if self.debug: + self.logger.debug("< EOF") if self.close_frame_received: if not self.eof_sent: self.send_eof() @@ -330,6 +331,9 @@ def parse(self) -> Generator[None, None, None]: extensions=self.extensions, ) + if self.debug: + self.logger.debug("< %s", frame) + if frame.opcode is OP_TEXT or frame.opcode is OP_BINARY: # 5.5.1 Close: "The application MUST NOT send any more data # frames after sending a Close frame." @@ -407,7 +411,8 @@ def send_frame(self, frame: Frame) -> None: f"cannot write to a WebSocket in the {self.state.name} state" ) - self.logger.debug("%s > %r", self.side, frame) + if self.debug: + self.logger.debug("> %s", frame) self.writes.append( frame.serialize(mask=self.side is CLIENT, extensions=self.extensions) ) @@ -415,5 +420,6 @@ def send_frame(self, frame: Frame) -> None: def send_eof(self) -> None: assert not self.eof_sent self.eof_sent = True - self.logger.debug("%s > EOF", self.side) + if self.debug: + self.logger.debug("> EOF") self.writes.append(SEND_EOF) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 760309c3f..5c281d3b8 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -169,8 +169,10 @@ def write_http_request(self, path: str, headers: Headers) -> None: self.path = path self.request_headers = headers - self.logger.debug("%s > GET %s HTTP/1.1", self.side, path) - self.logger.debug("%s > %r", self.side, headers) + if self.debug: + self.logger.debug("> GET %s HTTP/1.1", path) + for key, value in headers.raw_items(): + self.logger.debug("> %s: %s", key, value) # Since the path and headers only contain ASCII characters, # we can keep this simple. @@ -199,8 +201,10 @@ async def read_http_response(self) -> Tuple[int, Headers]: except Exception as exc: raise InvalidMessage("did not receive a valid HTTP response") from exc - self.logger.debug("%s < HTTP/1.1 %d %s", self.side, status_code, reason) - self.logger.debug("%s < %r", self.side, headers) + if self.debug: + self.logger.debug("< HTTP/1.1 %d %s", status_code, reason) + for key, value in headers.raw_items(): + self.logger.debug("< %s: %s", key, value) self.response_headers = headers diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 940e330a9..ec8abaed9 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -131,12 +131,17 @@ def __init__( self.max_queue = max_queue self.read_limit = read_limit self.write_limit = write_limit + + # Unique identifier. For logs. + self.id = uuid.uuid4() + + # Logger or LoggerAdapter for this connection. if logger is None: logger = logging.getLogger("websockets.protocol") self.logger = logger - # Unique identifier. For logs. - self.id = uuid.uuid4() + # Track if DEBUG is enabled. Shortcut logging calls if it isn't. + self.debug = logger.isEnabledFor(logging.DEBUG) assert loop is not None # Remove when dropping Python < 3.10 - use get_running_loop instead. @@ -164,7 +169,8 @@ def __init__( # Subclasses implement the opening handshake and, on success, execute # :meth:`connection_open` to change the state to OPEN. self.state = State.CONNECTING - self.logger.debug("%s - state = CONNECTING", self.side) + if self.debug: + self.logger.debug("= connection is CONNECTING") # HTTP protocol parameters. self.path: str @@ -248,7 +254,8 @@ def connection_open(self) -> None: # 4.1. The WebSocket Connection is Established. assert self.state is State.CONNECTING self.state = State.OPEN - self.logger.debug("%s - state = OPEN", self.side) + if self.debug: + self.logger.debug("= connection is OPEN") # Start the task that receives incoming WebSocket messages. self.transfer_data_task = self.loop.create_task(self.transfer_data()) # Start the task that sends pings at regular intervals. @@ -814,7 +821,7 @@ async def transfer_data(self) -> None: # This shouldn't happen often because exceptions expected under # regular circumstances are handled above. If it does, consider # catching and handling more exceptions. - self.logger.error("Error in data transfer", exc_info=True) + self.logger.error("data transfer failed", exc_info=True) self.transfer_data_exc = exc self.fail_connection(1011) @@ -925,20 +932,20 @@ async def read_data_frame(self, max_size: Optional[int]) -> Optional[Frame]: elif frame.opcode == OP_PING: # Answer pings. ping_hex = frame.data.hex() or "[empty]" - self.logger.debug( - "%s - received ping, sending pong: %s", self.side, ping_hex - ) + if self.debug: + self.logger.debug("%% received ping, sending pong: %s", ping_hex) await self.pong(frame.data) elif frame.opcode == OP_PONG: # Acknowledge pings on solicited pongs. if frame.data in self.pings: - self.logger.debug( - "%s - received solicited pong: %s", - self.side, - frame.data.hex() or "[empty]", - ) + if self.debug: + self.logger.debug( + "%% received solicited pong: %s", + frame.data.hex() or "[empty]", + ) # Acknowledge all pings up to the one matching this pong. + # Sending a pong for only the most recent ping is legal. ping_id = None ping_ids = [] for ping_id, ping in self.pings.items(): @@ -952,24 +959,25 @@ async def read_data_frame(self, max_size: Optional[int]) -> Optional[Frame]: # Remove acknowledged pings from self.pings. for ping_id in ping_ids: del self.pings[ping_id] - ping_ids = ping_ids[:-1] - if ping_ids: - pings_hex = ", ".join( - ping_id.hex() or "[empty]" for ping_id in ping_ids - ) - plural = "s" if len(ping_ids) > 1 else "" + # Log previous pings acknowledged. + if self.debug: + ping_ids = ping_ids[:-1] + if ping_ids: + pings_hex = ", ".join( + ping_id.hex() or "[empty]" for ping_id in ping_ids + ) + plural = "s" if len(ping_ids) > 1 else "" + self.logger.debug( + "%% acknowledged previous ping%s: %s", + plural, + pings_hex, + ) + else: + if self.debug: self.logger.debug( - "%s - acknowledged previous ping%s: %s", - self.side, - plural, - pings_hex, + "%% received unsolicited pong: %s", + frame.data.hex() or "[empty]", ) - else: - self.logger.debug( - "%s - received unsolicited pong: %s", - self.side, - frame.data.hex() or "[empty]", - ) # 5.6. Data Frames else: @@ -986,7 +994,8 @@ async def read_frame(self, max_size: Optional[int]) -> Frame: max_size=max_size, extensions=self.extensions, ) - self.logger.debug("%s < %r", self.side, frame) + if self.debug: + self.logger.debug("< %s", frame) return frame async def write_frame( @@ -999,9 +1008,12 @@ async def write_frame( ) frame = Frame(fin, Opcode(opcode), data) - self.logger.debug("%s > %r", self.side, frame) + if self.debug: + self.logger.debug("> %s", frame) frame.write( - self.transport.write, mask=self.is_client, extensions=self.extensions + self.transport.write, + mask=self.is_client, + extensions=self.extensions, ) try: @@ -1031,7 +1043,8 @@ async def write_close_frame(self, data: bytes = b"") -> None: if self.state is State.OPEN: # 7.1.3. The WebSocket Closing Handshake is Started self.state = State.CLOSING - self.logger.debug("%s - state = CLOSING", self.side) + if self.debug: + self.logger.debug("= connection is CLOSING") # 7.1.2. Start the WebSocket Closing Handshake await self.write_frame(True, OP_CLOSE, data, _expected_state=State.CLOSING) @@ -1073,7 +1086,8 @@ async def keepalive_ping(self) -> None: **loop_if_py_lt_38(self.loop), ) except asyncio.TimeoutError: - self.logger.debug("%s ! timed out waiting for pong", self.side) + if self.debug: + self.logger.debug("! timed out waiting for pong") self.fail_connection(1011) break @@ -1086,9 +1100,7 @@ async def keepalive_ping(self) -> None: pass except Exception: - self.logger.warning( - "Unexpected exception in keepalive ping task", exc_info=True - ) + self.logger.error("keepalive ping failed", exc_info=True) async def close_connection(self) -> None: """ @@ -1120,18 +1132,21 @@ async def close_connection(self) -> None: # Coverage marks this line as a partially executed branch. # I supect a bug in coverage. Ignore it for now. return # pragma: no cover - self.logger.debug("%s ! timed out waiting for TCP close", self.side) + if self.debug: + self.logger.debug("! timed out waiting for TCP close") # Half-close the TCP connection if possible (when there's no TLS). if self.transport.can_write_eof(): - self.logger.debug("%s x half-closing TCP connection", self.side) + if self.debug: + self.logger.debug("x half-closing TCP connection") self.transport.write_eof() if await self.wait_for_connection_lost(): # Coverage marks this line as a partially executed branch. # I supect a bug in coverage. Ignore it for now. return # pragma: no cover - self.logger.debug("%s ! timed out waiting for TCP close", self.side) + if self.debug: + self.logger.debug("! timed out waiting for TCP close") finally: # The try/finally ensures that the transport never remains open, @@ -1144,15 +1159,18 @@ async def close_connection(self) -> None: return # Close the TCP connection. Buffers are flushed asynchronously. - self.logger.debug("%s x closing TCP connection", self.side) + if self.debug: + self.logger.debug("x closing TCP connection") self.transport.close() if await self.wait_for_connection_lost(): return - self.logger.debug("%s ! timed out waiting for TCP close", self.side) + if self.debug: + self.logger.debug("! timed out waiting for TCP close") # Abort the TCP connection. Buffers are discarded. - self.logger.debug("%s x aborting TCP connection", self.side) + if self.debug: + self.logger.debug("x aborting TCP connection") self.transport.abort() # connection_lost() is called quickly after aborting. @@ -1200,12 +1218,8 @@ def fail_connection(self, code: int = 1006, reason: str = "") -> None: (The specification describes these steps in the opposite order.) """ - self.logger.debug( - "%s ! failing %s WebSocket connection with code %d", - self.side, - self.state.name, - code, - ) + if self.debug: + self.logger.debug("! failing connection with code %d", code) # Cancel transfer_data_task if the opening handshake succeeded. # cancel() is idempotent and ignored if the task is done already. @@ -1230,12 +1244,16 @@ def fail_connection(self, code: int = 1006, reason: str = "") -> None: # and write_frame(). self.state = State.CLOSING - self.logger.debug("%s - state = CLOSING", self.side) + if self.debug: + self.logger.debug("= connection is CLOSING") frame = Frame(True, OP_CLOSE, frame_data) - self.logger.debug("%s > %r", self.side, frame) + if self.debug: + self.logger.debug("> %s", frame) frame.write( - self.transport.write, mask=self.is_client, extensions=self.extensions + self.transport.write, + mask=self.is_client, + extensions=self.extensions, ) # Start close_connection_task if the opening handshake didn't succeed. @@ -1260,12 +1278,13 @@ def abort_pings(self) -> None: # nothing, but it prevents logging the exception. ping.cancel() - if self.pings: - pings_hex = ", ".join(ping_id.hex() or "[empty]" for ping_id in self.pings) - plural = "s" if len(self.pings) > 1 else "" - self.logger.debug( - "%s - aborted pending ping%s: %s", self.side, plural, pings_hex - ) + if self.debug: + if self.pings: + pings_hex = ", ".join( + ping_id.hex() or "[empty]" for ping_id in self.pings + ) + plural = "s" if len(self.pings) > 1 else "" + self.logger.debug("% aborted pending ping%s: %s", plural, pings_hex) # asyncio.Protocol methods @@ -1283,8 +1302,6 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None: which means it's the best point for configuring it. """ - self.logger.debug("%s - event = connection_made(%s)", self.side, transport) - transport = cast(asyncio.Transport, transport) transport.set_write_buffer_limits(self.write_limit) self.transport = transport @@ -1297,20 +1314,19 @@ def connection_lost(self, exc: Optional[Exception]) -> None: 7.1.4. The WebSocket Connection is Closed. """ - self.logger.debug("%s - event = connection_lost(%s)", self.side, exc) self.state = State.CLOSED - self.logger.debug("%s - state = CLOSED", self.side) if not hasattr(self, "close_code"): self.close_code = 1006 if not hasattr(self, "close_reason"): self.close_reason = "" self.logger.debug( - "%s x code = %d, reason = %s", - self.side, + "= connection is CLOSED - code = %d, reason = %s", self.close_code, self.close_reason or "[no reason]", ) + self.abort_pings() + # If self.connection_lost_waiter isn't pending, that's a bug, because: # - it's set only here in connection_lost() which is called only once; # - it must never be canceled. @@ -1355,9 +1371,6 @@ def resume_writing(self) -> None: # pragma: no cover waiter.set_result(None) def data_received(self, data: bytes) -> None: - self.logger.debug( - "%s - event = data_received(<%d bytes>)", self.side, len(data) - ) self.reader.feed_data(data) def eof_received(self) -> None: @@ -1373,5 +1386,4 @@ def eof_received(self) -> None: Besides, that doesn't work on TLS connections. """ - self.logger.debug("%s - event = eof_received()", self.side) self.reader.feed_eof() diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 5f76500df..31ad0b18e 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -237,22 +237,21 @@ async def handler(self) -> None: except asyncio.CancelledError: # pragma: no cover raise except ConnectionError: - self.logger.debug( - "Connection error in opening handshake", exc_info=True - ) raise except Exception as exc: if isinstance(exc, AbortHandshake): status, headers, body = exc.status, exc.headers, exc.body elif isinstance(exc, InvalidOrigin): - self.logger.debug("Invalid origin", exc_info=True) + if self.debug: + self.logger.debug("! invalid origin", exc_info=True) status, headers, body = ( http.HTTPStatus.FORBIDDEN, Headers(), f"Failed to open a WebSocket connection: {exc}.\n".encode(), ) elif isinstance(exc, InvalidUpgrade): - self.logger.debug("Invalid upgrade", exc_info=True) + if self.debug: + self.logger.debug("! invalid upgrade", exc_info=True) status, headers, body = ( http.HTTPStatus.UPGRADE_REQUIRED, Headers([("Upgrade", "websocket")]), @@ -264,14 +263,15 @@ async def handler(self) -> None: ).encode(), ) elif isinstance(exc, InvalidHandshake): - self.logger.debug("Invalid handshake", exc_info=True) + if self.debug: + self.logger.debug("! invalid handshake", exc_info=True) status, headers, body = ( http.HTTPStatus.BAD_REQUEST, Headers(), f"Failed to open a WebSocket connection: {exc}.\n".encode(), ) else: - self.logger.warning("Error in opening handshake", exc_info=True) + self.logger.error("opening handshake failed", exc_info=True) status, headers, body = ( http.HTTPStatus.INTERNAL_SERVER_ERROR, Headers(), @@ -288,6 +288,9 @@ async def handler(self) -> None: headers.setdefault("Connection", "close") self.write_http_response(status, headers, body) + self.logger.info( + "connection failed (%d %s)", status.value, status.phrase + ) self.fail_connection() await self.wait_closed() return @@ -295,7 +298,7 @@ async def handler(self) -> None: try: await self.ws_handler(self, path) except Exception: - self.logger.error("Error in connection handler", exc_info=True) + self.logger.error("connection handler failed", exc_info=True) if not self.closed: self.fail_connection(1011) raise @@ -303,12 +306,9 @@ async def handler(self) -> None: try: await self.close() except ConnectionError: - self.logger.debug( - "Connection error in closing handshake", exc_info=True - ) raise except Exception: - self.logger.warning("Error in closing handshake", exc_info=True) + self.logger.error("closing handshake failed", exc_info=True) raise except Exception: @@ -324,6 +324,7 @@ async def handler(self) -> None: # task because the server waits for tasks attached to registered # connections before terminating. self.ws_server.unregister(self) + self.logger.info("connection closed") async def read_http_request(self) -> Tuple[str, Headers]: """ @@ -343,8 +344,10 @@ async def read_http_request(self) -> Tuple[str, Headers]: except Exception as exc: raise InvalidMessage("did not receive a valid HTTP request") from exc - self.logger.debug("%s < GET %s HTTP/1.1", self.side, path) - self.logger.debug("%s < %r", self.side, headers) + if self.debug: + self.logger.debug("< GET %s HTTP/1.1", path) + for key, value in headers.raw_items(): + self.logger.debug("< %s: %s", key, value) self.path = path self.request_headers = headers @@ -362,8 +365,12 @@ def write_http_response( """ self.response_headers = headers - self.logger.debug("%s > HTTP/1.1 %d %s", self.side, status.value, status.phrase) - self.logger.debug("%s > %r", self.side, headers) + if self.debug: + self.logger.debug("> HTTP/1.1 %d %s", status.value, status.phrase) + for key, value in headers.raw_items(): + self.logger.debug("> %s: %s", key, value) + if body is not None: + self.logger.debug("> [body] (%d bytes)", len(body)) # Since the status line and headers only contain ASCII characters, # we can keep this simple. @@ -373,7 +380,6 @@ def write_http_response( self.transport.write(response.encode()) if body is not None: - self.logger.debug("%s > body (%d bytes)", self.side, len(body)) self.transport.write(body) async def process_request( @@ -692,6 +698,8 @@ async def handshake( self.write_http_response(http.HTTPStatus.SWITCHING_PROTOCOLS, response_headers) + self.logger.info("connection open") + self.connection_open() return path diff --git a/src/websockets/server.py b/src/websockets/server.py index 4a76ac886..87aa07b51 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -84,13 +84,15 @@ def accept(self, request: Request) -> Response: try: key, extensions_header, protocol_header = self.process_request(request) except InvalidOrigin as exc: - self.logger.debug("Invalid origin", exc_info=True) + if self.debug: + self.logger.debug("! invalid origin", exc_info=True) return self.reject( http.HTTPStatus.FORBIDDEN, f"Failed to open a WebSocket connection: {exc}.\n", )._replace(exception=exc) except InvalidUpgrade as exc: - self.logger.debug("Invalid upgrade", exc_info=True) + if self.debug: + self.logger.debug("! invalid upgrade", exc_info=True) return self.reject( http.HTTPStatus.UPGRADE_REQUIRED, ( @@ -102,13 +104,14 @@ def accept(self, request: Request) -> Response: headers=Headers([("Upgrade", "websocket")]), )._replace(exception=exc) except InvalidHandshake as exc: - self.logger.debug("Invalid handshake", exc_info=True) + if self.debug: + self.logger.debug("! invalid handshake", exc_info=True) return self.reject( http.HTTPStatus.BAD_REQUEST, f"Failed to open a WebSocket connection: {exc}.\n", )._replace(exception=exc) except Exception as exc: - self.logger.warning("Error in opening handshake", exc_info=True) + self.logger.error("opening handshake failed", exc_info=True) return self.reject( http.HTTPStatus.INTERNAL_SERVER_ERROR, ( @@ -145,6 +148,7 @@ def accept(self, request: Request) -> Response: headers.setdefault("Date", email.utils.formatdate(usegmt=True)) headers.setdefault("Server", USER_AGENT) + self.logger.info("connection open") return Response(101, "Switching Protocols", headers) def process_request( @@ -404,6 +408,7 @@ def reject( headers.setdefault("Content-Length", str(len(body))) headers.setdefault("Content-Type", "text/plain; charset=utf-8") headers.setdefault("Connection", "close") + self.logger.info("connection failed (%d %s)", status.value, status.phrase) return Response(status.value, status.phrase, headers, body) def send_response(self, response: Response) -> None: @@ -414,20 +419,24 @@ def send_response(self, response: Response) -> None: if response.status_code == 101: self.set_state(OPEN) - self.logger.debug( - "%s > HTTP/1.1 %d %s", - self.side, - response.status_code, - response.reason_phrase, - ) - self.logger.debug("%s > %r", self.side, response.headers) - if response.body is not None: - self.logger.debug("%s > body (%d bytes)", self.side, len(response.body)) + if self.debug: + code, phrase = response.status_code, response.reason_phrase + self.logger.debug("> HTTP/1.1 %d %s", code, phrase) + for key, value in response.headers.raw_items(): + self.logger.debug("> %s: %s", key, value) + if response.body is not None: + self.logger.debug("> [body] (%d bytes)", len(response.body)) self.writes.append(response.serialize()) def parse(self) -> Generator[None, None, None]: request = yield from Request.parse(self.reader.read_line) + + if self.debug: + self.logger.debug("< GET %s HTTP/1.1", request.path) + for key, value in request.headers.raw_items(): + self.logger.debug("< %s: %s", key, value) + assert self.state == CONNECTING self.events.append(request) yield from super().parse() diff --git a/tests/test_connection.py b/tests/test_connection.py index 3e39a3f9e..6203a1469 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1356,7 +1356,7 @@ def test_client_receives_data_after_exception(self): with self.assertRaises(RuntimeError) as raised: client.receive_data(b"\x00\x00") self.assertEqual( - str(raised.exception), "cannot receive data or EOF after an error" + str(raised.exception), "parser cannot receive data or EOF after an error" ) def test_server_receives_data_after_exception(self): @@ -1367,7 +1367,7 @@ def test_server_receives_data_after_exception(self): with self.assertRaises(RuntimeError) as raised: server.receive_data(b"\x00\x00") self.assertEqual( - str(raised.exception), "cannot receive data or EOF after an error" + str(raised.exception), "parser cannot receive data or EOF after an error" ) def test_client_receives_eof_after_exception(self): @@ -1378,7 +1378,7 @@ def test_client_receives_eof_after_exception(self): with self.assertRaises(RuntimeError) as raised: client.receive_eof() self.assertEqual( - str(raised.exception), "cannot receive data or EOF after an error" + str(raised.exception), "parser cannot receive data or EOF after an error" ) def test_server_receives_eof_after_exception(self): @@ -1389,7 +1389,7 @@ def test_server_receives_eof_after_exception(self): with self.assertRaises(RuntimeError) as raised: server.receive_eof() self.assertEqual( - str(raised.exception), "cannot receive data or EOF after an error" + str(raised.exception), "parser cannot receive data or EOF after an error" ) def test_client_receives_data_after_eof(self): diff --git a/tests/test_http11.py b/tests/test_http11.py index e73365cf0..afd85f64a 100644 --- a/tests/test_http11.py +++ b/tests/test_http11.py @@ -132,7 +132,9 @@ def setUp(self): def parse(self): return Response.parse( - self.reader.read_line, self.reader.read_exact, self.reader.read_to_eof + self.reader.read_line, + self.reader.read_exact, + self.reader.read_to_eof, ) def test_parse(self): From 6c138c7f48cfb41151bad829f1bc5d41a7a95e90 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 30 May 2021 20:30:07 +0200 Subject: [PATCH 057/762] Add protocol instance to logging context. --- docs/howto/logging.rst | 49 ++++++++++++++++++++++++++++--- docs/spelling_wordlist.txt | 1 + src/websockets/legacy/protocol.py | 4 ++- 3 files changed, 49 insertions(+), 5 deletions(-) diff --git a/docs/howto/logging.rst b/docs/howto/logging.rst index c83ec6906..5aaf31ee3 100644 --- a/docs/howto/logging.rst +++ b/docs/howto/logging.rst @@ -73,14 +73,55 @@ Here's how to enable debug logs for development:: level=logging.DEBUG, ) -You can select a different :class:`~logging.Logger` with the ``logger`` -argument:: +Furthermore, websockets adds a ``websocket`` attribute to every log record, so +you can include additional information about connections in logs. - import websockets +You could attempt to add information with a formatter:: + + # this doesn't work! + logging.basicConfig( + format="{asctime} {websocket.id} {websocket.remote_address[0]} {message}", + level=logging.INFO, + style="{", + ) + +However, this technique has two downsides: + +* The formatter applies to all records. It will crash if it receives a record + that doesn't have a ``websocket`` attribute. You could configure logging to + work around this problem but that could get complicated quickly. + +* Even with :meth:`str.format` style, you're restricted to attribute and index + lookups, which isn't enough to implement some fairly simple requirements. + +There's a better way. :func:`~server.serve` accepts a ``logger`` argument to +override the default :class:`~logging.Logger`. You can set ``logger`` to +a :class:`~logging.LoggerAdapter` that enriches logs. + +For example, if the server is behind a reverse proxy, ``remote_address`` gives +the IP address of the proxy, which isn't useful. IP addresses of clients are +generally available in a HTTP header set by the proxy. + +Here's how to include them in logs, assuming they're in the +``X-Forwarded-For`` header:: + + logging.basicConfig( + format="%(asctime)s %(message)s", + level=logging.INFO, + ) + + class LoggerAdapter(logging.LoggerAdapter): + def process(self, msg, kwargs): + try: + websocket = kwargs["extra"]["websocket"] + except KeyError: + return msg, kwargs + xff = websocket.request_headers.get("X-Forwarded-For") + return f"{websocket.id} {xff} {msg}", kwargs async with websockets.serve( ..., - logger=logging.getLogger("interface.websocket"), + logger=LoggerAdapter(logging.getLogger("websockets.server")), ): ... diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 4fa44fbaa..b460ef033 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -25,6 +25,7 @@ daemonize datastructures django dyno +formatter fractalideas gunicorn haproxy diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index ec8abaed9..90683aaf9 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -138,7 +138,9 @@ def __init__( # Logger or LoggerAdapter for this connection. if logger is None: logger = logging.getLogger("websockets.protocol") - self.logger = logger + # https://github.com/python/typeshed/issues/5561 + logger = cast(logging.Logger, logger) + self.logger = logging.LoggerAdapter(logger, {"websocket": self}) # Track if DEBUG is enabled. Shortcut logging calls if it isn't. self.debug = logger.isEnabledFor(logging.DEBUG) From e93f5e916291479c9b6be85609744e805681eb41 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 31 May 2021 07:57:04 +0200 Subject: [PATCH 058/762] Document structured logging. --- docs/howto/logging.rst | 40 +++++++++++++++++++++++++++++++++++ example/json_log_formatter.py | 33 +++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+) create mode 100644 example/json_log_formatter.py diff --git a/docs/howto/logging.rst b/docs/howto/logging.rst index 5aaf31ee3..b780c0162 100644 --- a/docs/howto/logging.rst +++ b/docs/howto/logging.rst @@ -125,6 +125,46 @@ Here's how to include them in logs, assuming they're in the ): ... +Logging to JSON +--------------- + +Even though :mod:`logging` predates structured logging, it's still possible to +output logs as JSON with a bit of effort. + +First, we need a :class:`~logging.Formatter` that renders JSON: + +.. literalinclude:: ../../example/json_log_formatter.py + +Then, we configure logging to apply this formatter:: + + handler = logging.StreamHandler() + handler.setFormatter(formatter) + + logger = logging.getLogger() + logger.addHandler(handler) + logger.setLevel(logging.INFO) + +Finally, we populate the ``event_data`` custom attribute in log records with +a :class:`~logging.LoggerAdapter`:: + + class LoggerAdapter(logging.LoggerAdapter): + def process(self, msg, kwargs): + try: + websocket = kwargs["extra"]["websocket"] + except KeyError: + return msg, kwargs + kwargs["extra"]["event_data"] = { + "connection_id": str(websocket.id), + "remote_addr": websocket.request_headers.get("X-Forwarded-For"), + } + return msg, kwargs + + async with websockets.serve( + ..., + logger=LoggerAdapter(logging.getLogger("websockets.server")), + ): + ... + Disable logging --------------- diff --git a/example/json_log_formatter.py b/example/json_log_formatter.py new file mode 100644 index 000000000..b8fc8d6dc --- /dev/null +++ b/example/json_log_formatter.py @@ -0,0 +1,33 @@ +import json +import logging +import datetime + +class JSONFormatter(logging.Formatter): + """ + Render logs as JSON. + + To add details to a log record, store them in a ``event_data`` + custom attribute. This dict is merged into the event. + + """ + def __init__(self): + pass # override logging.Formatter constructor + + def format(self, record): + event = { + "timestamp": self.getTimestamp(record.created), + "message": record.getMessage(), + "level": record.levelname, + "logger": record.name, + } + event_data = getattr(record, "event_data", None) + if event_data: + event.update(event_data) + if record.exc_info: + event["exc_info"] = self.formatException(record.exc_info) + if record.stack_info: + event["stack_info"] = self.formatStack(record.stack_info) + return json.dumps(event) + + def getTimestamp(self, created): + return datetime.datetime.utcfromtimestamp(created).isoformat() From a61ab26038e97e48b2f894c3e66e28bdfa3d8f3f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 30 May 2021 19:02:58 +0200 Subject: [PATCH 059/762] Log when server starts and stops. --- docs/howto/logging.rst | 13 +++++++------ src/websockets/legacy/server.py | 27 +++++++++++++++++++++++++-- 2 files changed, 32 insertions(+), 8 deletions(-) diff --git a/docs/howto/logging.rst b/docs/howto/logging.rst index b780c0162..f69ee47b9 100644 --- a/docs/howto/logging.rst +++ b/docs/howto/logging.rst @@ -73,8 +73,8 @@ Here's how to enable debug logs for development:: level=logging.DEBUG, ) -Furthermore, websockets adds a ``websocket`` attribute to every log record, so -you can include additional information about connections in logs. +Furthermore, websockets adds a ``websocket`` attribute to log records, so you +can include additional information about the current connection in logs. You could attempt to add information with a formatter:: @@ -85,11 +85,11 @@ You could attempt to add information with a formatter:: style="{", ) -However, this technique has two downsides: +However, this technique runs into two problems: * The formatter applies to all records. It will crash if it receives a record - that doesn't have a ``websocket`` attribute. You could configure logging to - work around this problem but that could get complicated quickly. + without a ``websocket`` attribute. For example, this happens when logging + that the server starts because there is no current connection. * Even with :meth:`str.format` style, you're restricted to attribute and index lookups, which isn't enough to implement some fairly simple requirements. @@ -208,7 +208,8 @@ Here's what websockets logs at each level. :attr:`~logging.INFO` ..................... -* Connections opened and closed in servers +* Server starting and stopping +* Server establishing and closing connections :attr:`~logging.DEBUG` ...................... diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 31ad0b18e..67fffff2b 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -725,10 +725,16 @@ class WebSocketServer: """ - def __init__(self, loop: asyncio.AbstractEventLoop) -> None: + def __init__( + self, loop: asyncio.AbstractEventLoop, logger: Optional[LoggerLike] = None + ) -> None: # Store a reference to loop to avoid relying on self.server._loop. self.loop = loop + if logger is None: + logger = logging.getLogger("websockets.server") + self.logger = logger + # Keep track of active connections. self.websockets: Set[WebSocketServerProtocol] = set() @@ -753,6 +759,19 @@ def wrap(self, server: asyncio.AbstractServer) -> None: """ self.server = server + assert server.sockets is not None + for sock in server.sockets: + if sock.family == socket.AF_INET: + name = "%s:%d" % sock.getsockname() + elif sock.family == socket.AF_INET6: + name = "[%s]:%d" % sock.getsockname()[:2] + elif sock.family == socket.AF_UNIX: + name = sock.getsockname() + # In the unlikely event that someone runs websockets over a + # protocol other than IP or Unix sockets, avoid crashing. + else: # pragma: no cover + name = str(sock.getsockname()) + self.logger.info("server listening on %s", name) def register(self, protocol: WebSocketServerProtocol) -> None: """ @@ -803,6 +822,8 @@ async def _close(self) -> None: then closes open connections with close code 1001. """ + self.logger.info("server closing") + # Stop accepting new connections. self.server.close() @@ -839,6 +860,8 @@ async def _close(self) -> None: # Tell wait_closed() to return. self.closed_waiter.set_result(None) + self.logger.info("server closed") + async def wait_closed(self) -> None: """ Wait until the server is closed. @@ -1008,7 +1031,7 @@ def __init__( else: warnings.warn("remove loop argument", DeprecationWarning) - ws_server = WebSocketServer(loop=loop) + ws_server = WebSocketServer(logger=logger, loop=loop) secure = kwargs.get("ssl") is not None From 848e0500a61e14180cebd8ad2eadaf9c9d3e4176 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 1 Jun 2021 07:45:10 +0200 Subject: [PATCH 060/762] Simplify logging of ping/pong. Focus on keepalive ping/pongs, as they're the only ones that users run in trouble with. The readable logs should do the rest. Ref #765. --- src/websockets/legacy/protocol.py | 34 ++++--------------------------- 1 file changed, 4 insertions(+), 30 deletions(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 90683aaf9..6a6a00012 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -933,21 +933,12 @@ async def read_data_frame(self, max_size: Optional[int]) -> Optional[Frame]: elif frame.opcode == OP_PING: # Answer pings. - ping_hex = frame.data.hex() or "[empty]" - if self.debug: - self.logger.debug("%% received ping, sending pong: %s", ping_hex) await self.pong(frame.data) elif frame.opcode == OP_PONG: - # Acknowledge pings on solicited pongs. if frame.data in self.pings: - if self.debug: - self.logger.debug( - "%% received solicited pong: %s", - frame.data.hex() or "[empty]", - ) - # Acknowledge all pings up to the one matching this pong. # Sending a pong for only the most recent ping is legal. + # Acknowledge all previous pings too in that case. ping_id = None ping_ids = [] for ping_id, ping in self.pings.items(): @@ -961,25 +952,6 @@ async def read_data_frame(self, max_size: Optional[int]) -> Optional[Frame]: # Remove acknowledged pings from self.pings. for ping_id in ping_ids: del self.pings[ping_id] - # Log previous pings acknowledged. - if self.debug: - ping_ids = ping_ids[:-1] - if ping_ids: - pings_hex = ", ".join( - ping_id.hex() or "[empty]" for ping_id in ping_ids - ) - plural = "s" if len(ping_ids) > 1 else "" - self.logger.debug( - "%% acknowledged previous ping%s: %s", - plural, - pings_hex, - ) - else: - if self.debug: - self.logger.debug( - "%% received unsolicited pong: %s", - frame.data.hex() or "[empty]", - ) # 5.6. Data Frames else: @@ -1078,6 +1050,7 @@ async def keepalive_ping(self) -> None: # ping() raises ConnectionClosed if the connection is lost, # when connection_lost() calls abort_pings(). + self.logger.debug("%% sending keepalive ping") pong_waiter = await self.ping() if self.ping_timeout is not None: @@ -1087,9 +1060,10 @@ async def keepalive_ping(self) -> None: self.ping_timeout, **loop_if_py_lt_38(self.loop), ) + self.logger.debug("%% received keepalive pong") except asyncio.TimeoutError: if self.debug: - self.logger.debug("! timed out waiting for pong") + self.logger.debug("! timed out waiting for keepalive pong") self.fail_connection(1011) break From 2883ca8a41bb4ce2a933e58a1fdbf958a97ccb84 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 1 Jun 2021 08:57:50 +0200 Subject: [PATCH 061/762] Simplify connection shutdown on handshake failure. --- src/websockets/legacy/protocol.py | 48 +++++++++++++++++-------------- src/websockets/legacy/server.py | 3 +- 2 files changed, 28 insertions(+), 23 deletions(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 6a6a00012..7a3570139 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -1127,32 +1127,38 @@ async def close_connection(self) -> None: finally: # The try/finally ensures that the transport never remains open, # even if this coroutine is canceled (for example). + await self.close_transport() - # If connection_lost() was called, the TCP connection is closed. - # However, if TLS is enabled, the transport still needs closing. - # Else asyncio complains: ResourceWarning: unclosed transport. - if self.connection_lost_waiter.done() and self.transport.is_closing(): - return + async def close_transport(self) -> None: + """ + Close the TCP connection. - # Close the TCP connection. Buffers are flushed asynchronously. - if self.debug: - self.logger.debug("x closing TCP connection") - self.transport.close() + """ + # If connection_lost() was called, the TCP connection is closed. + # However, if TLS is enabled, the transport still needs closing. + # Else asyncio complains: ResourceWarning: unclosed transport. + if self.connection_lost_waiter.done() and self.transport.is_closing(): + return - if await self.wait_for_connection_lost(): - return - if self.debug: - self.logger.debug("! timed out waiting for TCP close") + # Close the TCP connection. Buffers are flushed asynchronously. + if self.debug: + self.logger.debug("x closing TCP connection") + self.transport.close() - # Abort the TCP connection. Buffers are discarded. - if self.debug: - self.logger.debug("x aborting TCP connection") - self.transport.abort() + if await self.wait_for_connection_lost(): + return + if self.debug: + self.logger.debug("! timed out waiting for TCP close") + + # Abort the TCP connection. Buffers are discarded. + if self.debug: + self.logger.debug("x aborting TCP connection") + self.transport.abort() - # connection_lost() is called quickly after aborting. - # Coverage marks this line as a partially executed branch. - # I supect a bug in coverage. Ignore it for now. - await self.wait_for_connection_lost() # pragma: no cover + # connection_lost() is called quickly after aborting. + # Coverage marks this line as a partially executed branch. + # I supect a bug in coverage. Ignore it for now. + await self.wait_for_connection_lost() # pragma: no cover async def wait_for_connection_lost(self) -> bool: """ diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 67fffff2b..1704ae083 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -291,8 +291,7 @@ async def handler(self) -> None: self.logger.info( "connection failed (%d %s)", status.value, status.phrase ) - self.fail_connection() - await self.wait_closed() + await self.close_transport() return try: From 6722857ebefd365c9299a79e1526644e25fe1015 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 3 Jun 2021 06:37:52 +0200 Subject: [PATCH 062/762] Make it possibly to bypass the handshake. This can facilitate integration in other projects that bring their own HTTP stack. --- src/websockets/client.py | 51 +++++++++++++++++++++------------------- src/websockets/server.py | 21 ++++++++++------- tests/test_client.py | 7 ++++++ tests/test_server.py | 7 ++++++ 4 files changed, 53 insertions(+), 33 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 738b6c998..698228d3a 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -1,7 +1,7 @@ import collections from typing import Generator, List, Optional, Sequence -from .connection import CLIENT, CONNECTING, OPEN, Connection +from .connection import CLIENT, CONNECTING, OPEN, Connection, State from .datastructures import Headers, HeadersLike, MultipleValuesError from .exceptions import ( InvalidHandshake, @@ -50,12 +50,13 @@ def __init__( extensions: Optional[Sequence[ClientExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, extra_headers: Optional[HeadersLike] = None, + state: State = CONNECTING, max_size: Optional[int] = 2 ** 20, logger: Optional[LoggerLike] = None, ): super().__init__( side=CLIENT, - state=CONNECTING, + state=state, max_size=max_size, logger=logger, ) @@ -283,27 +284,29 @@ def send_request(self, request: Request) -> None: self.writes.append(request.serialize()) def parse(self) -> Generator[None, None, None]: - response = yield from Response.parse( - self.reader.read_line, - self.reader.read_exact, - self.reader.read_to_eof, - ) + if self.state is CONNECTING: + response = yield from Response.parse( + self.reader.read_line, + self.reader.read_exact, + self.reader.read_to_eof, + ) + + if self.debug: + code, phrase = response.status_code, response.reason_phrase + self.logger.debug("< HTTP/1.1 %d %s", code, phrase) + for key, value in response.headers.raw_items(): + self.logger.debug("< %s: %s", key, value) + if response.body is not None: + self.logger.debug("< [body] (%d bytes)", len(response.body)) + + try: + self.process_response(response) + except InvalidHandshake as exc: + response = response._replace(exception=exc) + else: + assert self.state == CONNECTING + self.set_state(OPEN) + finally: + self.events.append(response) - if self.debug: - code, phrase = response.status_code, response.reason_phrase - self.logger.debug("< HTTP/1.1 %d %s", code, phrase) - for key, value in response.headers.raw_items(): - self.logger.debug("< %s: %s", key, value) - if response.body is not None: - self.logger.debug("< [body] (%d bytes)", len(response.body)) - - assert self.state == CONNECTING - try: - self.process_response(response) - except InvalidHandshake as exc: - response = response._replace(exception=exc) - else: - self.set_state(OPEN) - finally: - self.events.append(response) yield from super().parse() diff --git a/src/websockets/server.py b/src/websockets/server.py index 87aa07b51..1ce943891 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -5,7 +5,7 @@ import http from typing import Callable, Generator, List, Optional, Sequence, Tuple, Union, cast -from .connection import CONNECTING, OPEN, SERVER, Connection +from .connection import CONNECTING, OPEN, SERVER, Connection, State from .datastructures import Headers, HeadersLike, MultipleValuesError from .exceptions import ( InvalidHandshake, @@ -56,12 +56,13 @@ def __init__( extensions: Optional[Sequence[ServerExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, extra_headers: Optional[HeadersLikeOrCallable] = None, + state: State = CONNECTING, max_size: Optional[int] = 2 ** 20, logger: Optional[LoggerLike] = None, ): super().__init__( side=SERVER, - state=CONNECTING, + state=state, max_size=max_size, logger=logger, ) @@ -417,6 +418,7 @@ def send_response(self, response: Response) -> None: """ if response.status_code == 101: + assert self.state is CONNECTING self.set_state(OPEN) if self.debug: @@ -430,13 +432,14 @@ def send_response(self, response: Response) -> None: self.writes.append(response.serialize()) def parse(self) -> Generator[None, None, None]: - request = yield from Request.parse(self.reader.read_line) + if self.state == CONNECTING: + request = yield from Request.parse(self.reader.read_line) - if self.debug: - self.logger.debug("< GET %s HTTP/1.1", request.path) - for key, value in request.headers.raw_items(): - self.logger.debug("< %s: %s", key, value) + if self.debug: + self.logger.debug("< GET %s HTTP/1.1", request.path) + for key, value in request.headers.raw_items(): + self.logger.debug("< %s: %s", key, value) + + self.events.append(request) - assert self.state == CONNECTING - self.events.append(request) yield from super().parse() diff --git a/tests/test_client.py b/tests/test_client.py index b96ebd272..840b7148f 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -6,6 +6,7 @@ from websockets.connection import CONNECTING, OPEN from websockets.datastructures import Headers from websockets.exceptions import InvalidHandshake, InvalidHeader +from websockets.frames import OP_TEXT, Frame from websockets.http import USER_AGENT from websockets.http11 import Request, Response from websockets.utils import accept_key @@ -572,6 +573,12 @@ def test_unsupported_subprotocol(self): class MiscTests(unittest.TestCase): + def test_bypass_handshake(self): + client = ClientConnection("ws://example.com/test", state=OPEN) + client.receive_data(b"\x81\x06Hello!") + [frame] = client.events_received() + self.assertEqual(frame, Frame(True, OP_TEXT, b"Hello!")) + def test_custom_logger(self): logger = logging.getLogger("test") with self.assertLogs("test", logging.DEBUG) as logs: diff --git a/tests/test_server.py b/tests/test_server.py index 6a25f0d25..db34e6ba3 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -6,6 +6,7 @@ from websockets.connection import CONNECTING, OPEN from websockets.datastructures import Headers from websockets.exceptions import InvalidHeader, InvalidOrigin, InvalidUpgrade +from websockets.frames import OP_TEXT, Frame from websockets.http import USER_AGENT from websockets.http11 import Request, Response from websockets.server import * @@ -629,6 +630,12 @@ def test_extra_headers_overrides_server(self): class MiscTests(unittest.TestCase): + def test_bypass_handshake(self): + server = ServerConnection(state=OPEN) + server.receive_data(b"\x81\x86\x00\x00\x00\x00Hello!") + [frame] = server.events_received() + self.assertEqual(frame, Frame(True, OP_TEXT, b"Hello!")) + def test_custom_logger(self): logger = logging.getLogger("test") with self.assertLogs("test", logging.DEBUG) as logs: From 0cce70a8645abeac25cae59eac6bd0fb008afa38 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 3 Jun 2021 09:26:10 +0200 Subject: [PATCH 063/762] Add error reason in close on ping timeout. Fix #636. --- src/websockets/legacy/protocol.py | 2 +- tests/legacy/test_protocol.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 7a3570139..992d678e2 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -1064,7 +1064,7 @@ async def keepalive_ping(self) -> None: except asyncio.TimeoutError: if self.debug: self.logger.debug("! timed out waiting for keepalive pong") - self.fail_connection(1011) + self.fail_connection(1011, "keepalive ping timeout") break # Remove this branch when dropping support for Python < 3.8 diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index 58444ce5a..6f6e6f686 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -1143,7 +1143,9 @@ def test_keepalive_ping_not_acknowledged_closes_connection(self): # Connection is closed at 6ms. self.loop.run_until_complete(asyncio.sleep(4 * MS)) - self.assertOneFrameSent(True, OP_CLOSE, serialize_close(1011, "")) + self.assertOneFrameSent( + True, OP_CLOSE, serialize_close(1011, "keepalive ping timeout") + ) # The keepalive ping task is complete. self.assertEqual(self.protocol.keepalive_ping_task.result(), None) From f447a5699a2c6b85f6d6fbcba10982aa3e8d5afd Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 3 Jun 2021 07:12:51 +0200 Subject: [PATCH 064/762] Convert WebSocketURI to a dataclass. --- src/websockets/uri.py | 20 +++++--------------- tests/test_uri.py | 31 ++++++++++++++++++++++++------- 2 files changed, 29 insertions(+), 22 deletions(-) diff --git a/src/websockets/uri.py b/src/websockets/uri.py index ce21b445b..958975b22 100644 --- a/src/websockets/uri.py +++ b/src/websockets/uri.py @@ -7,8 +7,9 @@ """ +import dataclasses import urllib.parse -from typing import NamedTuple, Optional, Tuple +from typing import Optional, Tuple from .exceptions import InvalidURI @@ -16,10 +17,8 @@ __all__ = ["parse_uri", "WebSocketURI"] -# Consider converting to a dataclass when dropping support for Python < 3.7. - - -class WebSocketURI(NamedTuple): +@dataclasses.dataclass +class WebSocketURI: """ WebSocket URI. @@ -37,16 +36,7 @@ class WebSocketURI(NamedTuple): host: str port: int resource_name: str - user_info: Optional[Tuple[str, str]] - - -# Work around https://bugs.python.org/issue19931 - -WebSocketURI.secure.__doc__ = "" -WebSocketURI.host.__doc__ = "" -WebSocketURI.port.__doc__ = "" -WebSocketURI.resource_name.__doc__ = "" -WebSocketURI.user_info.__doc__ = "" + user_info: Optional[Tuple[str, str]] = None # All characters from the gen-delims and sub-delims sets in RFC 3987. diff --git a/tests/test_uri.py b/tests/test_uri.py index 9eeb8431d..a91bcb083 100644 --- a/tests/test_uri.py +++ b/tests/test_uri.py @@ -5,15 +5,32 @@ VALID_URIS = [ - ("ws://localhost/", (False, "localhost", 80, "/", None)), - ("wss://localhost/", (True, "localhost", 443, "/", None)), - ("ws://localhost/path?query", (False, "localhost", 80, "/path?query", None)), - ("WS://LOCALHOST/PATH?QUERY", (False, "localhost", 80, "/PATH?QUERY", None)), - ("ws://user:pass@localhost/", (False, "localhost", 80, "/", ("user", "pass"))), - ("ws://høst/", (False, "xn--hst-0na", 80, "/", None)), + ( + "ws://localhost/", + WebSocketURI(False, "localhost", 80, "/", None), + ), + ( + "wss://localhost/", + WebSocketURI(True, "localhost", 443, "/", None), + ), + ( + "ws://localhost/path?query", + WebSocketURI(False, "localhost", 80, "/path?query", None), + ), + ( + "WS://LOCALHOST/PATH?QUERY", + WebSocketURI(False, "localhost", 80, "/PATH?QUERY", None), + ), + ( + "ws://user:pass@localhost/", + WebSocketURI(False, "localhost", 80, "/", ("user", "pass")), + ), + ("ws://høst/", WebSocketURI(False, "xn--hst-0na", 80, "/", None)), ( "ws://üser:påss@høst/πass", - (False, "xn--hst-0na", 80, "/%CF%80ass", ("%C3%BCser", "p%C3%A5ss")), + WebSocketURI( + False, "xn--hst-0na", 80, "/%CF%80ass", ("%C3%BCser", "p%C3%A5ss") + ), ), ] From 5dd1290f202a8274e68d48582b400013fb953010 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 3 Jun 2021 07:33:21 +0200 Subject: [PATCH 065/762] Convert Request/Response to dataclasses. --- src/websockets/client.py | 2 +- src/websockets/http11.py | 18 +++++++++--------- src/websockets/server.py | 15 ++++++++------- tests/test_server.py | 32 ++++++++++++++++---------------- 4 files changed, 34 insertions(+), 33 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 698228d3a..e9bc12cbe 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -302,7 +302,7 @@ def parse(self) -> Generator[None, None, None]: try: self.process_response(response) except InvalidHandshake as exc: - response = response._replace(exception=exc) + response.exception = exc else: assert self.state == CONNECTING self.set_state(OPEN) diff --git a/src/websockets/http11.py b/src/websockets/http11.py index 6f3cbccc4..22488dc89 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -1,5 +1,6 @@ +import dataclasses import re -from typing import Callable, Generator, NamedTuple, Optional +from typing import Callable, Generator, Optional from .datastructures import Headers from .exceptions import SecurityError @@ -37,10 +38,8 @@ def d(value: bytes) -> str: _value_re = re.compile(rb"[\x09\x20-\x7e\x80-\xff]*") -# Consider converting to dataclasses when dropping support for Python < 3.7. - - -class Request(NamedTuple): +@dataclasses.dataclass +class Request: """ WebSocket handshake request. @@ -52,6 +51,9 @@ class Request(NamedTuple): headers: Headers # body isn't useful is the context of this library + # If processing the request triggers an exception, it's stored here. + exception: Optional[Exception] = None + @classmethod def parse( cls, read_line: Callable[[], Generator[None, None, bytes]] @@ -121,10 +123,8 @@ def serialize(self) -> bytes: return request -# Consider converting to dataclasses when dropping support for Python < 3.7. - - -class Response(NamedTuple): +@dataclasses.dataclass +class Response: """ WebSocket handshake response. diff --git a/src/websockets/server.py b/src/websockets/server.py index 1ce943891..483ddbe1e 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -79,19 +79,18 @@ def accept(self, request: Request) -> Response: connection, which may be unexpected. """ - # TODO: when changing Request to a dataclass, set the exception - # attribute on the request rather than the Response, which will - # be semantically more correct. try: key, extensions_header, protocol_header = self.process_request(request) except InvalidOrigin as exc: + request.exception = exc if self.debug: self.logger.debug("! invalid origin", exc_info=True) return self.reject( http.HTTPStatus.FORBIDDEN, f"Failed to open a WebSocket connection: {exc}.\n", - )._replace(exception=exc) + ) except InvalidUpgrade as exc: + request.exception = exc if self.debug: self.logger.debug("! invalid upgrade", exc_info=True) return self.reject( @@ -103,15 +102,17 @@ def accept(self, request: Request) -> Response: f"with a browser. You need a WebSocket client.\n" ), headers=Headers([("Upgrade", "websocket")]), - )._replace(exception=exc) + ) except InvalidHandshake as exc: + request.exception = exc if self.debug: self.logger.debug("! invalid handshake", exc_info=True) return self.reject( http.HTTPStatus.BAD_REQUEST, f"Failed to open a WebSocket connection: {exc}.\n", - )._replace(exception=exc) + ) except Exception as exc: + request.exception = exc self.logger.error("opening handshake failed", exc_info=True) return self.reject( http.HTTPStatus.INTERNAL_SERVER_ERROR, @@ -119,7 +120,7 @@ def accept(self, request: Request) -> Response: "Failed to open a WebSocket connection.\n" "See server log for more information.\n" ), - )._replace(exception=exc) + ) headers = Headers() diff --git a/tests/test_server.py b/tests/test_server.py index db34e6ba3..86fa3f34d 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -187,7 +187,7 @@ def test_unexpected_exception(self): self.assertEqual(response.status_code, 500) with self.assertRaises(Exception) as raised: - raise response.exception + raise request.exception self.assertEqual(str(raised.exception), "BOOM") def test_missing_connection(self): @@ -199,7 +199,7 @@ def test_missing_connection(self): self.assertEqual(response.status_code, 426) self.assertEqual(response.headers["Upgrade"], "websocket") with self.assertRaises(InvalidUpgrade) as raised: - raise response.exception + raise request.exception self.assertEqual(str(raised.exception), "missing Connection header") def test_invalid_connection(self): @@ -212,7 +212,7 @@ def test_invalid_connection(self): self.assertEqual(response.status_code, 426) self.assertEqual(response.headers["Upgrade"], "websocket") with self.assertRaises(InvalidUpgrade) as raised: - raise response.exception + raise request.exception self.assertEqual(str(raised.exception), "invalid Connection header: close") def test_missing_upgrade(self): @@ -224,7 +224,7 @@ def test_missing_upgrade(self): self.assertEqual(response.status_code, 426) self.assertEqual(response.headers["Upgrade"], "websocket") with self.assertRaises(InvalidUpgrade) as raised: - raise response.exception + raise request.exception self.assertEqual(str(raised.exception), "missing Upgrade header") def test_invalid_upgrade(self): @@ -237,7 +237,7 @@ def test_invalid_upgrade(self): self.assertEqual(response.status_code, 426) self.assertEqual(response.headers["Upgrade"], "websocket") with self.assertRaises(InvalidUpgrade) as raised: - raise response.exception + raise request.exception self.assertEqual(str(raised.exception), "invalid Upgrade header: h2c") def test_missing_key(self): @@ -248,7 +248,7 @@ def test_missing_key(self): self.assertEqual(response.status_code, 400) with self.assertRaises(InvalidHeader) as raised: - raise response.exception + raise request.exception self.assertEqual(str(raised.exception), "missing Sec-WebSocket-Key header") def test_multiple_key(self): @@ -259,7 +259,7 @@ def test_multiple_key(self): self.assertEqual(response.status_code, 400) with self.assertRaises(InvalidHeader) as raised: - raise response.exception + raise request.exception self.assertEqual( str(raised.exception), "invalid Sec-WebSocket-Key header: " @@ -275,7 +275,7 @@ def test_invalid_key(self): self.assertEqual(response.status_code, 400) with self.assertRaises(InvalidHeader) as raised: - raise response.exception + raise request.exception self.assertEqual( str(raised.exception), "invalid Sec-WebSocket-Key header: not Base64 data!" ) @@ -291,7 +291,7 @@ def test_truncated_key(self): self.assertEqual(response.status_code, 400) with self.assertRaises(InvalidHeader) as raised: - raise response.exception + raise request.exception self.assertEqual( str(raised.exception), f"invalid Sec-WebSocket-Key header: {KEY[:16]}" ) @@ -304,7 +304,7 @@ def test_missing_version(self): self.assertEqual(response.status_code, 400) with self.assertRaises(InvalidHeader) as raised: - raise response.exception + raise request.exception self.assertEqual(str(raised.exception), "missing Sec-WebSocket-Version header") def test_multiple_version(self): @@ -315,7 +315,7 @@ def test_multiple_version(self): self.assertEqual(response.status_code, 400) with self.assertRaises(InvalidHeader) as raised: - raise response.exception + raise request.exception self.assertEqual( str(raised.exception), "invalid Sec-WebSocket-Version header: " @@ -331,7 +331,7 @@ def test_invalid_version(self): self.assertEqual(response.status_code, 400) with self.assertRaises(InvalidHeader) as raised: - raise response.exception + raise request.exception self.assertEqual( str(raised.exception), "invalid Sec-WebSocket-Version header: 11" ) @@ -343,7 +343,7 @@ def test_no_origin(self): self.assertEqual(response.status_code, 403) with self.assertRaises(InvalidOrigin) as raised: - raise response.exception + raise request.exception self.assertEqual(str(raised.exception), "missing Origin header") def test_origin(self): @@ -363,7 +363,7 @@ def test_unexpected_origin(self): self.assertEqual(response.status_code, 403) with self.assertRaises(InvalidOrigin) as raised: - raise response.exception + raise request.exception self.assertEqual( str(raised.exception), "invalid Origin header: https://other.example.com" ) @@ -381,7 +381,7 @@ def test_multiple_origin(self): # 400 Bad Request rather than 403 Forbidden. self.assertEqual(response.status_code, 400) with self.assertRaises(InvalidHeader) as raised: - raise response.exception + raise request.exception self.assertEqual( str(raised.exception), "invalid Origin header: more than one Origin header found", @@ -408,7 +408,7 @@ def test_unsupported_origin(self): self.assertEqual(response.status_code, 403) with self.assertRaises(InvalidOrigin) as raised: - raise response.exception + raise request.exception self.assertEqual( str(raised.exception), "invalid Origin header: https://original.example.com" ) From 634ee6a29327146d317564c88010310556c4c384 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 3 Jun 2021 22:30:49 +0200 Subject: [PATCH 066/762] Wrap NewFrame instead of inheriting it. This ensures Frame remains a NamedTuple for backwards compatibility. --- src/websockets/legacy/framing.py | 46 +++++++++++++++++++++++++++----- tests/legacy/test_framing.py | 1 + 2 files changed, 40 insertions(+), 7 deletions(-) diff --git a/src/websockets/legacy/framing.py b/src/websockets/legacy/framing.py index e947c9383..016677615 100644 --- a/src/websockets/legacy/framing.py +++ b/src/websockets/legacy/framing.py @@ -11,7 +11,7 @@ """ import struct -from typing import Any, Awaitable, Callable, Optional, Sequence +from typing import Any, Awaitable, Callable, NamedTuple, Optional, Sequence from ..exceptions import PayloadTooBig, ProtocolError from ..frames import Frame as NewFrame, Opcode @@ -23,7 +23,32 @@ from ..utils import apply_mask -class Frame(NewFrame): +class Frame(NamedTuple): + + fin: bool + opcode: Opcode + data: bytes + rsv1: bool = False + rsv2: bool = False + rsv3: bool = False + + @property + def new_frame(self) -> NewFrame: + return NewFrame( + self.fin, + self.opcode, + self.data, + self.rsv1, + self.rsv2, + self.rsv3, + ) + + def __str__(self) -> str: + return str(self.new_frame) + + def check(self) -> None: + return self.new_frame.check() + @classmethod async def read( cls, @@ -86,16 +111,23 @@ async def read( if mask: data = apply_mask(data, mask_bits) - frame = cls(fin, opcode, data, rsv1, rsv2, rsv3) + new_frame = NewFrame(fin, opcode, data, rsv1, rsv2, rsv3) if extensions is None: extensions = [] for extension in reversed(extensions): - frame = cls(*extension.decode(frame, max_size=max_size)) + new_frame = extension.decode(new_frame, max_size=max_size) - frame.check() + new_frame.check() - return frame + return cls( + new_frame.fin, + new_frame.opcode, + new_frame.data, + new_frame.rsv1, + new_frame.rsv2, + new_frame.rsv3, + ) def write( self, @@ -121,7 +153,7 @@ def write( # The frame is written in a single call to write in order to prevent # TCP fragmentation. See #68 for details. This also makes it safe to # send frames concurrently from multiple coroutines. - write(self.serialize(mask=mask, extensions=extensions)) + write(self.new_frame.serialize(mask=mask, extensions=extensions)) # at the bottom to allow circular import, because Extension depends on Frame diff --git a/tests/legacy/test_framing.py b/tests/legacy/test_framing.py index ac870c79e..0e2670faa 100644 --- a/tests/legacy/test_framing.py +++ b/tests/legacy/test_framing.py @@ -43,6 +43,7 @@ def encode(self, frame, mask=False, extensions=None): def round_trip(self, message, expected, mask=False, extensions=None): decoded = self.decode(message, mask, extensions=extensions) + decoded.check() self.assertEqual(decoded, expected) encoded = self.encode(decoded, mask, extensions=extensions) if mask: # non-deterministic encoding From d500274f9340b4336d27665a48e9ce34f23c9910 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 3 Jun 2021 09:16:21 +0200 Subject: [PATCH 067/762] Convert Frame to a dataclass. --- .../extensions/permessage_deflate.py | 9 +++--- src/websockets/frames.py | 9 +++--- tests/extensions/test_permessage_deflate.py | 29 +++++++++++++------ tests/extensions/utils.py | 6 ++-- tests/legacy/test_framing.py | 3 +- tests/test_frames.py | 3 +- 6 files changed, 37 insertions(+), 22 deletions(-) diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index 34cc1f950..56ef03e0c 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -4,6 +4,7 @@ """ +import dataclasses import zlib from typing import Any, Dict, List, Optional, Sequence, Tuple, Union @@ -115,7 +116,7 @@ def decode(self, frame: Frame, *, max_size: Optional[int] = None) -> Frame: else: if not frame.rsv1: return frame - frame = frame._replace(rsv1=False) + frame = dataclasses.replace(frame, rsv1=False) if not frame.fin: self.decode_cont_data = True @@ -138,7 +139,7 @@ def decode(self, frame: Frame, *, max_size: Optional[int] = None) -> Frame: if frame.fin and self.remote_no_context_takeover: del self.decoder - return frame._replace(data=data) + return dataclasses.replace(frame, data=data) def encode(self, frame: Frame) -> Frame: """ @@ -154,7 +155,7 @@ def encode(self, frame: Frame) -> Frame: if frame.opcode != OP_CONT: # Set the rsv1 flag on the first frame of a compressed message. - frame = frame._replace(rsv1=True) + frame = dataclasses.replace(frame, rsv1=True) # Re-initialize per-message decoder. if self.local_no_context_takeover: self.encoder = zlib.compressobj( @@ -170,7 +171,7 @@ def encode(self, frame: Frame) -> Frame: if frame.fin and self.local_no_context_takeover: del self.encoder - return frame._replace(data=data) + return dataclasses.replace(frame, data=data) def _build_parameters( diff --git a/src/websockets/frames.py b/src/websockets/frames.py index de7fdb941..79de63e47 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -3,11 +3,12 @@ """ +import dataclasses import enum import io import secrets import struct -from typing import Callable, Generator, NamedTuple, Optional, Sequence, Tuple +from typing import Callable, Generator, Optional, Sequence, Tuple from .exceptions import PayloadTooBig, ProtocolError from .typing import Data @@ -92,10 +93,8 @@ class Opcode(enum.IntEnum): } -# Consider converting to a dataclass when dropping support for Python < 3.7. - - -class Frame(NamedTuple): +@dataclasses.dataclass +class Frame: """ WebSocket frame. diff --git a/tests/extensions/test_permessage_deflate.py b/tests/extensions/test_permessage_deflate.py index d3c6f0ac6..2af04a806 100644 --- a/tests/extensions/test_permessage_deflate.py +++ b/tests/extensions/test_permessage_deflate.py @@ -1,3 +1,4 @@ +import dataclasses import unittest import zlib @@ -82,7 +83,9 @@ def test_encode_decode_text_frame(self): enc_frame = self.extension.encode(frame) - self.assertEqual(enc_frame, frame._replace(rsv1=True, data=b"JNL;\xbc\x12\x00")) + self.assertEqual( + enc_frame, dataclasses.replace(frame, rsv1=True, data=b"JNL;\xbc\x12\x00") + ) dec_frame = self.extension.decode(enc_frame) @@ -93,7 +96,9 @@ def test_encode_decode_binary_frame(self): enc_frame = self.extension.encode(frame) - self.assertEqual(enc_frame, frame._replace(rsv1=True, data=b"*IM\x04\x00")) + self.assertEqual( + enc_frame, dataclasses.replace(frame, rsv1=True, data=b"*IM\x04\x00") + ) dec_frame = self.extension.decode(enc_frame) @@ -110,13 +115,15 @@ def test_encode_decode_fragmented_text_frame(self): self.assertEqual( enc_frame1, - frame1._replace(rsv1=True, data=b"JNL;\xbc\x12\x00\x00\x00\xff\xff"), + dataclasses.replace( + frame1, rsv1=True, data=b"JNL;\xbc\x12\x00\x00\x00\xff\xff" + ), ) self.assertEqual( - enc_frame2, frame2._replace(data=b"RPS\x00\x00\x00\x00\xff\xff") + enc_frame2, dataclasses.replace(frame2, data=b"RPS\x00\x00\x00\x00\xff\xff") ) self.assertEqual( - enc_frame3, frame3._replace(data=b"J.\xca\xcf,.N\xcc+)\x06\x00") + enc_frame3, dataclasses.replace(frame3, data=b"J.\xca\xcf,.N\xcc+)\x06\x00") ) dec_frame1 = self.extension.decode(enc_frame1) @@ -136,11 +143,13 @@ def test_encode_decode_fragmented_binary_frame(self): self.assertEqual( enc_frame1, - frame1._replace(rsv1=True, data=b"*IMT\x00\x00\x00\x00\xff\xff"), + dataclasses.replace( + frame1, rsv1=True, data=b"*IMT\x00\x00\x00\x00\xff\xff" + ), ) self.assertEqual( enc_frame2, - frame2._replace(data=b"*\xc9\xccM\x05\x00"), + dataclasses.replace(frame2, data=b"*\xc9\xccM\x05\x00"), ) dec_frame1 = self.extension.decode(enc_frame1) @@ -242,8 +251,10 @@ def test_compress_settings(self): self.assertEqual( enc_frame, - frame._replace( - rsv1=True, data=b"\x00\x05\x00\xfa\xffcaf\xc3\xa9\x00" # not compressed + dataclasses.replace( + frame, + rsv1=True, + data=b"\x00\x05\x00\xfa\xffcaf\xc3\xa9\x00", # not compressed ), ) diff --git a/tests/extensions/utils.py b/tests/extensions/utils.py index 81990bb07..1eabc163f 100644 --- a/tests/extensions/utils.py +++ b/tests/extensions/utils.py @@ -1,3 +1,5 @@ +import dataclasses + from websockets.exceptions import NegotiationError @@ -49,11 +51,11 @@ class Rsv2Extension: def decode(self, frame, *, max_size=None): assert frame.rsv2 - return frame._replace(rsv2=False) + return dataclasses.replace(frame, rsv2=False) def encode(self, frame): assert not frame.rsv2 - return frame._replace(rsv2=True) + return dataclasses.replace(frame, rsv2=True) def __eq__(self, other): return isinstance(other, Rsv2Extension) diff --git a/tests/legacy/test_framing.py b/tests/legacy/test_framing.py index 0e2670faa..2baa827a9 100644 --- a/tests/legacy/test_framing.py +++ b/tests/legacy/test_framing.py @@ -1,5 +1,6 @@ import asyncio import codecs +import dataclasses import unittest import unittest.mock import warnings @@ -160,7 +161,7 @@ def encode(frame): assert frame.opcode == OP_TEXT text = frame.data.decode() data = codecs.encode(text, "rot13").encode() - return frame._replace(data=data) + return dataclasses.replace(frame, data=data) # This extensions is symmetrical. @staticmethod diff --git a/tests/test_frames.py b/tests/test_frames.py index 491386566..c05fa43a5 100644 --- a/tests/test_frames.py +++ b/tests/test_frames.py @@ -1,4 +1,5 @@ import codecs +import dataclasses import unittest import unittest.mock @@ -178,7 +179,7 @@ def encode(frame): assert frame.opcode == OP_TEXT text = frame.data.decode() data = codecs.encode(text, "rot13").encode() - return frame._replace(data=data) + return dataclasses.replace(frame, data=data) # This extensions is symmetrical. @staticmethod From ba4be454b6fd7f283d2ea8ec4354e2e6c9627267 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 3 Jun 2021 22:20:48 +0200 Subject: [PATCH 068/762] Give FIN a default value of False. This makes the default case, non fragmented messages, convenient. Also it uniformises with rsv1/2/3. --- src/websockets/connection.py | 18 +- src/websockets/frames.py | 4 +- src/websockets/legacy/framing.py | 4 +- tests/extensions/test_permessage_deflate.py | 44 ++--- tests/test_client.py | 2 +- tests/test_connection.py | 186 ++++++++++---------- tests/test_frames.py | 64 +++---- tests/test_server.py | 2 +- 8 files changed, 161 insertions(+), 163 deletions(-) diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 5052df3f7..dbdc15bb5 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -166,7 +166,7 @@ def send_continuation(self, data: bytes, fin: bool) -> None: if not self.expect_continuation_frame: raise ProtocolError("unexpected continuation frame") self.expect_continuation_frame = not fin - self.send_frame(Frame(fin, OP_CONT, data)) + self.send_frame(Frame(OP_CONT, data, fin)) def send_text(self, data: bytes, fin: bool = True) -> None: """ @@ -176,7 +176,7 @@ def send_text(self, data: bytes, fin: bool = True) -> None: if self.expect_continuation_frame: raise ProtocolError("expected a continuation frame") self.expect_continuation_frame = not fin - self.send_frame(Frame(fin, OP_TEXT, data)) + self.send_frame(Frame(OP_TEXT, data, fin)) def send_binary(self, data: bytes, fin: bool = True) -> None: """ @@ -186,7 +186,7 @@ def send_binary(self, data: bytes, fin: bool = True) -> None: if self.expect_continuation_frame: raise ProtocolError("expected a continuation frame") self.expect_continuation_frame = not fin - self.send_frame(Frame(fin, OP_BINARY, data)) + self.send_frame(Frame(OP_BINARY, data, fin)) def send_close(self, code: Optional[int] = None, reason: str = "") -> None: """ @@ -201,7 +201,7 @@ def send_close(self, code: Optional[int] = None, reason: str = "") -> None: data = b"" else: data = serialize_close(code, reason) - self.send_frame(Frame(True, OP_CLOSE, data)) + self.send_frame(Frame(OP_CLOSE, data)) # send_frame() guarantees that self.state is OPEN at this point. # 7.1.3. The WebSocket Closing Handshake is Started self.set_state(CLOSING) @@ -213,14 +213,14 @@ def send_ping(self, data: bytes) -> None: Send a ping frame. """ - self.send_frame(Frame(True, OP_PING, data)) + self.send_frame(Frame(OP_PING, data)) def send_pong(self, data: bytes) -> None: """ Send a pong frame. """ - self.send_frame(Frame(True, OP_PONG, data)) + self.send_frame(Frame(OP_PONG, data)) # Public API for getting incoming events after receiving data. @@ -257,7 +257,7 @@ def fail_connection(self, code: int = 1006, reason: str = "") -> None: # sent if it's CLOSING), except when failing the connection because of # an error reading from or writing to the network. if code != 1006 and self.state is OPEN: - self.send_frame(Frame(True, OP_CLOSE, serialize_close(code, reason))) + self.send_frame(Frame(OP_CLOSE, serialize_close(code, reason))) self.set_state(CLOSING) if not self.eof_sent: self.send_eof() @@ -365,7 +365,7 @@ def parse(self) -> Generator[None, None, None]: # send a Pong frame in response, unless it already received a # Close frame." if not self.close_frame_received: - pong_frame = Frame(True, OP_PONG, frame.data) + pong_frame = Frame(OP_PONG, frame.data) self.send_frame(pong_frame) elif frame.opcode is OP_PONG: @@ -391,7 +391,7 @@ def parse(self) -> Generator[None, None, None]: # serialize_close() because that fails when the close frame # is empty and parse_close() synthetizes a 1005 close code. # The rest is identical to send_close(). - self.send_frame(Frame(True, OP_CLOSE, frame.data)) + self.send_frame(Frame(OP_CLOSE, frame.data)) self.set_state(CLOSING) if self.side is SERVER: self.send_eof() diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 79de63e47..ec6ff2258 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -110,9 +110,9 @@ class Frame: """ - fin: bool opcode: Opcode data: bytes + fin: bool = True rsv1: bool = False rsv2: bool = False rsv3: bool = False @@ -224,7 +224,7 @@ def parse( if mask: data = apply_mask(data, mask_bytes) - frame = cls(fin, opcode, data, rsv1, rsv2, rsv3) + frame = cls(opcode, data, fin, rsv1, rsv2, rsv3) if extensions is None: extensions = [] diff --git a/src/websockets/legacy/framing.py b/src/websockets/legacy/framing.py index 016677615..901b2e2e3 100644 --- a/src/websockets/legacy/framing.py +++ b/src/websockets/legacy/framing.py @@ -35,9 +35,9 @@ class Frame(NamedTuple): @property def new_frame(self) -> NewFrame: return NewFrame( - self.fin, self.opcode, self.data, + self.fin, self.rsv1, self.rsv2, self.rsv3, @@ -111,7 +111,7 @@ async def read( if mask: data = apply_mask(data, mask_bits) - new_frame = NewFrame(fin, opcode, data, rsv1, rsv2, rsv3) + new_frame = NewFrame(opcode, data, fin, rsv1, rsv2, rsv3) if extensions is None: extensions = [] diff --git a/tests/extensions/test_permessage_deflate.py b/tests/extensions/test_permessage_deflate.py index 2af04a806..5ba4d8ddf 100644 --- a/tests/extensions/test_permessage_deflate.py +++ b/tests/extensions/test_permessage_deflate.py @@ -56,21 +56,21 @@ def test_repr(self): # Control frames aren't encoded or decoded. def test_no_encode_decode_ping_frame(self): - frame = Frame(True, OP_PING, b"") + frame = Frame(OP_PING, b"") self.assertEqual(self.extension.encode(frame), frame) self.assertEqual(self.extension.decode(frame), frame) def test_no_encode_decode_pong_frame(self): - frame = Frame(True, OP_PONG, b"") + frame = Frame(OP_PONG, b"") self.assertEqual(self.extension.encode(frame), frame) self.assertEqual(self.extension.decode(frame), frame) def test_no_encode_decode_close_frame(self): - frame = Frame(True, OP_CLOSE, serialize_close(1000, "")) + frame = Frame(OP_CLOSE, serialize_close(1000, "")) self.assertEqual(self.extension.encode(frame), frame) @@ -79,7 +79,7 @@ def test_no_encode_decode_close_frame(self): # Data frames are encoded and decoded. def test_encode_decode_text_frame(self): - frame = Frame(True, OP_TEXT, "café".encode("utf-8")) + frame = Frame(OP_TEXT, "café".encode("utf-8")) enc_frame = self.extension.encode(frame) @@ -92,7 +92,7 @@ def test_encode_decode_text_frame(self): self.assertEqual(dec_frame, frame) def test_encode_decode_binary_frame(self): - frame = Frame(True, OP_BINARY, b"tea") + frame = Frame(OP_BINARY, b"tea") enc_frame = self.extension.encode(frame) @@ -105,9 +105,9 @@ def test_encode_decode_binary_frame(self): self.assertEqual(dec_frame, frame) def test_encode_decode_fragmented_text_frame(self): - frame1 = Frame(False, OP_TEXT, "café".encode("utf-8")) - frame2 = Frame(False, OP_CONT, " & ".encode("utf-8")) - frame3 = Frame(True, OP_CONT, "croissants".encode("utf-8")) + frame1 = Frame(OP_TEXT, "café".encode("utf-8"), fin=False) + frame2 = Frame(OP_CONT, " & ".encode("utf-8"), fin=False) + frame3 = Frame(OP_CONT, "croissants".encode("utf-8")) enc_frame1 = self.extension.encode(frame1) enc_frame2 = self.extension.encode(frame2) @@ -135,8 +135,8 @@ def test_encode_decode_fragmented_text_frame(self): self.assertEqual(dec_frame3, frame3) def test_encode_decode_fragmented_binary_frame(self): - frame1 = Frame(False, OP_TEXT, b"tea ") - frame2 = Frame(True, OP_CONT, b"time") + frame1 = Frame(OP_TEXT, b"tea ", fin=False) + frame2 = Frame(OP_CONT, b"time") enc_frame1 = self.extension.encode(frame1) enc_frame2 = self.extension.encode(frame2) @@ -159,21 +159,21 @@ def test_encode_decode_fragmented_binary_frame(self): self.assertEqual(dec_frame2, frame2) def test_no_decode_text_frame(self): - frame = Frame(True, OP_TEXT, "café".encode("utf-8")) + frame = Frame(OP_TEXT, "café".encode("utf-8")) # Try decoding a frame that wasn't encoded. self.assertEqual(self.extension.decode(frame), frame) def test_no_decode_binary_frame(self): - frame = Frame(True, OP_TEXT, b"tea") + frame = Frame(OP_TEXT, b"tea") # Try decoding a frame that wasn't encoded. self.assertEqual(self.extension.decode(frame), frame) def test_no_decode_fragmented_text_frame(self): - frame1 = Frame(False, OP_TEXT, "café".encode("utf-8")) - frame2 = Frame(False, OP_CONT, " & ".encode("utf-8")) - frame3 = Frame(True, OP_CONT, "croissants".encode("utf-8")) + frame1 = Frame(OP_TEXT, "café".encode("utf-8"), fin=False) + frame2 = Frame(OP_CONT, " & ".encode("utf-8"), fin=False) + frame3 = Frame(OP_CONT, "croissants".encode("utf-8")) dec_frame1 = self.extension.decode(frame1) dec_frame2 = self.extension.decode(frame2) @@ -184,8 +184,8 @@ def test_no_decode_fragmented_text_frame(self): self.assertEqual(dec_frame3, frame3) def test_no_decode_fragmented_binary_frame(self): - frame1 = Frame(False, OP_TEXT, b"tea ") - frame2 = Frame(True, OP_CONT, b"time") + frame1 = Frame(OP_TEXT, b"tea ", fin=False) + frame2 = Frame(OP_CONT, b"time") dec_frame1 = self.extension.decode(frame1) dec_frame2 = self.extension.decode(frame2) @@ -194,7 +194,7 @@ def test_no_decode_fragmented_binary_frame(self): self.assertEqual(dec_frame2, frame2) def test_context_takeover(self): - frame = Frame(True, OP_TEXT, "café".encode("utf-8")) + frame = Frame(OP_TEXT, "café".encode("utf-8")) enc_frame1 = self.extension.encode(frame) enc_frame2 = self.extension.encode(frame) @@ -206,7 +206,7 @@ def test_remote_no_context_takeover(self): # No context takeover when decoding messages. self.extension = PerMessageDeflate(True, False, 15, 15) - frame = Frame(True, OP_TEXT, "café".encode("utf-8")) + frame = Frame(OP_TEXT, "café".encode("utf-8")) enc_frame1 = self.extension.encode(frame) enc_frame2 = self.extension.encode(frame) @@ -225,7 +225,7 @@ def test_local_no_context_takeover(self): # No context takeover when encoding and decoding messages. self.extension = PerMessageDeflate(True, True, 15, 15) - frame = Frame(True, OP_TEXT, "café".encode("utf-8")) + frame = Frame(OP_TEXT, "café".encode("utf-8")) enc_frame1 = self.extension.encode(frame) enc_frame2 = self.extension.encode(frame) @@ -245,7 +245,7 @@ def test_compress_settings(self): # Configure an extension so that no compression actually occurs. extension = PerMessageDeflate(False, False, 15, 15, {"level": 0}) - frame = Frame(True, OP_TEXT, "café".encode("utf-8")) + frame = Frame(OP_TEXT, "café".encode("utf-8")) enc_frame = extension.encode(frame) @@ -261,7 +261,7 @@ def test_compress_settings(self): # Frames aren't decoded beyond max_size. def test_decompress_max_size(self): - frame = Frame(True, OP_TEXT, ("a" * 20).encode("utf-8")) + frame = Frame(OP_TEXT, ("a" * 20).encode("utf-8")) enc_frame = self.extension.encode(frame) diff --git a/tests/test_client.py b/tests/test_client.py index 840b7148f..2ef1f6a95 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -577,7 +577,7 @@ def test_bypass_handshake(self): client = ClientConnection("ws://example.com/test", state=OPEN) client.receive_data(b"\x81\x06Hello!") [frame] = client.events_received() - self.assertEqual(frame, Frame(True, OP_TEXT, b"Hello!")) + self.assertEqual(frame, Frame(OP_TEXT, b"Hello!")) def test_custom_logger(self): logger = logging.getLogger("test") diff --git a/tests/test_connection.py b/tests/test_connection.py index 6203a1469..677881238 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -59,7 +59,6 @@ def assertConnectionClosing(self, connection, code=None, reason=""): """ close_frame = Frame( - True, OP_CLOSE, b"" if code is None else serialize_close(code, reason), ) @@ -76,7 +75,6 @@ def assertConnectionFailing(self, connection, code=None, reason=""): """ close_frame = Frame( - True, OP_CLOSE, b"" if code is None else serialize_close(code, reason), ) @@ -113,7 +111,7 @@ def test_client_receives_unmasked_frame(self): client.receive_data(self.unmasked_text_frame_date) self.assertFrameReceived( client, - Frame(True, OP_TEXT, b"Spam"), + Frame(OP_TEXT, b"Spam"), ) def test_server_receives_masked_frame(self): @@ -121,7 +119,7 @@ def test_server_receives_masked_frame(self): server.receive_data(self.masked_text_frame_data) self.assertFrameReceived( server, - Frame(True, OP_TEXT, b"Spam"), + Frame(OP_TEXT, b"Spam"), ) def test_client_receives_masked_frame(self): @@ -235,7 +233,7 @@ def test_client_receives_text(self): client.receive_data(b"\x81\x04\xf0\x9f\x98\x80") self.assertFrameReceived( client, - Frame(True, OP_TEXT, "😀".encode()), + Frame(OP_TEXT, "😀".encode()), ) def test_server_receives_text(self): @@ -243,7 +241,7 @@ def test_server_receives_text(self): server.receive_data(b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80") self.assertFrameReceived( server, - Frame(True, OP_TEXT, "😀".encode()), + Frame(OP_TEXT, "😀".encode()), ) def test_client_receives_text_over_size_limit(self): @@ -265,7 +263,7 @@ def test_client_receives_text_without_size_limit(self): client.receive_data(b"\x81\x04\xf0\x9f\x98\x80") self.assertFrameReceived( client, - Frame(True, OP_TEXT, "😀".encode()), + Frame(OP_TEXT, "😀".encode()), ) def test_server_receives_text_without_size_limit(self): @@ -273,7 +271,7 @@ def test_server_receives_text_without_size_limit(self): server.receive_data(b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80") self.assertFrameReceived( server, - Frame(True, OP_TEXT, "😀".encode()), + Frame(OP_TEXT, "😀".encode()), ) def test_client_sends_fragmented_text(self): @@ -304,17 +302,17 @@ def test_client_receives_fragmented_text(self): client.receive_data(b"\x01\x02\xf0\x9f") self.assertFrameReceived( client, - Frame(False, OP_TEXT, "😀".encode()[:2]), + Frame(OP_TEXT, "😀".encode()[:2], fin=False), ) client.receive_data(b"\x00\x04\x98\x80\xf0\x9f") self.assertFrameReceived( client, - Frame(False, OP_CONT, "😀😀".encode()[2:6]), + Frame(OP_CONT, "😀😀".encode()[2:6], fin=False), ) client.receive_data(b"\x80\x02\x98\x80") self.assertFrameReceived( client, - Frame(True, OP_CONT, "😀".encode()[2:]), + Frame(OP_CONT, "😀".encode()[2:]), ) def test_server_receives_fragmented_text(self): @@ -322,17 +320,17 @@ def test_server_receives_fragmented_text(self): server.receive_data(b"\x01\x82\x00\x00\x00\x00\xf0\x9f") self.assertFrameReceived( server, - Frame(False, OP_TEXT, "😀".encode()[:2]), + Frame(OP_TEXT, "😀".encode()[:2], fin=False), ) server.receive_data(b"\x00\x84\x00\x00\x00\x00\x98\x80\xf0\x9f") self.assertFrameReceived( server, - Frame(False, OP_CONT, "😀😀".encode()[2:6]), + Frame(OP_CONT, "😀😀".encode()[2:6], fin=False), ) server.receive_data(b"\x80\x82\x00\x00\x00\x00\x98\x80") self.assertFrameReceived( server, - Frame(True, OP_CONT, "😀".encode()[2:]), + Frame(OP_CONT, "😀".encode()[2:]), ) def test_client_receives_fragmented_text_over_size_limit(self): @@ -340,7 +338,7 @@ def test_client_receives_fragmented_text_over_size_limit(self): client.receive_data(b"\x01\x02\xf0\x9f") self.assertFrameReceived( client, - Frame(False, OP_TEXT, "😀".encode()[:2]), + Frame(OP_TEXT, "😀".encode()[:2], fin=False), ) with self.assertRaises(PayloadTooBig) as raised: client.receive_data(b"\x80\x02\x98\x80") @@ -352,7 +350,7 @@ def test_server_receives_fragmented_text_over_size_limit(self): server.receive_data(b"\x01\x82\x00\x00\x00\x00\xf0\x9f") self.assertFrameReceived( server, - Frame(False, OP_TEXT, "😀".encode()[:2]), + Frame(OP_TEXT, "😀".encode()[:2], fin=False), ) with self.assertRaises(PayloadTooBig) as raised: server.receive_data(b"\x80\x82\x00\x00\x00\x00\x98\x80") @@ -364,17 +362,17 @@ def test_client_receives_fragmented_text_without_size_limit(self): client.receive_data(b"\x01\x02\xf0\x9f") self.assertFrameReceived( client, - Frame(False, OP_TEXT, "😀".encode()[:2]), + Frame(OP_TEXT, "😀".encode()[:2], fin=False), ) client.receive_data(b"\x00\x04\x98\x80\xf0\x9f") self.assertFrameReceived( client, - Frame(False, OP_CONT, "😀😀".encode()[2:6]), + Frame(OP_CONT, "😀😀".encode()[2:6], fin=False), ) client.receive_data(b"\x80\x02\x98\x80") self.assertFrameReceived( client, - Frame(True, OP_CONT, "😀".encode()[2:]), + Frame(OP_CONT, "😀".encode()[2:]), ) def test_server_receives_fragmented_text_without_size_limit(self): @@ -382,17 +380,17 @@ def test_server_receives_fragmented_text_without_size_limit(self): server.receive_data(b"\x01\x82\x00\x00\x00\x00\xf0\x9f") self.assertFrameReceived( server, - Frame(False, OP_TEXT, "😀".encode()[:2]), + Frame(OP_TEXT, "😀".encode()[:2], fin=False), ) server.receive_data(b"\x00\x84\x00\x00\x00\x00\x98\x80\xf0\x9f") self.assertFrameReceived( server, - Frame(False, OP_CONT, "😀😀".encode()[2:6]), + Frame(OP_CONT, "😀😀".encode()[2:6], fin=False), ) server.receive_data(b"\x80\x82\x00\x00\x00\x00\x98\x80") self.assertFrameReceived( server, - Frame(True, OP_CONT, "😀".encode()[2:]), + Frame(OP_CONT, "😀".encode()[2:]), ) def test_client_sends_unexpected_text(self): @@ -414,7 +412,7 @@ def test_client_receives_unexpected_text(self): client.receive_data(b"\x01\x00") self.assertFrameReceived( client, - Frame(False, OP_TEXT, b""), + Frame(OP_TEXT, b"", fin=False), ) with self.assertRaises(ProtocolError) as raised: client.receive_data(b"\x01\x00") @@ -426,7 +424,7 @@ def test_server_receives_unexpected_text(self): server.receive_data(b"\x01\x80\x00\x00\x00\x00") self.assertFrameReceived( server, - Frame(False, OP_TEXT, b""), + Frame(OP_TEXT, b"", fin=False), ) with self.assertRaises(ProtocolError) as raised: server.receive_data(b"\x01\x80\x00\x00\x00\x00") @@ -489,7 +487,7 @@ def test_client_receives_binary(self): client.receive_data(b"\x82\x04\x01\x02\xfe\xff") self.assertFrameReceived( client, - Frame(True, OP_BINARY, b"\x01\x02\xfe\xff"), + Frame(OP_BINARY, b"\x01\x02\xfe\xff"), ) def test_server_receives_binary(self): @@ -497,7 +495,7 @@ def test_server_receives_binary(self): server.receive_data(b"\x82\x84\x00\x00\x00\x00\x01\x02\xfe\xff") self.assertFrameReceived( server, - Frame(True, OP_BINARY, b"\x01\x02\xfe\xff"), + Frame(OP_BINARY, b"\x01\x02\xfe\xff"), ) def test_client_receives_binary_over_size_limit(self): @@ -542,17 +540,17 @@ def test_client_receives_fragmented_binary(self): client.receive_data(b"\x02\x02\x01\x02") self.assertFrameReceived( client, - Frame(False, OP_BINARY, b"\x01\x02"), + Frame(OP_BINARY, b"\x01\x02", fin=False), ) client.receive_data(b"\x00\x04\xfe\xff\x01\x02") self.assertFrameReceived( client, - Frame(False, OP_CONT, b"\xfe\xff\x01\x02"), + Frame(OP_CONT, b"\xfe\xff\x01\x02", fin=False), ) client.receive_data(b"\x80\x02\xfe\xff") self.assertFrameReceived( client, - Frame(True, OP_CONT, b"\xfe\xff"), + Frame(OP_CONT, b"\xfe\xff"), ) def test_server_receives_fragmented_binary(self): @@ -560,17 +558,17 @@ def test_server_receives_fragmented_binary(self): server.receive_data(b"\x02\x82\x00\x00\x00\x00\x01\x02") self.assertFrameReceived( server, - Frame(False, OP_BINARY, b"\x01\x02"), + Frame(OP_BINARY, b"\x01\x02", fin=False), ) server.receive_data(b"\x00\x84\x00\x00\x00\x00\xee\xff\x01\x02") self.assertFrameReceived( server, - Frame(False, OP_CONT, b"\xee\xff\x01\x02"), + Frame(OP_CONT, b"\xee\xff\x01\x02", fin=False), ) server.receive_data(b"\x80\x82\x00\x00\x00\x00\xfe\xff") self.assertFrameReceived( server, - Frame(True, OP_CONT, b"\xfe\xff"), + Frame(OP_CONT, b"\xfe\xff"), ) def test_client_receives_fragmented_binary_over_size_limit(self): @@ -578,7 +576,7 @@ def test_client_receives_fragmented_binary_over_size_limit(self): client.receive_data(b"\x02\x02\x01\x02") self.assertFrameReceived( client, - Frame(False, OP_BINARY, b"\x01\x02"), + Frame(OP_BINARY, b"\x01\x02", fin=False), ) with self.assertRaises(PayloadTooBig) as raised: client.receive_data(b"\x80\x02\xfe\xff") @@ -590,7 +588,7 @@ def test_server_receives_fragmented_binary_over_size_limit(self): server.receive_data(b"\x02\x82\x00\x00\x00\x00\x01\x02") self.assertFrameReceived( server, - Frame(False, OP_BINARY, b"\x01\x02"), + Frame(OP_BINARY, b"\x01\x02", fin=False), ) with self.assertRaises(PayloadTooBig) as raised: server.receive_data(b"\x80\x82\x00\x00\x00\x00\xfe\xff") @@ -616,7 +614,7 @@ def test_client_receives_unexpected_binary(self): client.receive_data(b"\x02\x00") self.assertFrameReceived( client, - Frame(False, OP_BINARY, b""), + Frame(OP_BINARY, b"", fin=False), ) with self.assertRaises(ProtocolError) as raised: client.receive_data(b"\x02\x00") @@ -628,7 +626,7 @@ def test_server_receives_unexpected_binary(self): server.receive_data(b"\x02\x80\x00\x00\x00\x00") self.assertFrameReceived( server, - Frame(False, OP_BINARY, b""), + Frame(OP_BINARY, b"", fin=False), ) with self.assertRaises(ProtocolError) as raised: server.receive_data(b"\x02\x80\x00\x00\x00\x00") @@ -690,14 +688,14 @@ def test_client_receives_close(self): client = Connection(Side.CLIENT) with self.enforce_mask(b"\x3c\x3c\x3c\x3c"): client.receive_data(b"\x88\x00") - self.assertEqual(client.events_received(), [Frame(True, OP_CLOSE, b"")]) + self.assertEqual(client.events_received(), [Frame(OP_CLOSE, b"")]) self.assertEqual(client.data_to_send(), [b"\x88\x80\x3c\x3c\x3c\x3c"]) self.assertIs(client.state, State.CLOSING) def test_server_receives_close(self): server = Connection(Side.SERVER) server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") - self.assertEqual(server.events_received(), [Frame(True, OP_CLOSE, b"")]) + self.assertEqual(server.events_received(), [Frame(OP_CLOSE, b"")]) self.assertEqual(server.data_to_send(), [b"\x88\x00", b""]) self.assertIs(server.state, State.CLOSING) @@ -707,10 +705,10 @@ def test_client_sends_close_then_receives_close(self): client.send_close() self.assertFrameReceived(client, None) - self.assertFrameSent(client, Frame(True, OP_CLOSE, b"")) + self.assertFrameSent(client, Frame(OP_CLOSE, b"")) client.receive_data(b"\x88\x00") - self.assertFrameReceived(client, Frame(True, OP_CLOSE, b"")) + self.assertFrameReceived(client, Frame(OP_CLOSE, b"")) self.assertFrameSent(client, None) client.receive_eof() @@ -723,10 +721,10 @@ def test_server_sends_close_then_receives_close(self): server.send_close() self.assertFrameReceived(server, None) - self.assertFrameSent(server, Frame(True, OP_CLOSE, b""), eof=True) + self.assertFrameSent(server, Frame(OP_CLOSE, b""), eof=True) server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") - self.assertFrameReceived(server, Frame(True, OP_CLOSE, b"")) + self.assertFrameReceived(server, Frame(OP_CLOSE, b"")) self.assertFrameSent(server, None) server.receive_eof() @@ -738,8 +736,8 @@ def test_client_receives_close_then_sends_close(self): client = Connection(Side.CLIENT) client.receive_data(b"\x88\x00") - self.assertFrameReceived(client, Frame(True, OP_CLOSE, b"")) - self.assertFrameSent(client, Frame(True, OP_CLOSE, b"")) + self.assertFrameReceived(client, Frame(OP_CLOSE, b"")) + self.assertFrameSent(client, Frame(OP_CLOSE, b"")) client.receive_eof() self.assertFrameReceived(client, None) @@ -750,8 +748,8 @@ def test_server_receives_close_then_sends_close(self): server = Connection(Side.SERVER) server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") - self.assertFrameReceived(server, Frame(True, OP_CLOSE, b"")) - self.assertFrameSent(server, Frame(True, OP_CLOSE, b""), eof=True) + self.assertFrameReceived(server, Frame(OP_CLOSE, b"")) + self.assertFrameSent(server, Frame(OP_CLOSE, b""), eof=True) server.receive_eof() self.assertFrameReceived(server, None) @@ -882,11 +880,11 @@ def test_client_receives_ping(self): client.receive_data(b"\x89\x00") self.assertFrameReceived( client, - Frame(True, OP_PING, b""), + Frame(OP_PING, b""), ) self.assertFrameSent( client, - Frame(True, OP_PONG, b""), + Frame(OP_PONG, b""), ) def test_server_receives_ping(self): @@ -894,11 +892,11 @@ def test_server_receives_ping(self): server.receive_data(b"\x89\x80\x00\x44\x88\xcc") self.assertFrameReceived( server, - Frame(True, OP_PING, b""), + Frame(OP_PING, b""), ) self.assertFrameSent( server, - Frame(True, OP_PONG, b""), + Frame(OP_PONG, b""), ) def test_client_sends_ping_with_data(self): @@ -919,11 +917,11 @@ def test_client_receives_ping_with_data(self): client.receive_data(b"\x89\x04\x22\x66\xaa\xee") self.assertFrameReceived( client, - Frame(True, OP_PING, b"\x22\x66\xaa\xee"), + Frame(OP_PING, b"\x22\x66\xaa\xee"), ) self.assertFrameSent( client, - Frame(True, OP_PONG, b"\x22\x66\xaa\xee"), + Frame(OP_PONG, b"\x22\x66\xaa\xee"), ) def test_server_receives_ping_with_data(self): @@ -931,25 +929,25 @@ def test_server_receives_ping_with_data(self): server.receive_data(b"\x89\x84\x00\x44\x88\xcc\x22\x22\x22\x22") self.assertFrameReceived( server, - Frame(True, OP_PING, b"\x22\x66\xaa\xee"), + Frame(OP_PING, b"\x22\x66\xaa\xee"), ) self.assertFrameSent( server, - Frame(True, OP_PONG, b"\x22\x66\xaa\xee"), + Frame(OP_PONG, b"\x22\x66\xaa\xee"), ) def test_client_sends_fragmented_ping_frame(self): client = Connection(Side.CLIENT) # This is only possible through a private API. with self.assertRaises(ProtocolError) as raised: - client.send_frame(Frame(False, OP_PING, b"")) + client.send_frame(Frame(OP_PING, b"", fin=False)) self.assertEqual(str(raised.exception), "fragmented control frame") def test_server_sends_fragmented_ping_frame(self): server = Connection(Side.SERVER) # This is only possible through a private API. with self.assertRaises(ProtocolError) as raised: - server.send_frame(Frame(False, OP_PING, b"")) + server.send_frame(Frame(OP_PING, b"", fin=False)) self.assertEqual(str(raised.exception), "fragmented control frame") def test_client_receives_fragmented_ping_frame(self): @@ -1000,7 +998,7 @@ def test_client_receives_ping_after_receiving_close(self): client.receive_data(b"\x89\x04\x22\x66\xaa\xee") self.assertFrameReceived( client, - Frame(True, OP_PING, b"\x22\x66\xaa\xee"), + Frame(OP_PING, b"\x22\x66\xaa\xee"), ) self.assertFrameSent(client, None) @@ -1011,7 +1009,7 @@ def test_server_receives_ping_after_receiving_close(self): server.receive_data(b"\x89\x84\x00\x44\x88\xcc\x22\x22\x22\x22") self.assertFrameReceived( server, - Frame(True, OP_PING, b"\x22\x66\xaa\xee"), + Frame(OP_PING, b"\x22\x66\xaa\xee"), ) self.assertFrameSent(server, None) @@ -1038,7 +1036,7 @@ def test_client_receives_pong(self): client.receive_data(b"\x8a\x00") self.assertFrameReceived( client, - Frame(True, OP_PONG, b""), + Frame(OP_PONG, b""), ) def test_server_receives_pong(self): @@ -1046,7 +1044,7 @@ def test_server_receives_pong(self): server.receive_data(b"\x8a\x80\x00\x44\x88\xcc") self.assertFrameReceived( server, - Frame(True, OP_PONG, b""), + Frame(OP_PONG, b""), ) def test_client_sends_pong_with_data(self): @@ -1067,7 +1065,7 @@ def test_client_receives_pong_with_data(self): client.receive_data(b"\x8a\x04\x22\x66\xaa\xee") self.assertFrameReceived( client, - Frame(True, OP_PONG, b"\x22\x66\xaa\xee"), + Frame(OP_PONG, b"\x22\x66\xaa\xee"), ) def test_server_receives_pong_with_data(self): @@ -1075,21 +1073,21 @@ def test_server_receives_pong_with_data(self): server.receive_data(b"\x8a\x84\x00\x44\x88\xcc\x22\x22\x22\x22") self.assertFrameReceived( server, - Frame(True, OP_PONG, b"\x22\x66\xaa\xee"), + Frame(OP_PONG, b"\x22\x66\xaa\xee"), ) def test_client_sends_fragmented_pong_frame(self): client = Connection(Side.CLIENT) # This is only possible through a private API. with self.assertRaises(ProtocolError) as raised: - client.send_frame(Frame(False, OP_PONG, b"")) + client.send_frame(Frame(OP_PONG, b"", fin=False)) self.assertEqual(str(raised.exception), "fragmented control frame") def test_server_sends_fragmented_pong_frame(self): server = Connection(Side.SERVER) # This is only possible through a private API. with self.assertRaises(ProtocolError) as raised: - server.send_frame(Frame(False, OP_PONG, b"")) + server.send_frame(Frame(OP_PONG, b"", fin=False)) self.assertEqual(str(raised.exception), "fragmented control frame") def test_client_receives_fragmented_pong_frame(self): @@ -1130,7 +1128,7 @@ def test_client_receives_pong_after_receiving_close(self): client.receive_data(b"\x8a\x04\x22\x66\xaa\xee") self.assertFrameReceived( client, - Frame(True, OP_PONG, b"\x22\x66\xaa\xee"), + Frame(OP_PONG, b"\x22\x66\xaa\xee"), ) def test_server_receives_pong_after_receiving_close(self): @@ -1140,7 +1138,7 @@ def test_server_receives_pong_after_receiving_close(self): server.receive_data(b"\x8a\x84\x00\x44\x88\xcc\x22\x22\x22\x22") self.assertFrameReceived( server, - Frame(True, OP_PONG, b"\x22\x66\xaa\xee"), + Frame(OP_PONG, b"\x22\x66\xaa\xee"), ) @@ -1155,59 +1153,59 @@ class FragmentationTests(ConnectionTestCase): def test_client_send_ping_pong_in_fragmented_message(self): client = Connection(Side.CLIENT) client.send_text(b"Spam", fin=False) - self.assertFrameSent(client, Frame(False, OP_TEXT, b"Spam")) + self.assertFrameSent(client, Frame(OP_TEXT, b"Spam", fin=False)) client.send_ping(b"Ping") - self.assertFrameSent(client, Frame(True, OP_PING, b"Ping")) + self.assertFrameSent(client, Frame(OP_PING, b"Ping")) client.send_continuation(b"Ham", fin=False) - self.assertFrameSent(client, Frame(False, OP_CONT, b"Ham")) + self.assertFrameSent(client, Frame(OP_CONT, b"Ham", fin=False)) client.send_pong(b"Pong") - self.assertFrameSent(client, Frame(True, OP_PONG, b"Pong")) + self.assertFrameSent(client, Frame(OP_PONG, b"Pong")) client.send_continuation(b"Eggs", fin=True) - self.assertFrameSent(client, Frame(True, OP_CONT, b"Eggs")) + self.assertFrameSent(client, Frame(OP_CONT, b"Eggs")) def test_server_send_ping_pong_in_fragmented_message(self): server = Connection(Side.SERVER) server.send_text(b"Spam", fin=False) - self.assertFrameSent(server, Frame(False, OP_TEXT, b"Spam")) + self.assertFrameSent(server, Frame(OP_TEXT, b"Spam", fin=False)) server.send_ping(b"Ping") - self.assertFrameSent(server, Frame(True, OP_PING, b"Ping")) + self.assertFrameSent(server, Frame(OP_PING, b"Ping")) server.send_continuation(b"Ham", fin=False) - self.assertFrameSent(server, Frame(False, OP_CONT, b"Ham")) + self.assertFrameSent(server, Frame(OP_CONT, b"Ham", fin=False)) server.send_pong(b"Pong") - self.assertFrameSent(server, Frame(True, OP_PONG, b"Pong")) + self.assertFrameSent(server, Frame(OP_PONG, b"Pong")) server.send_continuation(b"Eggs", fin=True) - self.assertFrameSent(server, Frame(True, OP_CONT, b"Eggs")) + self.assertFrameSent(server, Frame(OP_CONT, b"Eggs")) def test_client_receive_ping_pong_in_fragmented_message(self): client = Connection(Side.CLIENT) client.receive_data(b"\x01\x04Spam") self.assertFrameReceived( client, - Frame(False, OP_TEXT, b"Spam"), + Frame(OP_TEXT, b"Spam", fin=False), ) client.receive_data(b"\x89\x04Ping") self.assertFrameReceived( client, - Frame(True, OP_PING, b"Ping"), + Frame(OP_PING, b"Ping"), ) self.assertFrameSent( client, - Frame(True, OP_PONG, b"Ping"), + Frame(OP_PONG, b"Ping"), ) client.receive_data(b"\x00\x03Ham") self.assertFrameReceived( client, - Frame(False, OP_CONT, b"Ham"), + Frame(OP_CONT, b"Ham", fin=False), ) client.receive_data(b"\x8a\x04Pong") self.assertFrameReceived( client, - Frame(True, OP_PONG, b"Pong"), + Frame(OP_PONG, b"Pong"), ) client.receive_data(b"\x80\x04Eggs") self.assertFrameReceived( client, - Frame(True, OP_CONT, b"Eggs"), + Frame(OP_CONT, b"Eggs"), ) def test_server_receive_ping_pong_in_fragmented_message(self): @@ -1215,37 +1213,37 @@ def test_server_receive_ping_pong_in_fragmented_message(self): server.receive_data(b"\x01\x84\x00\x00\x00\x00Spam") self.assertFrameReceived( server, - Frame(False, OP_TEXT, b"Spam"), + Frame(OP_TEXT, b"Spam", fin=False), ) server.receive_data(b"\x89\x84\x00\x00\x00\x00Ping") self.assertFrameReceived( server, - Frame(True, OP_PING, b"Ping"), + Frame(OP_PING, b"Ping"), ) self.assertFrameSent( server, - Frame(True, OP_PONG, b"Ping"), + Frame(OP_PONG, b"Ping"), ) server.receive_data(b"\x00\x83\x00\x00\x00\x00Ham") self.assertFrameReceived( server, - Frame(False, OP_CONT, b"Ham"), + Frame(OP_CONT, b"Ham", fin=False), ) server.receive_data(b"\x8a\x84\x00\x00\x00\x00Pong") self.assertFrameReceived( server, - Frame(True, OP_PONG, b"Pong"), + Frame(OP_PONG, b"Pong"), ) server.receive_data(b"\x80\x84\x00\x00\x00\x00Eggs") self.assertFrameReceived( server, - Frame(True, OP_CONT, b"Eggs"), + Frame(OP_CONT, b"Eggs"), ) def test_client_send_close_in_fragmented_message(self): client = Connection(Side.CLIENT) client.send_text(b"Spam", fin=False) - self.assertFrameSent(client, Frame(False, OP_TEXT, b"Spam")) + self.assertFrameSent(client, Frame(OP_TEXT, b"Spam", fin=False)) # The spec says: "An endpoint MUST be capable of handling control # frames in the middle of a fragmented message." However, since the # endpoint must not send a data frame after a close frame, a close @@ -1258,7 +1256,7 @@ def test_client_send_close_in_fragmented_message(self): def test_server_send_close_in_fragmented_message(self): server = Connection(Side.CLIENT) server.send_text(b"Spam", fin=False) - self.assertFrameSent(server, Frame(False, OP_TEXT, b"Spam")) + self.assertFrameSent(server, Frame(OP_TEXT, b"Spam", fin=False)) # The spec says: "An endpoint MUST be capable of handling control # frames in the middle of a fragmented message." However, since the # endpoint must not send a data frame after a close frame, a close @@ -1272,7 +1270,7 @@ def test_client_receive_close_in_fragmented_message(self): client.receive_data(b"\x01\x04Spam") self.assertFrameReceived( client, - Frame(False, OP_TEXT, b"Spam"), + Frame(OP_TEXT, b"Spam", fin=False), ) # The spec says: "An endpoint MUST be capable of handling control # frames in the middle of a fragmented message." However, since the @@ -1288,7 +1286,7 @@ def test_server_receive_close_in_fragmented_message(self): server.receive_data(b"\x01\x84\x00\x00\x00\x00Spam") self.assertFrameReceived( server, - Frame(False, OP_TEXT, b"Spam"), + Frame(OP_TEXT, b"Spam", fin=False), ) # The spec says: "An endpoint MUST be capable of handling control # frames in the middle of a fragmented message." However, since the @@ -1477,10 +1475,10 @@ def test_client_extension_decodes_frame(self): client = Connection(Side.CLIENT) client.extensions = [Rsv2Extension()] client.receive_data(b"\xaa\x00") - self.assertEqual(client.events_received(), [Frame(True, OP_PONG, b"")]) + self.assertEqual(client.events_received(), [Frame(OP_PONG, b"")]) def test_server_extension_decodes_frame(self): server = Connection(Side.SERVER) server.extensions = [Rsv2Extension()] server.receive_data(b"\xaa\x80\x00\x44\x88\xcc") - self.assertEqual(server.events_received(), [Frame(True, OP_PONG, b"")]) + self.assertEqual(server.events_received(), [Frame(OP_PONG, b"")]) diff --git a/tests/test_frames.py b/tests/test_frames.py index c05fa43a5..d36344639 100644 --- a/tests/test_frames.py +++ b/tests/test_frames.py @@ -48,77 +48,77 @@ def assertFrameData(self, frame, data, mask, extensions=None): class FrameTests(FramesTestCase): def test_text_unmasked(self): self.assertFrameData( - Frame(True, OP_TEXT, b"Spam"), + Frame(OP_TEXT, b"Spam"), b"\x81\x04Spam", mask=False, ) def test_text_masked(self): self.assertFrameData( - Frame(True, OP_TEXT, b"Spam"), + Frame(OP_TEXT, b"Spam"), b"\x81\x84\x5b\xfb\xe1\xa8\x08\x8b\x80\xc5", mask=True, ) def test_binary_unmasked(self): self.assertFrameData( - Frame(True, OP_BINARY, b"Eggs"), + Frame(OP_BINARY, b"Eggs"), b"\x82\x04Eggs", mask=False, ) def test_binary_masked(self): self.assertFrameData( - Frame(True, OP_BINARY, b"Eggs"), + Frame(OP_BINARY, b"Eggs"), b"\x82\x84\x53\xcd\xe2\x89\x16\xaa\x85\xfa", mask=True, ) def test_non_ascii_text_unmasked(self): self.assertFrameData( - Frame(True, OP_TEXT, "café".encode("utf-8")), + Frame(OP_TEXT, "café".encode("utf-8")), b"\x81\x05caf\xc3\xa9", mask=False, ) def test_non_ascii_text_masked(self): self.assertFrameData( - Frame(True, OP_TEXT, "café".encode("utf-8")), + Frame(OP_TEXT, "café".encode("utf-8")), b"\x81\x85\x64\xbe\xee\x7e\x07\xdf\x88\xbd\xcd", mask=True, ) def test_close(self): self.assertFrameData( - Frame(True, OP_CLOSE, b""), + Frame(OP_CLOSE, b""), b"\x88\x00", mask=False, ) def test_ping(self): self.assertFrameData( - Frame(True, OP_PING, b"ping"), + Frame(OP_PING, b"ping"), b"\x89\x04ping", mask=False, ) def test_pong(self): self.assertFrameData( - Frame(True, OP_PONG, b"pong"), + Frame(OP_PONG, b"pong"), b"\x8a\x04pong", mask=False, ) def test_long(self): self.assertFrameData( - Frame(True, OP_BINARY, 126 * b"a"), + Frame(OP_BINARY, 126 * b"a"), b"\x82\x7e\x00\x7e" + 126 * b"a", mask=False, ) def test_very_long(self): self.assertFrameData( - Frame(True, OP_BINARY, 65536 * b"a"), + Frame(OP_BINARY, 65536 * b"a"), b"\x82\x7f\x00\x00\x00\x00\x00\x01\x00\x00" + 65536 * b"a", mask=False, ) @@ -187,7 +187,7 @@ def decode(frame, *, max_size=None): return Rot13.encode(frame) self.assertFrameData( - Frame(True, OP_TEXT, b"hello"), + Frame(OP_TEXT, b"hello"), b"\x81\x05uryyb", mask=False, extensions=[Rot13()], @@ -197,125 +197,125 @@ def decode(frame, *, max_size=None): class StrTests(unittest.TestCase): def test_cont_text(self): self.assertEqual( - str(Frame(False, OP_CONT, b" cr\xc3\xa8me")), + str(Frame(OP_CONT, b" cr\xc3\xa8me", fin=False)), "CONT crème [text, 7 bytes, continued]", ) def test_cont_binary(self): self.assertEqual( - str(Frame(False, OP_CONT, b"\xfc\xfd\xfe\xff")), + str(Frame(OP_CONT, b"\xfc\xfd\xfe\xff", fin=False)), "CONT fc fd fe ff [binary, 4 bytes, continued]", ) def test_cont_final_text(self): self.assertEqual( - str(Frame(True, OP_CONT, b" cr\xc3\xa8me")), + str(Frame(OP_CONT, b" cr\xc3\xa8me")), "CONT crème [text, 7 bytes]", ) def test_cont_final_binary(self): self.assertEqual( - str(Frame(True, OP_CONT, b"\xfc\xfd\xfe\xff")), + str(Frame(OP_CONT, b"\xfc\xfd\xfe\xff")), "CONT fc fd fe ff [binary, 4 bytes]", ) def test_cont_text_truncated(self): self.assertEqual( - str(Frame(False, OP_CONT, b"caf\xc3\xa9 " * 16)), + str(Frame(OP_CONT, b"caf\xc3\xa9 " * 16, fin=False)), "CONT café café café café café café café café café caf..." "afé café café café café [text, 96 bytes, continued]", ) def test_cont_binary_truncated(self): self.assertEqual( - str(Frame(False, OP_CONT, bytes(range(256)))), + str(Frame(OP_CONT, bytes(range(256)), fin=False)), "CONT 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f ..." " f8 f9 fa fb fc fd fe ff [binary, 256 bytes, continued]", ) def test_text(self): self.assertEqual( - str(Frame(True, OP_TEXT, b"caf\xc3\xa9")), + str(Frame(OP_TEXT, b"caf\xc3\xa9")), "TEXT café [5 bytes]", ) def test_text_non_final(self): self.assertEqual( - str(Frame(False, OP_TEXT, b"caf\xc3\xa9")), + str(Frame(OP_TEXT, b"caf\xc3\xa9", fin=False)), "TEXT café [5 bytes, continued]", ) def test_text_truncated(self): self.assertEqual( - str(Frame(True, OP_TEXT, b"caf\xc3\xa9 " * 16)), + str(Frame(OP_TEXT, b"caf\xc3\xa9 " * 16)), "TEXT café café café café café café café café café caf..." "afé café café café café [96 bytes]", ) def test_binary(self): self.assertEqual( - str(Frame(True, OP_BINARY, b"\x00\x01\x02\x03")), + str(Frame(OP_BINARY, b"\x00\x01\x02\x03")), "BINARY 00 01 02 03 [4 bytes]", ) def test_binary_non_final(self): self.assertEqual( - str(Frame(False, OP_BINARY, b"\x00\x01\x02\x03")), + str(Frame(OP_BINARY, b"\x00\x01\x02\x03", fin=False)), "BINARY 00 01 02 03 [4 bytes, continued]", ) def test_binary_truncated(self): self.assertEqual( - str(Frame(True, OP_BINARY, bytes(range(256)))), + str(Frame(OP_BINARY, bytes(range(256)))), "BINARY 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f ..." " f8 f9 fa fb fc fd fe ff [256 bytes]", ) def test_close(self): self.assertEqual( - str(Frame(True, OP_CLOSE, b"\x03\xe8")), + str(Frame(OP_CLOSE, b"\x03\xe8")), "CLOSE code = 1000 (OK), no reason [2 bytes]", ) def test_close_reason(self): self.assertEqual( - str(Frame(True, OP_CLOSE, b"\x03\xe9Bye!")), + str(Frame(OP_CLOSE, b"\x03\xe9Bye!")), "CLOSE code = 1001 (going away), reason = Bye! [6 bytes]", ) def test_ping(self): self.assertEqual( - str(Frame(True, OP_PING, b"")), + str(Frame(OP_PING, b"")), "PING [0 bytes]", ) def test_ping_text(self): self.assertEqual( - str(Frame(True, OP_PING, b"ping")), + str(Frame(OP_PING, b"ping")), "PING ping [text, 4 bytes]", ) def test_ping_binary(self): self.assertEqual( - str(Frame(True, OP_PING, b"\xff\x00\xff\x00")), + str(Frame(OP_PING, b"\xff\x00\xff\x00")), "PING ff 00 ff 00 [binary, 4 bytes]", ) def test_pong(self): self.assertEqual( - str(Frame(True, OP_PONG, b"")), + str(Frame(OP_PONG, b"")), "PONG [0 bytes]", ) def test_pong_text(self): self.assertEqual( - str(Frame(True, OP_PONG, b"pong")), + str(Frame(OP_PONG, b"pong")), "PONG pong [text, 4 bytes]", ) def test_pong_binary(self): self.assertEqual( - str(Frame(True, OP_PONG, b"\xff\x00\xff\x00")), + str(Frame(OP_PONG, b"\xff\x00\xff\x00")), "PONG ff 00 ff 00 [binary, 4 bytes]", ) diff --git a/tests/test_server.py b/tests/test_server.py index 86fa3f34d..d2c41598e 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -634,7 +634,7 @@ def test_bypass_handshake(self): server = ServerConnection(state=OPEN) server.receive_data(b"\x81\x86\x00\x00\x00\x00Hello!") [frame] = server.events_received() - self.assertEqual(frame, Frame(True, OP_TEXT, b"Hello!")) + self.assertEqual(frame, Frame(OP_TEXT, b"Hello!")) def test_custom_logger(self): logger = logging.getLogger("test") From d5ccd3658b29fea1b57b869d7b9911efa333ff8d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 3 Jun 2021 21:04:19 +0200 Subject: [PATCH 069/762] Fix docstring after refactoring. --- src/websockets/frames.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/websockets/frames.py b/src/websockets/frames.py index ec6ff2258..440f5a93b 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -106,7 +106,7 @@ class Frame: :param bytes data: payload data Only these fields are needed. The MASK bit, payload length and masking-key - are handled on the fly by :func:`parse_frame` and :meth:`serialize_frame`. + are handled on the fly by :meth:`parse` and :meth:`serialize`. """ From d3d1d0550c62ac68edb09dc6f6a11b4715148251 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Jun 2021 07:39:01 +0200 Subject: [PATCH 070/762] Store response headers in InvalidStatusCode. Also introduce InvalidStatus for the new implementation, storing the entire response (but not for the legacy implementationn). Fix #712. --- docs/project/changelog.rst | 2 +- src/websockets/__init__.py | 2 ++ src/websockets/client.py | 4 ++-- src/websockets/exceptions.py | 25 ++++++++++++++++++++++--- src/websockets/legacy/client.py | 2 +- tests/test_exceptions.py | 2 +- 6 files changed, 29 insertions(+), 8 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index caf637614..2a0199147 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -548,7 +548,7 @@ Also: * Added an optional C extension to speed up low-level operations. * An invalid response status code during :func:`~legacy.client.connect` now - raises :class:`~exceptions.InvalidStatusCode` with a ``code`` attribute. + raises :class:`~exceptions.InvalidStatusCode`. * Providing a ``sock`` argument to :func:`~legacy.client.connect` no longer crashes. diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index f136a4e45..8c69a6d63 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -24,6 +24,7 @@ "InvalidParameterName", "InvalidParameterValue", "InvalidState", + "InvalidStatus", "InvalidStatusCode", "InvalidUpgrade", "InvalidURI", @@ -73,6 +74,7 @@ "InvalidHeaderValue": ".exceptions", "InvalidOrigin": ".exceptions", "InvalidUpgrade": ".exceptions", + "InvalidStatus": ".exceptions", "InvalidStatusCode": ".exceptions", "NegotiationError": ".exceptions", "DuplicateParameter": ".exceptions", diff --git a/src/websockets/client.py b/src/websockets/client.py index e9bc12cbe..a517140cd 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -7,7 +7,7 @@ InvalidHandshake, InvalidHeader, InvalidHeaderValue, - InvalidStatusCode, + InvalidStatus, InvalidUpgrade, NegotiationError, ) @@ -127,7 +127,7 @@ def process_response(self, response: Response) -> None: """ if response.status_code != 101: - raise InvalidStatusCode(response.status_code) + raise InvalidStatus(response) headers = response.headers diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 4bd2a41a6..644735258 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -13,7 +13,8 @@ * :exc:`InvalidHeaderValue` * :exc:`InvalidOrigin` * :exc:`InvalidUpgrade` - * :exc:`InvalidStatusCode` + * :exc:`InvalidStatus` + * :exc:`InvalidStatusCode` (legacy) * :exc:`NegotiationError` * :exc:`DuplicateParameter` * :exc:`InvalidParameterName` @@ -46,6 +47,7 @@ "InvalidHeaderValue", "InvalidOrigin", "InvalidUpgrade", + "InvalidStatus", "InvalidStatusCode", "NegotiationError", "DuplicateParameter", @@ -190,16 +192,30 @@ class InvalidUpgrade(InvalidHeader): """ +class InvalidStatus(InvalidHandshake): + """ + Raised when a handshake response rejects the WebSocket upgrade. + + """ + + def __init__(self, response: "Response") -> None: + self.response = response + message = f"server rejected WebSocket connection: HTTP {response.status_code:d}" + super().__init__(message) + + class InvalidStatusCode(InvalidHandshake): """ Raised when a handshake response status code is invalid. - The integer status code is available in the ``status_code`` attribute. + The integer status code is available in the ``status_code`` attribute and + HTTP headers in the ``headers`` attribute. """ - def __init__(self, status_code: int) -> None: + def __init__(self, status_code: int, headers: Headers) -> None: self.status_code = status_code + self.headers = headers message = f"server rejected WebSocket connection: HTTP {status_code}" super().__init__(message) @@ -333,3 +349,6 @@ class ProtocolError(WebSocketException): # at the bottom to allow circular import, because the frames module imports exceptions from .frames import format_close # noqa + +# at the bottom to allow circular import, because the http11 module imports exceptions +from .http11 import Response # noqa diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 5c281d3b8..68ab9acf1 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -391,7 +391,7 @@ async def handshake( raise InvalidHeader("Location") raise RedirectHandshake(response_headers["Location"]) elif status_code != 101: - raise InvalidStatusCode(status_code) + raise InvalidStatusCode(status_code, response_headers) check_response(response_headers, key) diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index b800d4f91..094fb6d33 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -88,7 +88,7 @@ def test_str(self): "invalid Connection header: websocket", ), ( - InvalidStatusCode(403), + InvalidStatusCode(403, Headers()), "server rejected WebSocket connection: HTTP 403", ), ( From 0c157d457b871a3055329a50ac30cef640787b67 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Jun 2021 08:21:40 +0200 Subject: [PATCH 071/762] Add timeout to connect. Fix #574. --- docs/project/changelog.rst | 10 ++++++++++ src/websockets/legacy/client.py | 11 ++++++++++- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 2a0199147..603466593 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -34,6 +34,14 @@ They may change at any time. **Version 10.0 drops compatibility with Python 3.6.** +.. note:: + + **Version 10.0 enables a timeout of 10 seconds on** + :func:`~legacy.client.connect` **by default.** + + You can adjust the timeout with the ``open_timeout`` parameter. Set it to + ``None`` to disable the timeout entirely. + .. note:: **Version 10.0 deprecates the** ``loop`` **parameter from all APIs.** @@ -43,6 +51,8 @@ They may change at any time. * Added compatibility with Python 3.10. +* Added ``open_timeout`` to :func:`~legacy.client.connect`. + * Improved logging. * Optimized default compression settings to reduce memory usage. diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 68ab9acf1..df1a2f57c 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -439,6 +439,10 @@ class Connect: be replaced by a wrapper or a subclass to customize the protocol that manages the connection. + If the WebSocket connection isn't established within ``open_timeout`` + seconds, :func:`connect` raises :exc:`~asyncio.TimeoutError`. The default + is 10 seconds. Set ``open_timeout`` to ``None`` to disable the timeout. + The behavior of ``ping_interval``, ``ping_timeout``, ``close_timeout``, ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit`` is described in :class:`WebSocketClientProtocol`. @@ -471,6 +475,7 @@ def __init__( uri: str, *, create_protocol: Optional[Callable[[Any], WebSocketClientProtocol]] = None, + open_timeout: Optional[float] = 10, ping_interval: Optional[float] = 20, ping_timeout: Optional[float] = 20, close_timeout: Optional[float] = None, @@ -571,6 +576,7 @@ def __init__( loop.create_connection, factory, host, port, **kwargs ) + self.open_timeout = open_timeout # This is a coroutine function. self._create_connection = create_connection self._wsuri = wsuri @@ -626,7 +632,10 @@ async def __aexit__( def __await__(self) -> Generator[Any, None, WebSocketClientProtocol]: # Create a suitable iterator by calling __await__ on a coroutine. - return self.__await_impl__().__await__() + return self.__await_impl_timeout__().__await__() + + async def __await_impl_timeout__(self) -> WebSocketClientProtocol: + return await asyncio.wait_for(self.__await_impl__(), self.open_timeout) async def __await_impl__(self) -> WebSocketClientProtocol: for redirects in range(self.MAX_REDIRECTS_ALLOWED): From 0779eb973d2779c33b560488212ea8b652601596 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 4 Jun 2021 23:08:11 +0200 Subject: [PATCH 072/762] Prevent crash of data transfer task When the read buffer isn't empty, the following scenario is possible: * TCP connection is closed (one way or another) * transfer_data is still processing buffered data * it encounters a ping and attempts to send a pong transfer_data must never crash, so ignore the error in this case. Fix #977 --- src/websockets/legacy/protocol.py | 9 ++++++--- tests/legacy/test_protocol.py | 10 ++++++++++ 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 992d678e2..353317ab5 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -926,14 +926,17 @@ async def read_data_frame(self, max_size: Optional[int]) -> Optional[Frame]: # is empty and parse_close() synthetizes a 1005 close code. await self.write_close_frame(frame.data) except ConnectionClosed: - # It doesn't really matter if the connection was closed - # before we could send back a close frame. + # Connection closed before we could echo the close frame. pass return None elif frame.opcode == OP_PING: # Answer pings. - await self.pong(frame.data) + try: + await self.pong(frame.data) + except ConnectionClosed: + # Connection closed before we could respond to the ping. + pass elif frame.opcode == OP_PONG: if frame.data in self.pings: diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index 6f6e6f686..c1948c838 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -877,6 +877,16 @@ def test_answer_ping(self): self.run_loop_once() self.assertOneFrameSent(True, OP_PONG, b"test") + def test_answer_ping_does_not_crash_if_connection_closed(self): + self.make_drain_slow() + # Drop the connection right after receiving a ping frame, + # which prevents responding wwith a pong frame properly. + self.receive_frame(Frame(True, OP_PING, b"test")) + self.receive_eof() + + with self.assertNoLogs(): + self.loop.run_until_complete(self.protocol.close()) + def test_ignore_pong(self): self.receive_frame(Frame(True, OP_PONG, b"test")) self.run_loop_once() From ebc8448b6603977c4e81e182e9827705c9a5889c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Jun 2021 13:14:04 +0200 Subject: [PATCH 073/762] Improve representation of frames. Specifically this avoids cutting log messages in two lines. Ref #765. --- src/websockets/frames.py | 6 +++--- tests/test_frames.py | 42 ++++++++++++++++++++++++++++------------ 2 files changed, 33 insertions(+), 15 deletions(-) diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 440f5a93b..9011ac867 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -129,7 +129,7 @@ def __str__(self) -> str: if self.opcode is OP_TEXT: # Decoding only the beginning and the end is needlessly hard. # Decode the entire payload then elide later if necessary. - data = self.data.decode() + data = repr(self.data.decode()) elif self.opcode is OP_BINARY: # We'll show at most the first 16 bytes and the last 8 bytes. # Encode just what we need, plus two dummy bytes to elide later. @@ -145,7 +145,7 @@ def __str__(self) -> str: # Ping and Pong frames could contain UTF-8. Attempt to decode as # UTF-8 and display it as text; fallback to binary. try: - data = self.data.decode() + data = repr(self.data.decode()) coding = "text" except UnicodeDecodeError: binary = self.data @@ -154,7 +154,7 @@ def __str__(self) -> str: data = " ".join(f"{byte:02x}" for byte in binary) coding = "binary" else: - data = "" + data = "''" if len(data) > 75: data = data[:48] + "..." + data[-24:] diff --git a/tests/test_frames.py b/tests/test_frames.py index d36344639..85414e9d3 100644 --- a/tests/test_frames.py +++ b/tests/test_frames.py @@ -198,7 +198,7 @@ class StrTests(unittest.TestCase): def test_cont_text(self): self.assertEqual( str(Frame(OP_CONT, b" cr\xc3\xa8me", fin=False)), - "CONT crème [text, 7 bytes, continued]", + "CONT ' crème' [text, 7 bytes, continued]", ) def test_cont_binary(self): @@ -210,7 +210,7 @@ def test_cont_binary(self): def test_cont_final_text(self): self.assertEqual( str(Frame(OP_CONT, b" cr\xc3\xa8me")), - "CONT crème [text, 7 bytes]", + "CONT ' crème' [text, 7 bytes]", ) def test_cont_final_binary(self): @@ -222,8 +222,8 @@ def test_cont_final_binary(self): def test_cont_text_truncated(self): self.assertEqual( str(Frame(OP_CONT, b"caf\xc3\xa9 " * 16, fin=False)), - "CONT café café café café café café café café café caf..." - "afé café café café café [text, 96 bytes, continued]", + "CONT 'café café café café café café café café café ca..." + "fé café café café café ' [text, 96 bytes, continued]", ) def test_cont_binary_truncated(self): @@ -236,20 +236,26 @@ def test_cont_binary_truncated(self): def test_text(self): self.assertEqual( str(Frame(OP_TEXT, b"caf\xc3\xa9")), - "TEXT café [5 bytes]", + "TEXT 'café' [5 bytes]", ) def test_text_non_final(self): self.assertEqual( str(Frame(OP_TEXT, b"caf\xc3\xa9", fin=False)), - "TEXT café [5 bytes, continued]", + "TEXT 'café' [5 bytes, continued]", ) def test_text_truncated(self): self.assertEqual( str(Frame(OP_TEXT, b"caf\xc3\xa9 " * 16)), - "TEXT café café café café café café café café café caf..." - "afé café café café café [96 bytes]", + "TEXT 'café café café café café café café café café ca..." + "fé café café café café ' [96 bytes]", + ) + + def test_text_with_newline(self): + self.assertEqual( + str(Frame(OP_TEXT, b"Hello\nworld!")), + "TEXT 'Hello\\nworld!' [12 bytes]", ) def test_binary(self): @@ -286,13 +292,19 @@ def test_close_reason(self): def test_ping(self): self.assertEqual( str(Frame(OP_PING, b"")), - "PING [0 bytes]", + "PING '' [0 bytes]", ) def test_ping_text(self): self.assertEqual( str(Frame(OP_PING, b"ping")), - "PING ping [text, 4 bytes]", + "PING 'ping' [text, 4 bytes]", + ) + + def test_ping_text_with_newline(self): + self.assertEqual( + str(Frame(OP_PING, b"ping\n")), + "PING 'ping\\n' [text, 5 bytes]", ) def test_ping_binary(self): @@ -304,13 +316,19 @@ def test_ping_binary(self): def test_pong(self): self.assertEqual( str(Frame(OP_PONG, b"")), - "PONG [0 bytes]", + "PONG '' [0 bytes]", ) def test_pong_text(self): self.assertEqual( str(Frame(OP_PONG, b"pong")), - "PONG pong [text, 4 bytes]", + "PONG 'pong' [text, 4 bytes]", + ) + + def test_pong_text_with_newline(self): + self.assertEqual( + str(Frame(OP_PONG, b"pong\n")), + "PONG 'pong\\n' [text, 5 bytes]", ) def test_pong_binary(self): From e5458a16f7c4162289c248a386e21fce24fa621f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Jun 2021 22:29:22 +0200 Subject: [PATCH 074/762] Simplify table. Fix #985. --- docs/topics/compression.rst | 31 ++++++++++++------------------- 1 file changed, 12 insertions(+), 19 deletions(-) diff --git a/docs/topics/compression.rst b/docs/topics/compression.rst index 206bfa7b7..e23319636 100644 --- a/docs/topics/compression.rst +++ b/docs/topics/compression.rst @@ -109,25 +109,18 @@ Here's how various compression settings affect memory usage of a single connection on a 64-bit system, as well a benchmark of compressed size and compression time for a corpus of small JSON documents. -+-------------+-------------+--------------+--------------+------------------+------------------+ -| Compression | Window Bits | Memory Level | Memory usage | Size vs. default | Time vs. default | -+=============+=============+==============+==============+==================+==================+ -| | 15 | 8 | 322 KiB | -4.0% | +15% + -+-------------+-------------+--------------+--------------+------------------+------------------+ -| | 14 | 7 | 178 KiB | -2.6% | +10% | -+-------------+-------------+--------------+--------------+------------------+------------------+ -| | 13 | 6 | 106 KiB | -1.4% | +5% | -+-------------+-------------+--------------+--------------+------------------+------------------+ -| *default* | 12 | 5 | 70 KiB | = | = | -+-------------+-------------+--------------+--------------+------------------+------------------+ -| | 11 | 4 | 52 KiB | +3.7% | -5% | -+-------------+-------------+--------------+--------------+------------------+------------------+ -| | 10 | 3 | 43 KiB | +90% | +50% | -+-------------+-------------+--------------+--------------+------------------+------------------+ -| | 9 | 2 | 39 KiB | +160% | +100% | -+-------------+-------------+--------------+--------------+------------------+------------------+ -| *disabled* | — | — | 19 KiB | +452% | — | -+-------------+-------------+--------------+--------------+------------------+------------------+ +=========== ============ ============ ================ ================ +Window Bits Memory Level Memory usage Size vs. default Time vs. default +=========== ============ ============ ================ ================ +15 8 322 KiB -4.0% +15% +14 7 178 KiB -2.6% +10% +13 6 106 KiB -1.4% +5% +**12** **5** **70 KiB** **=** **=** +11 4 52 KiB +3.7% -5% +10 3 43 KiB +90% +50% +9 2 39 KiB +160% +100% +— — 19 KiB +452% — +=========== ============ ============ ================ ================ Window Bits and Memory Level don't have to move in lockstep. However, other combinations don't yield significantly better results than those shown above. From e444fb57b88c5c446fbe406c66d230e9ce15a8d1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 7 Jun 2021 20:20:51 +0200 Subject: [PATCH 075/762] Add CVE reference. --- docs/project/changelog.rst | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 603466593..fc751bbd0 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -69,7 +69,11 @@ They may change at any time. **Version 9.1 fixes a security issue introduced in version 8.0.** - Version 8.0 was vulnerable to timing attacks on HTTP Basic Auth passwords. + Version 8.0 was vulnerable to timing attacks on HTTP Basic Auth passwords + (`CVE-2021-33880`_). + + .. _CVE-2021-33880: https://nvd.nist.gov/vuln/detail/CVE-2021-33880 + 9.0.2 ..... From cc1254b28867fcd2391e436fe5f10f7f40c77729 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 7 Jun 2021 21:01:23 +0200 Subject: [PATCH 076/762] Don't use get_running_loop outside of coroutines. get_event_loop() returns the running loop if there's one anyway. --- src/websockets/legacy/client.py | 3 +-- src/websockets/legacy/compatibility.py | 11 ----------- src/websockets/legacy/server.py | 4 ++-- 3 files changed, 3 insertions(+), 15 deletions(-) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index df1a2f57c..55bc77422 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -33,7 +33,6 @@ from ..http import USER_AGENT, build_host from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol from ..uri import WebSocketURI, parse_uri -from .compatibility import asyncio_get_running_loop from .handshake import build_request, check_response from .http import read_response from .protocol import WebSocketCommonProtocol @@ -517,7 +516,7 @@ def __init__( # Backwards compatibility: the loop parameter used to be supported. loop: Optional[asyncio.AbstractEventLoop] = kwargs.pop("loop", None) if loop is None: - loop = asyncio_get_running_loop() + loop = asyncio.get_event_loop() else: warnings.warn("remove loop argument", DeprecationWarning) diff --git a/src/websockets/legacy/compatibility.py b/src/websockets/legacy/compatibility.py index 86f6715fd..96df028e4 100644 --- a/src/websockets/legacy/compatibility.py +++ b/src/websockets/legacy/compatibility.py @@ -9,14 +9,3 @@ def loop_if_py_lt_38(loop: asyncio.AbstractEventLoop) -> Dict[str, Any]: """ return {"loop": loop} if sys.version_info[:2] < (3, 8) else {} - - -def asyncio_get_running_loop() -> asyncio.AbstractEventLoop: - """ - Helper for the deprecation of get_event_loop in Python 3.10. - - """ - if sys.version_info[:2] < (3, 10): # pragma: no cover - return asyncio.get_event_loop() - else: # pragma: no cover - return asyncio.get_running_loop() diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 1704ae083..ab03112b2 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -42,7 +42,7 @@ from ..headers import build_extension, parse_extension, parse_subprotocol from ..http import USER_AGENT from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol -from .compatibility import asyncio_get_running_loop, loop_if_py_lt_38 +from .compatibility import loop_if_py_lt_38 from .handshake import build_response, check_request from .http import read_request from .protocol import WebSocketCommonProtocol @@ -1026,7 +1026,7 @@ def __init__( # Backwards compatibility: the loop parameter used to be supported. loop: Optional[asyncio.AbstractEventLoop] = kwargs.pop("loop", None) if loop is None: - loop = asyncio_get_running_loop() + loop = asyncio.get_event_loop() else: warnings.warn("remove loop argument", DeprecationWarning) From cb11516e0ed4fe2b67ab6c1511650bd42115d0b6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 7 Jun 2021 22:11:26 +0200 Subject: [PATCH 077/762] Stop requiring loop in WebSocketCommonProtocol. Change tests to avoid passing a loop argument. Fix #988. --- src/websockets/legacy/client.py | 7 ++--- src/websockets/legacy/protocol.py | 8 ++++-- src/websockets/legacy/server.py | 7 ++--- tests/legacy/test_client_server.py | 35 ++++++++++++------------- tests/legacy/test_protocol.py | 41 +++++++++++++++++++++--------- 5 files changed, 59 insertions(+), 39 deletions(-) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 55bc77422..695da3cdb 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -514,10 +514,11 @@ def __init__( legacy_recv: bool = kwargs.pop("legacy_recv", False) # Backwards compatibility: the loop parameter used to be supported. - loop: Optional[asyncio.AbstractEventLoop] = kwargs.pop("loop", None) - if loop is None: + _loop: Optional[asyncio.AbstractEventLoop] = kwargs.pop("loop", None) + if _loop is None: loop = asyncio.get_event_loop() else: + loop = _loop warnings.warn("remove loop argument", DeprecationWarning) wsuri = parse_uri(uri) @@ -543,7 +544,7 @@ def __init__( max_queue=max_queue, read_limit=read_limit, write_limit=write_limit, - loop=loop, + loop=_loop, host=wsuri.host, port=wsuri.port, secure=wsuri.secure, diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 353317ab5..ca6076142 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -124,6 +124,12 @@ def __init__( if close_timeout is None: close_timeout = timeout + # Backwards compatibility: the loop parameter used to be supported. + if loop is None: + loop = asyncio.get_event_loop() + else: + warnings.warn("remove loop argument", DeprecationWarning) + self.ping_interval = ping_interval self.ping_timeout = ping_timeout self.close_timeout = close_timeout @@ -145,8 +151,6 @@ def __init__( # Track if DEBUG is enabled. Shortcut logging calls if it isn't. self.debug = logger.isEnabledFor(logging.DEBUG) - assert loop is not None - # Remove when dropping Python < 3.10 - use get_running_loop instead. self.loop = loop self._host = host diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index ab03112b2..c8425993d 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -1024,10 +1024,11 @@ def __init__( legacy_recv: bool = kwargs.pop("legacy_recv", False) # Backwards compatibility: the loop parameter used to be supported. - loop: Optional[asyncio.AbstractEventLoop] = kwargs.pop("loop", None) - if loop is None: + _loop: Optional[asyncio.AbstractEventLoop] = kwargs.pop("loop", None) + if _loop is None: loop = asyncio.get_event_loop() else: + loop = _loop warnings.warn("remove loop argument", DeprecationWarning) ws_server = WebSocketServer(logger=logger, loop=loop) @@ -1053,7 +1054,7 @@ def __init__( max_queue=max_queue, read_limit=read_limit, write_limit=write_limit, - loop=loop, + loop=_loop, legacy_recv=legacy_recv, origins=origins, extensions=extensions, diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 60c0a14ae..041eca28e 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -213,17 +213,17 @@ def start_server(self, deprecation_warnings=None, **kwargs): kwargs.setdefault("compression", None) # Disable pings by default in tests. kwargs.setdefault("ping_interval", None) - # Python 3.10 dislikes not having a running event loop - if sys.version_info[:2] >= (3, 10): # pragma: no cover - kwargs.setdefault("loop", self.loop) with warnings.catch_warnings(record=True) as recorded_warnings: start_server = serve(handler, "localhost", 0, **kwargs) self.server = self.loop.run_until_complete(start_server) expected_warnings = [] if deprecation_warnings is None else deprecation_warnings - if sys.version_info[:2] >= (3, 10): # pragma: no cover - expected_warnings += ["remove loop argument"] + if ( + sys.version_info[:2] >= (3, 10) + and "remove loop argument" not in expected_warnings + ): # pragma: no cover + expected_warnings += ["There is no current event loop"] self.assertDeprecationWarnings(recorded_warnings, expected_warnings) def start_redirecting_server( @@ -234,10 +234,6 @@ def start_redirecting_server( deprecation_warnings=None, **kwargs, ): - # Python 3.10 dislikes not having a running event loop - if sys.version_info[:2] >= (3, 10): # pragma: no cover - kwargs.setdefault("loop", self.loop) - async def process_request(path, headers): server_uri = get_server_uri(self.server, self.secure, path) if force_insecure: @@ -259,8 +255,11 @@ async def process_request(path, headers): self.redirecting_server = self.loop.run_until_complete(start_server) expected_warnings = [] if deprecation_warnings is None else deprecation_warnings - if sys.version_info[:2] >= (3, 10): # pragma: no cover - expected_warnings += ["remove loop argument"] + if ( + sys.version_info[:2] >= (3, 10) + and "remove loop argument" not in expected_warnings + ): # pragma: no cover + expected_warnings += ["There is no current event loop"] self.assertDeprecationWarnings(recorded_warnings, expected_warnings) def start_client( @@ -270,9 +269,6 @@ def start_client( kwargs.setdefault("compression", None) # Disable pings by default in tests. kwargs.setdefault("ping_interval", None) - # Python 3.10 dislikes not having a running event loop - if sys.version_info[:2] >= (3, 10): # pragma: no cover - kwargs.setdefault("loop", self.loop) secure = kwargs.get("ssl") is not None try: @@ -286,8 +282,11 @@ def start_client( self.client = self.loop.run_until_complete(start_client) expected_warnings = [] if deprecation_warnings is None else deprecation_warnings - if sys.version_info[:2] >= (3, 10): # pragma: no cover - expected_warnings += ["remove loop argument"] + if ( + sys.version_info[:2] >= (3, 10) + and "remove loop argument" not in expected_warnings + ): # pragma: no cover + expected_warnings += ["There is no current event loop"] self.assertDeprecationWarnings(recorded_warnings, expected_warnings) def stop_client(self): @@ -409,8 +408,6 @@ def test_infinite_redirect(self): with temp_test_redirecting_server( self, http.HTTPStatus.FOUND, - loop=self.loop, - deprecation_warnings=["remove loop argument"], ): self.server = self.redirecting_server with self.assertRaises(InvalidHandshake): @@ -430,7 +427,7 @@ def test_redirect_missing_location(self): with temp_test_client(self): self.fail("Did not raise") # pragma: no cover - def test_explicit_event_loop(self): + def test_loop_backwards_compatibility(self): with self.temp_server( loop=self.loop, deprecation_warnings=["remove loop argument"] ): diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index c1948c838..ccbbffe7c 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -1,5 +1,6 @@ import asyncio import contextlib +import sys import unittest import unittest.mock import warnings @@ -86,8 +87,9 @@ class CommonTests: def setUp(self): super().setUp() - # Disable pings to make it easier to test what frames are sent exactly. - self.protocol = WebSocketCommonProtocol(ping_interval=None, loop=self.loop) + with warnings.catch_warnings(record=True): + # Disable pings to make it easier to test what frames are sent exactly. + self.protocol = WebSocketCommonProtocol(ping_interval=None) self.transport = TransportMock() self.transport.setup_mock(self.loop, self.protocol) @@ -309,14 +311,29 @@ def assertCompletesWithin(self, min_time, max_time): # Test constructor. def test_timeout_backwards_compatibility(self): - with warnings.catch_warnings(record=True) as recorded_warnings: - protocol = WebSocketCommonProtocol(timeout=5, loop=self.loop) + with warnings.catch_warnings(record=True) as recorded: + protocol = WebSocketCommonProtocol(timeout=5) self.assertEqual(protocol.close_timeout, 5) - self.assertDeprecationWarnings( - recorded_warnings, ["rename timeout to close_timeout"] - ) + expected = ["rename timeout to close_timeout"] + if sys.version_info[:2] >= (3, 10): # pragma: no cover + expected += ["There is no current event loop"] + + self.assertDeprecationWarnings(recorded, expected) + + def test_loop_backwards_compatibility(self): + loop = asyncio.new_event_loop() + self.addCleanup(loop.close) + + with warnings.catch_warnings(record=True) as recorded: + protocol = WebSocketCommonProtocol(loop=loop) + + self.assertEqual(protocol.loop, loop) + + expected = ["remove loop argument"] + + self.assertDeprecationWarnings(recorded, expected) # Test public attributes. @@ -1116,11 +1133,11 @@ def restart_protocol_with_keepalive_ping( self.transport.close() self.loop.run_until_complete(self.protocol.close()) # copied from setUp, but enables keepalive pings - self.protocol = WebSocketCommonProtocol( - ping_interval=ping_interval, - ping_timeout=ping_timeout, - loop=self.loop, - ) + with warnings.catch_warnings(record=True): + self.protocol = WebSocketCommonProtocol( + ping_interval=ping_interval, + ping_timeout=ping_timeout, + ) self.transport = TransportMock() self.transport.setup_mock(self.loop, self.protocol) self.protocol.is_client = initial_protocol.is_client From 29094a27e65ae34b9219effd2f4bd755f0a63ff5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Jun 2021 11:49:15 +0200 Subject: [PATCH 078/762] Add support for reconnecting automatically. Fix #414. --- docs/howto/logging.rst | 1 + docs/project/changelog.rst | 3 ++ docs/spelling_wordlist.txt | 2 + src/websockets/legacy/client.py | 69 +++++++++++++++++++++++--- src/websockets/legacy/server.py | 2 +- tests/legacy/test_client_server.py | 79 ++++++++++++++++++++++++++++++ 6 files changed, 149 insertions(+), 7 deletions(-) diff --git a/docs/howto/logging.rst b/docs/howto/logging.rst index f69ee47b9..824812959 100644 --- a/docs/howto/logging.rst +++ b/docs/howto/logging.rst @@ -210,6 +210,7 @@ Here's what websockets logs at each level. * Server starting and stopping * Server establishing and closing connections +* Client reconnecting automatically :attr:`~logging.DEBUG` ...................... diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index fc751bbd0..a0cd8f07c 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -51,6 +51,9 @@ They may change at any time. * Added compatibility with Python 3.10. +* Added support for reconnecting automatically by using + :func:`~legacy.client.connect` as an asynchronous iterator. + * Added ``open_timeout`` to :func:`~legacy.client.connect`. * Improved logging. diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index b460ef033..8346acefa 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -6,6 +6,7 @@ autoscaler awaitable aymeric backend +backoff backpressure balancer balancers @@ -52,6 +53,7 @@ pong pongs proxying pythonic +reconnection redis retransmit runtime diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 695da3cdb..63fdbf9b2 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -9,7 +9,18 @@ import logging import warnings from types import TracebackType -from typing import Any, Callable, Generator, List, Optional, Sequence, Tuple, Type, cast +from typing import ( + Any, + AsyncIterator, + Callable, + Generator, + List, + Optional, + Sequence, + Tuple, + Type, + cast, +) from ..datastructures import Headers, HeadersLike from ..exceptions import ( @@ -412,12 +423,23 @@ class Connect: Awaiting :func:`connect` yields a :class:`WebSocketClientProtocol` which can then be used to send and receive messages. - :func:`connect` can also be used as a asynchronous context manager:: + :func:`connect` can be used as a asynchronous context manager:: async with connect(...) as websocket: ... - In that case, the connection is closed when exiting the context. + The connection is closed automatically when exiting the context. + + :func:`connect` can be used as an infinite asynchronous iterator to + reconnect automatically on errors:: + + async for websocket in connect(...): + ... + + You must catch all exceptions, or else you will exit the loop prematurely. + As above, connections are closed automatically. Connection attempts are + delayed with exponential backoff, starting at three seconds and + increasing up to one minute. :func:`connect` is a wrapper around the event loop's :meth:`~asyncio.loop.create_connection` method. Unknown keyword arguments @@ -577,6 +599,10 @@ def __init__( ) self.open_timeout = open_timeout + if logger is None: + logger = logging.getLogger("websockets.client") + self.logger = logger + # This is a coroutine function. self._create_connection = create_connection self._wsuri = wsuri @@ -615,7 +641,38 @@ def handle_redirect(self, uri: str) -> None: # Set the new WebSocket URI. This suffices for same-origin redirects. self._wsuri = new_wsuri - # async with connect(...) + # async for ... in connect(...): + + BACKOFF_MIN = 2.0 + BACKOFF_MAX = 60.0 + BACKOFF_FACTOR = 1.5 + + async def __aiter__(self) -> AsyncIterator[WebSocketClientProtocol]: + backoff_delay = self.BACKOFF_MIN + while True: + try: + async with self as protocol: + yield protocol + # Remove this branch when dropping support for Python < 3.8 + # because CancelledError no longer inherits Exception. + except asyncio.CancelledError: # pragma: no cover + raise + except Exception: + # Connection timed out - increase backoff delay + backoff_delay = backoff_delay * self.BACKOFF_FACTOR + backoff_delay = min(backoff_delay, self.BACKOFF_MAX) + self.logger.info( + "! connect failed; retrying in %d seconds", + int(backoff_delay), + exc_info=True, + ) + await asyncio.sleep(backoff_delay) + continue + else: + # Connection succeeded - reset backoff delay + backoff_delay = self.BACKOFF_MIN + + # async with connect(...) as ...: async def __aenter__(self) -> WebSocketClientProtocol: return await self @@ -628,7 +685,7 @@ async def __aexit__( ) -> None: await self.protocol.close() - # await connect(...) + # ... = await connect(...) def __await__(self) -> Generator[Any, None, WebSocketClientProtocol]: # Create a suitable iterator by calling __await__ on a coroutine. @@ -665,7 +722,7 @@ async def __await_impl__(self) -> WebSocketClientProtocol: else: raise SecurityError("too many redirects") - # yield from connect(...) + # ... = yield from connect(...) __iter__ = __await__ diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index c8425993d..9ace9a2c6 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -904,7 +904,7 @@ class Serve: :exc:`~websockets.exceptions.ConnectionClosedOK` exception on their current or next interaction with the WebSocket connection. - :func:`serve` can also be used as an asynchronous context manager:: + :func:`serve` can be used as an asynchronous context manager:: stop = asyncio.Future() # set this future to exit the server diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 041eca28e..2f754aa89 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -1471,6 +1471,85 @@ async def run_client(): self.assertEqual(messages, self.MESSAGES) +class ReconnectionTests(ClientServerTestsMixin, AsyncioTestCase): + async def echo_handler(ws, path): + async for msg in ws: + await ws.send(msg) + + service_available = True + + async def maybe_service_unavailable(path, headers): + if not ReconnectionTests.service_available: + return http.HTTPStatus.SERVICE_UNAVAILABLE, [], b"" + + async def disable_server(self, duration): + ReconnectionTests.service_available = False + await asyncio.sleep(duration) + ReconnectionTests.service_available = True + + @with_server(handler=echo_handler, process_request=maybe_service_unavailable) + def test_reconnect(self): + # Big, ugly integration test :-( + + async def run_client(): + iteration = 0 + connect_inst = connect(get_server_uri(self.server)) + connect_inst.BACKOFF_MIN = 10 * MS + connect_inst.BACKOFF_MAX = 200 * MS + async for ws in connect_inst: + await ws.send("spam") + msg = await ws.recv() + self.assertEqual(msg, "spam") + + iteration += 1 + if iteration == 1: + # Exit block normally. + pass + elif iteration == 2: + # Disable server for a little bit + asyncio.create_task(self.disable_server(70 * MS)) + await asyncio.sleep(0) + elif iteration == 3: + # Exit block after catching connection error. + server_ws = next(iter(self.server.websockets)) + await server_ws.close() + with self.assertRaises(ConnectionClosed): + await ws.recv() + else: + # Exit block with an exception. + raise Exception("BOOM!") + + with self.assertLogs("websockets", logging.INFO) as logs: + with self.assertRaisesRegex(Exception, "BOOM!"): + self.loop.run_until_complete(run_client()) + + self.assertEqual( + [record.getMessage() for record in logs.records], + [ + # Iteration 1 + "connection open", + "connection closed", + # Iteration 2 + "connection open", + "connection closed", + # Iteration 3 + "connection failed (503 Service Unavailable)", + "connection closed", + "! connect failed; retrying in 0 seconds", + "connection failed (503 Service Unavailable)", + "connection closed", + "! connect failed; retrying in 0 seconds", + "connection failed (503 Service Unavailable)", + "connection closed", + "! connect failed; retrying in 0 seconds", + "connection open", + "connection closed", + # Iteration 4 + "connection open", + ], + ) + + class LoggerTests(ClientServerTestsMixin, AsyncioTestCase): def test_logger_client(self): with self.assertLogs("test.server", logging.DEBUG) as server_logs: From 0a60fac6c2cd922198bd17e034c9a4406e273db7 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 9 Jun 2021 08:15:00 +0200 Subject: [PATCH 079/762] Make test less timing-sensitive. Fix #991. --- tests/legacy/test_client_server.py | 36 ++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 2f754aa89..d3e3f1e9f 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -1507,7 +1507,7 @@ async def run_client(): pass elif iteration == 2: # Disable server for a little bit - asyncio.create_task(self.disable_server(70 * MS)) + asyncio.create_task(self.disable_server(50 * MS)) await asyncio.sleep(0) elif iteration == 3: # Exit block after catching connection error. @@ -1523,28 +1523,40 @@ async def run_client(): with self.assertRaisesRegex(Exception, "BOOM!"): self.loop.run_until_complete(run_client()) + # Iteration 1 self.assertEqual( - [record.getMessage() for record in logs.records], + [record.getMessage() for record in logs.records][:2], [ - # Iteration 1 "connection open", "connection closed", - # Iteration 2 + ], + ) + # Iteration 2 + self.assertEqual( + [record.getMessage() for record in logs.records][2:4], + [ "connection open", "connection closed", - # Iteration 3 - "connection failed (503 Service Unavailable)", - "connection closed", - "! connect failed; retrying in 0 seconds", - "connection failed (503 Service Unavailable)", - "connection closed", - "! connect failed; retrying in 0 seconds", + ], + ) + # Iteration 3 + self.assertEqual( + [record.getMessage() for record in logs.records][4:-1], + [ "connection failed (503 Service Unavailable)", "connection closed", "! connect failed; retrying in 0 seconds", + ] + * ((len(logs.records) - 5) // 3) + + [ "connection open", "connection closed", - # Iteration 4 + ], + ) + # Iteration 4 + self.assertEqual( + [record.getMessage() for record in logs.records][-1:], + [ "connection open", ], ) From 298cdabef68e9619d89100684eb27f719fc361a6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 10 Jun 2021 22:58:57 +0200 Subject: [PATCH 080/762] Postpone evaluation of annotations. See PEP 563. --- src/websockets/__main__.py | 11 +++++++---- src/websockets/client.py | 2 ++ src/websockets/connection.py | 2 ++ src/websockets/datastructures.py | 2 ++ src/websockets/exceptions.py | 4 +++- src/websockets/extensions/base.py | 2 ++ src/websockets/extensions/permessage_deflate.py | 6 ++++-- src/websockets/frames.py | 6 ++++-- src/websockets/headers.py | 2 ++ src/websockets/http.py | 2 ++ src/websockets/http11.py | 2 ++ src/websockets/imports.py | 2 ++ src/websockets/legacy/auth.py | 1 + src/websockets/legacy/client.py | 2 ++ src/websockets/legacy/compatibility.py | 2 ++ src/websockets/legacy/framing.py | 6 ++++-- src/websockets/legacy/handshake.py | 2 ++ src/websockets/legacy/http.py | 2 ++ src/websockets/legacy/protocol.py | 2 ++ src/websockets/legacy/server.py | 6 ++++-- src/websockets/server.py | 2 ++ src/websockets/streams.py | 2 ++ src/websockets/typing.py | 2 ++ src/websockets/uri.py | 2 ++ src/websockets/utils.py | 2 ++ 25 files changed, 63 insertions(+), 13 deletions(-) diff --git a/src/websockets/__main__.py b/src/websockets/__main__.py index 4358c323c..dae165cd2 100644 --- a/src/websockets/__main__.py +++ b/src/websockets/__main__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import argparse import asyncio import os @@ -44,7 +46,8 @@ def win_enable_vt100() -> None: def exit_from_event_loop_thread( - loop: asyncio.AbstractEventLoop, stop: "asyncio.Future[None]" + loop: asyncio.AbstractEventLoop, + stop: asyncio.Future[None], ) -> None: loop.stop() if not stop.done(): @@ -92,8 +95,8 @@ def print_over_input(string: str) -> None: async def run_client( uri: str, loop: asyncio.AbstractEventLoop, - inputs: "asyncio.Queue[str]", - stop: "asyncio.Future[None]", + inputs: asyncio.Queue[str], + stop: asyncio.Future[None], ) -> None: try: websocket = await connect(uri) @@ -183,7 +186,7 @@ async def queue_factory() -> "asyncio.Queue[str]": return asyncio.Queue() # Create a queue of user inputs. There's no need to limit its size. - inputs: "asyncio.Queue[str]" = loop.run_until_complete(queue_factory()) + inputs: asyncio.Queue[str] = loop.run_until_complete(queue_factory()) # Create a stop condition when receiving SIGINT or SIGTERM. stop: asyncio.Future[None] = loop.create_future() diff --git a/src/websockets/client.py b/src/websockets/client.py index a517140cd..cd83a59bc 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import collections from typing import Generator, List, Optional, Sequence diff --git a/src/websockets/connection.py b/src/websockets/connection.py index dbdc15bb5..c98f5c39b 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import enum import logging import uuid diff --git a/src/websockets/datastructures.py b/src/websockets/datastructures.py index 66f91e9bb..117ffd4f2 100644 --- a/src/websockets/datastructures.py +++ b/src/websockets/datastructures.py @@ -3,6 +3,8 @@ """ +from __future__ import annotations + from typing import ( Any, Dict, diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 644735258..061b902fb 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -28,6 +28,8 @@ """ +from __future__ import annotations + import http from typing import Optional @@ -198,7 +200,7 @@ class InvalidStatus(InvalidHandshake): """ - def __init__(self, response: "Response") -> None: + def __init__(self, response: Response) -> None: self.response = response message = f"server rejected WebSocket connection: HTTP {response.status_code:d}" super().__init__(message) diff --git a/src/websockets/extensions/base.py b/src/websockets/extensions/base.py index cfc090799..82ae5e27f 100644 --- a/src/websockets/extensions/base.py +++ b/src/websockets/extensions/base.py @@ -8,6 +8,8 @@ """ +from __future__ import annotations + from typing import List, Optional, Sequence, Tuple from ..frames import Frame diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index 56ef03e0c..d578a7a1d 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -4,6 +4,8 @@ """ +from __future__ import annotations + import dataclasses import zlib from typing import Any, Dict, List, Optional, Sequence, Tuple, Union @@ -328,7 +330,7 @@ def get_request_params(self) -> List[ExtensionParameter]: def process_response_params( self, params: Sequence[ExtensionParameter], - accepted_extensions: Sequence["Extension"], + accepted_extensions: Sequence[Extension], ) -> PerMessageDeflate: """ Process response parameters. @@ -510,7 +512,7 @@ def __init__( def process_request_params( self, params: Sequence[ExtensionParameter], - accepted_extensions: Sequence["Extension"], + accepted_extensions: Sequence[Extension], ) -> Tuple[List[ExtensionParameter], PerMessageDeflate]: """ Process request parameters. diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 9011ac867..99e43388b 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -3,6 +3,8 @@ """ +from __future__ import annotations + import dataclasses import enum import io @@ -170,7 +172,7 @@ def parse( *, mask: bool, max_size: Optional[int] = None, - extensions: Optional[Sequence["extensions.Extension"]] = None, + extensions: Optional[Sequence[extensions.Extension]] = None, ) -> Generator[None, None, "Frame"]: """ Read a WebSocket frame. @@ -239,7 +241,7 @@ def serialize( self, *, mask: bool, - extensions: Optional[Sequence["extensions.Extension"]] = None, + extensions: Optional[Sequence[extensions.Extension]] = None, ) -> bytes: """ Write a WebSocket frame. diff --git a/src/websockets/headers.py b/src/websockets/headers.py index 6779c9c04..12d2a4e94 100644 --- a/src/websockets/headers.py +++ b/src/websockets/headers.py @@ -4,6 +4,8 @@ """ +from __future__ import annotations + import base64 import binascii import re diff --git a/src/websockets/http.py b/src/websockets/http.py index 9092836c2..6168c5144 100644 --- a/src/websockets/http.py +++ b/src/websockets/http.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import ipaddress import sys diff --git a/src/websockets/http11.py b/src/websockets/http11.py index 22488dc89..aaa61f8c7 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import dataclasses import re from typing import Callable, Generator, Optional diff --git a/src/websockets/imports.py b/src/websockets/imports.py index 06917ce1d..c9508d188 100644 --- a/src/websockets/imports.py +++ b/src/websockets/imports.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import warnings from typing import Any, Dict, Iterable, Optional diff --git a/src/websockets/legacy/auth.py b/src/websockets/legacy/auth.py index 16016e6fd..5f2b1311a 100644 --- a/src/websockets/legacy/auth.py +++ b/src/websockets/legacy/auth.py @@ -4,6 +4,7 @@ """ +from __future__ import annotations import functools import hmac diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 63fdbf9b2..468b5c15c 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -3,6 +3,8 @@ """ +from __future__ import annotations + import asyncio import collections.abc import functools diff --git a/src/websockets/legacy/compatibility.py b/src/websockets/legacy/compatibility.py index 96df028e4..df81de9db 100644 --- a/src/websockets/legacy/compatibility.py +++ b/src/websockets/legacy/compatibility.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import sys from typing import Any, Dict diff --git a/src/websockets/legacy/framing.py b/src/websockets/legacy/framing.py index 901b2e2e3..14667925f 100644 --- a/src/websockets/legacy/framing.py +++ b/src/websockets/legacy/framing.py @@ -10,6 +10,8 @@ """ +from __future__ import annotations + import struct from typing import Any, Awaitable, Callable, NamedTuple, Optional, Sequence @@ -56,7 +58,7 @@ async def read( *, mask: bool, max_size: Optional[int] = None, - extensions: Optional[Sequence["extensions.Extension"]] = None, + extensions: Optional[Sequence[extensions.Extension]] = None, ) -> "Frame": """ Read a WebSocket frame. @@ -134,7 +136,7 @@ def write( write: Callable[[bytes], Any], *, mask: bool, - extensions: Optional[Sequence["extensions.Extension"]] = None, + extensions: Optional[Sequence[extensions.Extension]] = None, ) -> None: """ Write a WebSocket frame. diff --git a/src/websockets/legacy/handshake.py b/src/websockets/legacy/handshake.py index 44da72d21..49d08cfe8 100644 --- a/src/websockets/legacy/handshake.py +++ b/src/websockets/legacy/handshake.py @@ -25,6 +25,8 @@ """ +from __future__ import annotations + import base64 import binascii from typing import List diff --git a/src/websockets/legacy/http.py b/src/websockets/legacy/http.py index c18e08e8d..0b9a92267 100644 --- a/src/websockets/legacy/http.py +++ b/src/websockets/legacy/http.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import re from typing import Tuple diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index ca6076142..f8daf544b 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -7,6 +7,8 @@ """ +from __future__ import annotations + import asyncio import codecs import collections diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 9ace9a2c6..a7a98e006 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -3,6 +3,8 @@ """ +from __future__ import annotations + import asyncio import collections.abc import email.utils @@ -169,8 +171,8 @@ class WebSocketServerProtocol(WebSocketCommonProtocol): def __init__( self, - ws_handler: Callable[["WebSocketServerProtocol", str], Awaitable[Any]], - ws_server: "WebSocketServer", + ws_handler: Callable[[WebSocketServerProtocol, str], Awaitable[Any]], + ws_server: WebSocketServer, *, origins: Optional[Sequence[Optional[Origin]]] = None, extensions: Optional[Sequence[ServerExtensionFactory]] = None, diff --git a/src/websockets/server.py b/src/websockets/server.py index 483ddbe1e..09ed63150 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import base64 import binascii import collections diff --git a/src/websockets/streams.py b/src/websockets/streams.py index 6f3163034..e02a6ab39 100644 --- a/src/websockets/streams.py +++ b/src/websockets/streams.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Generator diff --git a/src/websockets/typing.py b/src/websockets/typing.py index b1858d73e..13b172f15 100644 --- a/src/websockets/typing.py +++ b/src/websockets/typing.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging from typing import List, NewType, Optional, Tuple, Union diff --git a/src/websockets/uri.py b/src/websockets/uri.py index 958975b22..7406a60a8 100644 --- a/src/websockets/uri.py +++ b/src/websockets/uri.py @@ -7,6 +7,8 @@ """ +from __future__ import annotations + import dataclasses import urllib.parse from typing import Optional, Tuple diff --git a/src/websockets/utils.py b/src/websockets/utils.py index 59210e438..ffb706963 100644 --- a/src/websockets/utils.py +++ b/src/websockets/utils.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import base64 import hashlib import itertools From 71dbbffabcaaaba5dbcfb21df20f98e2bdbe01f6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 10 Jun 2021 23:16:30 +0200 Subject: [PATCH 081/762] Break import cycles. Import modules rather than their contents. Fix #989. --- src/websockets/client.py | 2 +- src/websockets/connection.py | 2 +- src/websockets/exceptions.py | 19 +++++---------- src/websockets/extensions/base.py | 11 ++++++--- .../extensions/permessage_deflate.py | 19 +++++++++------ src/websockets/frames.py | 24 +++++++++---------- src/websockets/http11.py | 19 +++++++-------- src/websockets/legacy/client.py | 2 +- src/websockets/legacy/framing.py | 15 +++++------- src/websockets/legacy/protocol.py | 2 +- src/websockets/legacy/server.py | 2 +- src/websockets/server.py | 2 +- src/websockets/uri.py | 6 ++--- 13 files changed, 61 insertions(+), 64 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index cd83a59bc..3217090d0 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -13,7 +13,7 @@ InvalidUpgrade, NegotiationError, ) -from .extensions.base import ClientExtensionFactory, Extension +from .extensions import ClientExtensionFactory, Extension from .headers import ( build_authorization_basic, build_extension, diff --git a/src/websockets/connection.py b/src/websockets/connection.py index c98f5c39b..57d3a2227 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -6,7 +6,7 @@ from typing import Generator, List, Optional, Union from .exceptions import InvalidState, PayloadTooBig, ProtocolError -from .extensions.base import Extension +from .extensions import Extension from .frames import ( OP_BINARY, OP_CLOSE, diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 061b902fb..f110322aa 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -33,7 +33,7 @@ import http from typing import Optional -from .datastructures import Headers, HeadersLike +from . import datastructures, frames, http11 __all__ = [ @@ -84,7 +84,7 @@ class ConnectionClosed(WebSocketException): def __init__(self, code: int, reason: str) -> None: self.code = code self.reason = reason - super().__init__(format_close(code, reason)) + super().__init__(frames.format_close(code, reason)) class ConnectionClosedError(ConnectionClosed): @@ -200,7 +200,7 @@ class InvalidStatus(InvalidHandshake): """ - def __init__(self, response: Response) -> None: + def __init__(self, response: http11.Response) -> None: self.response = response message = f"server rejected WebSocket connection: HTTP {response.status_code:d}" super().__init__(message) @@ -215,7 +215,7 @@ class InvalidStatusCode(InvalidHandshake): """ - def __init__(self, status_code: int, headers: Headers) -> None: + def __init__(self, status_code: int, headers: datastructures.Headers) -> None: self.status_code = status_code self.headers = headers message = f"server rejected WebSocket connection: HTTP {status_code}" @@ -284,11 +284,11 @@ class AbortHandshake(InvalidHandshake): def __init__( self, status: http.HTTPStatus, - headers: HeadersLike, + headers: datastructures.HeadersLike, body: bytes = b"", ) -> None: self.status = status - self.headers = Headers(headers) + self.headers = datastructures.Headers(headers) self.body = body message = f"HTTP {status:d}, {len(self.headers)} headers, {len(body)} bytes" super().__init__(message) @@ -347,10 +347,3 @@ class ProtocolError(WebSocketException): WebSocketProtocolError = ProtocolError # for backwards compatibility - - -# at the bottom to allow circular import, because the frames module imports exceptions -from .frames import format_close # noqa - -# at the bottom to allow circular import, because the http11 module imports exceptions -from .http11 import Response # noqa diff --git a/src/websockets/extensions/base.py b/src/websockets/extensions/base.py index 82ae5e27f..2de9176bd 100644 --- a/src/websockets/extensions/base.py +++ b/src/websockets/extensions/base.py @@ -12,7 +12,7 @@ from typing import List, Optional, Sequence, Tuple -from ..frames import Frame +from .. import frames from ..typing import ExtensionName, ExtensionParameter @@ -32,7 +32,12 @@ def name(self) -> ExtensionName: """ - def decode(self, frame: Frame, *, max_size: Optional[int] = None) -> Frame: + def decode( + self, + frame: frames.Frame, + *, + max_size: Optional[int] = None, + ) -> frames.Frame: """ Decode an incoming frame. @@ -41,7 +46,7 @@ def decode(self, frame: Frame, *, max_size: Optional[int] = None) -> Frame: """ - def encode(self, frame: Frame) -> Frame: + def encode(self, frame: frames.Frame) -> frames.Frame: """ Encode an outgoing frame. diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index d578a7a1d..5604fb8f9 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -10,6 +10,7 @@ import zlib from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from .. import frames from ..exceptions import ( DuplicateParameter, InvalidParameterName, @@ -17,7 +18,6 @@ NegotiationError, PayloadTooBig, ) -from ..frames import CTRL_OPCODES, OP_CONT, Frame from ..typing import ExtensionName, ExtensionParameter from .base import ClientExtensionFactory, Extension, ServerExtensionFactory @@ -93,19 +93,24 @@ def __repr__(self) -> str: f"local_max_window_bits={self.local_max_window_bits})" ) - def decode(self, frame: Frame, *, max_size: Optional[int] = None) -> Frame: + def decode( + self, + frame: frames.Frame, + *, + max_size: Optional[int] = None, + ) -> frames.Frame: """ Decode an incoming frame. """ # Skip control frames. - if frame.opcode in CTRL_OPCODES: + if frame.opcode in frames.CTRL_OPCODES: return frame # Handle continuation data frames: # - skip if the message isn't encoded # - reset "decode continuation data" flag if it's a final frame - if frame.opcode == OP_CONT: + if frame.opcode is frames.OP_CONT: if not self.decode_cont_data: return frame if frame.fin: @@ -143,19 +148,19 @@ def decode(self, frame: Frame, *, max_size: Optional[int] = None) -> Frame: return dataclasses.replace(frame, data=data) - def encode(self, frame: Frame) -> Frame: + def encode(self, frame: frames.Frame) -> frames.Frame: """ Encode an outgoing frame. """ # Skip control frames. - if frame.opcode in CTRL_OPCODES: + if frame.opcode in frames.CTRL_OPCODES: return frame # Since we always encode messages, there's no "encode continuation # data" flag similar to "decode continuation data" at this time. - if frame.opcode != OP_CONT: + if frame.opcode is not frames.OP_CONT: # Set the rsv1 flag on the first frame of a compressed message. frame = dataclasses.replace(frame, rsv1=True) # Re-initialize per-message decoder. diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 99e43388b..510fc6198 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -12,7 +12,7 @@ import struct from typing import Callable, Generator, Optional, Sequence, Tuple -from .exceptions import PayloadTooBig, ProtocolError +from . import exceptions, extensions from .typing import Data @@ -204,10 +204,10 @@ def parse( try: opcode = Opcode(head1 & 0b00001111) except ValueError as exc: - raise ProtocolError("invalid opcode") from exc + raise exceptions.ProtocolError("invalid opcode") from exc if (True if head2 & 0b10000000 else False) != mask: - raise ProtocolError("incorrect masking") + raise exceptions.ProtocolError("incorrect masking") length = head2 & 0b01111111 if length == 126: @@ -217,7 +217,9 @@ def parse( data = yield from read_exact(8) (length,) = struct.unpack("!Q", data) if max_size is not None and length > max_size: - raise PayloadTooBig(f"over size limit ({length} > {max_size} bytes)") + raise exceptions.PayloadTooBig( + f"over size limit ({length} > {max_size} bytes)" + ) if mask: mask_bytes = yield from read_exact(4) @@ -306,13 +308,13 @@ def check(self) -> None: """ if self.rsv1 or self.rsv2 or self.rsv3: - raise ProtocolError("reserved bits must be 0") + raise exceptions.ProtocolError("reserved bits must be 0") if self.opcode in CTRL_OPCODES: if len(self.data) > 125: - raise ProtocolError("control frame too long") + raise exceptions.ProtocolError("control frame too long") if not self.fin: - raise ProtocolError("fragmented control frame") + raise exceptions.ProtocolError("fragmented control frame") def prepare_data(data: Data) -> Tuple[int, bytes]: @@ -380,7 +382,7 @@ def parse_close(data: bytes) -> Tuple[int, str]: return 1005, "" else: assert length == 1 - raise ProtocolError("close frame too short") + raise exceptions.ProtocolError("close frame too short") def serialize_close(code: int, reason: str) -> bytes: @@ -403,7 +405,7 @@ def check_close(code: int) -> None: """ if not (code in EXTERNAL_CLOSE_CODES or 3000 <= code < 5000): - raise ProtocolError("invalid status code") + raise exceptions.ProtocolError("invalid status code") def format_close(code: int, reason: str) -> str: @@ -425,7 +427,3 @@ def format_close(code: int, reason: str) -> str: result += "no reason" return result - - -# at the bottom to allow circular import, because Extension depends on Frame -from . import extensions # noqa diff --git a/src/websockets/http11.py b/src/websockets/http11.py index aaa61f8c7..11b9d7f39 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -4,8 +4,7 @@ import re from typing import Callable, Generator, Optional -from .datastructures import Headers -from .exceptions import SecurityError +from . import datastructures, exceptions MAX_HEADERS = 256 @@ -50,7 +49,7 @@ class Request: """ path: str - headers: Headers + headers: datastructures.Headers # body isn't useful is the context of this library # If processing the request triggers an exception, it's stored here. @@ -75,7 +74,7 @@ def parse( :param read_line: generator-based coroutine that reads a LF-terminated line or raises an exception if there isn't enough data :raises EOFError: if the connection is closed without a full HTTP request - :raises SecurityError: if the request exceeds a security limit + :raises exceptions.SecurityError: if the request exceeds a security limit :raises ValueError: if the request isn't well formatted """ @@ -134,7 +133,7 @@ class Response: status_code: int reason_phrase: str - headers: Headers + headers: datastructures.Headers body: Optional[bytes] = None # If processing the response triggers an exception, it's stored here. @@ -162,7 +161,7 @@ def parse( :param read_exact: generator-based coroutine that reads the requested number of bytes or raises an exception if there isn't enough data :raises EOFError: if the connection is closed without a full HTTP response - :raises SecurityError: if the response exceeds a security limit + :raises exceptions.SecurityError: if the response exceeds a security limit :raises LookupError: if the response isn't well formatted :raises ValueError: if the response isn't well formatted @@ -242,7 +241,7 @@ def serialize(self) -> bytes: def parse_headers( read_line: Callable[[], Generator[None, None, bytes]] -) -> Generator[None, None, Headers]: +) -> Generator[None, None, datastructures.Headers]: """ Parse HTTP headers. @@ -256,7 +255,7 @@ def parse_headers( # We don't attempt to support obsolete line folding. - headers = Headers() + headers = datastructures.Headers() for _ in range(MAX_HEADERS + 1): try: line = yield from parse_line(read_line) @@ -280,7 +279,7 @@ def parse_headers( headers[name] = value else: - raise SecurityError("too many HTTP headers") + raise exceptions.SecurityError("too many HTTP headers") return headers @@ -301,7 +300,7 @@ def parse_line( line = yield from read_line() # Security: this guarantees header values are small (hard-coded = 4 KiB) if len(line) > MAX_LINE: - raise SecurityError("line too long") + raise exceptions.SecurityError("line too long") # Not mandatory but safe - https://tools.ietf.org/html/rfc7230#section-3.5 if not line.endswith(b"\r\n"): raise EOFError("line without CRLF") diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 468b5c15c..9197bd504 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -34,7 +34,7 @@ RedirectHandshake, SecurityError, ) -from ..extensions.base import ClientExtensionFactory, Extension +from ..extensions import ClientExtensionFactory, Extension from ..extensions.permessage_deflate import enable_client_permessage_deflate from ..headers import ( build_authorization_basic, diff --git a/src/websockets/legacy/framing.py b/src/websockets/legacy/framing.py index 14667925f..12e006911 100644 --- a/src/websockets/legacy/framing.py +++ b/src/websockets/legacy/framing.py @@ -15,8 +15,8 @@ import struct from typing import Any, Awaitable, Callable, NamedTuple, Optional, Sequence +from .. import extensions, frames from ..exceptions import PayloadTooBig, ProtocolError -from ..frames import Frame as NewFrame, Opcode try: @@ -28,15 +28,15 @@ class Frame(NamedTuple): fin: bool - opcode: Opcode + opcode: frames.Opcode data: bytes rsv1: bool = False rsv2: bool = False rsv3: bool = False @property - def new_frame(self) -> NewFrame: - return NewFrame( + def new_frame(self) -> frames.Frame: + return frames.Frame( self.opcode, self.data, self.fin, @@ -89,7 +89,7 @@ async def read( rsv3 = True if head1 & 0b00010000 else False try: - opcode = Opcode(head1 & 0b00001111) + opcode = frames.Opcode(head1 & 0b00001111) except ValueError as exc: raise ProtocolError("invalid opcode") from exc @@ -113,7 +113,7 @@ async def read( if mask: data = apply_mask(data, mask_bits) - new_frame = NewFrame(opcode, data, fin, rsv1, rsv2, rsv3) + new_frame = frames.Frame(opcode, data, fin, rsv1, rsv2, rsv3) if extensions is None: extensions = [] @@ -158,9 +158,6 @@ def write( write(self.new_frame.serialize(mask=mask, extensions=extensions)) -# at the bottom to allow circular import, because Extension depends on Frame -from .. import extensions # noqa - # Backwards compatibility with previously documented public APIs from ..frames import parse_close # noqa from ..frames import prepare_data # noqa diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index f8daf544b..45204ea3f 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -42,7 +42,7 @@ PayloadTooBig, ProtocolError, ) -from ..extensions.base import Extension +from ..extensions import Extension from ..frames import ( OP_BINARY, OP_CLOSE, diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index a7a98e006..10416abc5 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -39,7 +39,7 @@ InvalidUpgrade, NegotiationError, ) -from ..extensions.base import Extension, ServerExtensionFactory +from ..extensions import Extension, ServerExtensionFactory from ..extensions.permessage_deflate import enable_server_permessage_deflate from ..headers import build_extension, parse_extension, parse_subprotocol from ..http import USER_AGENT diff --git a/src/websockets/server.py b/src/websockets/server.py index 09ed63150..a53798f65 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -17,7 +17,7 @@ InvalidUpgrade, NegotiationError, ) -from .extensions.base import Extension, ServerExtensionFactory +from .extensions import Extension, ServerExtensionFactory from .headers import ( build_extension, parse_connection, diff --git a/src/websockets/uri.py b/src/websockets/uri.py index 7406a60a8..ed4521d53 100644 --- a/src/websockets/uri.py +++ b/src/websockets/uri.py @@ -13,7 +13,7 @@ import urllib.parse from typing import Optional, Tuple -from .exceptions import InvalidURI +from . import exceptions __all__ = ["parse_uri", "WebSocketURI"] @@ -59,7 +59,7 @@ def parse_uri(uri: str) -> WebSocketURI: assert parsed.fragment == "" assert parsed.hostname is not None except AssertionError as exc: - raise InvalidURI(uri) from exc + raise exceptions.InvalidURI(uri) from exc secure = parsed.scheme == "wss" host = parsed.hostname @@ -72,7 +72,7 @@ def parse_uri(uri: str) -> WebSocketURI: # urllib.parse.urlparse accepts URLs with a username but without a # password. This doesn't make sense for HTTP Basic Auth credentials. if parsed.password is None: - raise InvalidURI(uri) + raise exceptions.InvalidURI(uri) user_info = (parsed.username, parsed.password) try: From d6188e71df9b5cf6dabadb352ac0ce489e4408cf Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 11 Jun 2021 09:57:29 +0200 Subject: [PATCH 082/762] Standardize import style. --- .../extensions/permessage_deflate.py | 45 ++++++++----------- src/websockets/headers.py | 41 +++++++++++------ 2 files changed, 46 insertions(+), 40 deletions(-) diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index 5604fb8f9..0c9088a9e 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -10,14 +10,7 @@ import zlib from typing import Any, Dict, List, Optional, Sequence, Tuple, Union -from .. import frames -from ..exceptions import ( - DuplicateParameter, - InvalidParameterName, - InvalidParameterValue, - NegotiationError, - PayloadTooBig, -) +from .. import exceptions, frames from ..typing import ExtensionName, ExtensionParameter from .base import ClientExtensionFactory, Extension, ServerExtensionFactory @@ -140,7 +133,7 @@ def decode( max_length = 0 if max_size is None else max_size data = self.decoder.decompress(data, max_length) if self.decoder.unconsumed_tail: - raise PayloadTooBig(f"over size limit (? > {max_size} bytes)") + raise exceptions.PayloadTooBig(f"over size limit (? > {max_size} bytes)") # Allow garbage collection of the decoder if it won't be reused. if frame.fin and self.remote_no_context_takeover: @@ -224,40 +217,40 @@ def _extract_parameters( if name == "server_no_context_takeover": if server_no_context_takeover: - raise DuplicateParameter(name) + raise exceptions.DuplicateParameter(name) if value is None: server_no_context_takeover = True else: - raise InvalidParameterValue(name, value) + raise exceptions.InvalidParameterValue(name, value) elif name == "client_no_context_takeover": if client_no_context_takeover: - raise DuplicateParameter(name) + raise exceptions.DuplicateParameter(name) if value is None: client_no_context_takeover = True else: - raise InvalidParameterValue(name, value) + raise exceptions.InvalidParameterValue(name, value) elif name == "server_max_window_bits": if server_max_window_bits is not None: - raise DuplicateParameter(name) + raise exceptions.DuplicateParameter(name) if value in _MAX_WINDOW_BITS_VALUES: server_max_window_bits = int(value) else: - raise InvalidParameterValue(name, value) + raise exceptions.InvalidParameterValue(name, value) elif name == "client_max_window_bits": if client_max_window_bits is not None: - raise DuplicateParameter(name) + raise exceptions.DuplicateParameter(name) if is_server and value is None: # only in handshake requests client_max_window_bits = True elif value in _MAX_WINDOW_BITS_VALUES: client_max_window_bits = int(value) else: - raise InvalidParameterValue(name, value) + raise exceptions.InvalidParameterValue(name, value) else: - raise InvalidParameterName(name) + raise exceptions.InvalidParameterName(name) return ( server_no_context_takeover, @@ -344,7 +337,7 @@ def process_response_params( """ if any(other.name == self.name for other in accepted_extensions): - raise NegotiationError(f"received duplicate {self.name}") + raise exceptions.NegotiationError(f"received duplicate {self.name}") # Request parameters are available in instance variables. @@ -370,7 +363,7 @@ def process_response_params( if self.server_no_context_takeover: if not server_no_context_takeover: - raise NegotiationError("expected server_no_context_takeover") + raise exceptions.NegotiationError("expected server_no_context_takeover") # client_no_context_takeover # @@ -400,9 +393,9 @@ def process_response_params( else: if server_max_window_bits is None: - raise NegotiationError("expected server_max_window_bits") + raise exceptions.NegotiationError("expected server_max_window_bits") elif server_max_window_bits > self.server_max_window_bits: - raise NegotiationError("unsupported server_max_window_bits") + raise exceptions.NegotiationError("unsupported server_max_window_bits") # client_max_window_bits @@ -418,7 +411,7 @@ def process_response_params( if self.client_max_window_bits is None: if client_max_window_bits is not None: - raise NegotiationError("unexpected client_max_window_bits") + raise exceptions.NegotiationError("unexpected client_max_window_bits") elif self.client_max_window_bits is True: pass @@ -427,7 +420,7 @@ def process_response_params( if client_max_window_bits is None: client_max_window_bits = self.client_max_window_bits elif client_max_window_bits > self.client_max_window_bits: - raise NegotiationError("unsupported client_max_window_bits") + raise exceptions.NegotiationError("unsupported client_max_window_bits") return PerMessageDeflate( server_no_context_takeover, # remote_no_context_takeover @@ -526,7 +519,7 @@ def process_request_params( """ if any(other.name == self.name for other in accepted_extensions): - raise NegotiationError(f"skipped duplicate {self.name}") + raise exceptions.NegotiationError(f"skipped duplicate {self.name}") # Load request parameters in local variables. ( @@ -604,7 +597,7 @@ def process_request_params( else: if client_max_window_bits is None: - raise NegotiationError("required client_max_window_bits") + raise exceptions.NegotiationError("required client_max_window_bits") elif client_max_window_bits is True: client_max_window_bits = self.client_max_window_bits elif self.client_max_window_bits < client_max_window_bits: diff --git a/src/websockets/headers.py b/src/websockets/headers.py index 12d2a4e94..82aaa8848 100644 --- a/src/websockets/headers.py +++ b/src/websockets/headers.py @@ -11,7 +11,7 @@ import re from typing import Callable, List, Optional, Sequence, Tuple, TypeVar, cast -from .exceptions import InvalidHeaderFormat, InvalidHeaderValue +from . import exceptions from .typing import ( ConnectionOption, ExtensionHeader, @@ -87,7 +87,7 @@ def parse_token(header: str, pos: int, header_name: str) -> Tuple[str, int]: """ match = _token_re.match(header, pos) if match is None: - raise InvalidHeaderFormat(header_name, "expected token", header, pos) + raise exceptions.InvalidHeaderFormat(header_name, "expected token", header, pos) return match.group(), match.end() @@ -110,7 +110,9 @@ def parse_quoted_string(header: str, pos: int, header_name: str) -> Tuple[str, i """ match = _quoted_string_re.match(header, pos) if match is None: - raise InvalidHeaderFormat(header_name, "expected quoted string", header, pos) + raise exceptions.InvalidHeaderFormat( + header_name, "expected quoted string", header, pos + ) return _unquote_re.sub(r"\1", match.group()[1:-1]), match.end() @@ -181,7 +183,9 @@ def parse_list( if peek_ahead(header, pos) == ",": pos = parse_OWS(header, pos + 1) else: - raise InvalidHeaderFormat(header_name, "expected comma", header, pos) + raise exceptions.InvalidHeaderFormat( + header_name, "expected comma", header, pos + ) # Remove extra delimiters before the next item. while peek_ahead(header, pos) == ",": @@ -244,7 +248,9 @@ def parse_upgrade_protocol( """ match = _protocol_re.match(header, pos) if match is None: - raise InvalidHeaderFormat(header_name, "expected protocol", header, pos) + raise exceptions.InvalidHeaderFormat( + header_name, "expected protocol", header, pos + ) return cast(UpgradeProtocol, match.group()), match.end() @@ -285,7 +291,7 @@ def parse_extension_item_param( # https://tools.ietf.org/html/rfc6455#section-9.1 says: the value # after quoted-string unescaping MUST conform to the 'token' ABNF. if _token_re.fullmatch(value) is None: - raise InvalidHeaderFormat( + raise exceptions.InvalidHeaderFormat( header_name, "invalid quoted header content", header, pos_before ) else: @@ -451,7 +457,9 @@ def parse_token68(header: str, pos: int, header_name: str) -> Tuple[str, int]: """ match = _token68_re.match(header, pos) if match is None: - raise InvalidHeaderFormat(header_name, "expected token68", header, pos) + raise exceptions.InvalidHeaderFormat( + header_name, "expected token68", header, pos + ) return match.group(), match.end() @@ -461,7 +469,7 @@ def parse_end(header: str, pos: int, header_name: str) -> None: """ if pos < len(header): - raise InvalidHeaderFormat(header_name, "trailing data", header, pos) + raise exceptions.InvalidHeaderFormat(header_name, "trailing data", header, pos) def parse_authorization_basic(header: str) -> Tuple[str, str]: @@ -479,9 +487,12 @@ def parse_authorization_basic(header: str) -> Tuple[str, str]: # https://tools.ietf.org/html/rfc7617#section-2 scheme, pos = parse_token(header, 0, "Authorization") if scheme.lower() != "basic": - raise InvalidHeaderValue("Authorization", f"unsupported scheme: {scheme}") + raise exceptions.InvalidHeaderValue( + "Authorization", + f"unsupported scheme: {scheme}", + ) if peek_ahead(header, pos) != " ": - raise InvalidHeaderFormat( + raise exceptions.InvalidHeaderFormat( "Authorization", "expected space after scheme", header, pos ) pos += 1 @@ -491,14 +502,16 @@ def parse_authorization_basic(header: str) -> Tuple[str, str]: try: user_pass = base64.b64decode(basic_credentials.encode()).decode() except binascii.Error: - raise InvalidHeaderValue( - "Authorization", "expected base64-encoded credentials" + raise exceptions.InvalidHeaderValue( + "Authorization", + "expected base64-encoded credentials", ) from None try: username, password = user_pass.split(":", 1) except ValueError: - raise InvalidHeaderValue( - "Authorization", "expected username:password credentials" + raise exceptions.InvalidHeaderValue( + "Authorization", + "expected username:password credentials", ) from None return username, password From 85e8799de819da0d6e370c61b10f72a5368fa40b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 13 Jun 2021 14:26:46 +0200 Subject: [PATCH 083/762] Deduplicate State class. --- src/websockets/legacy/protocol.py | 9 +-------- tests/legacy/test_client_server.py | 2 +- tests/legacy/test_protocol.py | 3 ++- 3 files changed, 4 insertions(+), 10 deletions(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 45204ea3f..c7dd4d22c 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -12,7 +12,6 @@ import asyncio import codecs import collections -import enum import logging import random import struct @@ -33,6 +32,7 @@ cast, ) +from ..connection import State from ..datastructures import Headers from ..exceptions import ( ConnectionClosed, @@ -64,13 +64,6 @@ __all__ = ["WebSocketCommonProtocol"] -# A WebSocket connection goes through the following four states, in order: - - -class State(enum.IntEnum): - CONNECTING, OPEN, CLOSING, CLOSED = range(4) - - # In order to ensure consistency, the code always checks the current value of # WebSocketCommonProtocol.state before assigning a new value and never yields # between the check and the assignment. diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index d3e3f1e9f..f3e96ac4e 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -15,6 +15,7 @@ import urllib.request import warnings +from websockets.connection import State from websockets.datastructures import Headers from websockets.exceptions import ( ConnectionClosed, @@ -32,7 +33,6 @@ from websockets.legacy.client import * from websockets.legacy.handshake import build_response from websockets.legacy.http import read_response -from websockets.legacy.protocol import State from websockets.legacy.server import * from websockets.uri import parse_uri diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index ccbbffe7c..11589baf1 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -5,6 +5,7 @@ import unittest.mock import warnings +from websockets.connection import State from websockets.exceptions import ConnectionClosed, InvalidState from websockets.frames import ( OP_BINARY, @@ -17,7 +18,7 @@ ) from websockets.legacy.compatibility import loop_if_py_lt_38 from websockets.legacy.framing import Frame -from websockets.legacy.protocol import State, WebSocketCommonProtocol +from websockets.legacy.protocol import WebSocketCommonProtocol from .utils import MS, AsyncioTestCase From f96c0d0e71e781e8603561c1d0e3c2322ffd76ed Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 13 Jun 2021 14:20:31 +0200 Subject: [PATCH 084/762] Add a utility function for broadcasting messages. Fix #870. --- docs/project/changelog.rst | 2 + docs/reference/utilities.rst | 5 +++ src/websockets/__init__.py | 2 + src/websockets/legacy/protocol.py | 73 ++++++++++++++++++++++++++----- src/websockets/legacy/server.py | 4 +- tests/legacy/test_protocol.py | 48 +++++++++++++++++++- 6 files changed, 120 insertions(+), 14 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index a0cd8f07c..c8553d17d 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -51,6 +51,8 @@ They may change at any time. * Added compatibility with Python 3.10. +* Added :func:`~websockets.broadcast` to send a message to many clients. + * Added support for reconnecting automatically by using :func:`~legacy.client.connect` as an asynchronous iterator. diff --git a/docs/reference/utilities.rst b/docs/reference/utilities.rst index 198e928b0..f1d89eddc 100644 --- a/docs/reference/utilities.rst +++ b/docs/reference/utilities.rst @@ -1,6 +1,11 @@ Utilities ========= +Broadcast +--------- + +.. autofunction:: websockets.broadcast + Data structures --------------- diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index 8c69a6d63..6378b82cf 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -6,6 +6,7 @@ "AbortHandshake", "basic_auth_protocol_factory", "BasicAuthWebSocketServerProtocol", + "broadcast", "ClientConnection", "connect", "ConnectionClosed", @@ -56,6 +57,7 @@ "auth": ".legacy", "basic_auth_protocol_factory": ".legacy.auth", "BasicAuthWebSocketServerProtocol": ".legacy.auth", + "broadcast": ".legacy.protocol", "ClientConnection": ".client", "connect": ".legacy.client", "unix_connect": ".legacy.client", diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index c7dd4d22c..8a9557792 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -61,7 +61,7 @@ from .framing import Frame -__all__ = ["WebSocketCommonProtocol"] +__all__ = ["WebSocketCommonProtocol", "broadcast"] # In order to ensure consistency, the code always checks the current value of @@ -974,15 +974,7 @@ async def read_frame(self, max_size: Optional[int]) -> Frame: self.logger.debug("< %s", frame) return frame - async def write_frame( - self, fin: bool, opcode: int, data: bytes, *, _expected_state: int = State.OPEN - ) -> None: - # Defensive assertion for protocol compliance. - if self.state is not _expected_state: # pragma: no cover - raise InvalidState( - f"Cannot write to a WebSocket in the {self.state.name} state" - ) - + def write_frame_sync(self, fin: bool, opcode: int, data: bytes) -> None: frame = Frame(fin, Opcode(opcode), data) if self.debug: self.logger.debug("> %s", frame) @@ -992,6 +984,7 @@ async def write_frame( extensions=self.extensions, ) + async def drain(self) -> None: try: # drain() cannot be called concurrently by multiple coroutines: # http://bugs.python.org/issue29930. Remove this lock when no @@ -1006,6 +999,17 @@ async def write_frame( # with the correct code and reason. await self.ensure_open() + async def write_frame( + self, fin: bool, opcode: int, data: bytes, *, _state: int = State.OPEN + ) -> None: + # Defensive assertion for protocol compliance. + if self.state is not _state: # pragma: no cover + raise InvalidState( + f"Cannot write to a WebSocket in the {self.state.name} state" + ) + self.write_frame_sync(fin, opcode, data) + await self.drain() + async def write_close_frame(self, data: bytes = b"") -> None: """ Write a close frame if and only if the connection state is OPEN. @@ -1023,7 +1027,7 @@ async def write_close_frame(self, data: bytes = b"") -> None: self.logger.debug("= connection is CLOSING") # 7.1.2. Start the WebSocket Closing Handshake - await self.write_frame(True, OP_CLOSE, data, _expected_state=State.CLOSING) + await self.write_frame(True, OP_CLOSE, data, _state=State.CLOSING) async def keepalive_ping(self) -> None: """ @@ -1371,3 +1375,50 @@ def eof_received(self) -> None: """ self.reader.feed_eof() + + +def broadcast(websockets: Iterable[WebSocketCommonProtocol], message: Data) -> None: + """ + Broadcast a message to several WebSocket connections. + + A string (:class:`str`) is sent as a `Text frame`_. A bytestring or + bytes-like object (:class:`bytes`, :class:`bytearray`, or + :class:`memoryview`) is sent as a `Binary frame`_. + + .. _Text frame: https://tools.ietf.org/html/rfc6455#section-5.6 + .. _Binary frame: https://tools.ietf.org/html/rfc6455#section-5.6 + + :func:`broadcast` pushes the message synchronously to all connections even + if their write buffers overflow ``write_limit``. There's no backpressure. + + :func:`broadcast` skips silently connections that aren't open in order to + avoid errors on connections where the closing handshake is in progress. + + If you broadcast messages faster than a connection can handle them, + messages will pile up in its write buffer until the connection times out. + Keep low values for ``ping_interval`` and ``ping_timeout`` to prevent + excessive memory use by slow connections when you use :func:`broadcast`. + + Unlike :meth:`~websockets.server.WebSocketServerProtocol.send`, + :func:`broadcast` doesn't support sending fragmented messages. Indeed, + fragmentation is useful for sending large messages without buffering + them in memory, while :func:`broadcast` buffers one copy per connection + as fast as possible. + + :raises RuntimeError: if a connection is busy sending a fragmented message + :raises TypeError: for unsupported inputs + + """ + if not isinstance(message, (str, bytes, bytearray, memoryview)): + raise TypeError("data must be bytes, str, or iterable") + + opcode, data = prepare_data(message) + + for websocket in websockets: + if websocket.state is not State.OPEN: + continue + + if websocket._fragmented_message_waiter is not None: + raise RuntimeError("busy sending a fragmented message") + + websocket.write_frame_sync(True, opcode, data) diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 10416abc5..ae749aa5f 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -836,8 +836,8 @@ async def _close(self) -> None: await asyncio.sleep(0, **loop_if_py_lt_38(self.loop)) # Close OPEN connections with status code 1001. Since the server was - # closed, handshake() closes OPENING conections with a HTTP 503 error. - # Wait until all connections are closed. + # closed, handshake() closes OPENING connections with a HTTP 503 + # error. Wait until all connections are closed. # asyncio.wait doesn't accept an empty first argument if self.websockets: diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index 11589baf1..97aa01596 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -18,7 +18,7 @@ ) from websockets.legacy.compatibility import loop_if_py_lt_38 from websockets.legacy.framing import Frame -from websockets.legacy.protocol import WebSocketCommonProtocol +from websockets.legacy.protocol import WebSocketCommonProtocol, broadcast from .utils import MS, AsyncioTestCase @@ -1388,6 +1388,52 @@ def test_remote_close_during_send(self): # There is no test_local_close_during_send because this cannot really # happen, considering that writes are serialized. + def test_broadcast_text(self): + broadcast([self.protocol], "café") + self.assertOneFrameSent(True, OP_TEXT, "café".encode("utf-8")) + + def test_broadcast_binary(self): + broadcast([self.protocol], b"tea") + self.assertOneFrameSent(True, OP_BINARY, b"tea") + + def test_broadcast_type_error(self): + with self.assertRaises(TypeError): + broadcast([self.protocol], ["ca", "fé"]) + + def test_broadcast_no_clients(self): + broadcast([], "café") + self.assertNoFrameSent() + + def test_broadcast_two_clients(self): + broadcast([self.protocol, self.protocol], "café") + self.assertFramesSent( + (True, OP_TEXT, "café".encode("utf-8")), + (True, OP_TEXT, "café".encode("utf-8")), + ) + + def test_broadcast_skips_closed_connection(self): + self.close_connection() + + broadcast([self.protocol], "café") + self.assertNoFrameSent() + + def test_broadcast_skips_closing_connection(self): + close_task = self.half_close_connection_local() + + broadcast([self.protocol], "café") + self.assertNoFrameSent() + + self.loop.run_until_complete(close_task) # cleanup + + def test_broadcast_within_fragmented_text(self): + self.make_drain_slow() + self.loop.create_task(self.protocol.send(["ca", "fé"])) + self.run_loop_once() + self.assertOneFrameSent(False, OP_TEXT, "ca".encode("utf-8")) + + with self.assertRaises(RuntimeError): + broadcast([self.protocol], "café") + class ServerTests(CommonTests, AsyncioTestCase): def setUp(self): From a3958847c033c5d532eaef0b335c103171f9f7e7 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 6 Jun 2021 20:02:02 +0200 Subject: [PATCH 085/762] Add broadcast benchmarking scripts. --- experiments/broadcast/clients.py | 61 ++++++++++++ experiments/broadcast/server.py | 153 +++++++++++++++++++++++++++++++ 2 files changed, 214 insertions(+) create mode 100644 experiments/broadcast/clients.py create mode 100644 experiments/broadcast/server.py diff --git a/experiments/broadcast/clients.py b/experiments/broadcast/clients.py new file mode 100644 index 000000000..fe39dfe05 --- /dev/null +++ b/experiments/broadcast/clients.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python + +import asyncio +import statistics +import sys +import time + +import websockets + + +LATENCIES = {} + + +async def log_latency(interval): + while True: + await asyncio.sleep(interval) + p = statistics.quantiles(LATENCIES.values(), n=100) + print(f"clients = {len(LATENCIES)}") + print( + f"p50 = {p[49] / 1e6:.1f}ms, " + f"p95 = {p[94] / 1e6:.1f}ms, " + f"p99 = {p[98] / 1e6:.1f}ms" + ) + print() + + +async def client(): + try: + async with websockets.connect( + "ws://localhost:8765", + ping_timeout=None, + ) as websocket: + async for msg in websocket: + client_time = time.time_ns() + server_time = int(msg[:19].decode()) + LATENCIES[websocket] = client_time - server_time + except Exception as exc: + print(exc) + + +async def main(count, interval): + asyncio.create_task(log_latency(interval)) + clients = [] + for _ in range(count): + clients.append(asyncio.create_task(client())) + await asyncio.sleep(0.001) # 1ms between each connection + await asyncio.wait(clients) + + +if __name__ == "__main__": + try: + count = int(sys.argv[1]) + interval = float(sys.argv[2]) + except Exception as exc: + print(f"Usage: {sys.argv[0]} count interval") + print(" Connect clients e.g. 1000") + print(" Report latency every seconds e.g. 1") + print() + print(exc) + else: + asyncio.run(main(count, interval)) diff --git a/experiments/broadcast/server.py b/experiments/broadcast/server.py new file mode 100644 index 000000000..355d6b5e9 --- /dev/null +++ b/experiments/broadcast/server.py @@ -0,0 +1,153 @@ +#!/usr/bin/env python + +import asyncio +import functools +import os +import sys +import time + +import websockets + + +CLIENTS = set() + + +async def send(websocket, message): + try: + await websocket.send(message) + except websockets.ConnectionClosed: + pass + + +async def relay(queue, websocket): + while True: + message = await queue.get() + await websocket.send(message) + + +class PubSub: + def __init__(self): + self.waiter = asyncio.Future() + + def publish(self, value): + waiter, self.waiter = self.waiter, asyncio.Future() + waiter.set_result((value, self.waiter)) + + async def subscribe(self): + waiter = self.waiter + while True: + value, waiter = await waiter + yield value + + __aiter__ = subscribe + + +PUBSUB = PubSub() + + +async def handler(websocket, path, method=None): + if method in ["default", "naive", "task", "wait"]: + CLIENTS.add(websocket) + try: + await websocket.wait_closed() + finally: + CLIENTS.remove(websocket) + elif method == "queue": + queue = asyncio.Queue() + relay_task = asyncio.create_task(relay(queue, websocket)) + CLIENTS.add(queue) + try: + await websocket.wait_closed() + finally: + CLIENTS.remove(queue) + relay_task.cancel() + elif method == "pubsub": + async for message in PUBSUB: + await websocket.send(message) + else: + raise NotImplementedError(f"unsupported method: {method}") + + +async def broadcast(method, size, delay): + """Broadcast messages at regular intervals.""" + load_average = 0 + time_average = 0 + pc1, pt1 = time.perf_counter_ns(), time.process_time_ns() + await asyncio.sleep(delay) + while True: + print(f"clients = {len(CLIENTS)}") + pc0, pt0 = time.perf_counter_ns(), time.process_time_ns() + load_average = 0.9 * load_average + 0.1 * (pt0 - pt1) / (pc0 - pc1) + print( + f"load = {(pt0 - pt1) / (pc0 - pc1) * 100:.1f}% / " + f"average = {load_average * 100:.1f}%, " + f"late = {(pc0 - pc1 - delay * 1e9) / 1e6:.1f} ms" + ) + pc1, pt1 = pc0, pt0 + + assert size > 20 + message = str(time.time_ns()).encode() + b" " + os.urandom(size - 20) + + if method == "default": + websockets.broadcast(CLIENTS, message) + elif method == "naive": + # Since the loop can yield control, make a copy of CLIENTS + # to avoid: RuntimeError: Set changed size during iteration + for websocket in CLIENTS.copy(): + await send(websocket, message) + elif method == "task": + for websocket in CLIENTS: + asyncio.create_task(send(websocket, message)) + elif method == "wait": + if CLIENTS: # asyncio.wait doesn't accept an empty list + await asyncio.wait( + [ + asyncio.create_task(send(websocket, message)) + for websocket in CLIENTS + ] + ) + elif method == "queue": + for queue in CLIENTS: + queue.put_nowait(message) + elif method == "pubsub": + PUBSUB.publish(message) + else: + raise NotImplementedError(f"unsupported method: {method}") + + pc2 = time.perf_counter_ns() + wait = delay + (pc1 - pc2) / 1e9 + time_average = 0.9 * time_average + 0.1 * (pc2 - pc1) + print( + f"broadcast = {(pc2 - pc1) / 1e6:.1f}ms / " + f"average = {time_average / 1e6:.1f}ms, " + f"wait = {wait * 1e3:.1f}ms" + ) + await asyncio.sleep(wait) + print() + + +async def main(method, size, delay): + async with websockets.serve( + functools.partial(handler, method=method), + "localhost", + 8765, + compression=None, + ping_timeout=None, + ): + await broadcast(method, size, delay) + + +if __name__ == "__main__": + try: + method = sys.argv[1] + assert method in ["default", "naive", "task", "wait", "queue", "pubsub"] + size = int(sys.argv[2]) + delay = float(sys.argv[3]) + except Exception as exc: + print(f"Usage: {sys.argv[0]} method size delay") + print(" Start a server broadcasting messages with e.g. naive") + print(" Send a payload of bytes every seconds") + print() + print(exc) + else: + asyncio.run(main(method, size, delay)) From 185e9c6e076aecdff0aee3e858049f569cc0ed8e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 13 Jun 2021 17:55:29 +0200 Subject: [PATCH 086/762] Discuss broadcasting messages. Fix #653. --- docs/spelling_wordlist.txt | 1 + docs/topics/broadcast.rst | 349 +++++++++++++++++++++++++++++++++++++ docs/topics/index.rst | 1 + 3 files changed, 351 insertions(+) create mode 100644 docs/topics/broadcast.rst diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 8346acefa..86a41ff7a 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -59,6 +59,7 @@ retransmit runtime scalable serializers +stateful subclasses subclassing subprotocol diff --git a/docs/topics/broadcast.rst b/docs/topics/broadcast.rst new file mode 100644 index 000000000..c43a8fa36 --- /dev/null +++ b/docs/topics/broadcast.rst @@ -0,0 +1,349 @@ +Broadcasting messages +===================== + +.. currentmodule: websockets + + +.. note:: + + If you just want to send a message to all connected clients, use + :func:`~websockets.broadcast`. + + If you want to learn about its design in depth, continue reading this + document. + +WebSocket servers often send the same message to all connected clients to a +subset of clients for which the message is relevant. + +Let's explore options for broadcasting a message, explain the design +of :func:`~websockets.broadcast`, and discuss alternatives. + +For each option, we'll provide a connection handler called ``handler()`` and a +function or coroutine called ``broadcast()`` that sends a message to all +connected clients. + +Integrating them is left as an exercise for the reader. You could start with:: + + import asyncio + import websockets + + async def handler(websocket, path): + ... + + async def broadcast(message): + ... + + async def broadcast_messages(): + while True: + await asyncio.sleep(1) + message = ... # your application logic goes here + await broadcast(message) + + async def main(): + async with websockets.serve(handler, "localhost", 8765): + await broadcast_messages() # runs forever + + if __name__ == "__main__": + asyncio.run(main()) + +``broadcast_messages()`` must yield control to the event loop between each +message, or else it will never let the server run. That's why it includes +``await asyncio.sleep(1)``. + +A complete example is available in the `experiments/broadcast`_ directory. + +.. _experiments/broadcast: https://github.com/aaugustin/websockets/tree/main/experiments/broadcast + +The naive way +------------- + +The most obvious way to send a message to all connected clients consists in +keeping track of them and sending the message to each of them. + +Here's a connection handler that registers clients in a global variable:: + + CLIENTS = set() + + async def handler(websocket, path): + CLIENTS.add(websocket) + try: + await websocket.wait_closed() + finally: + CLIENTS.remove(websocket) + +This implementation assumes that the client will never send any messages. If +you'd rather not make this assumption, you can change:: + + await websocket.wait_closed() + +to:: + + async for _ in websocket: + pass + +Here's a coroutine that broadcasts a message to all clients:: + + async def broadcast(message): + for websocket in CLIENTS.copy(): + try: + await websocket.send(message) + except websockets.ConnectionClosed: + pass + +There are two tricks in this version of ``broadcast()``. + +First, it makes a copy of ``CLIENTS`` before iterating it. Else, if a client +connects or disconnects while ``broadcast()`` is running, the loop would fail +with:: + + RuntimeError: Set changed size during iteration + +Second, it ignores :exc:`~exceptions.ConnectionClosed` exceptions because a +client could disconnect between the moment ``broadcast()`` makes a copy of +``CLIENTS`` and the moment it sends a message to this client. This is fine: a +client that disconnected doesn't belongs to "all connected clients" anymore. + +The naive way can be very fast. Indeed, if all connections have enough free +space in their write buffers, ``await websocket.send(message)`` writes the +message and returns immediately, as it doesn't need to wait for the buffer to +drain. In this case, ``broadcast()`` doesn't yield control to the event loop, +which minimizes overhead. + +The naive way can also fail badly. If the write buffer of a connection reaches +``write_limit``, ``broadcast()`` waits for the buffer to drain before sending +the message to other clients. This can cause a massive drop in performance. + +As a consequence, this pattern works only when write buffers never fill up, +which is usually outside of the control of the server. + +If you know for sure that you will never write more than ``write_limit`` bytes +within ``ping_interval + ping_timeout``, then websockets will terminate slow +connections before the write buffer has time to fill up. + +Don't set extreme ``write_limit``, ``ping_interval``, and ``ping_timeout`` +values to ensure that this condition holds. Set reasonable values and use the +built-in :func:`~websockets.broadcast` function instead. + +The concurrent way +------------------ + +The naive way didn't work well because it serialized writes, while the whole +point of asynchronous I/O is to perform I/O concurrently. + +Let's modify ``broadcast()`` to send messages concurrently:: + + async def send(websocket, message): + try: + await websocket.send(message) + except websockets.ConnectionClosed: + pass + + def broadcast(message): + for websocket in CLIENTS: + asyncio.create_task(send(websocket, message)) + +We move the error handling logic in a new coroutine and we schedule +a :class:`~asyncio.Task` to run it instead of executing it immediately. + +Since ``broadcast()`` no longer awaits coroutines, we can make it a function +rather than a coroutine and do away with the copy of ``CLIENTS``. + +This version of ``broadcast()`` makes clients independent from one another: a +slow client won't block others. As a side effect, it makes messages +independent from one another. + +If you broadcast several messages, there is no strong guarantee that they will +be sent in the expected order. Fortunately, the event loop runs tasks in the +order in which they are created, so the order is correct in practice. + +Technically, this is an implementation detail of the event loop. However, it +seems unlikely for an event loop to run tasks in an order other than FIFO. + +If you wanted to enforce the order without relying this implementation detail, +you could be tempted to wait until all clients have received the message:: + + async def broadcast(message): + if CLIENTS: # asyncio.wait doesn't accept an empty list + await asyncio.wait([ + asyncio.create_task(send(websocket, message)) + for websocket in CLIENTS + ]) + +However, this doesn't really work in practice. Quite often, it will block +until the slowest client times out. + +Backpressure meets broadcast +---------------------------- + +At this point, it becomes apparent that backpressure, usually a good practice, +doesn't work well when broadcasting a message to thousands of clients. + +When you're sending messages to a single client, you don't want to send them +faster than the network can transfer them and the client accept them. This is +why :meth:`~server.WebSocketServerProtocol.send` checks if the write buffer +is full and, if it is, waits until it drain, giving the network and the +client time to catch up. This provides backpressure. + +Without backpressure, you could pile up data in the write buffer until the +server process runs out of memory and the operating system kills it. + +The :meth:`~server.WebSocketServerProtocol.send` API is designed to enforce +backpressure by default. This helps users of websockets write robust programs +even if they never heard about backpressure. + +For comparison, :class:`asyncio.StreamWriter` requires users to understand +backpressure and to await :meth:`~asyncio.StreamWriter.drain` explicitly +after each :meth:`~asyncio.StreamWriter.write`. + +When broadcasting messages, backpressure consists in slowing down all clients +in an attempt to let the slowest client catch up. With thousands of clients, +the slowest one is probably timing out and isn't going to receive the message +anyway. So it doesn't make sense to synchronize with the slowest client. + +How do we avoid running out of memory when slow clients can't keep up with the +broadcast rate, then? The most straightforward option is to disconnect them. + +If a client gets too far behind, eventually it reaches the limit defined by +``ping_timeout`` and websockets terminates the connection. You can read the +discussion of :doc:`keepalive and timeouts <./timeouts>` for details. + +How :func:`~websockets.broadcast` works +--------------------------------------- + +The built-in :func:`~websockets.broadcast` function is similar to the naive +way. The main difference is that it doesn't apply backpressure. + +This provides the best performance by avoiding the overhead of scheduling and +running one task per client. + +Also, when sending text messages, encoding to UTF-8 happens only once rather +than once per client, providing a small performance gain. + +Per-client queues +----------------- + +At this point, we deal with slow clients rather brutally: we disconnect then. + +Can we do better? For example, we could decide to skip or to batch messages, +depending on how far behind a client is. + +To implement this logic, we can create a queue of messages for each client and +run a task that gets messages from the queue and sends them to the client:: + + import asyncio + + CLIENTS = set() + + async def relay(queue, websocket): + while True: + # Implement custom logic based on queue.qsize() and + # websocket.transport.get_write_buffer_size() here. + message = await queue.get() + await websocket.send(message) + + async def handler(websocket, path): + queue = asyncio.Queue() + relay_task = asyncio.create_task(relay(queue, websocket)) + CLIENTS.add(queue) + try: + await websocket.wait_closed() + finally: + CLIENTS.remove(queue) + relay_task.cancel() + +Then we can broadcast a message by pushing it to all queues:: + + def broadcast(message): + for queue in CLIENTS: + queue.put_nowait(message) + +The queues provide an additional buffer between the ``broadcast()`` function +and clients. This makes it easier to support slow clients without excessive +memory usage because queued messages aren't duplicated to write buffers +until ``relay()`` processes them. + +Publish–subscribe +----------------- + +Can we avoid centralizing the list of connected clients in a global variable? + +If each client subscribes to a stream a messages, then broadcasting becomes as +simple as publishing a message to the stream. + +Here's a message stream that supports multiple consumers:: + + class PubSub: + def __init__(self): + self.waiter = asyncio.Future() + + def publish(self, value): + waiter, self.waiter = self.waiter, asyncio.Future() + waiter.set_result((value, self.waiter)) + + async def subscribe(self): + waiter = self.waiter + while True: + value, waiter = await waiter + yield value + + __aiter__ = subscribe + + PUBSUB = PubSub() + +The stream is implemented as a linked list of futures. It isn't necessary to +synchronize consumers. They can read the stream at their own pace, +independently from one another. Once all consumers read a message, there are +no references left, therefore the garbage collector deletes it. + +The connection handler subscribes to the stream and sends messages:: + + async def handler(websocket, path): + async for message in PUBSUB: + await websocket.send(message) + +The broadcast function publishes to the stream:: + + def broadcast(message): + PUBSUB.publish(message) + +Like per-client queues, this version supports slow clients with limited memory +usage. Unlike per-client queues, it makes it difficult to tell how far behind +a client is. The ``PubSub`` class could be extended or refactored to provide +this information. + +The ``for`` loop is gone from this version of the ``broadcast()`` function. +However, there's still a ``for`` loop iterating on all clients hidden deep +inside :mod:`asyncio`. When ``publish()`` sets the result of the ``waiter`` +future, :mod:`asyncio` loops on callbacks registered with this future and +schedules them. This is how connection handlers receive the next value from +the asynchronous iterator returned by ``subscribe()``. + +Performance considerations +-------------------------- + +The built-in :func:`~websockets.broadcast` function sends all messages without +yielding control to the event loop. So does the naive way when the network +and clients are fast and reliable. + +For each client, a WebSocket frame is prepared and sent to the network. This +is the minimum amount of work required to broadcast a message. + +It would be tempting to prepare a frame and reuse it for all connections. +However, this isn't possible in general for two reasons: + +* Clients can negotiate different extensions. You would have to enforce the + same extensions with the same parameters. For example, you would have to + select some compression settings and reject clients that cannot support + these settings. + +* Extensions can be stateful, producing different encodings of the same + message depending on previous messages. For example, you would have to + disable context takeover to make compression stateless, resulting in poor + compression rates. + +All other patterns discussed above yield control to the event loop once per +client because messages are sent by different tasks. This makes them slower +than the built-in :func:`~websockets.broadcast` function. + +There is no major difference between the performance of per-message queues and +publish–subscribe. diff --git a/docs/topics/index.rst b/docs/topics/index.rst index 9269c585a..dc192a290 100644 --- a/docs/topics/index.rst +++ b/docs/topics/index.rst @@ -6,6 +6,7 @@ Topics deployment authentication + broadcast compression timeouts design From fa497e501b505ba7255315cb660128d01cfbef44 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 13 Jun 2021 22:13:15 +0200 Subject: [PATCH 087/762] Take advantage of broadcast() in examples. Fix #995. --- docs/intro/index.rst | 2 +- example/counter.py | 20 ++++---------------- example/django/notifications.py | 9 ++++++--- 3 files changed, 11 insertions(+), 20 deletions(-) diff --git a/docs/intro/index.rst b/docs/intro/index.rst index 49e17b668..1df92ed59 100644 --- a/docs/intro/index.rst +++ b/docs/intro/index.rst @@ -181,7 +181,7 @@ unregister them when they disconnect. connected.add(websocket) try: # Broadcast a message to all connected clients. - await asyncio.wait([ws.send("Hello!") for ws in connected]) + websockets.broadcast(connected, "Hello!") await asyncio.sleep(10) finally: # Unregister. diff --git a/example/counter.py b/example/counter.py index a9ed61893..e41f6fabb 100755 --- a/example/counter.py +++ b/example/counter.py @@ -22,23 +22,11 @@ def users_event(): return json.dumps({"type": "users", "count": len(USERS)}) -async def broadcast(message): - # asyncio.wait doesn't accept an empty list - if not USERS: - return - # Ignore return value. If a user disconnects before we send - # the message to them, there's nothing we can do anyway. - await asyncio.wait([ - asyncio.create_task(user.send(message)) - for user in USERS - ]) - - async def counter(websocket, path): try: # Register user USERS.add(websocket) - await broadcast(users_event()) + websockets.broadcast(USERS, users_event()) # Send current state to user await websocket.send(state_event()) # Manage state changes @@ -46,16 +34,16 @@ async def counter(websocket, path): data = json.loads(message) if data["action"] == "minus": STATE["value"] -= 1 - await broadcast(state_event()) + websockets.broadcast(USERS, state_event()) elif data["action"] == "plus": STATE["value"] += 1 - await broadcast(state_event()) + websockets.broadcast(USERS, state_event()) else: logging.error("unsupported event: %s", data) finally: # Unregister user USERS.remove(websocket) - await broadcast(users_event()) + websockets.broadcast(USERS, users_event()) async def main(): diff --git a/example/django/notifications.py b/example/django/notifications.py index 641643f92..41fb719dc 100644 --- a/example/django/notifications.py +++ b/example/django/notifications.py @@ -55,9 +55,12 @@ async def process_events(): payload = message["data"].decode() # Broadcast event to all users who have permissions to see it. event = json.loads(payload) - for websocket, connection in CONNECTIONS.items(): - if event["content_type_id"] in connection["content_type_ids"]: - asyncio.create_task(websocket.send(payload)) + recipients = ( + websocket + for websocket, connection in CONNECTIONS.items() + if event["content_type_id"] in connection["content_type_ids"] + ) + websockets.broadcast(recipients, payload) async def main(): From 4c8d987e02895e0c372ddfca63b73d86aace35e9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 13 Jun 2021 22:54:15 +0200 Subject: [PATCH 088/762] Add performance tips. Fix #968. --- docs/spelling_wordlist.txt | 1 + docs/topics/index.rst | 1 + docs/topics/performance.rst | 20 ++++++++++++++++++++ 3 files changed, 22 insertions(+) create mode 100644 docs/topics/performance.rst diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 86a41ff7a..b3389a920 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -73,6 +73,7 @@ unregister uple username uvicorn +uvloop virtualenv WebSocket websocket diff --git a/docs/topics/index.rst b/docs/topics/index.rst index dc192a290..993303106 100644 --- a/docs/topics/index.rst +++ b/docs/topics/index.rst @@ -12,3 +12,4 @@ Topics design memory security + performance diff --git a/docs/topics/performance.rst b/docs/topics/performance.rst new file mode 100644 index 000000000..45e23b239 --- /dev/null +++ b/docs/topics/performance.rst @@ -0,0 +1,20 @@ +Performance +=========== + +Here are tips to optimize performance. + +uvloop +------ + +You can make a websockets application faster by running it with uvloop_. + +(This advice isn't specific to websockets. It applies to any :mod:`asyncio` +application.) + +.. _uvloop: https://github.com/MagicStack/uvloop + +broadcast +--------- + +:func:`~websockets.broadcast` is the most efficient way to send a message to +many clients. From 897d1b27f362c4113308e76b989321ece3bfd4af Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 14 Jun 2021 20:53:10 +0200 Subject: [PATCH 089/762] Validate subprotocols argument. Fix #263. --- src/websockets/headers.py | 19 +++++++++++++++++-- src/websockets/legacy/client.py | 4 ++++ src/websockets/legacy/server.py | 10 +++++++++- tests/legacy/test_client_server.py | 9 +++++++++ tests/test_headers.py | 15 +++++++++++++++ 5 files changed, 54 insertions(+), 3 deletions(-) diff --git a/src/websockets/headers.py b/src/websockets/headers.py index 82aaa8848..181e976e3 100644 --- a/src/websockets/headers.py +++ b/src/websockets/headers.py @@ -29,6 +29,7 @@ "build_extension", "parse_subprotocol", "build_subprotocol", + "validate_subprotocols", "build_www_authenticate_basic", "parse_authorization_basic", "build_authorization_basic", @@ -417,19 +418,33 @@ def parse_subprotocol(header: str) -> List[Subprotocol]: parse_subprotocol_list = parse_subprotocol # alias for backwards compatibility -def build_subprotocol(protocols: Sequence[Subprotocol]) -> str: +def build_subprotocol(subprotocols: Sequence[Subprotocol]) -> str: """ Build a ``Sec-WebSocket-Protocol`` header. This is the reverse of :func:`parse_subprotocol`. """ - return ", ".join(protocols) + return ", ".join(subprotocols) build_subprotocol_list = build_subprotocol # alias for backwards compatibility +def validate_subprotocols(subprotocols: Sequence[Subprotocol]) -> None: + """ + Validate that ``subprotocols`` is suitable for :func:`build_subprotocol`. + + """ + if not isinstance(subprotocols, Sequence): + raise TypeError("subprotocols must be a list") + if isinstance(subprotocols, str): + raise TypeError("subprotocols must be a list, not a str") + for subprotocol in subprotocols: + if not _token_re.fullmatch(subprotocol): + raise ValueError(f"invalid subprotocol: {subprotocol}") + + def build_www_authenticate_basic(realm: str) -> str: """ Build a ``WWW-Authenticate`` header for HTTP Basic Auth. diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 9197bd504..d94e9d7b9 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -42,6 +42,7 @@ build_subprotocol, parse_extension, parse_subprotocol, + validate_subprotocols, ) from ..http import USER_AGENT, build_host from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol @@ -559,6 +560,9 @@ def __init__( elif compression is not None: raise ValueError(f"unsupported compression: {compression}") + if subprotocols is not None: + validate_subprotocols(subprotocols) + factory = functools.partial( create_protocol, ping_interval=ping_interval, diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index ae749aa5f..80d72f93e 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -41,7 +41,12 @@ ) from ..extensions import Extension, ServerExtensionFactory from ..extensions.permessage_deflate import enable_server_permessage_deflate -from ..headers import build_extension, parse_extension, parse_subprotocol +from ..headers import ( + build_extension, + parse_extension, + parse_subprotocol, + validate_subprotocols, +) from ..http import USER_AGENT from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol from .compatibility import loop_if_py_lt_38 @@ -1042,6 +1047,9 @@ def __init__( elif compression is not None: raise ValueError(f"unsupported compression: {compression}") + if subprotocols is not None: + validate_subprotocols(subprotocols) + factory = functools.partial( create_protocol, ws_handler, diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index f3e96ac4e..755fcefdd 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -1013,6 +1013,15 @@ def test_subprotocol(self): self.assertEqual(server_subprotocol, repr("chat")) self.assertEqual(self.client.subprotocol, "chat") + def test_invalid_subprotocol_server(self): + with self.assertRaises(TypeError): + self.start_server(subprotocols="sip") + + @with_server() + def test_invalid_subprotocol_client(self): + with self.assertRaises(TypeError): + self.start_client(subprotocols="sip") + @with_server(subprotocols=["superchat"]) @with_client("/subprotocol", subprotocols=["otherchat"]) def test_subprotocol_not_accepted(self): diff --git a/tests/test_headers.py b/tests/test_headers.py index 26d85fa5e..badec5a86 100644 --- a/tests/test_headers.py +++ b/tests/test_headers.py @@ -126,6 +126,21 @@ def test_parse_subprotocol_invalid_header(self): with self.assertRaises(InvalidHeaderFormat): parse_subprotocol(header) + def test_validate_subprotocols(self): + for subprotocols in [[], ["sip"], ["v1.usp"], ["sip", "v1.usp"]]: + with self.subTest(subprotocols=subprotocols): + validate_subprotocols(subprotocols) + + def test_validate_subprotocols_invalid(self): + for subprotocols, exception in [ + ({"sip": None}, TypeError), + ("sip", TypeError), + ([""], ValueError), + ]: + with self.subTest(subprotocols=subprotocols): + with self.assertRaises(exception): + validate_subprotocols(subprotocols) + def test_build_www_authenticate_basic(self): # Test vector from RFC 7617 self.assertEqual( From 047222c434fa8f54d92bb904438267113a09f5a1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 16 Jun 2021 07:42:46 +0200 Subject: [PATCH 090/762] Add an abstraction for close codes and reasons. --- src/websockets/frames.py | 104 ++++++++++++++++++++++++++------------- 1 file changed, 70 insertions(+), 34 deletions(-) diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 510fc6198..175478794 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -372,17 +372,7 @@ def parse_close(data: bytes) -> Tuple[int, str]: :raises UnicodeDecodeError: if the reason isn't valid UTF-8 """ - length = len(data) - if length >= 2: - (code,) = struct.unpack("!H", data[:2]) - check_close(code) - reason = data[2:].decode("utf-8") - return code, reason - elif length == 0: - return 1005, "" - else: - assert length == 1 - raise exceptions.ProtocolError("close frame too short") + return dataclasses.astuple(Close.parse(data)) # type: ignore def serialize_close(code: int, reason: str) -> bytes: @@ -392,38 +382,84 @@ def serialize_close(code: int, reason: str) -> bytes: This is the reverse of :func:`parse_close`. """ - check_close(code) - return struct.pack("!H", code) + reason.encode("utf-8") + return Close(code, reason).serialize() -def check_close(code: int) -> None: +def format_close(code: int, reason: str) -> str: """ - Check that the close code has an acceptable value for a close frame. - - :raises ~websockets.exceptions.ProtocolError: if the close code - is invalid + Display a human-readable version of the close code and reason. """ - if not (code in EXTERNAL_CLOSE_CODES or 3000 <= code < 5000): - raise exceptions.ProtocolError("invalid status code") + return str(Close(code, reason)) -def format_close(code: int, reason: str) -> str: +@dataclasses.dataclass +class Close: """ - Display a human-readable version of the close code and reason. + WebSocket close code and reason. """ - if 3000 <= code < 4000: - explanation = "registered" - elif 4000 <= code < 5000: - explanation = "private use" - else: - explanation = CLOSE_CODES.get(code, "unknown") - result = f"code = {code} ({explanation}), " - if reason: - result += f"reason = {reason}" - else: - result += "no reason" + code: int + reason: str + + def __str__(self) -> str: + """ + Return a human-readable represention of a close code and reason. + + """ + if 3000 <= self.code < 4000: + explanation = "registered" + elif 4000 <= self.code < 5000: + explanation = "private use" + else: + explanation = CLOSE_CODES.get(self.code, "unknown") + result = f"code = {self.code} ({explanation}), " + + if self.reason: + result += f"reason = {self.reason}" + else: + result += "no reason" + + return result - return result + @classmethod + def parse(cls, data: bytes) -> Close: + """ + Parse the payload of a close frame. + + :raises ~websockets.exceptions.ProtocolError: if data is ill-formed + :raises UnicodeDecodeError: if the reason isn't valid UTF-8 + + """ + if len(data) >= 2: + (code,) = struct.unpack("!H", data[:2]) + reason = data[2:].decode("utf-8") + close = cls(code, reason) + close.check() + return close + elif len(data) == 0: + return cls(1005, "") + else: + raise exceptions.ProtocolError("close frame too short") + + def serialize(self) -> bytes: + """ + Serialize the payload for a close frame. + + This is the reverse of :meth:`parse`. + + """ + self.check() + return struct.pack("!H", self.code) + self.reason.encode("utf-8") + + def check(self) -> None: + """ + Check that the close code has a valid value for a close frame. + + :raises ~websockets.exceptions.ProtocolError: if the close code + is invalid + + """ + if not (self.code in EXTERNAL_CLOSE_CODES or 3000 <= self.code < 5000): + raise exceptions.ProtocolError("invalid status code") From 5bb653f2da9351ee00275e04c63bef19da565528 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 16 Jun 2021 07:59:27 +0200 Subject: [PATCH 091/762] Adopt the abstraction for close code and reasons. Preserve backwards compatibility in the framing module. --- src/websockets/__main__.py | 4 +-- src/websockets/connection.py | 14 ++++---- src/websockets/exceptions.py | 2 +- src/websockets/frames.py | 37 ++------------------- src/websockets/legacy/framing.py | 32 +++++++++++++++--- src/websockets/legacy/protocol.py | 14 ++++---- tests/extensions/test_permessage_deflate.py | 4 +-- tests/legacy/test_framing.py | 31 +++++++++++++++++ tests/legacy/test_protocol.py | 22 ++++++------ tests/test_connection.py | 6 ++-- tests/test_frames.py | 34 +++++++++---------- 11 files changed, 110 insertions(+), 90 deletions(-) diff --git a/src/websockets/__main__.py b/src/websockets/__main__.py index dae165cd2..6347ee278 100644 --- a/src/websockets/__main__.py +++ b/src/websockets/__main__.py @@ -9,7 +9,7 @@ from typing import Any, Set from .exceptions import ConnectionClosed -from .frames import format_close +from .frames import Close from .legacy.client import connect @@ -143,7 +143,7 @@ async def run_client( finally: await websocket.close() - close_status = format_close(websocket.close_code, websocket.close_reason) + close_status = Close(websocket.close_code, websocket.close_reason) print_over_input(f"Connection closed: {close_status}.") diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 57d3a2227..067ad54ce 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -14,9 +14,8 @@ OP_PING, OP_PONG, OP_TEXT, + Close, Frame, - parse_close, - serialize_close, ) from .http11 import Request, Response from .streams import StreamReader @@ -202,7 +201,7 @@ def send_close(self, code: Optional[int] = None, reason: str = "") -> None: raise ValueError("cannot send a reason without a code") data = b"" else: - data = serialize_close(code, reason) + data = Close(code, reason).serialize() self.send_frame(Frame(OP_CLOSE, data)) # send_frame() guarantees that self.state is OPEN at this point. # 7.1.3. The WebSocket Closing Handshake is Started @@ -259,7 +258,7 @@ def fail_connection(self, code: int = 1006, reason: str = "") -> None: # sent if it's CLOSING), except when failing the connection because of # an error reading from or writing to the network. if code != 1006 and self.state is OPEN: - self.send_frame(Frame(OP_CLOSE, serialize_close(code, reason))) + self.send_frame(Frame(OP_CLOSE, Close(code, reason).serialize())) self.set_state(CLOSING) if not self.eof_sent: self.send_eof() @@ -379,7 +378,8 @@ def parse(self) -> Generator[None, None, None]: self.close_frame_received = True # 7.1.5. The WebSocket Connection Close Code # 7.1.6. The WebSocket Connection Close Reason - self.close_code, self.close_reason = parse_close(frame.data) + close = Close.parse(frame.data) + self.close_code, self.close_reason = close.code, close.reason if self.cur_size is not None: raise ProtocolError("incomplete fragmented message") @@ -390,8 +390,8 @@ def parse(self) -> Generator[None, None, None]: # received.)" if self.state is OPEN: # Echo the original data instead of re-serializing it with - # serialize_close() because that fails when the close frame - # is empty and parse_close() synthetizes a 1005 close code. + # Close.serialize() because that fails when the close frame + # is empty and Close.parse() synthetizes a 1005 close code. # The rest is identical to send_close(). self.send_frame(Frame(OP_CLOSE, frame.data)) self.set_state(CLOSING) diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index f110322aa..59922be9c 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -84,7 +84,7 @@ class ConnectionClosed(WebSocketException): def __init__(self, code: int, reason: str) -> None: self.code = code self.reason = reason - super().__init__(frames.format_close(code, reason)) + super().__init__(str(frames.Close(code, reason))) class ConnectionClosedError(ConnectionClosed): diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 175478794..b0de645ec 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -35,8 +35,7 @@ "Frame", "prepare_data", "prepare_ctrl", - "parse_close", - "serialize_close", + "Close", ] @@ -140,8 +139,7 @@ def __str__(self) -> str: binary = binary[:16] + b"\x00\x00" + binary[-8:] data = " ".join(f"{byte:02x}" for byte in binary) elif self.opcode is OP_CLOSE: - code, reason = parse_close(self.data) - data = format_close(code, reason) + data = str(Close.parse(self.data)) elif self.data: # We don't know if a Continuation frame contains text or binary. # Ping and Pong frames could contain UTF-8. Attempt to decode as @@ -362,37 +360,6 @@ def prepare_ctrl(data: Data) -> bytes: raise TypeError("data must be bytes-like or str") -def parse_close(data: bytes) -> Tuple[int, str]: - """ - Parse the payload from a close frame. - - Return ``(code, reason)``. - - :raises ~websockets.exceptions.ProtocolError: if data is ill-formed - :raises UnicodeDecodeError: if the reason isn't valid UTF-8 - - """ - return dataclasses.astuple(Close.parse(data)) # type: ignore - - -def serialize_close(code: int, reason: str) -> bytes: - """ - Serialize the payload for a close frame. - - This is the reverse of :func:`parse_close`. - - """ - return Close(code, reason).serialize() - - -def format_close(code: int, reason: str) -> str: - """ - Display a human-readable version of the close code and reason. - - """ - return str(Close(code, reason)) - - @dataclasses.dataclass class Close: """ diff --git a/src/websockets/legacy/framing.py b/src/websockets/legacy/framing.py index 12e006911..c8ae48690 100644 --- a/src/websockets/legacy/framing.py +++ b/src/websockets/legacy/framing.py @@ -12,8 +12,9 @@ from __future__ import annotations +import dataclasses import struct -from typing import Any, Awaitable, Callable, NamedTuple, Optional, Sequence +from typing import Any, Awaitable, Callable, NamedTuple, Optional, Sequence, Tuple from .. import extensions, frames from ..exceptions import PayloadTooBig, ProtocolError @@ -159,7 +160,28 @@ def write( # Backwards compatibility with previously documented public APIs -from ..frames import parse_close # noqa -from ..frames import prepare_data # noqa -from ..frames import serialize_close # noqa -from ..frames import prepare_ctrl as encode_data # noqa + +from ..frames import Close, prepare_ctrl as encode_data, prepare_data # noqa + + +def parse_close(data: bytes) -> Tuple[int, str]: + """ + Parse the payload from a close frame. + + Return ``(code, reason)``. + + :raises ~websockets.exceptions.ProtocolError: if data is ill-formed + :raises UnicodeDecodeError: if the reason isn't valid UTF-8 + + """ + return dataclasses.astuple(Close.parse(data)) # type: ignore + + +def serialize_close(code: int, reason: str) -> bytes: + """ + Serialize the payload for a close frame. + + This is the reverse of :func:`parse_close`. + + """ + return Close(code, reason).serialize() diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 8a9557792..c3221bb45 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -50,11 +50,10 @@ OP_PING, OP_PONG, OP_TEXT, + Close, Opcode, - parse_close, prepare_ctrl, prepare_data, - serialize_close, ) from ..typing import Data, LoggerLike, Subprotocol from .compatibility import loop_if_py_lt_38 @@ -609,7 +608,7 @@ async def close(self, code: int = 1000, reason: str = "") -> None: """ try: await asyncio.wait_for( - self.write_close_frame(serialize_close(code, reason)), + self.write_close_frame(Close(code, reason).serialize()), self.close_timeout, **loop_if_py_lt_38(self.loop), ) @@ -918,11 +917,12 @@ async def read_data_frame(self, max_size: Optional[int]) -> Optional[Frame]: if frame.opcode == OP_CLOSE: # 7.1.5. The WebSocket Connection Close Code # 7.1.6. The WebSocket Connection Close Reason - self.close_code, self.close_reason = parse_close(frame.data) + close = Close.parse(frame.data) + self.close_code, self.close_reason = close.code, close.reason try: # Echo the original data instead of re-serializing it with - # serialize_close() because that fails when the close frame - # is empty and parse_close() synthetizes a 1005 close code. + # Close.serialize() because that fails when the close frame + # is empty and Close.parse() synthetizes a 1005 close code. await self.write_close_frame(frame.data) except ConnectionClosed: # Connection closed before we could echo the close frame. @@ -1220,7 +1220,7 @@ def fail_connection(self, code: int = 1006, reason: str = "") -> None: # Don't send a close frame if the connection is broken. if code != 1006 and self.state is State.OPEN: - frame_data = serialize_close(code, reason) + frame_data = Close(code, reason).serialize() # Write the close frame without draining the write buffer. diff --git a/tests/extensions/test_permessage_deflate.py b/tests/extensions/test_permessage_deflate.py index 5ba4d8ddf..bcd08a7ef 100644 --- a/tests/extensions/test_permessage_deflate.py +++ b/tests/extensions/test_permessage_deflate.py @@ -17,8 +17,8 @@ OP_PING, OP_PONG, OP_TEXT, + Close, Frame, - serialize_close, ) from .test_base import ClientNoOpExtensionFactory, ServerNoOpExtensionFactory @@ -70,7 +70,7 @@ def test_no_encode_decode_pong_frame(self): self.assertEqual(self.extension.decode(frame), frame) def test_no_encode_decode_close_frame(self): - frame = Frame(OP_CLOSE, serialize_close(1000, "")) + frame = Frame(OP_CLOSE, Close(1000, "").serialize()) self.assertEqual(self.extension.encode(frame), frame) diff --git a/tests/legacy/test_framing.py b/tests/legacy/test_framing.py index 2baa827a9..4646817a8 100644 --- a/tests/legacy/test_framing.py +++ b/tests/legacy/test_framing.py @@ -171,3 +171,34 @@ def decode(frame, *, max_size=None): self.round_trip( b"\x81\x05uryyb", Frame(True, OP_TEXT, b"hello"), extensions=[Rot13()] ) + + +class ParseAndSerializeCloseTests(unittest.TestCase): + def assertCloseData(self, code, reason, data): + """ + Serializing code / reason yields data. Parsing data yields code / reason. + + """ + serialized = serialize_close(code, reason) + self.assertEqual(serialized, data) + parsed = parse_close(data) + self.assertEqual(parsed, (code, reason)) + + def test_parse_close_and_serialize_close(self): + self.assertCloseData(1000, "", b"\x03\xe8") + self.assertCloseData(1000, "OK", b"\x03\xe8OK") + + def test_parse_close_empty(self): + self.assertEqual(parse_close(b""), (1005, "")) + + def test_parse_close_errors(self): + with self.assertRaises(ProtocolError): + parse_close(b"\x03") + with self.assertRaises(ProtocolError): + parse_close(b"\x03\xe7") + with self.assertRaises(UnicodeDecodeError): + parse_close(b"\x03\xe8\xff\xff") + + def test_serialize_close_errors(self): + with self.assertRaises(ProtocolError): + serialize_close(999, "") diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index 97aa01596..1e3f1b77e 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -14,7 +14,7 @@ OP_PING, OP_PONG, OP_TEXT, - serialize_close, + Close, ) from websockets.legacy.compatibility import loop_if_py_lt_38 from websockets.legacy.framing import Frame @@ -113,9 +113,9 @@ async def delayed_drain(): self.protocol._drain = delayed_drain - close_frame = Frame(True, OP_CLOSE, serialize_close(1000, "close")) - local_close = Frame(True, OP_CLOSE, serialize_close(1000, "local")) - remote_close = Frame(True, OP_CLOSE, serialize_close(1000, "remote")) + close_frame = Frame(True, OP_CLOSE, Close(1000, "close").serialize()) + local_close = Frame(True, OP_CLOSE, Close(1000, "local").serialize()) + remote_close = Frame(True, OP_CLOSE, Close(1000, "remote").serialize()) def receive_frame(self, frame): """ @@ -156,7 +156,7 @@ def close_connection(self, code=1000, reason="close"): This puts the connection in the CLOSED state. """ - close_frame_data = serialize_close(code, reason) + close_frame_data = Close(code, reason).serialize() # Prepare the response to the closing handshake from the remote side. self.receive_frame(Frame(True, OP_CLOSE, close_frame_data)) self.receive_eof_if_client() @@ -178,7 +178,7 @@ def half_close_connection_local(self, code=1000, reason="close"): canceled, else asyncio complains about destroying a pending task. """ - close_frame_data = serialize_close(code, reason) + close_frame_data = Close(code, reason).serialize() # Trigger the closing handshake from the local endpoint. close_task = self.loop.create_task(self.protocol.close(code, reason)) self.run_loop_once() # wait_for executes @@ -213,7 +213,7 @@ def half_close_connection_remote(self, code=1000, reason="close"): if not self.protocol.is_client: self.make_drain_slow() - close_frame_data = serialize_close(code, reason) + close_frame_data = Close(code, reason).serialize() # Trigger the closing handshake from the remote endpoint. self.receive_frame(Frame(True, OP_CLOSE, close_frame_data)) self.run_loop_once() # read_frame executes @@ -298,7 +298,7 @@ def assertConnectionFailed(self, code, message): if code == 1006: self.assertNoFrameSent() else: - self.assertOneFrameSent(True, OP_CLOSE, serialize_close(code, message)) + self.assertOneFrameSent(True, OP_CLOSE, Close(code, message).serialize()) @contextlib.contextmanager def assertCompletesWithin(self, min_time, max_time): @@ -649,7 +649,7 @@ def test_send_iterable_mixed_type_error(self): self.loop.run_until_complete(self.protocol.send(["café", b"tea"])) self.assertFramesSent( (False, OP_TEXT, "café".encode("utf-8")), - (True, OP_CLOSE, serialize_close(1011, "")), + (True, OP_CLOSE, Close(1011, "").serialize()), ) def test_send_iterable_prevents_concurrent_send(self): @@ -723,7 +723,7 @@ def test_send_async_iterable_mixed_type_error(self): ) self.assertFramesSent( (False, OP_TEXT, "café".encode("utf-8")), - (True, OP_CLOSE, serialize_close(1011, "")), + (True, OP_CLOSE, Close(1011, "").serialize()), ) def test_send_async_iterable_prevents_concurrent_send(self): @@ -1172,7 +1172,7 @@ def test_keepalive_ping_not_acknowledged_closes_connection(self): # Connection is closed at 6ms. self.loop.run_until_complete(asyncio.sleep(4 * MS)) self.assertOneFrameSent( - True, OP_CLOSE, serialize_close(1011, "keepalive ping timeout") + True, OP_CLOSE, Close(1011, "keepalive ping timeout").serialize() ) # The keepalive ping task is complete. diff --git a/tests/test_connection.py b/tests/test_connection.py index 677881238..f2ce8de46 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -9,8 +9,8 @@ OP_PING, OP_PONG, OP_TEXT, + Close, Frame, - serialize_close, ) from .extensions.utils import Rsv2Extension @@ -60,7 +60,7 @@ def assertConnectionClosing(self, connection, code=None, reason=""): """ close_frame = Frame( OP_CLOSE, - b"" if code is None else serialize_close(code, reason), + b"" if code is None else Close(code, reason).serialize(), ) # A close frame was received. self.assertFrameReceived(connection, close_frame) @@ -76,7 +76,7 @@ def assertConnectionFailing(self, connection, code=None, reason=""): """ close_frame = Frame( OP_CLOSE, - b"" if code is None else serialize_close(code, reason), + b"" if code is None else Close(code, reason).serialize(), ) # No frame was received. self.assertFrameReceived(connection, None) diff --git a/tests/test_frames.py b/tests/test_frames.py index 85414e9d3..f5ed7ebc1 100644 --- a/tests/test_frames.py +++ b/tests/test_frames.py @@ -394,32 +394,32 @@ def test_prepare_ctrl_none(self): prepare_ctrl(None) -class ParseAndSerializeCloseTests(unittest.TestCase): - def assertCloseData(self, code, reason, data): +class CloseTests(unittest.TestCase): + def assertCloseData(self, close, data): """ - Serializing code / reason yields data. Parsing data yields code / reason. + Serializing close yields data. Parsing data yields close. """ - serialized = serialize_close(code, reason) + serialized = close.serialize() self.assertEqual(serialized, data) - parsed = parse_close(data) - self.assertEqual(parsed, (code, reason)) + parsed = Close.parse(data) + self.assertEqual(parsed, close) - def test_parse_close_and_serialize_close(self): - self.assertCloseData(1000, "", b"\x03\xe8") - self.assertCloseData(1000, "OK", b"\x03\xe8OK") + def test_parse_and_serialize(self): + self.assertCloseData(Close(1000, ""), b"\x03\xe8") + self.assertCloseData(Close(1000, "OK"), b"\x03\xe8OK") - def test_parse_close_empty(self): - self.assertEqual(parse_close(b""), (1005, "")) + def test_parse_empty(self): + self.assertEqual(Close.parse(b""), Close(1005, "")) - def test_parse_close_errors(self): + def test_parse_errors(self): with self.assertRaises(ProtocolError): - parse_close(b"\x03") + Close.parse(b"\x03") with self.assertRaises(ProtocolError): - parse_close(b"\x03\xe7") + Close.parse(b"\x03\xe7") with self.assertRaises(UnicodeDecodeError): - parse_close(b"\x03\xe8\xff\xff") + Close.parse(b"\x03\xe8\xff\xff") - def test_serialize_close_errors(self): + def test_serialize_errors(self): with self.assertRaises(ProtocolError): - serialize_close(999, "") + Close(999, "").serialize() From 2dd793ffc0e0faab2ed34e4e06fc0b9702a84999 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 16 Jun 2021 08:42:03 +0200 Subject: [PATCH 092/762] Simplify display of close codes and reasons. --- src/websockets/frames.py | 6 ++---- src/websockets/legacy/protocol.py | 2 +- tests/test_exceptions.py | 12 ++++++------ tests/test_frames.py | 4 ++-- 4 files changed, 11 insertions(+), 13 deletions(-) diff --git a/src/websockets/frames.py b/src/websockets/frames.py index b0de645ec..15b1a9de2 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -381,12 +381,10 @@ def __str__(self) -> str: explanation = "private use" else: explanation = CLOSE_CODES.get(self.code, "unknown") - result = f"code = {self.code} ({explanation}), " + result = f"{self.code} ({explanation})" if self.reason: - result += f"reason = {self.reason}" - else: - result += "no reason" + result = f"{result} {self.reason}" return result diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index c3221bb45..7443a565e 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -1308,7 +1308,7 @@ def connection_lost(self, exc: Optional[Exception]) -> None: if not hasattr(self, "close_reason"): self.close_reason = "" self.logger.debug( - "= connection is CLOSED - code = %d, reason = %s", + "= connection is CLOSED - %d %s", self.close_code, self.close_reason or "[no reason]", ) diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 094fb6d33..9c8eef4fc 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -14,27 +14,27 @@ def test_str(self): ), ( ConnectionClosed(1000, ""), - "code = 1000 (OK), no reason", + "1000 (OK)", ), ( ConnectionClosed(1006, None), - "code = 1006 (connection closed abnormally [internal]), no reason" + "1006 (connection closed abnormally [internal])" ), ( ConnectionClosed(3000, None), - "code = 3000 (registered), no reason" + "3000 (registered)" ), ( ConnectionClosed(4000, None), - "code = 4000 (private use), no reason" + "4000 (private use)" ), ( ConnectionClosedError(1016, None), - "code = 1016 (unknown), no reason" + "1016 (unknown)" ), ( ConnectionClosedOK(1001, "bye"), - "code = 1001 (going away), reason = bye", + "1001 (going away) bye", ), ( InvalidHandshake("invalid request"), diff --git a/tests/test_frames.py b/tests/test_frames.py index f5ed7ebc1..7620fe415 100644 --- a/tests/test_frames.py +++ b/tests/test_frames.py @@ -280,13 +280,13 @@ def test_binary_truncated(self): def test_close(self): self.assertEqual( str(Frame(OP_CLOSE, b"\x03\xe8")), - "CLOSE code = 1000 (OK), no reason [2 bytes]", + "CLOSE 1000 (OK) [2 bytes]", ) def test_close_reason(self): self.assertEqual( str(Frame(OP_CLOSE, b"\x03\xe9Bye!")), - "CLOSE code = 1001 (going away), reason = Bye! [6 bytes]", + "CLOSE 1001 (going away) Bye! [6 bytes]", ) def test_ping(self): From 62eb267c34034e97813bc210f0b04a6acbedd7a5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 16 Jun 2021 22:21:01 +0200 Subject: [PATCH 093/762] Add details to ConnectionClosed. Fix #587. Ref #767. --- docs/howto/django.rst | 4 +- docs/howto/faq.rst | 11 +-- docs/howto/haproxy.rst | 2 +- docs/howto/heroku.rst | 4 +- docs/howto/kubernetes.rst | 4 +- docs/howto/nginx.rst | 2 +- docs/howto/supervisor.rst | 4 +- docs/project/changelog.rst | 9 +++ docs/reference/client.rst | 12 +--- docs/reference/server.rst | 12 +--- src/websockets/__main__.py | 1 + src/websockets/exceptions.py | 54 ++++++++++---- src/websockets/frames.py | 6 +- src/websockets/legacy/protocol.py | 114 ++++++++++++++++++++---------- tests/legacy/test_protocol.py | 14 ++++ tests/test_exceptions.py | 33 +++++---- tests/test_frames.py | 9 ++- 17 files changed, 195 insertions(+), 100 deletions(-) diff --git a/docs/howto/django.rst b/docs/howto/django.rst index 001d1cb73..c87c3821c 100644 --- a/docs/howto/django.rst +++ b/docs/howto/django.rst @@ -146,7 +146,7 @@ prompt: Connected to ws://localhost:8888/ > < Hello ! - Connection closed: code = 1000 (OK), no reason. + Connection closed: 1000 (OK). It works! @@ -158,7 +158,7 @@ closes the connection: $ python -m websockets ws://localhost:8888/ Connected to ws://localhost:8888. > not a token - Connection closed: code = 1011 (unexpected error), reason = authentication failed. + Connection closed: 1011 (unexpected error) authentication failed. You can also test from a browser by generating a new token and running the following code in the JavaScript console of the browser: diff --git a/docs/howto/faq.rst b/docs/howto/faq.rst index ce3704a2e..a4ad84680 100644 --- a/docs/howto/faq.rst +++ b/docs/howto/faq.rst @@ -244,8 +244,8 @@ See `issue 867`_. Both sides ---------- -What does ``ConnectionClosedError: code = 1006`` mean? -...................................................... +What does ``ConnectionClosedError: no close frame received or sent`` mean? +.......................................................................... If you're seeing this traceback in the logs of a server: @@ -260,7 +260,7 @@ If you're seeing this traceback in the logs of a server: Traceback (most recent call last): ... - websockets.exceptions.ConnectionClosedError: code = 1006 (connection closed abnormally [internal]), no reason + websockets.exceptions.ConnectionClosedError: no close frame received or sent or if a client crashes with this traceback: @@ -274,10 +274,11 @@ or if a client crashes with this traceback: Traceback (most recent call last): ... - websockets.exceptions.ConnectionClosedError: code = 1006 (connection closed abnormally [internal]), no reason + websockets.exceptions.ConnectionClosedError: no close frame received or sent it means that the TCP connection was lost. As a consequence, the WebSocket -connection was closed without receiving a close frame, which is abnormal. +connection was closed without receiving and sending a close frame, which is +abnormal. You can catch and handle :exc:`~exceptions.ConnectionClosed` to prevent it from being logged. diff --git a/docs/howto/haproxy.rst b/docs/howto/haproxy.rst index d520d278a..0ecb46a04 100644 --- a/docs/howto/haproxy.rst +++ b/docs/howto/haproxy.rst @@ -58,4 +58,4 @@ You can confirm that HAProxy proxies connections properly: Connected to ws://localhost:8080/. > Hello! < Hello! - Connection closed: code = 1000 (OK), no reason. + Connection closed: 1000 (OK). diff --git a/docs/howto/heroku.rst b/docs/howto/heroku.rst index 92dee5f27..b728106e9 100644 --- a/docs/howto/heroku.rst +++ b/docs/howto/heroku.rst @@ -142,7 +142,7 @@ then press Ctrl-D to terminate the connection: > Hello! < Hello! - Connection closed: code = 1000 (OK), no reason. + Connection closed: 1000 (OK). You can also confirm that your application shuts down gracefully. Connect an interactive client again — remember to replace ``websockets-echo`` with your app: @@ -167,7 +167,7 @@ away). $ python -m websockets wss://websockets-echo.herokuapp.com/ Connected to wss://websockets-echo.herokuapp.com/. - Connection closed: code = 1001 (going away), no reason. + Connection closed: 1001 (going away). If graceful shutdown wasn't working, the server wouldn't perform a closing handshake and the connection would be closed with code 1006 (connection closed diff --git a/docs/howto/kubernetes.rst b/docs/howto/kubernetes.rst index ef5c963d7..0e77aeac1 100644 --- a/docs/howto/kubernetes.rst +++ b/docs/howto/kubernetes.rst @@ -73,7 +73,7 @@ shut down gracefully: Connected to ws://localhost:32080/. > Hey there! < Hey there! - Connection closed: code = 1001 (going away), no reason. + Connection closed: 1001 (going away). If it didn't, you'd get code 1006 (connection closed abnormally). @@ -119,7 +119,7 @@ You can connect to the service — press Ctrl-D to exit: $ python -m websockets ws://localhost:32080/ Connected to ws://localhost:32080/. - Connection closed: code = 1000 (OK), no reason. + Connection closed: 1000 (OK). Validate deployment ------------------- diff --git a/docs/howto/nginx.rst b/docs/howto/nginx.rst index cb4dc83f2..e20f82098 100644 --- a/docs/howto/nginx.rst +++ b/docs/howto/nginx.rst @@ -81,4 +81,4 @@ You can confirm that nginx proxies connections properly: Connected to ws://localhost:8080/. > Hello! < Hello! - Connection closed: code = 1000 (OK), no reason. + Connection closed: 1000 (OK). diff --git a/docs/howto/supervisor.rst b/docs/howto/supervisor.rst index 6d679ca2b..0c07aebae 100644 --- a/docs/howto/supervisor.rst +++ b/docs/howto/supervisor.rst @@ -68,7 +68,7 @@ press Ctrl-D to exit: Connected to ws://localhost:8080/. > Hello! < Hello! - Connection closed: code = 1000 (OK), no reason. + Connection closed: 1000 (OK). Look at the pid of an instance of the app in the logs and terminate it: @@ -117,7 +117,7 @@ And you can see that the connection to the app was closed gracefully: $ python -m websockets ws://localhost:8080/ Connected to ws://localhost:8080/. - Connection closed: code = 1001 (going away), no reason. + Connection closed: 1001 (going away). In this example, we've been sharing the same virtualenv for supervisor and websockets. diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index c8553d17d..cf38c6159 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -49,6 +49,13 @@ They may change at any time. This reflects a decision made in Python 3.8. See the release notes of Python 3.10 for details. +.. note:: + + **Version 10.0 changes parameters of** ``ConnectionClosed.__init__`` **.** + + If you raise :exc:`~exceptions.ConnectionClosed` or a subclass — rather + than catch them when websockets raises them — you must change your code. + * Added compatibility with Python 3.10. * Added :func:`~websockets.broadcast` to send a message to many clients. @@ -60,6 +67,8 @@ They may change at any time. * Improved logging. +* Provided additional information in :exc:`ConnectionClosed` exceptions. + * Optimized default compression settings to reduce memory usage. * Made it easier to customize authentication with diff --git a/docs/reference/client.rst b/docs/reference/client.rst index c7c738c81..84f66a19a 100644 --- a/docs/reference/client.rst +++ b/docs/reference/client.rst @@ -55,17 +55,9 @@ Client Available once the connection is open. - .. attribute:: close_code + .. autoattribute:: close_code - WebSocket close code. - - Available once the connection is closed. - - .. attribute:: close_reason - - WebSocket close reason. - - Available once the connection is closed. + .. autoattribute:: close_reason .. automethod:: recv diff --git a/docs/reference/server.rst b/docs/reference/server.rst index 2d54eca7a..667c0b9d0 100644 --- a/docs/reference/server.rst +++ b/docs/reference/server.rst @@ -65,17 +65,9 @@ Server Available once the connection is open. - .. attribute:: close_code + .. autoattribute:: close_code - WebSocket close code. - - Available once the connection is closed. - - .. attribute:: close_reason - - WebSocket close reason. - - Available once the connection is closed. + .. autoattribute:: close_reason .. automethod:: process_request diff --git a/src/websockets/__main__.py b/src/websockets/__main__.py index 6347ee278..785d2c3c9 100644 --- a/src/websockets/__main__.py +++ b/src/websockets/__main__.py @@ -143,6 +143,7 @@ async def run_client( finally: await websocket.close() + assert websocket.close_code is not None and websocket.close_reason is not None close_status = Close(websocket.close_code, websocket.close_reason) print_over_input(f"Connection closed: {close_status}.") diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 59922be9c..67745da61 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -81,37 +81,63 @@ class ConnectionClosed(WebSocketException): """ - def __init__(self, code: int, reason: str) -> None: - self.code = code - self.reason = reason - super().__init__(str(frames.Close(code, reason))) + def __init__( + self, + rcvd: Optional[frames.Close], + sent: Optional[frames.Close], + rcvd_then_sent: Optional[bool] = None, + ) -> None: + self.rcvd = rcvd + self.sent = sent + self.rcvd_then_sent = rcvd_then_sent + if rcvd is None: + if sent is None: + assert rcvd_then_sent is None + msg = "no close frame received or sent" + else: + assert rcvd_then_sent is None + msg = f"sent {sent}; no close frame received" + else: + if sent is None: + assert rcvd_then_sent is None + msg = f"received {rcvd}; no close frame sent" + else: + assert rcvd_then_sent is not None + if rcvd_then_sent: + msg = f"received {rcvd}; then sent {sent}" + else: + msg = f"sent {sent}; then received {rcvd}" + super().__init__(msg) + + # code and reason attributes are provided for backwards-compatibility + + @property + def code(self) -> int: + return 1006 if self.rcvd is None else self.rcvd.code + + @property + def reason(self) -> str: + return "" if self.rcvd is None else self.rcvd.reason class ConnectionClosedError(ConnectionClosed): """ Like :exc:`ConnectionClosed`, when the connection terminated with an error. - This means the close code is different from 1000 (OK) and 1001 (going away). + A close code other than 1000 (OK) or 1001 (going away) was received or + sent, or the closing handshake didn't complete properly. """ - def __init__(self, code: int, reason: str) -> None: - assert code != 1000 and code != 1001 - super().__init__(code, reason) - class ConnectionClosedOK(ConnectionClosed): """ Like :exc:`ConnectionClosed`, when the connection terminated properly. - This means the close code is 1000 (OK) or 1001 (going away). + A close code 1000 (OK) or 1001 (going away) was received and sent. """ - def __init__(self, code: int, reason: str) -> None: - assert code == 1000 or code == 1001 - super().__init__(code, reason) - class InvalidHandshake(WebSocketException): """ diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 15b1a9de2..4c57386b7 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -93,18 +93,20 @@ class Opcode(enum.IntEnum): 1014, } +OK_CLOSE_CODES = {1000, 1001} + @dataclasses.dataclass class Frame: """ WebSocket frame. + :param int opcode: opcode + :param bytes data: payload data :param bool fin: FIN bit :param bool rsv1: RSV1 bit :param bool rsv2: RSV2 bit :param bool rsv3: RSV3 bit - :param int opcode: opcode - :param bytes data: payload data Only these fields are needed. The MASK bit, payload length and masking-key are handled on the fly by :meth:`parse` and :meth:`serialize`. diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 7443a565e..badd1e0d8 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -44,6 +44,7 @@ ) from ..extensions import Extension from ..frames import ( + OK_CLOSE_CODES, OP_BINARY, OP_CLOSE, OP_CONT, @@ -181,10 +182,10 @@ def __init__( self.extensions: List[Extension] = [] self.subprotocol: Optional[Subprotocol] = None - # The close code and reason are set when receiving a close frame or - # losing the TCP connection. - self.close_code: int - self.close_reason: str + # Close code and reason, set when a close frame is sent or received. + self.close_rcvd: Optional[Close] = None + self.close_sent: Optional[Close] = None + self.close_rcvd_then_sent: Optional[bool] = None # Completed when the connection state becomes CLOSED. Translates the # :meth:`connection_lost` callback to a :class:`~asyncio.Future` @@ -339,6 +340,36 @@ def closed(self) -> bool: """ return self.state is State.CLOSED + @property + def close_code(self) -> Optional[int]: + """ + WebSocket close code received in a close frame. + + Available once the connection is closed. + + """ + if self.state is not State.CLOSED: + return None + elif self.close_rcvd is None: + return 1006 + else: + return self.close_rcvd.code + + @property + def close_reason(self) -> Optional[str]: + """ + WebSocket close reason received in a close frame. + + Available once the connection is closed. + + """ + if self.state is not State.CLOSED: + return None + elif self.close_rcvd is None: + return "" + else: + return self.close_rcvd.reason + async def wait_closed(self) -> None: """ Wait until the connection is closed. @@ -608,7 +639,7 @@ async def close(self, code: int = 1000, reason: str = "") -> None: """ try: await asyncio.wait_for( - self.write_close_frame(Close(code, reason).serialize()), + self.write_close_frame(Close(code, reason)), self.close_timeout, **loop_if_py_lt_38(self.loop), ) @@ -714,14 +745,27 @@ async def pong(self, data: Data = b"") -> None: # Private methods - no guarantees. def connection_closed_exc(self) -> ConnectionClosed: - exception: ConnectionClosed - if self.close_code == 1000 or self.close_code == 1001: - exception = ConnectionClosedOK(self.close_code, self.close_reason) + exc: ConnectionClosed + if ( + self.close_rcvd is not None + and self.close_rcvd.code in OK_CLOSE_CODES + and self.close_sent is not None + and self.close_sent.code in OK_CLOSE_CODES + ): + exc = ConnectionClosedOK( + self.close_rcvd, + self.close_sent, + self.close_rcvd_then_sent, + ) else: - exception = ConnectionClosedError(self.close_code, self.close_reason) + exc = ConnectionClosedError( + self.close_rcvd, + self.close_sent, + self.close_rcvd_then_sent, + ) # Chain to the exception that terminated data transfer, if any. - exception.__cause__ = self.transfer_data_exc - return exception + exc.__cause__ = self.transfer_data_exc + return exc async def ensure_open(self) -> None: """ @@ -917,13 +961,14 @@ async def read_data_frame(self, max_size: Optional[int]) -> Optional[Frame]: if frame.opcode == OP_CLOSE: # 7.1.5. The WebSocket Connection Close Code # 7.1.6. The WebSocket Connection Close Reason - close = Close.parse(frame.data) - self.close_code, self.close_reason = close.code, close.reason + self.close_rcvd = Close.parse(frame.data) + if self.close_sent is not None: + self.close_rcvd_then_sent = False try: # Echo the original data instead of re-serializing it with # Close.serialize() because that fails when the close frame # is empty and Close.parse() synthetizes a 1005 close code. - await self.write_close_frame(frame.data) + await self.write_close_frame(self.close_rcvd, frame.data) except ConnectionClosed: # Connection closed before we could echo the close frame. pass @@ -1010,7 +1055,9 @@ async def write_frame( self.write_frame_sync(fin, opcode, data) await self.drain() - async def write_close_frame(self, data: bytes = b"") -> None: + async def write_close_frame( + self, close: Close, data: Optional[bytes] = None + ) -> None: """ Write a close frame if and only if the connection state is OPEN. @@ -1026,6 +1073,12 @@ async def write_close_frame(self, data: bytes = b"") -> None: if self.debug: self.logger.debug("= connection is CLOSING") + self.close_sent = close + if self.close_rcvd is not None: + self.close_rcvd_then_sent = True + if data is None: + data = close.serialize() + # 7.1.2. Start the WebSocket Closing Handshake await self.write_frame(True, OP_CLOSE, data, _state=State.CLOSING) @@ -1219,8 +1272,7 @@ def fail_connection(self, code: int = 1006, reason: str = "") -> None: # an error reading from or writing to the network. # Don't send a close frame if the connection is broken. if code != 1006 and self.state is State.OPEN: - - frame_data = Close(code, reason).serialize() + close = Close(code, reason) # Write the close frame without draining the write buffer. @@ -1228,21 +1280,19 @@ def fail_connection(self, code: int = 1006, reason: str = "") -> None: # get stuck and simplifies the implementation of the callers. # Not drainig the write buffer is acceptable in this context. - # This duplicates a few lines of code from write_close_frame() - # and write_frame(). + # This duplicates a few lines of code from write_close_frame(). self.state = State.CLOSING if self.debug: self.logger.debug("= connection is CLOSING") - frame = Frame(True, OP_CLOSE, frame_data) - if self.debug: - self.logger.debug("> %s", frame) - frame.write( - self.transport.write, - mask=self.is_client, - extensions=self.extensions, - ) + # If self.close_rcvd was set, the connection state would be + # CLOSING. Therefore self.close_rcvd isn't set and we don't + # have to set self.close_rcvd_then_sent. + assert self.close_rcvd is None + self.close_sent = close + + self.write_frame_sync(True, OP_CLOSE, close.serialize()) # Start close_connection_task if the opening handshake didn't succeed. if not hasattr(self, "close_connection_task"): @@ -1303,15 +1353,7 @@ def connection_lost(self, exc: Optional[Exception]) -> None: """ self.state = State.CLOSED - if not hasattr(self, "close_code"): - self.close_code = 1006 - if not hasattr(self, "close_reason"): - self.close_reason = "" - self.logger.debug( - "= connection is CLOSED - %d %s", - self.close_code, - self.close_reason or "[no reason]", - ) + self.logger.debug("= connection is CLOSED") self.abort_pings() diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index 1e3f1b77e..033cebe17 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -386,6 +386,20 @@ def test_wait_closed(self): self.close_connection() self.assertTrue(wait_closed.done()) + def test_close_code(self): + self.close_connection(1001, "Bye!") + self.assertEqual(self.protocol.close_code, 1001) + + def test_close_reason(self): + self.close_connection(1001, "Bye!") + self.assertEqual(self.protocol.close_reason, "Bye!") + + def test_close_code_not_set(self): + self.assertIsNone(self.protocol.close_code) + + def test_close_reason_not_set(self): + self.assertIsNone(self.protocol.close_reason) + # Test the recv coroutine. def test_recv_text(self): diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 9c8eef4fc..85ebc24da 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -2,6 +2,7 @@ from websockets.datastructures import Headers from websockets.exceptions import * +from websockets.frames import Close class ExceptionsTests(unittest.TestCase): @@ -13,28 +14,36 @@ def test_str(self): "something went wrong", ), ( - ConnectionClosed(1000, ""), - "1000 (OK)", + ConnectionClosed(Close(1000, ""), Close(1000, ""), True), + "received 1000 (OK); then sent 1000 (OK)", ), ( - ConnectionClosed(1006, None), - "1006 (connection closed abnormally [internal])" + ConnectionClosed(Close(1001, "Bye!"), Close(1001, "Bye!"), False), + "sent 1001 (going away) Bye!; then received 1001 (going away) Bye!", ), ( - ConnectionClosed(3000, None), - "3000 (registered)" + ConnectionClosed(Close(1000, "race"), Close(1000, "cond"), True), + "received 1000 (OK) race; then sent 1000 (OK) cond", ), ( - ConnectionClosed(4000, None), - "4000 (private use)" + ConnectionClosed(Close(1000, "cond"), Close(1000, "race"), False), + "sent 1000 (OK) race; then received 1000 (OK) cond", ), ( - ConnectionClosedError(1016, None), - "1016 (unknown)" + ConnectionClosed(None, Close(1009, ""), None), + "sent 1009 (message too big); no close frame received", ), ( - ConnectionClosedOK(1001, "bye"), - "1001 (going away) bye", + ConnectionClosed(Close(1002, ""), None, None), + "received 1002 (protocol error); no close frame sent", + ), + ( + ConnectionClosedOK(Close(1000, ""), Close(1000, ""), True), + "received 1000 (OK); then sent 1000 (OK)", + ), + ( + ConnectionClosedError(None, None, None), + "no close frame received or sent" ), ( InvalidHandshake("invalid request"), diff --git a/tests/test_frames.py b/tests/test_frames.py index 7620fe415..c8f9867d4 100644 --- a/tests/test_frames.py +++ b/tests/test_frames.py @@ -405,8 +405,15 @@ def assertCloseData(self, close, data): parsed = Close.parse(data) self.assertEqual(parsed, close) + def test_str(self): + self.assertEqual(str(Close(1000, "")), "1000 (OK)") + self.assertEqual(str(Close(1001, "Bye!")), "1001 (going away) Bye!") + self.assertEqual(str(Close(3000, "")), "3000 (registered)") + self.assertEqual(str(Close(4000, "")), "4000 (private use)") + self.assertEqual(str(Close(5000, "")), "5000 (unknown)") + def test_parse_and_serialize(self): - self.assertCloseData(Close(1000, ""), b"\x03\xe8") + self.assertCloseData(Close(1001, ""), b"\x03\xe9") self.assertCloseData(Close(1000, "OK"), b"\x03\xe8OK") def test_parse_empty(self): From 1186dd6cfc072a7038d5478f6f5722038efba51c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 16 Jun 2021 22:51:50 +0200 Subject: [PATCH 094/762] Store details of close frames. Ref #587. --- src/websockets/connection.py | 70 ++++++++++++++++++++++++++---------- tests/test_connection.py | 50 +++++++++++++++++++++++++- 2 files changed, 101 insertions(+), 19 deletions(-) diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 067ad54ce..d682226ab 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -102,15 +102,10 @@ def __init__( self.extensions: List[Extension] = [] self.subprotocol: Optional[Subprotocol] = None - # Connection state isn't enough to tell if a close frame was received: - # when this side closes the connection, state is CLOSING as soon as a - # close frame is sent, before a close frame is received. - self.close_frame_received = False - - # Close code and reason. Set when receiving a close frame or when the - # TCP connection drops. - self.close_code: int - self.close_reason: str + # Close code and reason, set when a close frame is sent or received. + self.close_rcvd: Optional[Close] = None + self.close_sent: Optional[Close] = None + self.close_rcvd_then_sent: Optional[bool] = None # Track if send_eof() was called. self.eof_sent = False @@ -128,6 +123,36 @@ def set_state(self, state: State) -> None: self.logger.debug("= connection is %s", state.name) self.state = state + @property + def close_code(self) -> Optional[int]: + """ + WebSocket close code received in a close frame. + + Available once the connection is closed. + + """ + if self.state is not State.CLOSED: + return None + elif self.close_rcvd is None: + return 1006 + else: + return self.close_rcvd.code + + @property + def close_reason(self) -> Optional[str]: + """ + WebSocket close reason received in a close frame. + + Available once the connection is closed. + + """ + if self.state is not State.CLOSED: + return None + elif self.close_rcvd is None: + return "" + else: + return self.close_rcvd.reason + # Public APIs for receiving data. def receive_data(self, data: bytes) -> None: @@ -199,10 +224,13 @@ def send_close(self, code: Optional[int] = None, reason: str = "") -> None: if code is None: if reason != "": raise ValueError("cannot send a reason without a code") + close = Close(1005, "") data = b"" else: - data = Close(code, reason).serialize() + close = Close(code, reason) + data = close.serialize() self.send_frame(Frame(OP_CLOSE, data)) + self.close_sent = close # send_frame() guarantees that self.state is OPEN at this point. # 7.1.3. The WebSocket Closing Handshake is Started self.set_state(CLOSING) @@ -258,7 +286,9 @@ def fail_connection(self, code: int = 1006, reason: str = "") -> None: # sent if it's CLOSING), except when failing the connection because of # an error reading from or writing to the network. if code != 1006 and self.state is OPEN: - self.send_frame(Frame(OP_CLOSE, Close(code, reason).serialize())) + close = Close(code, reason) + self.send_frame(Frame(OP_CLOSE, close.serialize())) + self.close_sent = close self.set_state(CLOSING) if not self.eof_sent: self.send_eof() @@ -304,7 +334,7 @@ def parse(self) -> Generator[None, None, None]: if eof: if self.debug: self.logger.debug("< EOF") - if self.close_frame_received: + if self.close_rcvd is not None: if not self.eof_sent: self.send_eof() yield @@ -338,7 +368,7 @@ def parse(self) -> Generator[None, None, None]: if frame.opcode is OP_TEXT or frame.opcode is OP_BINARY: # 5.5.1 Close: "The application MUST NOT send any more data # frames after sending a Close frame." - if self.close_frame_received: + if self.close_rcvd is not None: raise ProtocolError("data frame after close frame") if self.cur_size is not None: @@ -351,7 +381,7 @@ def parse(self) -> Generator[None, None, None]: elif frame.opcode is OP_CONT: # 5.5.1 Close: "The application MUST NOT send any more data # frames after sending a Close frame." - if self.close_frame_received: + if self.close_rcvd is not None: raise ProtocolError("data frame after close frame") if self.cur_size is None: @@ -365,7 +395,7 @@ def parse(self) -> Generator[None, None, None]: # 5.5.2. Ping: "Upon receipt of a Ping frame, an endpoint MUST # send a Pong frame in response, unless it already received a # Close frame." - if not self.close_frame_received: + if self.close_rcvd is None: pong_frame = Frame(OP_PONG, frame.data) self.send_frame(pong_frame) @@ -375,11 +405,12 @@ def parse(self) -> Generator[None, None, None]: pass elif frame.opcode is OP_CLOSE: - self.close_frame_received = True # 7.1.5. The WebSocket Connection Close Code # 7.1.6. The WebSocket Connection Close Reason - close = Close.parse(frame.data) - self.close_code, self.close_reason = close.code, close.reason + self.close_rcvd = Close.parse(frame.data) + if self.state is CLOSING: + assert self.close_sent is not None + self.close_rcvd_then_sent = False if self.cur_size is not None: raise ProtocolError("incomplete fragmented message") @@ -388,12 +419,15 @@ def parse(self) -> Generator[None, None, None]: # Close frame in response. (When sending a Close frame in # response, the endpoint typically echos the status code it # received.)" + if self.state is OPEN: # Echo the original data instead of re-serializing it with # Close.serialize() because that fails when the close frame # is empty and Close.parse() synthetizes a 1005 close code. # The rest is identical to send_close(). self.send_frame(Frame(OP_CLOSE, frame.data)) + self.close_sent = self.close_rcvd + self.close_rcvd_then_sent = True self.set_state(CLOSING) if self.side is SERVER: self.send_eof() diff --git a/tests/test_connection.py b/tests/test_connection.py index f2ce8de46..2a887564b 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -667,10 +667,58 @@ def test_server_receives_binary_after_receiving_close(self): class CloseTests(ConnectionTestCase): """ - Test close frames. See 5.5.1. Close in RFC 6544. + Test close frames. + + See RFC 6544: + + 5.5.1. Close + 7.1.6. The WebSocket Connection Close Reason + 7.1.7. Fail the WebSocket Connection """ + def test_close_code(self): + client = Connection(Side.CLIENT) + client.receive_data(b"\x88\x04\x03\xe8OK") + client.set_state(State.CLOSED) + self.assertEqual(client.close_code, 1000) + + def test_close_reason(self): + server = Connection(Side.SERVER) + server.receive_data(b"\x88\x84\x00\x00\x00\x00\x03\xe8OK") + server.set_state(State.CLOSED) + self.assertEqual(server.close_reason, "OK") + + def test_close_code_not_provided(self): + server = Connection(Side.SERVER) + server.receive_data(b"\x88\x80\x00\x00\x00\x00") + server.set_state(State.CLOSED) + self.assertEqual(server.close_code, 1005) + + def test_close_reason_not_provided(self): + client = Connection(Side.CLIENT) + client.receive_data(b"\x88\x00") + client.set_state(State.CLOSED) + self.assertEqual(client.close_reason, "") + + def test_close_code_not_available(self): + client = Connection(Side.CLIENT) + client.set_state(State.CLOSED) + self.assertEqual(client.close_code, 1006) + + def test_close_reason_not_available(self): + server = Connection(Side.SERVER) + server.set_state(State.CLOSED) + self.assertEqual(server.close_reason, "") + + def test_close_code_not_available_yet(self): + server = Connection(Side.SERVER) + self.assertIsNone(server.close_code) + + def test_close_reason_not_available_yet(self): + client = Connection(Side.CLIENT) + self.assertIsNone(client.close_reason) + def test_client_sends_close(self): client = Connection(Side.CLIENT) with self.enforce_mask(b"\x3c\x3c\x3c\x3c"): From 8152df6f0b02c1e8c655a3845b7d39e9904c5a9b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 17 Jun 2021 06:57:47 +0200 Subject: [PATCH 095/762] Refactor exceptions to create messages on demand. --- src/websockets/exceptions.py | 96 +++++++++++++++++++++--------------- tests/test_exceptions.py | 5 ++ 2 files changed, 60 insertions(+), 41 deletions(-) diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 67745da61..c8ae1d6b5 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -90,24 +90,25 @@ def __init__( self.rcvd = rcvd self.sent = sent self.rcvd_then_sent = rcvd_then_sent - if rcvd is None: - if sent is None: - assert rcvd_then_sent is None - msg = "no close frame received or sent" + + def __str__(self) -> str: + if self.rcvd is None: + if self.sent is None: + assert self.rcvd_then_sent is None + return "no close frame received or sent" else: - assert rcvd_then_sent is None - msg = f"sent {sent}; no close frame received" + assert self.rcvd_then_sent is None + return f"sent {self.sent}; no close frame received" else: - if sent is None: - assert rcvd_then_sent is None - msg = f"received {rcvd}; no close frame sent" + if self.sent is None: + assert self.rcvd_then_sent is None + return f"received {self.rcvd}; no close frame sent" else: - assert rcvd_then_sent is not None - if rcvd_then_sent: - msg = f"received {rcvd}; then sent {sent}" + assert self.rcvd_then_sent is not None + if self.rcvd_then_sent: + return f"received {self.rcvd}; then sent {self.sent}" else: - msg = f"sent {sent}; then received {rcvd}" - super().__init__(msg) + return f"sent {self.sent}; then received {self.rcvd}" # code and reason attributes are provided for backwards-compatibility @@ -171,13 +172,14 @@ class InvalidHeader(InvalidHandshake): def __init__(self, name: str, value: Optional[str] = None) -> None: self.name = name self.value = value - if value is None: - message = f"missing {name} header" - elif value == "": - message = f"empty {name} header" + + def __str__(self) -> str: + if self.value is None: + return f"missing {self.name} header" + elif self.value == "": + return f"empty {self.name} header" else: - message = f"invalid {name} header: {value}" - super().__init__(message) + return f"invalid {self.name} header: {self.value}" class InvalidHeaderFormat(InvalidHeader): @@ -189,9 +191,7 @@ class InvalidHeaderFormat(InvalidHeader): """ def __init__(self, name: str, error: str, header: str, pos: int) -> None: - self.name = name - error = f"{error} at {pos} in {header}" - super().__init__(name, error) + super().__init__(name, f"{error} at {pos} in {header}") class InvalidHeaderValue(InvalidHeader): @@ -228,8 +228,12 @@ class InvalidStatus(InvalidHandshake): def __init__(self, response: http11.Response) -> None: self.response = response - message = f"server rejected WebSocket connection: HTTP {response.status_code:d}" - super().__init__(message) + + def __str__(self) -> str: + return ( + "server rejected WebSocket connection: " + f"HTTP {self.response.status_code:d}" + ) class InvalidStatusCode(InvalidHandshake): @@ -244,8 +248,9 @@ class InvalidStatusCode(InvalidHandshake): def __init__(self, status_code: int, headers: datastructures.Headers) -> None: self.status_code = status_code self.headers = headers - message = f"server rejected WebSocket connection: HTTP {status_code}" - super().__init__(message) + + def __str__(self) -> str: + return f"server rejected WebSocket connection: HTTP {self.status_code}" class NegotiationError(InvalidHandshake): @@ -263,8 +268,9 @@ class DuplicateParameter(NegotiationError): def __init__(self, name: str) -> None: self.name = name - message = f"duplicate parameter: {name}" - super().__init__(message) + + def __str__(self) -> str: + return f"duplicate parameter: {self.name}" class InvalidParameterName(NegotiationError): @@ -275,8 +281,9 @@ class InvalidParameterName(NegotiationError): def __init__(self, name: str) -> None: self.name = name - message = f"invalid parameter name: {name}" - super().__init__(message) + + def __str__(self) -> str: + return f"invalid parameter name: {self.name}" class InvalidParameterValue(NegotiationError): @@ -288,13 +295,14 @@ class InvalidParameterValue(NegotiationError): def __init__(self, name: str, value: Optional[str]) -> None: self.name = name self.value = value - if value is None: - message = f"missing value for parameter {name}" - elif value == "": - message = f"empty value for parameter {name}" + + def __str__(self) -> str: + if self.value is None: + return f"missing value for parameter {self.name}" + elif self.value == "": + return f"empty value for parameter {self.name}" else: - message = f"invalid value for parameter {name}: {value}" - super().__init__(message) + return f"invalid value for parameter {self.name}: {self.value}" class AbortHandshake(InvalidHandshake): @@ -316,8 +324,13 @@ def __init__( self.status = status self.headers = datastructures.Headers(headers) self.body = body - message = f"HTTP {status:d}, {len(self.headers)} headers, {len(body)} bytes" - super().__init__(message) + + def __str__(self) -> str: + return ( + f"HTTP {self.status:d}, " + f"{len(self.headers)} headers, " + f"{len(self.body)} bytes" + ) class RedirectHandshake(InvalidHandshake): @@ -354,8 +367,9 @@ class InvalidURI(WebSocketException): def __init__(self, uri: str) -> None: self.uri = uri - message = "{} isn't a valid URI".format(uri) - super().__init__(message) + + def __str__(self) -> str: + return f"{self.uri} isn't a valid URI" class PayloadTooBig(WebSocketException): diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 85ebc24da..e172cdd02 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -3,6 +3,7 @@ from websockets.datastructures import Headers from websockets.exceptions import * from websockets.frames import Close +from websockets.http11 import Response class ExceptionsTests(unittest.TestCase): @@ -96,6 +97,10 @@ def test_str(self): InvalidUpgrade("Connection", "websocket"), "invalid Connection header: websocket", ), + ( + InvalidStatus(Response(401, "Unauthorized", Headers())), + "server rejected WebSocket connection: HTTP 401", + ), ( InvalidStatusCode(403, Headers()), "server rejected WebSocket connection: HTTP 403", From 0bfa9f2ea86dbc12f24e4a2756cb339efafe76d8 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 17 Jun 2021 07:17:58 +0200 Subject: [PATCH 096/762] Document ConnectionClosed attributes. --- src/websockets/exceptions.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index c8ae1d6b5..b3462484f 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -76,8 +76,17 @@ class ConnectionClosed(WebSocketException): """ Raised when trying to interact with a closed connection. - Provides the connection close code and reason in its ``code`` and - ``reason`` attributes respectively. + If a close frame was received, its code and reason are available in the + ``rcvd.code`` and ``rcvd.reason`` attributes. Else, the ``rcvd`` + attribute is ``None``. + + Likewise, if a close frame was sent, its code and reason are available in + the ``sent.code`` and ``sent.reason`` attributes. Else, the ``sent`` + attribute is ``None``. + + If close frames were received and sent, the ``rcvd_then_sent`` attribute + tells in which order this happened, from the perspective of this side of + the connection. """ From f148d821f3b9ce620c65011281056f7655ea6fa6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 17 Jun 2021 22:06:16 +0200 Subject: [PATCH 097/762] Remove superfluous logging. --- src/websockets/legacy/protocol.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index badd1e0d8..cab72caef 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -1316,14 +1316,6 @@ def abort_pings(self) -> None: # nothing, but it prevents logging the exception. ping.cancel() - if self.debug: - if self.pings: - pings_hex = ", ".join( - ping_id.hex() or "[empty]" for ping_id in self.pings - ) - plural = "s" if len(self.pings) > 1 else "" - self.logger.debug("% aborted pending ping%s: %s", plural, pings_hex) - # asyncio.Protocol methods def connection_made(self, transport: asyncio.BaseTransport) -> None: From b2a95c45fae19fff0a3473158dd02afe2ca42604 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 17 Jun 2021 22:30:00 +0200 Subject: [PATCH 098/762] Don't answer pings on closing connection. Technically, this is the wrong behavior, but I'll live with this in the legacy layer. The new Sans I/O layer has the right behavior. Fix #669. --- src/websockets/legacy/protocol.py | 13 +++++++------ tests/legacy/test_protocol.py | 12 +++++++++++- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index cab72caef..df74bdb38 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -975,12 +975,13 @@ async def read_data_frame(self, max_size: Optional[int]) -> Optional[Frame]: return None elif frame.opcode == OP_PING: - # Answer pings. - try: - await self.pong(frame.data) - except ConnectionClosed: - # Connection closed before we could respond to the ping. - pass + # Answer pings, unless connection is CLOSING. + if self.state is State.OPEN: + try: + await self.pong(frame.data) + except ConnectionClosed: + # Connection closed while draining write buffer. + pass elif frame.opcode == OP_PONG: if frame.data in self.pings: diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index 033cebe17..61a5fe7cf 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -909,10 +909,20 @@ def test_answer_ping(self): self.run_loop_once() self.assertOneFrameSent(True, OP_PONG, b"test") + def test_answer_ping_does_not_crash_if_connection_closing(self): + close_task = self.half_close_connection_local() + + self.receive_frame(Frame(True, OP_PING, b"test")) + + with self.assertNoLogs(): + self.loop.run_until_complete(self.protocol.close()) + + self.loop.run_until_complete(close_task) # cleanup + def test_answer_ping_does_not_crash_if_connection_closed(self): self.make_drain_slow() # Drop the connection right after receiving a ping frame, - # which prevents responding wwith a pong frame properly. + # which prevents responding with a pong frame properly. self.receive_frame(Frame(True, OP_PING, b"test")) self.receive_eof() From 361f38066622636eff8d9e0d3d071e93e8e02b05 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 17 Jun 2021 23:31:31 +0200 Subject: [PATCH 099/762] Handle connection drop during handshake. Fix #984. --- src/websockets/legacy/server.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 80d72f93e..3fde99568 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -29,6 +29,7 @@ cast, ) +from ..connection import State from ..datastructures import Headers, HeadersLike, MultipleValuesError from ..exceptions import ( AbortHandshake, @@ -656,6 +657,10 @@ async def handshake( warnings.warn("declare process_request as a coroutine", DeprecationWarning) early_response = early_response_awaitable # type: ignore + # The connection may drop while process_request is running. + if self.state is State.CLOSED: + raise self.connection_closed_exc() # pragma: no cover + # Change the response to a 503 error if the server is shutting down. if not self.ws_server.is_serving(): early_response = ( From f9c28e8b7810c5a2fcadc88825046618f1fdc012 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 18 Jun 2021 18:00:30 +0200 Subject: [PATCH 100/762] Polish class structure. --- src/websockets/connection.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/websockets/connection.py b/src/websockets/connection.py index d682226ab..9fd7db490 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -118,10 +118,7 @@ def __init__( next(self.parser) # start coroutine self.parser_exc: Optional[Exception] = None - def set_state(self, state: State) -> None: - if self.debug: - self.logger.debug("= connection is %s", state.name) - self.state = state + # Public attributes @property def close_code(self) -> Optional[int]: @@ -153,6 +150,13 @@ def close_reason(self) -> Optional[str]: else: return self.close_rcvd.reason + # Private attributes + + def set_state(self, state: State) -> None: + if self.debug: + self.logger.debug("= connection is %s", state.name) + self.state = state + # Public APIs for receiving data. def receive_data(self, data: bytes) -> None: From ebda766d856b59faa86e14f9cf53a5e0ede0a355 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 18 Jun 2021 18:03:01 +0200 Subject: [PATCH 101/762] Fix connection closure. * In a regular closing handshake, server closes the connection after receiving and sending a close frame. * When failing the connection, server closes the connection after sending a close frame. --- src/websockets/connection.py | 55 +++++++++++++++++++++--------------- tests/test_connection.py | 26 +++++++++-------- 2 files changed, 46 insertions(+), 35 deletions(-) diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 9fd7db490..e5c9b866c 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -238,8 +238,6 @@ def send_close(self, code: Optional[int] = None, reason: str = "") -> None: # send_frame() guarantees that self.state is OPEN at this point. # 7.1.3. The WebSocket Closing Handshake is Started self.set_state(CLOSING) - if self.side is SERVER: - self.send_eof() def send_ping(self, data: bytes) -> None: """ @@ -255,6 +253,22 @@ def send_pong(self, data: bytes) -> None: """ self.send_frame(Frame(OP_PONG, data)) + def fail(self, code: Optional[int] = None, reason: str = "") -> None: + """ + Fail the WebSocket connection. + + """ + # 7.1.7. Fail the WebSocket Connection + + # Send a close frame when the state is OPEN (a close frame was already + # sent if it's CLOSING), except when failing the connection because + # of an error reading from or writing to the network. + if self.state is OPEN and code != 1006: + self.send_close(code, reason) + + if self.side is SERVER and not self.eof_sent: + self.send_eof() + # Public API for getting incoming events after receiving data. def events_received(self) -> List[Event]: @@ -274,8 +288,9 @@ def data_to_send(self) -> List[bytes]: """ Return data to write to the connection. - Call this method immediately after calling any of the ``receive_*()`` - or ``send_*()`` methods and write the data to the connection. + Call this method immediately after calling any of the ``receive_*()``, + ``send_*()``, or ``fail()`` methods and write the data to the + connection. The empty bytestring signals the end of the data stream. @@ -285,18 +300,6 @@ def data_to_send(self) -> List[bytes]: # Private APIs for receiving data. - def fail_connection(self, code: int = 1006, reason: str = "") -> None: - # Send a close frame when the state is OPEN (a close frame was already - # sent if it's CLOSING), except when failing the connection because of - # an error reading from or writing to the network. - if code != 1006 and self.state is OPEN: - close = Close(code, reason) - self.send_frame(Frame(OP_CLOSE, close.serialize())) - self.close_sent = close - self.set_state(CLOSING) - if not self.eof_sent: - self.send_eof() - def step_parser(self) -> None: # Run parser until more data is needed or EOF try: @@ -310,25 +313,25 @@ def step_parser(self) -> None: "parser cannot receive data or EOF after an error" ) from self.parser_exc except ProtocolError as exc: - self.fail_connection(1002, str(exc)) + self.fail(1002, str(exc)) self.parser_exc = exc raise except EOFError as exc: - self.fail_connection(1006, str(exc)) + self.fail(1006, str(exc)) self.parser_exc = exc raise except UnicodeDecodeError as exc: - self.fail_connection(1007, f"{exc.reason} at position {exc.start}") + self.fail(1007, f"{exc.reason} at position {exc.start}") self.parser_exc = exc raise except PayloadTooBig as exc: - self.fail_connection(1009, str(exc)) + self.fail(1009, str(exc)) self.parser_exc = exc raise except Exception as exc: self.logger.error("parser failed", exc_info=True) # Don't include exception details, which may be security-sensitive. - self.fail_connection(1011) + self.fail(1011) self.parser_exc = exc raise @@ -418,6 +421,7 @@ def parse(self) -> Generator[None, None, None]: if self.cur_size is not None: raise ProtocolError("incomplete fragmented message") + # 5.5.1 Close: "If an endpoint receives a Close frame and did # not previously send a Close frame, the endpoint MUST send a # Close frame in response. (When sending a Close frame in @@ -433,8 +437,13 @@ def parse(self) -> Generator[None, None, None]: self.close_sent = self.close_rcvd self.close_rcvd_then_sent = True self.set_state(CLOSING) - if self.side is SERVER: - self.send_eof() + + # 7.1.2. Start the WebSocket Closing Handshake: "Once an + # endpoint has both sent and received a Close control frame, + # that endpoint SHOULD _Close the WebSocket Connection_" + + if self.side is SERVER: + self.send_eof() else: # pragma: no cover # This can't happen because Frame.parse() validates opcodes. diff --git a/tests/test_connection.py b/tests/test_connection.py index 2a887564b..2d6ccb3f9 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -80,8 +80,10 @@ def assertConnectionFailing(self, connection, code=None, reason=""): ) # No frame was received. self.assertFrameReceived(connection, None) - # A close frame and the end of stream were sent. - self.assertFrameSent(connection, close_frame, eof=True) + # A close frame and possibly the end of stream were sent. + self.assertFrameSent( + connection, close_frame, eof=connection.side is Side.SERVER + ) class MaskingTests(ConnectionTestCase): @@ -187,7 +189,7 @@ def test_server_sends_continuation_after_sending_close(self): # this is the same test as test_server_sends_unexpected_continuation. server = Connection(Side.SERVER) server.send_close(1000) - self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8", b""]) + self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) with self.assertRaises(ProtocolError) as raised: server.send_continuation(b"", fin=False) self.assertEqual(str(raised.exception), "unexpected continuation frame") @@ -442,7 +444,7 @@ def test_client_sends_text_after_sending_close(self): def test_server_sends_text_after_sending_close(self): server = Connection(Side.SERVER) server.send_close(1000) - self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8", b""]) + self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) with self.assertRaises(InvalidState): server.send_text(b"") @@ -644,7 +646,7 @@ def test_client_sends_binary_after_sending_close(self): def test_server_sends_binary_after_sending_close(self): server = Connection(Side.SERVER) server.send_close(1000) - self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8", b""]) + self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) with self.assertRaises(InvalidState): server.send_binary(b"") @@ -729,7 +731,7 @@ def test_client_sends_close(self): def test_server_sends_close(self): server = Connection(Side.SERVER) server.send_close() - self.assertEqual(server.data_to_send(), [b"\x88\x00", b""]) + self.assertEqual(server.data_to_send(), [b"\x88\x00"]) self.assertIs(server.state, State.CLOSING) def test_client_receives_close(self): @@ -769,11 +771,11 @@ def test_server_sends_close_then_receives_close(self): server.send_close() self.assertFrameReceived(server, None) - self.assertFrameSent(server, Frame(OP_CLOSE, b""), eof=True) + self.assertFrameSent(server, Frame(OP_CLOSE, b"")) server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") self.assertFrameReceived(server, Frame(OP_CLOSE, b"")) - self.assertFrameSent(server, None) + self.assertFrameSent(server, None, eof=True) server.receive_eof() self.assertFrameReceived(server, None) @@ -813,7 +815,7 @@ def test_client_sends_close_with_code(self): def test_server_sends_close_with_code(self): server = Connection(Side.SERVER) server.send_close(1000) - self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8", b""]) + self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) self.assertIs(server.state, State.CLOSING) def test_client_receives_close_with_code(self): @@ -840,7 +842,7 @@ def test_client_sends_close_with_code_and_reason(self): def test_server_sends_close_with_code_and_reason(self): server = Connection(Side.SERVER) server.send_close(1000, "OK") - self.assertEqual(server.data_to_send(), [b"\x88\x04\x03\xe8OK", b""]) + self.assertEqual(server.data_to_send(), [b"\x88\x04\x03\xe8OK"]) self.assertIs(server.state, State.CLOSING) def test_client_receives_close_with_code_and_reason(self): @@ -1029,7 +1031,7 @@ def test_client_sends_ping_after_sending_close(self): def test_server_sends_ping_after_sending_close(self): server = Connection(Side.SERVER) server.send_close(1000) - self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8", b""]) + self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) # The spec says: "An endpoint MAY send a Ping frame any time (...) # before the connection is closed" but websockets doesn't support # sending a Ping frame after a Close frame. @@ -1164,7 +1166,7 @@ def test_client_sends_pong_after_sending_close(self): def test_server_sends_pong_after_sending_close(self): server = Connection(Side.SERVER) server.send_close(1000) - self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8", b""]) + self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) # websockets doesn't support sending a Pong frame after a Close frame. with self.assertRaises(InvalidState): server.send_pong(b"") From 8f6c4d9e22c7105248cebbe0cb50e47008e12257 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 18 Jun 2021 18:35:10 +0200 Subject: [PATCH 102/762] Stop processing after failing connection. --- src/websockets/connection.py | 15 ++++++----- src/websockets/streams.py | 8 ++++++ tests/test_connection.py | 49 ++++++++++++++++++++++++------------ tests/test_streams.py | 9 +++++++ 4 files changed, 57 insertions(+), 24 deletions(-) diff --git a/src/websockets/connection.py b/src/websockets/connection.py index e5c9b866c..966547f59 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -269,6 +269,11 @@ def fail(self, code: Optional[int] = None, reason: str = "") -> None: if self.side is SERVER and not self.eof_sent: self.send_eof() + # "An endpoint MUST NOT continue to attempt to process data + # (including a responding Close frame) from the remote endpoint + # after being instructed to _Fail the WebSocket Connection_." + self.reader.abort() + # Public API for getting incoming events after receiving data. def events_received(self) -> List[Event]: @@ -304,14 +309,8 @@ def step_parser(self) -> None: # Run parser until more data is needed or EOF try: next(self.parser) - except StopIteration: - # This happens if receive_data() or receive_eof() is called after - # the parser raised an exception. (It cannot happen after reaching - # EOF because receive_data() or receive_eof() would fail earlier.) - assert self.parser_exc is not None - raise RuntimeError( - "parser cannot receive data or EOF after an error" - ) from self.parser_exc + except StopIteration: # pragma: no cover + raise AssertionError("parser shouldn't exit") except ProtocolError as exc: self.fail(1002, str(exc)) self.parser_exc = exc diff --git a/src/websockets/streams.py b/src/websockets/streams.py index e02a6ab39..92a356265 100644 --- a/src/websockets/streams.py +++ b/src/websockets/streams.py @@ -115,3 +115,11 @@ def feed_eof(self) -> None: if self.eof: raise EOFError("stream ended") self.eof = True + + def abort(self) -> None: + """ + End the stream, discarding all buffered data. + + """ + self.eof = True + del self.buffer[:] diff --git a/tests/test_connection.py b/tests/test_connection.py index 2d6ccb3f9..45a0594d2 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1192,6 +1192,31 @@ def test_server_receives_pong_after_receiving_close(self): ) +class FailTests(ConnectionTestCase): + """ + Test failing the connection. + + See 7.1.7. Fail the WebSocket Connection in RFC 6544. + + """ + + def test_client_stops_processing_frames_after_fail(self): + client = Connection(Side.CLIENT) + client.fail(1002) + self.assertConnectionFailing(client, 1002) + with self.assertRaises(EOFError) as raised: + client.receive_data(b"\x88\x02\x03\xea") + self.assertEqual(str(raised.exception), "stream ended") + + def test_server_stops_processing_frames_after_fail(self): + server = Connection(Side.SERVER) + server.fail(1002) + self.assertConnectionFailing(server, 1002) + with self.assertRaises(EOFError) as raised: + server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xea") + self.assertEqual(str(raised.exception), "stream ended") + + class FragmentationTests(ConnectionTestCase): """ Test message fragmentation. @@ -1401,44 +1426,36 @@ def test_client_receives_data_after_exception(self): with self.assertRaises(ProtocolError) as raised: client.receive_data(b"\xff\xff") self.assertEqual(str(raised.exception), "invalid opcode") - with self.assertRaises(RuntimeError) as raised: + with self.assertRaises(EOFError) as raised: client.receive_data(b"\x00\x00") - self.assertEqual( - str(raised.exception), "parser cannot receive data or EOF after an error" - ) + self.assertEqual(str(raised.exception), "stream ended") def test_server_receives_data_after_exception(self): server = Connection(Side.SERVER) with self.assertRaises(ProtocolError) as raised: server.receive_data(b"\xff\xff") self.assertEqual(str(raised.exception), "invalid opcode") - with self.assertRaises(RuntimeError) as raised: + with self.assertRaises(EOFError) as raised: server.receive_data(b"\x00\x00") - self.assertEqual( - str(raised.exception), "parser cannot receive data or EOF after an error" - ) + self.assertEqual(str(raised.exception), "stream ended") def test_client_receives_eof_after_exception(self): client = Connection(Side.CLIENT) with self.assertRaises(ProtocolError) as raised: client.receive_data(b"\xff\xff") self.assertEqual(str(raised.exception), "invalid opcode") - with self.assertRaises(RuntimeError) as raised: + with self.assertRaises(EOFError) as raised: client.receive_eof() - self.assertEqual( - str(raised.exception), "parser cannot receive data or EOF after an error" - ) + self.assertEqual(str(raised.exception), "stream ended") def test_server_receives_eof_after_exception(self): server = Connection(Side.SERVER) with self.assertRaises(ProtocolError) as raised: server.receive_data(b"\xff\xff") self.assertEqual(str(raised.exception), "invalid opcode") - with self.assertRaises(RuntimeError) as raised: + with self.assertRaises(EOFError) as raised: server.receive_eof() - self.assertEqual( - str(raised.exception), "parser cannot receive data or EOF after an error" - ) + self.assertEqual(str(raised.exception), "stream ended") def test_client_receives_data_after_eof(self): client = Connection(Side.CLIENT) diff --git a/tests/test_streams.py b/tests/test_streams.py index 566deb2db..e942c8d12 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -144,3 +144,12 @@ def test_feed_eof_after_feed_eof(self): with self.assertRaises(EOFError) as raised: self.reader.feed_eof() self.assertEqual(str(raised.exception), "stream ended") + + def test_abort(self): + gen = self.reader.read_to_eof() + + self.reader.feed_data(b"spam") + self.reader.abort() + + data = self.assertGeneratorReturns(gen) + self.assertEqual(data, b"") From bff6397ffb69dd52c3e91a23268915582ce0b5e2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 18 Jun 2021 21:52:03 +0200 Subject: [PATCH 103/762] Factor out bytes-like types. --- src/websockets/frames.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 4c57386b7..839e2f7b7 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -96,6 +96,9 @@ class Opcode(enum.IntEnum): OK_CLOSE_CODES = {1000, 1001} +BytesLike = bytes, bytearray, memoryview + + @dataclasses.dataclass class Frame: """ @@ -334,7 +337,7 @@ def prepare_data(data: Data) -> Tuple[int, bytes]: """ if isinstance(data, str): return OP_TEXT, data.encode("utf-8") - elif isinstance(data, (bytes, bytearray, memoryview)): + elif isinstance(data, BytesLike): return OP_BINARY, data else: raise TypeError("data must be bytes-like or str") @@ -356,7 +359,7 @@ def prepare_ctrl(data: Data) -> bytes: """ if isinstance(data, str): return data.encode("utf-8") - elif isinstance(data, (bytes, bytearray, memoryview)): + elif isinstance(data, BytesLike): return bytes(data) else: raise TypeError("data must be bytes-like or str") From 82ba1ff963f623b80f2d6f70c684f545390708a2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 19 Jun 2021 11:46:22 +0200 Subject: [PATCH 104/762] Clarify comments. Attributes are also APIs. --- src/websockets/connection.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 966547f59..620fc6f78 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -157,7 +157,7 @@ def set_state(self, state: State) -> None: self.logger.debug("= connection is %s", state.name) self.state = state - # Public APIs for receiving data. + # Public methods for receiving data. def receive_data(self, data: bytes) -> None: """ @@ -186,7 +186,7 @@ def receive_eof(self) -> None: self.reader.feed_eof() self.step_parser() - # Public APIs for sending events. + # Public methods for sending events. def send_continuation(self, data: bytes, fin: bool) -> None: """ @@ -274,7 +274,7 @@ def fail(self, code: Optional[int] = None, reason: str = "") -> None: # after being instructed to _Fail the WebSocket Connection_." self.reader.abort() - # Public API for getting incoming events after receiving data. + # Public method for getting incoming events after receiving data. def events_received(self) -> List[Event]: """ @@ -287,7 +287,7 @@ def events_received(self) -> List[Event]: events, self.events = self.events, [] return events - # Public API for getting outgoing data after receiving data or sending events. + # Public method for getting outgoing data after receiving data or sending events. def data_to_send(self) -> List[bytes]: """ @@ -303,7 +303,7 @@ def data_to_send(self) -> List[bytes]: writes, self.writes = self.writes, [] return writes - # Private APIs for receiving data. + # Private methods for receiving data. def step_parser(self) -> None: # Run parser until more data is needed or EOF @@ -450,7 +450,7 @@ def parse(self) -> Generator[None, None, None]: self.events.append(frame) - # Private APIs for sending events. + # Private methods for sending events. def send_frame(self, frame: Frame) -> None: # Defensive assertion for protocol compliance. From be42b07606a5cd3975dc92ea0e8a8f2821b55af0 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 19 Jun 2021 11:46:58 +0200 Subject: [PATCH 105/762] Uniformize signature. --- src/websockets/legacy/client.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index d94e9d7b9..c87b5f8d5 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -737,7 +737,9 @@ async def __await_impl__(self) -> WebSocketClientProtocol: def unix_connect( - path: Optional[str], uri: str = "ws://localhost/", **kwargs: Any + path: Optional[str] = None, + uri: str = "ws://localhost/", + **kwargs: Any, ) -> Connect: """ Similar to :func:`connect`, but for connecting to a Unix socket. From 98e8b8d31326557c5c1a85fa0bd75d390b34b3f3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Jun 2021 07:39:17 +0200 Subject: [PATCH 106/762] Compare enum with identity. --- src/websockets/client.py | 2 +- src/websockets/server.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 3217090d0..cc807e97c 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -306,7 +306,7 @@ def parse(self) -> Generator[None, None, None]: except InvalidHandshake as exc: response.exception = exc else: - assert self.state == CONNECTING + assert self.state is CONNECTING self.set_state(OPEN) finally: self.events.append(response) diff --git a/src/websockets/server.py b/src/websockets/server.py index a53798f65..0710dc0be 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -435,7 +435,7 @@ def send_response(self, response: Response) -> None: self.writes.append(response.serialize()) def parse(self) -> Generator[None, None, None]: - if self.state == CONNECTING: + if self.state is CONNECTING: request = yield from Request.parse(self.reader.read_line) if self.debug: From 81a4a2a369190ae4a30e7914ead889c5931172d1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Jun 2021 07:41:45 +0200 Subject: [PATCH 107/762] Close connecting when opening handshake fails. --- src/websockets/server.py | 10 ++++++---- tests/test_server.py | 3 ++- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/websockets/server.py b/src/websockets/server.py index 0710dc0be..545b38c98 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -420,10 +420,6 @@ def send_response(self, response: Response) -> None: Send a WebSocket handshake response to the client. """ - if response.status_code == 101: - assert self.state is CONNECTING - self.set_state(OPEN) - if self.debug: code, phrase = response.status_code, response.reason_phrase self.logger.debug("> HTTP/1.1 %d %s", code, phrase) @@ -434,6 +430,12 @@ def send_response(self, response: Response) -> None: self.writes.append(response.serialize()) + if response.status_code == 101: + assert self.state is CONNECTING + self.set_state(OPEN) + else: + self.send_eof() + def parse(self) -> Generator[None, None, None]: if self.state is CONNECTING: request = yield from Request.parse(self.reader.read_line) diff --git a/tests/test_server.py b/tests/test_server.py index d2c41598e..042d64a31 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -122,7 +122,8 @@ def test_send_reject(self): f"Content-Type: text/plain; charset=utf-8\r\n" f"Connection: close\r\n" f"\r\n" - f"Sorry folks.\n".encode() + f"Sorry folks.\n".encode(), + b"", ], ) self.assertEqual(server.state, CONNECTING) From 5495af64f016c332f4a6f91811e9896f07e0ea59 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Jun 2021 09:22:16 +0200 Subject: [PATCH 108/762] Prevent client from closing TCP connection Read until EOF and wait for the server to close the connection first. --- src/websockets/connection.py | 273 +++++++++++++++++++---------------- src/websockets/streams.py | 5 +- tests/test_connection.py | 64 ++++---- tests/test_streams.py | 6 +- 4 files changed, 190 insertions(+), 158 deletions(-) diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 620fc6f78..0a3a87bd5 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -170,7 +170,7 @@ def receive_data(self, data: bytes) -> None: """ self.reader.feed_data(data) - self.step_parser() + next(self.parser) def receive_eof(self) -> None: """ @@ -184,7 +184,7 @@ def receive_eof(self) -> None: """ self.reader.feed_eof() - self.step_parser() + next(self.parser) # Public methods for sending events. @@ -233,10 +233,10 @@ def send_close(self, code: Optional[int] = None, reason: str = "") -> None: else: close = Close(code, reason) data = close.serialize() - self.send_frame(Frame(OP_CLOSE, data)) - self.close_sent = close # send_frame() guarantees that self.state is OPEN at this point. # 7.1.3. The WebSocket Closing Handshake is Started + self.send_frame(Frame(OP_CLOSE, data)) + self.close_sent = close self.set_state(CLOSING) def send_ping(self, data: bytes) -> None: @@ -253,7 +253,7 @@ def send_pong(self, data: bytes) -> None: """ self.send_frame(Frame(OP_PONG, data)) - def fail(self, code: Optional[int] = None, reason: str = "") -> None: + def fail(self, code: int, reason: str = "") -> None: """ Fail the WebSocket connection. @@ -263,8 +263,13 @@ def fail(self, code: Optional[int] = None, reason: str = "") -> None: # Send a close frame when the state is OPEN (a close frame was already # sent if it's CLOSING), except when failing the connection because # of an error reading from or writing to the network. - if self.state is OPEN and code != 1006: - self.send_close(code, reason) + if self.state is OPEN: + if code != 1006: + close = Close(code, reason) + data = close.serialize() + self.send_frame(Frame(OP_CLOSE, data)) + self.close_sent = close + self.set_state(CLOSING) if self.side is SERVER and not self.eof_sent: self.send_eof() @@ -272,7 +277,8 @@ def fail(self, code: Optional[int] = None, reason: str = "") -> None: # "An endpoint MUST NOT continue to attempt to process data # (including a responding Close frame) from the remote endpoint # after being instructed to _Fail the WebSocket Connection_." - self.reader.abort() + self.parser = self.discard() + next(self.parser) # start coroutine # Public method for getting incoming events after receiving data. @@ -305,28 +311,65 @@ def data_to_send(self) -> List[bytes]: # Private methods for receiving data. - def step_parser(self) -> None: - # Run parser until more data is needed or EOF + def parse(self) -> Generator[None, None, None]: try: - next(self.parser) - except StopIteration: # pragma: no cover - raise AssertionError("parser shouldn't exit") + while True: + if (yield from self.reader.at_eof()): + if self.debug: + self.logger.debug("< EOF") + if self.close_rcvd_then_sent is not None: + if self.side is CLIENT: + self.send_eof() + # If parse() completes normally, execution ends here. + yield + # Once the reader reaches EOF, its feed_data/eof() + # methods raise an error, so our receive_data/eof() + # methods don't step the generator. + raise AssertionError( + "parse() shouldn't step after EOF" + ) # pragma: no cover + else: + raise EOFError("unexpected end of stream") + + if self.max_size is None: + max_size = None + elif self.cur_size is None: + max_size = self.max_size + else: + max_size = self.max_size - self.cur_size + + frame = yield from Frame.parse( + self.reader.read_exact, + mask=self.side is SERVER, + max_size=max_size, + extensions=self.extensions, + ) + + if self.debug: + self.logger.debug("< %s", frame) + + self.recv_frame(frame) + except ProtocolError as exc: self.fail(1002, str(exc)) self.parser_exc = exc raise + except EOFError as exc: self.fail(1006, str(exc)) self.parser_exc = exc raise + except UnicodeDecodeError as exc: self.fail(1007, f"{exc.reason} at position {exc.start}") self.parser_exc = exc raise + except PayloadTooBig as exc: self.fail(1009, str(exc)) self.parser_exc = exc raise + except Exception as exc: self.logger.error("parser failed", exc_info=True) # Don't include exception details, which may be security-sensitive. @@ -334,121 +377,97 @@ def step_parser(self) -> None: self.parser_exc = exc raise - def parse(self) -> Generator[None, None, None]: - while True: - eof = yield from self.reader.at_eof() - if eof: - if self.debug: - self.logger.debug("< EOF") - if self.close_rcvd is not None: - if not self.eof_sent: - self.send_eof() - yield - # Once the reader reaches EOF, its feed_data/eof() methods - # raise an error, so our receive_data/eof() methods never - # call step_parser(), so the generator shouldn't resume - # executing until it's garbage collected. - raise AssertionError( - "parser shouldn't step after EOF" - ) # pragma: no cover - else: - raise EOFError("unexpected end of stream") - - if self.max_size is None: - max_size = None - elif self.cur_size is None: - max_size = self.max_size + def discard(self) -> Generator[None, None, None]: + while not (yield from self.reader.at_eof()): + self.reader.discard() + if self.side is CLIENT: + self.send_eof() + # If discard() completes normally, execution ends here. + yield + # Once the reader reaches EOF, its feed_data/eof() + # methods raise an error, so our receive_data/eof() + # methods don't step the generator. + raise AssertionError("discard() shouldn't step after EOF") # pragma: no cover + + def recv_frame(self, frame: Frame) -> None: + if frame.opcode is OP_TEXT or frame.opcode is OP_BINARY: + # 5.5.1 Close: "The application MUST NOT send any more data + # frames after sending a Close frame." + if self.close_rcvd is not None: + raise ProtocolError("data frame after close frame") + + if self.cur_size is not None: + raise ProtocolError("expected a continuation frame") + if frame.fin: + self.cur_size = None else: - max_size = self.max_size - self.cur_size - - frame = yield from Frame.parse( - self.reader.read_exact, - mask=self.side is SERVER, - max_size=max_size, - extensions=self.extensions, - ) - - if self.debug: - self.logger.debug("< %s", frame) - - if frame.opcode is OP_TEXT or frame.opcode is OP_BINARY: - # 5.5.1 Close: "The application MUST NOT send any more data - # frames after sending a Close frame." - if self.close_rcvd is not None: - raise ProtocolError("data frame after close frame") - - if self.cur_size is not None: - raise ProtocolError("expected a continuation frame") - if frame.fin: - self.cur_size = None - else: - self.cur_size = len(frame.data) - - elif frame.opcode is OP_CONT: - # 5.5.1 Close: "The application MUST NOT send any more data - # frames after sending a Close frame." - if self.close_rcvd is not None: - raise ProtocolError("data frame after close frame") - - if self.cur_size is None: - raise ProtocolError("unexpected continuation frame") - if frame.fin: - self.cur_size = None - else: - self.cur_size += len(frame.data) - - elif frame.opcode is OP_PING: - # 5.5.2. Ping: "Upon receipt of a Ping frame, an endpoint MUST - # send a Pong frame in response, unless it already received a - # Close frame." - if self.close_rcvd is None: - pong_frame = Frame(OP_PONG, frame.data) - self.send_frame(pong_frame) - - elif frame.opcode is OP_PONG: - # 5.5.3 Pong: "A response to an unsolicited Pong frame is not - # expected." - pass - - elif frame.opcode is OP_CLOSE: - # 7.1.5. The WebSocket Connection Close Code - # 7.1.6. The WebSocket Connection Close Reason - self.close_rcvd = Close.parse(frame.data) - if self.state is CLOSING: - assert self.close_sent is not None - self.close_rcvd_then_sent = False - - if self.cur_size is not None: - raise ProtocolError("incomplete fragmented message") - - # 5.5.1 Close: "If an endpoint receives a Close frame and did - # not previously send a Close frame, the endpoint MUST send a - # Close frame in response. (When sending a Close frame in - # response, the endpoint typically echos the status code it - # received.)" - - if self.state is OPEN: - # Echo the original data instead of re-serializing it with - # Close.serialize() because that fails when the close frame - # is empty and Close.parse() synthetizes a 1005 close code. - # The rest is identical to send_close(). - self.send_frame(Frame(OP_CLOSE, frame.data)) - self.close_sent = self.close_rcvd - self.close_rcvd_then_sent = True - self.set_state(CLOSING) - - # 7.1.2. Start the WebSocket Closing Handshake: "Once an - # endpoint has both sent and received a Close control frame, - # that endpoint SHOULD _Close the WebSocket Connection_" - - if self.side is SERVER: - self.send_eof() - - else: # pragma: no cover - # This can't happen because Frame.parse() validates opcodes. - raise AssertionError(f"unexpected opcode: {frame.opcode:02x}") - - self.events.append(frame) + self.cur_size = len(frame.data) + + elif frame.opcode is OP_CONT: + # 5.5.1 Close: "The application MUST NOT send any more data + # frames after sending a Close frame." + if self.close_rcvd is not None: + raise ProtocolError("data frame after close frame") + + if self.cur_size is None: + raise ProtocolError("unexpected continuation frame") + if frame.fin: + self.cur_size = None + else: + self.cur_size += len(frame.data) + + elif frame.opcode is OP_PING: + # 5.5.2. Ping: "Upon receipt of a Ping frame, an endpoint MUST + # send a Pong frame in response, unless it already received a + # Close frame." + if self.close_rcvd is None: + pong_frame = Frame(OP_PONG, frame.data) + self.send_frame(pong_frame) + + elif frame.opcode is OP_PONG: + # 5.5.3 Pong: "A response to an unsolicited Pong frame is not + # expected." + pass + + elif frame.opcode is OP_CLOSE: + # 7.1.5. The WebSocket Connection Close Code + # 7.1.6. The WebSocket Connection Close Reason + self.close_rcvd = Close.parse(frame.data) + if self.state is CLOSING: + assert self.close_sent is not None + self.close_rcvd_then_sent = False + + if self.cur_size is not None: + raise ProtocolError("incomplete fragmented message") + + # 5.5.1 Close: "If an endpoint receives a Close frame and did + # not previously send a Close frame, the endpoint MUST send a + # Close frame in response. (When sending a Close frame in + # response, the endpoint typically echos the status code it + # received.)" + + if self.state is OPEN: + # Echo the original data instead of re-serializing it with + # Close.serialize() because that fails when the close frame + # is empty and Close.parse() synthetizes a 1005 close code. + # The rest is identical to send_close(). + self.send_frame(Frame(OP_CLOSE, frame.data)) + self.close_sent = self.close_rcvd + self.close_rcvd_then_sent = True + self.set_state(CLOSING) + + # 7.1.2. Start the WebSocket Closing Handshake: "Once an + # endpoint has both sent and received a Close control frame, + # that endpoint SHOULD _Close the WebSocket Connection_" + + if self.side is SERVER: + self.send_eof() + + else: # pragma: no cover + # This can't happen because Frame.parse() validates opcodes. + raise AssertionError(f"unexpected opcode: {frame.opcode:02x}") + + self.events.append(frame) # Private methods for sending events. diff --git a/src/websockets/streams.py b/src/websockets/streams.py index 92a356265..d1ce377e7 100644 --- a/src/websockets/streams.py +++ b/src/websockets/streams.py @@ -116,10 +116,9 @@ def feed_eof(self) -> None: raise EOFError("stream ended") self.eof = True - def abort(self) -> None: + def discard(self) -> None: """ - End the stream, discarding all buffered data. + Discarding all buffered data, but don't end the stream. """ - self.eof = True del self.buffer[:] diff --git a/tests/test_connection.py b/tests/test_connection.py index 45a0594d2..b8843028d 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1204,17 +1204,15 @@ def test_client_stops_processing_frames_after_fail(self): client = Connection(Side.CLIENT) client.fail(1002) self.assertConnectionFailing(client, 1002) - with self.assertRaises(EOFError) as raised: - client.receive_data(b"\x88\x02\x03\xea") - self.assertEqual(str(raised.exception), "stream ended") + client.receive_data(b"\x88\x02\x03\xea") + self.assertFrameReceived(client, None) def test_server_stops_processing_frames_after_fail(self): server = Connection(Side.SERVER) server.fail(1002) self.assertConnectionFailing(server, 1002) - with self.assertRaises(EOFError) as raised: - server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xea") - self.assertEqual(str(raised.exception), "stream ended") + server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xea") + self.assertFrameReceived(server, None) class FragmentationTests(ConnectionTestCase): @@ -1423,39 +1421,53 @@ def test_server_receives_eof_inside_frame(self): def test_client_receives_data_after_exception(self): client = Connection(Side.CLIENT) - with self.assertRaises(ProtocolError) as raised: + with self.assertRaises(ProtocolError): client.receive_data(b"\xff\xff") - self.assertEqual(str(raised.exception), "invalid opcode") - with self.assertRaises(EOFError) as raised: - client.receive_data(b"\x00\x00") - self.assertEqual(str(raised.exception), "stream ended") + self.assertConnectionFailing(client, 1002, "invalid opcode") + client.receive_data(b"\x00\x00") + self.assertFrameSent(client, None) def test_server_receives_data_after_exception(self): server = Connection(Side.SERVER) - with self.assertRaises(ProtocolError) as raised: + with self.assertRaises(ProtocolError): server.receive_data(b"\xff\xff") - self.assertEqual(str(raised.exception), "invalid opcode") - with self.assertRaises(EOFError) as raised: - server.receive_data(b"\x00\x00") - self.assertEqual(str(raised.exception), "stream ended") + self.assertConnectionFailing(server, 1002, "invalid opcode") + server.receive_data(b"\x00\x00") + self.assertFrameSent(server, None) def test_client_receives_eof_after_exception(self): client = Connection(Side.CLIENT) - with self.assertRaises(ProtocolError) as raised: + with self.assertRaises(ProtocolError): client.receive_data(b"\xff\xff") - self.assertEqual(str(raised.exception), "invalid opcode") - with self.assertRaises(EOFError) as raised: - client.receive_eof() - self.assertEqual(str(raised.exception), "stream ended") + self.assertConnectionFailing(client, 1002, "invalid opcode") + client.receive_eof() + self.assertFrameSent(client, None, eof=True) def test_server_receives_eof_after_exception(self): server = Connection(Side.SERVER) - with self.assertRaises(ProtocolError) as raised: + with self.assertRaises(ProtocolError): server.receive_data(b"\xff\xff") - self.assertEqual(str(raised.exception), "invalid opcode") - with self.assertRaises(EOFError) as raised: - server.receive_eof() - self.assertEqual(str(raised.exception), "stream ended") + self.assertConnectionFailing(server, 1002, "invalid opcode") + server.receive_eof() + self.assertFrameSent(server, None) + + def test_client_receives_data_and_eof_after_exception(self): + client = Connection(Side.CLIENT) + with self.assertRaises(ProtocolError): + client.receive_data(b"\xff\xff") + self.assertConnectionFailing(client, 1002, "invalid opcode") + client.receive_data(b"\x00\x00") + client.receive_eof() + self.assertFrameSent(client, None, eof=True) + + def test_server_receives_data_and_eof_after_exception(self): + server = Connection(Side.SERVER) + with self.assertRaises(ProtocolError): + server.receive_data(b"\xff\xff") + self.assertConnectionFailing(server, 1002, "invalid opcode") + server.receive_data(b"\x00\x00") + server.receive_eof() + self.assertFrameSent(server, None) def test_client_receives_data_after_eof(self): client = Connection(Side.CLIENT) diff --git a/tests/test_streams.py b/tests/test_streams.py index e942c8d12..8abefbcc9 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -145,11 +145,13 @@ def test_feed_eof_after_feed_eof(self): self.reader.feed_eof() self.assertEqual(str(raised.exception), "stream ended") - def test_abort(self): + def test_discard(self): gen = self.reader.read_to_eof() self.reader.feed_data(b"spam") - self.reader.abort() + self.reader.discard() + self.assertGeneratorRunning(gen) + self.reader.feed_eof() data = self.assertGeneratorReturns(gen) self.assertEqual(data, b"") From 008c160365bcda5b38ddeccf5613b1b2a8395599 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Jun 2021 10:36:12 +0200 Subject: [PATCH 109/762] Prevent receive_data/eof from raising exceptions. --- src/websockets/connection.py | 17 ++- tests/test_connection.py | 248 +++++++++++++++++------------------ 2 files changed, 132 insertions(+), 133 deletions(-) diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 0a3a87bd5..f8637ab20 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -168,6 +168,8 @@ def receive_data(self, data: bytes) -> None: - You must call :meth:`data_to_send` and send this data. - You should call :meth:`events_received` and process these events. + :raises EOFError: if :meth:`receive_eof` was called before + """ self.reader.feed_data(data) next(self.parser) @@ -179,9 +181,11 @@ def receive_eof(self) -> None: After calling this method: - You must call :meth:`data_to_send` and send this data. - - You shouldn't call :meth:`events_received` as it won't + - You aren't exepcted to call :meth:`events_received` as it won't return any new events. + :raises EOFError: if :meth:`receive_eof` was called before + """ self.reader.feed_eof() next(self.parser) @@ -324,7 +328,7 @@ def parse(self) -> Generator[None, None, None]: yield # Once the reader reaches EOF, its feed_data/eof() # methods raise an error, so our receive_data/eof() - # methods don't step the generator. + # methods don't step parse(). raise AssertionError( "parse() shouldn't step after EOF" ) # pragma: no cover @@ -353,29 +357,28 @@ def parse(self) -> Generator[None, None, None]: except ProtocolError as exc: self.fail(1002, str(exc)) self.parser_exc = exc - raise except EOFError as exc: self.fail(1006, str(exc)) self.parser_exc = exc - raise except UnicodeDecodeError as exc: self.fail(1007, f"{exc.reason} at position {exc.start}") self.parser_exc = exc - raise except PayloadTooBig as exc: self.fail(1009, str(exc)) self.parser_exc = exc - raise except Exception as exc: self.logger.error("parser failed", exc_info=True) # Don't include exception details, which may be security-sensitive. self.fail(1011) self.parser_exc = exc - raise + + yield + # If an error occurs, parse() is replaced by discard(). + raise AssertionError("parse() shouldn't step after EOF") # pragma: no cover def discard(self) -> Generator[None, None, None]: while not (yield from self.reader.at_eof()): diff --git a/tests/test_connection.py b/tests/test_connection.py index b8843028d..ec62fc7fc 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -126,16 +126,16 @@ def test_server_receives_masked_frame(self): def test_client_receives_masked_frame(self): client = Connection(Side.CLIENT) - with self.assertRaises(ProtocolError) as raised: - client.receive_data(self.masked_text_frame_data) - self.assertEqual(str(raised.exception), "incorrect masking") + client.receive_data(self.masked_text_frame_data) + self.assertIsInstance(client.parser_exc, ProtocolError) + self.assertEqual(str(client.parser_exc), "incorrect masking") self.assertConnectionFailing(client, 1002, "incorrect masking") def test_server_receives_unmasked_frame(self): server = Connection(Side.SERVER) - with self.assertRaises(ProtocolError) as raised: - server.receive_data(self.unmasked_text_frame_date) - self.assertEqual(str(raised.exception), "incorrect masking") + server.receive_data(self.unmasked_text_frame_date) + self.assertIsInstance(server.parser_exc, ProtocolError) + self.assertEqual(str(server.parser_exc), "incorrect masking") self.assertConnectionFailing(server, 1002, "incorrect masking") @@ -159,16 +159,16 @@ def test_server_sends_unexpected_continuation(self): def test_client_receives_unexpected_continuation(self): client = Connection(Side.CLIENT) - with self.assertRaises(ProtocolError) as raised: - client.receive_data(b"\x00\x00") - self.assertEqual(str(raised.exception), "unexpected continuation frame") + client.receive_data(b"\x00\x00") + self.assertIsInstance(client.parser_exc, ProtocolError) + self.assertEqual(str(client.parser_exc), "unexpected continuation frame") self.assertConnectionFailing(client, 1002, "unexpected continuation frame") def test_server_receives_unexpected_continuation(self): server = Connection(Side.SERVER) - with self.assertRaises(ProtocolError) as raised: - server.receive_data(b"\x00\x80\x00\x00\x00\x00") - self.assertEqual(str(raised.exception), "unexpected continuation frame") + server.receive_data(b"\x00\x80\x00\x00\x00\x00") + self.assertIsInstance(server.parser_exc, ProtocolError) + self.assertEqual(str(server.parser_exc), "unexpected continuation frame") self.assertConnectionFailing(server, 1002, "unexpected continuation frame") def test_client_sends_continuation_after_sending_close(self): @@ -198,17 +198,17 @@ def test_client_receives_continuation_after_receiving_close(self): client = Connection(Side.CLIENT) client.receive_data(b"\x88\x02\x03\xe8") self.assertConnectionClosing(client, 1000) - with self.assertRaises(ProtocolError) as raised: - client.receive_data(b"\x00\x00") - self.assertEqual(str(raised.exception), "data frame after close frame") + client.receive_data(b"\x00\x00") + self.assertIsInstance(client.parser_exc, ProtocolError) + self.assertEqual(str(client.parser_exc), "data frame after close frame") def test_server_receives_continuation_after_receiving_close(self): server = Connection(Side.SERVER) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") self.assertConnectionClosing(server, 1001) - with self.assertRaises(ProtocolError) as raised: - server.receive_data(b"\x00\x80\x00\xff\x00\xff") - self.assertEqual(str(raised.exception), "data frame after close frame") + server.receive_data(b"\x00\x80\x00\xff\x00\xff") + self.assertIsInstance(server.parser_exc, ProtocolError) + self.assertEqual(str(server.parser_exc), "data frame after close frame") class TextTests(ConnectionTestCase): @@ -248,16 +248,16 @@ def test_server_receives_text(self): def test_client_receives_text_over_size_limit(self): client = Connection(Side.CLIENT, max_size=3) - with self.assertRaises(PayloadTooBig) as raised: - client.receive_data(b"\x81\x04\xf0\x9f\x98\x80") - self.assertEqual(str(raised.exception), "over size limit (4 > 3 bytes)") + client.receive_data(b"\x81\x04\xf0\x9f\x98\x80") + self.assertIsInstance(client.parser_exc, PayloadTooBig) + self.assertEqual(str(client.parser_exc), "over size limit (4 > 3 bytes)") self.assertConnectionFailing(client, 1009, "over size limit (4 > 3 bytes)") def test_server_receives_text_over_size_limit(self): server = Connection(Side.SERVER, max_size=3) - with self.assertRaises(PayloadTooBig) as raised: - server.receive_data(b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80") - self.assertEqual(str(raised.exception), "over size limit (4 > 3 bytes)") + server.receive_data(b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80") + self.assertIsInstance(server.parser_exc, PayloadTooBig) + self.assertEqual(str(server.parser_exc), "over size limit (4 > 3 bytes)") self.assertConnectionFailing(server, 1009, "over size limit (4 > 3 bytes)") def test_client_receives_text_without_size_limit(self): @@ -342,9 +342,9 @@ def test_client_receives_fragmented_text_over_size_limit(self): client, Frame(OP_TEXT, "😀".encode()[:2], fin=False), ) - with self.assertRaises(PayloadTooBig) as raised: - client.receive_data(b"\x80\x02\x98\x80") - self.assertEqual(str(raised.exception), "over size limit (2 > 1 bytes)") + client.receive_data(b"\x80\x02\x98\x80") + self.assertIsInstance(client.parser_exc, PayloadTooBig) + self.assertEqual(str(client.parser_exc), "over size limit (2 > 1 bytes)") self.assertConnectionFailing(client, 1009, "over size limit (2 > 1 bytes)") def test_server_receives_fragmented_text_over_size_limit(self): @@ -354,9 +354,9 @@ def test_server_receives_fragmented_text_over_size_limit(self): server, Frame(OP_TEXT, "😀".encode()[:2], fin=False), ) - with self.assertRaises(PayloadTooBig) as raised: - server.receive_data(b"\x80\x82\x00\x00\x00\x00\x98\x80") - self.assertEqual(str(raised.exception), "over size limit (2 > 1 bytes)") + server.receive_data(b"\x80\x82\x00\x00\x00\x00\x98\x80") + self.assertIsInstance(server.parser_exc, PayloadTooBig) + self.assertEqual(str(server.parser_exc), "over size limit (2 > 1 bytes)") self.assertConnectionFailing(server, 1009, "over size limit (2 > 1 bytes)") def test_client_receives_fragmented_text_without_size_limit(self): @@ -416,9 +416,9 @@ def test_client_receives_unexpected_text(self): client, Frame(OP_TEXT, b"", fin=False), ) - with self.assertRaises(ProtocolError) as raised: - client.receive_data(b"\x01\x00") - self.assertEqual(str(raised.exception), "expected a continuation frame") + client.receive_data(b"\x01\x00") + self.assertIsInstance(client.parser_exc, ProtocolError) + self.assertEqual(str(client.parser_exc), "expected a continuation frame") self.assertConnectionFailing(client, 1002, "expected a continuation frame") def test_server_receives_unexpected_text(self): @@ -428,9 +428,9 @@ def test_server_receives_unexpected_text(self): server, Frame(OP_TEXT, b"", fin=False), ) - with self.assertRaises(ProtocolError) as raised: - server.receive_data(b"\x01\x80\x00\x00\x00\x00") - self.assertEqual(str(raised.exception), "expected a continuation frame") + server.receive_data(b"\x01\x80\x00\x00\x00\x00") + self.assertIsInstance(server.parser_exc, ProtocolError) + self.assertEqual(str(server.parser_exc), "expected a continuation frame") self.assertConnectionFailing(server, 1002, "expected a continuation frame") def test_client_sends_text_after_sending_close(self): @@ -452,17 +452,17 @@ def test_client_receives_text_after_receiving_close(self): client = Connection(Side.CLIENT) client.receive_data(b"\x88\x02\x03\xe8") self.assertConnectionClosing(client, 1000) - with self.assertRaises(ProtocolError) as raised: - client.receive_data(b"\x81\x00") - self.assertEqual(str(raised.exception), "data frame after close frame") + client.receive_data(b"\x81\x00") + self.assertIsInstance(client.parser_exc, ProtocolError) + self.assertEqual(str(client.parser_exc), "data frame after close frame") def test_server_receives_text_after_receiving_close(self): server = Connection(Side.SERVER) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") self.assertConnectionClosing(server, 1001) - with self.assertRaises(ProtocolError) as raised: - server.receive_data(b"\x81\x80\x00\xff\x00\xff") - self.assertEqual(str(raised.exception), "data frame after close frame") + server.receive_data(b"\x81\x80\x00\xff\x00\xff") + self.assertIsInstance(server.parser_exc, ProtocolError) + self.assertEqual(str(server.parser_exc), "data frame after close frame") class BinaryTests(ConnectionTestCase): @@ -502,16 +502,16 @@ def test_server_receives_binary(self): def test_client_receives_binary_over_size_limit(self): client = Connection(Side.CLIENT, max_size=3) - with self.assertRaises(PayloadTooBig) as raised: - client.receive_data(b"\x82\x04\x01\x02\xfe\xff") - self.assertEqual(str(raised.exception), "over size limit (4 > 3 bytes)") + client.receive_data(b"\x82\x04\x01\x02\xfe\xff") + self.assertIsInstance(client.parser_exc, PayloadTooBig) + self.assertEqual(str(client.parser_exc), "over size limit (4 > 3 bytes)") self.assertConnectionFailing(client, 1009, "over size limit (4 > 3 bytes)") def test_server_receives_binary_over_size_limit(self): server = Connection(Side.SERVER, max_size=3) - with self.assertRaises(PayloadTooBig) as raised: - server.receive_data(b"\x82\x84\x00\x00\x00\x00\x01\x02\xfe\xff") - self.assertEqual(str(raised.exception), "over size limit (4 > 3 bytes)") + server.receive_data(b"\x82\x84\x00\x00\x00\x00\x01\x02\xfe\xff") + self.assertIsInstance(server.parser_exc, PayloadTooBig) + self.assertEqual(str(server.parser_exc), "over size limit (4 > 3 bytes)") self.assertConnectionFailing(server, 1009, "over size limit (4 > 3 bytes)") def test_client_sends_fragmented_binary(self): @@ -580,9 +580,9 @@ def test_client_receives_fragmented_binary_over_size_limit(self): client, Frame(OP_BINARY, b"\x01\x02", fin=False), ) - with self.assertRaises(PayloadTooBig) as raised: - client.receive_data(b"\x80\x02\xfe\xff") - self.assertEqual(str(raised.exception), "over size limit (2 > 1 bytes)") + client.receive_data(b"\x80\x02\xfe\xff") + self.assertIsInstance(client.parser_exc, PayloadTooBig) + self.assertEqual(str(client.parser_exc), "over size limit (2 > 1 bytes)") self.assertConnectionFailing(client, 1009, "over size limit (2 > 1 bytes)") def test_server_receives_fragmented_binary_over_size_limit(self): @@ -592,9 +592,9 @@ def test_server_receives_fragmented_binary_over_size_limit(self): server, Frame(OP_BINARY, b"\x01\x02", fin=False), ) - with self.assertRaises(PayloadTooBig) as raised: - server.receive_data(b"\x80\x82\x00\x00\x00\x00\xfe\xff") - self.assertEqual(str(raised.exception), "over size limit (2 > 1 bytes)") + server.receive_data(b"\x80\x82\x00\x00\x00\x00\xfe\xff") + self.assertIsInstance(server.parser_exc, PayloadTooBig) + self.assertEqual(str(server.parser_exc), "over size limit (2 > 1 bytes)") self.assertConnectionFailing(server, 1009, "over size limit (2 > 1 bytes)") def test_client_sends_unexpected_binary(self): @@ -618,9 +618,9 @@ def test_client_receives_unexpected_binary(self): client, Frame(OP_BINARY, b"", fin=False), ) - with self.assertRaises(ProtocolError) as raised: - client.receive_data(b"\x02\x00") - self.assertEqual(str(raised.exception), "expected a continuation frame") + client.receive_data(b"\x02\x00") + self.assertIsInstance(client.parser_exc, ProtocolError) + self.assertEqual(str(client.parser_exc), "expected a continuation frame") self.assertConnectionFailing(client, 1002, "expected a continuation frame") def test_server_receives_unexpected_binary(self): @@ -630,9 +630,9 @@ def test_server_receives_unexpected_binary(self): server, Frame(OP_BINARY, b"", fin=False), ) - with self.assertRaises(ProtocolError) as raised: - server.receive_data(b"\x02\x80\x00\x00\x00\x00") - self.assertEqual(str(raised.exception), "expected a continuation frame") + server.receive_data(b"\x02\x80\x00\x00\x00\x00") + self.assertIsInstance(server.parser_exc, ProtocolError) + self.assertEqual(str(server.parser_exc), "expected a continuation frame") self.assertConnectionFailing(server, 1002, "expected a continuation frame") def test_client_sends_binary_after_sending_close(self): @@ -654,17 +654,17 @@ def test_client_receives_binary_after_receiving_close(self): client = Connection(Side.CLIENT) client.receive_data(b"\x88\x02\x03\xe8") self.assertConnectionClosing(client, 1000) - with self.assertRaises(ProtocolError) as raised: - client.receive_data(b"\x82\x00") - self.assertEqual(str(raised.exception), "data frame after close frame") + client.receive_data(b"\x82\x00") + self.assertIsInstance(client.parser_exc, ProtocolError) + self.assertEqual(str(client.parser_exc), "data frame after close frame") def test_server_receives_binary_after_receiving_close(self): server = Connection(Side.SERVER) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") self.assertConnectionClosing(server, 1001) - with self.assertRaises(ProtocolError) as raised: - server.receive_data(b"\x82\x80\x00\xff\x00\xff") - self.assertEqual(str(raised.exception), "data frame after close frame") + server.receive_data(b"\x82\x80\x00\xff\x00\xff") + self.assertIsInstance(server.parser_exc, ProtocolError) + self.assertEqual(str(server.parser_exc), "data frame after close frame") class CloseTests(ConnectionTestCase): @@ -871,26 +871,27 @@ def test_server_sends_close_with_reason_only(self): def test_client_receives_close_with_truncated_code(self): client = Connection(Side.CLIENT) - with self.assertRaises(ProtocolError) as raised: - client.receive_data(b"\x88\x01\x03") - self.assertEqual(str(raised.exception), "close frame too short") + client.receive_data(b"\x88\x01\x03") + self.assertIsInstance(client.parser_exc, ProtocolError) + self.assertEqual(str(client.parser_exc), "close frame too short") self.assertConnectionFailing(client, 1002, "close frame too short") self.assertIs(client.state, State.CLOSING) def test_server_receives_close_with_truncated_code(self): server = Connection(Side.SERVER) - with self.assertRaises(ProtocolError) as raised: - server.receive_data(b"\x88\x81\x00\x00\x00\x00\x03") - self.assertEqual(str(raised.exception), "close frame too short") + server.receive_data(b"\x88\x81\x00\x00\x00\x00\x03") + self.assertIsInstance(server.parser_exc, ProtocolError) + self.assertEqual(str(server.parser_exc), "close frame too short") self.assertConnectionFailing(server, 1002, "close frame too short") self.assertIs(server.state, State.CLOSING) def test_client_receives_close_with_non_utf8_reason(self): client = Connection(Side.CLIENT) - with self.assertRaises(UnicodeDecodeError) as raised: - client.receive_data(b"\x88\x04\x03\xe8\xff\xff") + + client.receive_data(b"\x88\x04\x03\xe8\xff\xff") + self.assertIsInstance(client.parser_exc, UnicodeDecodeError) self.assertEqual( - str(raised.exception), + str(client.parser_exc), "'utf-8' codec can't decode byte 0xff in position 0: invalid start byte", ) self.assertConnectionFailing(client, 1007, "invalid start byte at position 0") @@ -898,10 +899,11 @@ def test_client_receives_close_with_non_utf8_reason(self): def test_server_receives_close_with_non_utf8_reason(self): server = Connection(Side.SERVER) - with self.assertRaises(UnicodeDecodeError) as raised: - server.receive_data(b"\x88\x84\x00\x00\x00\x00\x03\xe9\xff\xff") + + server.receive_data(b"\x88\x84\x00\x00\x00\x00\x03\xe9\xff\xff") + self.assertIsInstance(server.parser_exc, UnicodeDecodeError) self.assertEqual( - str(raised.exception), + str(server.parser_exc), "'utf-8' codec can't decode byte 0xff in position 0: invalid start byte", ) self.assertConnectionFailing(server, 1007, "invalid start byte at position 0") @@ -1002,16 +1004,16 @@ def test_server_sends_fragmented_ping_frame(self): def test_client_receives_fragmented_ping_frame(self): client = Connection(Side.CLIENT) - with self.assertRaises(ProtocolError) as raised: - client.receive_data(b"\x09\x00") - self.assertEqual(str(raised.exception), "fragmented control frame") + client.receive_data(b"\x09\x00") + self.assertIsInstance(client.parser_exc, ProtocolError) + self.assertEqual(str(client.parser_exc), "fragmented control frame") self.assertConnectionFailing(client, 1002, "fragmented control frame") def test_server_receives_fragmented_ping_frame(self): server = Connection(Side.SERVER) - with self.assertRaises(ProtocolError) as raised: - server.receive_data(b"\x09\x80\x3c\x3c\x3c\x3c") - self.assertEqual(str(raised.exception), "fragmented control frame") + server.receive_data(b"\x09\x80\x3c\x3c\x3c\x3c") + self.assertIsInstance(server.parser_exc, ProtocolError) + self.assertEqual(str(server.parser_exc), "fragmented control frame") self.assertConnectionFailing(server, 1002, "fragmented control frame") def test_client_sends_ping_after_sending_close(self): @@ -1142,16 +1144,16 @@ def test_server_sends_fragmented_pong_frame(self): def test_client_receives_fragmented_pong_frame(self): client = Connection(Side.CLIENT) - with self.assertRaises(ProtocolError) as raised: - client.receive_data(b"\x0a\x00") - self.assertEqual(str(raised.exception), "fragmented control frame") + client.receive_data(b"\x0a\x00") + self.assertIsInstance(client.parser_exc, ProtocolError) + self.assertEqual(str(client.parser_exc), "fragmented control frame") self.assertConnectionFailing(client, 1002, "fragmented control frame") def test_server_receives_fragmented_pong_frame(self): server = Connection(Side.SERVER) - with self.assertRaises(ProtocolError) as raised: - server.receive_data(b"\x0a\x80\x3c\x3c\x3c\x3c") - self.assertEqual(str(raised.exception), "fragmented control frame") + server.receive_data(b"\x0a\x80\x3c\x3c\x3c\x3c") + self.assertIsInstance(server.parser_exc, ProtocolError) + self.assertEqual(str(server.parser_exc), "fragmented control frame") self.assertConnectionFailing(server, 1002, "fragmented control frame") def test_client_sends_pong_after_sending_close(self): @@ -1349,9 +1351,9 @@ def test_client_receive_close_in_fragmented_message(self): # frames in the middle of a fragmented message." However, since the # endpoint must not send a data frame after a close frame, a close # frame can't be "in the middle" of a fragmented message. - with self.assertRaises(ProtocolError) as raised: - client.receive_data(b"\x88\x02\x03\xe8") - self.assertEqual(str(raised.exception), "incomplete fragmented message") + client.receive_data(b"\x88\x02\x03\xe8") + self.assertIsInstance(client.parser_exc, ProtocolError) + self.assertEqual(str(client.parser_exc), "incomplete fragmented message") self.assertConnectionFailing(client, 1002, "incomplete fragmented message") def test_server_receive_close_in_fragmented_message(self): @@ -1365,9 +1367,9 @@ def test_server_receive_close_in_fragmented_message(self): # frames in the middle of a fragmented message." However, since the # endpoint must not send a data frame after a close frame, a close # frame can't be "in the middle" of a fragmented message. - with self.assertRaises(ProtocolError) as raised: - server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") - self.assertEqual(str(raised.exception), "incomplete fragmented message") + server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") + self.assertIsInstance(server.parser_exc, ProtocolError) + self.assertEqual(str(server.parser_exc), "incomplete fragmented message") self.assertConnectionFailing(server, 1002, "incomplete fragmented message") @@ -1391,70 +1393,65 @@ def test_server_receives_eof(self): def test_client_receives_eof_between_frames(self): client = Connection(Side.CLIENT) - with self.assertRaises(EOFError) as raised: - client.receive_eof() - self.assertEqual(str(raised.exception), "unexpected end of stream") + client.receive_eof() + self.assertIsInstance(client.parser_exc, EOFError) + self.assertEqual(str(client.parser_exc), "unexpected end of stream") def test_server_receives_eof_between_frames(self): server = Connection(Side.SERVER) - with self.assertRaises(EOFError) as raised: - server.receive_eof() - self.assertEqual(str(raised.exception), "unexpected end of stream") + server.receive_eof() + self.assertIsInstance(server.parser_exc, EOFError) + self.assertEqual(str(server.parser_exc), "unexpected end of stream") def test_client_receives_eof_inside_frame(self): client = Connection(Side.CLIENT) client.receive_data(b"\x81") - with self.assertRaises(EOFError) as raised: - client.receive_eof() + client.receive_eof() + self.assertIsInstance(client.parser_exc, EOFError) self.assertEqual( - str(raised.exception), "stream ends after 1 bytes, expected 2 bytes" + str(client.parser_exc), "stream ends after 1 bytes, expected 2 bytes" ) def test_server_receives_eof_inside_frame(self): server = Connection(Side.SERVER) server.receive_data(b"\x81") - with self.assertRaises(EOFError) as raised: - server.receive_eof() + server.receive_eof() + self.assertIsInstance(server.parser_exc, EOFError) self.assertEqual( - str(raised.exception), "stream ends after 1 bytes, expected 2 bytes" + str(server.parser_exc), "stream ends after 1 bytes, expected 2 bytes" ) def test_client_receives_data_after_exception(self): client = Connection(Side.CLIENT) - with self.assertRaises(ProtocolError): - client.receive_data(b"\xff\xff") + client.receive_data(b"\xff\xff") self.assertConnectionFailing(client, 1002, "invalid opcode") client.receive_data(b"\x00\x00") self.assertFrameSent(client, None) def test_server_receives_data_after_exception(self): server = Connection(Side.SERVER) - with self.assertRaises(ProtocolError): - server.receive_data(b"\xff\xff") + server.receive_data(b"\xff\xff") self.assertConnectionFailing(server, 1002, "invalid opcode") server.receive_data(b"\x00\x00") self.assertFrameSent(server, None) def test_client_receives_eof_after_exception(self): client = Connection(Side.CLIENT) - with self.assertRaises(ProtocolError): - client.receive_data(b"\xff\xff") + client.receive_data(b"\xff\xff") self.assertConnectionFailing(client, 1002, "invalid opcode") client.receive_eof() self.assertFrameSent(client, None, eof=True) def test_server_receives_eof_after_exception(self): server = Connection(Side.SERVER) - with self.assertRaises(ProtocolError): - server.receive_data(b"\xff\xff") + server.receive_data(b"\xff\xff") self.assertConnectionFailing(server, 1002, "invalid opcode") server.receive_eof() self.assertFrameSent(server, None) def test_client_receives_data_and_eof_after_exception(self): client = Connection(Side.CLIENT) - with self.assertRaises(ProtocolError): - client.receive_data(b"\xff\xff") + client.receive_data(b"\xff\xff") self.assertConnectionFailing(client, 1002, "invalid opcode") client.receive_data(b"\x00\x00") client.receive_eof() @@ -1462,8 +1459,7 @@ def test_client_receives_data_and_eof_after_exception(self): def test_server_receives_data_and_eof_after_exception(self): server = Connection(Side.SERVER) - with self.assertRaises(ProtocolError): - server.receive_data(b"\xff\xff") + server.receive_data(b"\xff\xff") self.assertConnectionFailing(server, 1002, "invalid opcode") server.receive_data(b"\x00\x00") server.receive_eof() @@ -1516,18 +1512,18 @@ def test_client_hits_internal_error_reading_frame(self): client = Connection(Side.CLIENT) # This isn't supposed to happen, so we're simulating it. with unittest.mock.patch("struct.unpack", side_effect=RuntimeError("BOOM")): - with self.assertRaises(RuntimeError) as raised: - client.receive_data(b"\x81\x00") - self.assertEqual(str(raised.exception), "BOOM") + client.receive_data(b"\x81\x00") + self.assertIsInstance(client.parser_exc, RuntimeError) + self.assertEqual(str(client.parser_exc), "BOOM") self.assertConnectionFailing(client, 1011, "") def test_server_hits_internal_error_reading_frame(self): server = Connection(Side.SERVER) # This isn't supposed to happen, so we're simulating it. with unittest.mock.patch("struct.unpack", side_effect=RuntimeError("BOOM")): - with self.assertRaises(RuntimeError) as raised: - server.receive_data(b"\x81\x80\x00\x00\x00\x00") - self.assertEqual(str(raised.exception), "BOOM") + server.receive_data(b"\x81\x80\x00\x00\x00\x00") + self.assertIsInstance(server.parser_exc, RuntimeError) + self.assertEqual(str(server.parser_exc), "BOOM") self.assertConnectionFailing(server, 1011, "") From 8add3a648ae4886b78db018f5b672abc0f6e5a8c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Jun 2021 14:59:43 +0200 Subject: [PATCH 110/762] Indicate when TCP connection should close --- src/websockets/connection.py | 16 ++++++++ tests/test_connection.py | 71 +++++++++++++++++++++++++++++++++--- 2 files changed, 82 insertions(+), 5 deletions(-) diff --git a/src/websockets/connection.py b/src/websockets/connection.py index f8637ab20..581614e49 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -313,6 +313,22 @@ def data_to_send(self) -> List[bytes]: writes, self.writes = self.writes, [] return writes + def close_expected(self) -> bool: + """ + Tell whether the TCP connection is expected to close soon. + + Call this method immediately after calling any of the ``receive_*()`` + methods and, if it returns ``True``, schedule closing the TCP + connection after a short timeout. + + """ + # We already got a TCP Close if and only if the state is CLOSED. + # We expect a TCP close if and only if we sent a close frame: + # * Normal closure: once we send a close frame, we expect a TCP close. + # * Abnormal closure: we always send a close frame except on EOFError, + # but that's fine because we already got the TCP close. + return self.state is not CLOSED and self.close_sent is not None + # Private methods for receiving data. def parse(self) -> Generator[None, None, None]: diff --git a/tests/test_connection.py b/tests/test_connection.py index ec62fc7fc..e420c7853 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1027,7 +1027,8 @@ def test_client_sends_ping_after_sending_close(self): with self.assertRaises(InvalidState) as raised: client.send_ping(b"") self.assertEqual( - str(raised.exception), "cannot write to a WebSocket in the CLOSING state" + str(raised.exception), + "cannot write to a WebSocket in the CLOSING state", ) def test_server_sends_ping_after_sending_close(self): @@ -1040,7 +1041,8 @@ def test_server_sends_ping_after_sending_close(self): with self.assertRaises(InvalidState) as raised: server.send_ping(b"") self.assertEqual( - str(raised.exception), "cannot write to a WebSocket in the CLOSING state" + str(raised.exception), + "cannot write to a WebSocket in the CLOSING state", ) def test_client_receives_ping_after_receiving_close(self): @@ -1375,7 +1377,7 @@ def test_server_receive_close_in_fragmented_message(self): class EOFTests(ConnectionTestCase): """ - Test connection termination. + Test half-closes on connection termination. """ @@ -1409,7 +1411,8 @@ def test_client_receives_eof_inside_frame(self): client.receive_eof() self.assertIsInstance(client.parser_exc, EOFError) self.assertEqual( - str(client.parser_exc), "stream ends after 1 bytes, expected 2 bytes" + str(client.parser_exc), + "stream ends after 1 bytes, expected 2 bytes", ) def test_server_receives_eof_inside_frame(self): @@ -1418,7 +1421,8 @@ def test_server_receives_eof_inside_frame(self): server.receive_eof() self.assertIsInstance(server.parser_exc, EOFError) self.assertEqual( - str(server.parser_exc), "stream ends after 1 bytes, expected 2 bytes" + str(server.parser_exc), + "stream ends after 1 bytes, expected 2 bytes", ) def test_client_receives_data_after_exception(self): @@ -1502,6 +1506,63 @@ def test_server_receives_eof_after_eof(self): self.assertEqual(str(raised.exception), "stream ended") +class TCPCloseTests(ConnectionTestCase): + """ + Test expectation of TCP close on connection termination. + + """ + + def test_client_default(self): + client = Connection(Side.CLIENT) + self.assertFalse(client.close_expected()) + + def test_server_default(self): + server = Connection(Side.SERVER) + self.assertFalse(server.close_expected()) + + def test_client_sends_close(self): + client = Connection(Side.CLIENT) + client.send_close() + self.assertTrue(client.close_expected()) + + def test_server_sends_close(self): + server = Connection(Side.SERVER) + server.send_close() + self.assertTrue(server.close_expected()) + + def test_client_receives_close(self): + client = Connection(Side.CLIENT) + client.receive_data(b"\x88\x00") + self.assertTrue(client.close_expected()) + + def test_client_receives_close_then_eof(self): + client = Connection(Side.CLIENT) + client.receive_data(b"\x88\x00") + client.receive_eof() + self.assertFalse(client.close_expected()) + + def test_server_receives_close_then_eof(self): + server = Connection(Side.SERVER) + server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") + server.receive_eof() + self.assertFalse(server.close_expected()) + + def test_server_receives_close(self): + server = Connection(Side.SERVER) + server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") + self.assertTrue(server.close_expected()) + + def test_client_fails_connection(self): + client = Connection(Side.CLIENT) + client.fail(1002) + self.assertTrue(client.close_expected()) + + def test_server_fails_connection(self): + server = Connection(Side.SERVER) + server.fail(1002) + self.assertTrue(server.close_expected()) + + class ErrorTests(ConnectionTestCase): """ Test other error cases. From d5ab29687ae19170cbf6ee80ecd26d7de0862f5a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Jun 2021 15:15:41 +0200 Subject: [PATCH 111/762] Set connection state to closed after EOF --- src/websockets/connection.py | 6 ++++-- tests/test_connection.py | 22 ++++++++++++++-------- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 581614e49..2e6871c00 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -128,7 +128,7 @@ def close_code(self) -> Optional[int]: Available once the connection is closed. """ - if self.state is not State.CLOSED: + if self.state is not CLOSED: return None elif self.close_rcvd is None: return 1006 @@ -143,7 +143,7 @@ def close_reason(self) -> Optional[str]: Available once the connection is closed. """ - if self.state is not State.CLOSED: + if self.state is not CLOSED: return None elif self.close_rcvd is None: return "" @@ -340,6 +340,7 @@ def parse(self) -> Generator[None, None, None]: if self.close_rcvd_then_sent is not None: if self.side is CLIENT: self.send_eof() + self.set_state(CLOSED) # If parse() completes normally, execution ends here. yield # Once the reader reaches EOF, its feed_data/eof() @@ -402,6 +403,7 @@ def discard(self) -> Generator[None, None, None]: if self.side is CLIENT: self.send_eof() # If discard() completes normally, execution ends here. + self.set_state(CLOSED) yield # Once the reader reaches EOF, its feed_data/eof() # methods raise an error, so our receive_data/eof() diff --git a/tests/test_connection.py b/tests/test_connection.py index e420c7853..fa559bc43 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -682,35 +682,35 @@ class CloseTests(ConnectionTestCase): def test_close_code(self): client = Connection(Side.CLIENT) client.receive_data(b"\x88\x04\x03\xe8OK") - client.set_state(State.CLOSED) + client.receive_eof() self.assertEqual(client.close_code, 1000) def test_close_reason(self): server = Connection(Side.SERVER) server.receive_data(b"\x88\x84\x00\x00\x00\x00\x03\xe8OK") - server.set_state(State.CLOSED) + server.receive_eof() self.assertEqual(server.close_reason, "OK") def test_close_code_not_provided(self): server = Connection(Side.SERVER) server.receive_data(b"\x88\x80\x00\x00\x00\x00") - server.set_state(State.CLOSED) + server.receive_eof() self.assertEqual(server.close_code, 1005) def test_close_reason_not_provided(self): client = Connection(Side.CLIENT) client.receive_data(b"\x88\x00") - client.set_state(State.CLOSED) + client.receive_eof() self.assertEqual(client.close_reason, "") def test_close_code_not_available(self): client = Connection(Side.CLIENT) - client.set_state(State.CLOSED) + client.receive_eof() self.assertEqual(client.close_code, 1006) def test_close_reason_not_available(self): server = Connection(Side.SERVER) - server.set_state(State.CLOSED) + server.receive_eof() self.assertEqual(server.close_reason, "") def test_close_code_not_available_yet(self): @@ -1385,25 +1385,29 @@ def test_client_receives_eof(self): client = Connection(Side.CLIENT) client.receive_data(b"\x88\x00") self.assertConnectionClosing(client) - client.receive_eof() # does not raise an exception + client.receive_eof() + self.assertIs(client.state, State.CLOSED) def test_server_receives_eof(self): server = Connection(Side.SERVER) server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") self.assertConnectionClosing(server) - server.receive_eof() # does not raise an exception + server.receive_eof() + self.assertIs(server.state, State.CLOSED) def test_client_receives_eof_between_frames(self): client = Connection(Side.CLIENT) client.receive_eof() self.assertIsInstance(client.parser_exc, EOFError) self.assertEqual(str(client.parser_exc), "unexpected end of stream") + self.assertIs(client.state, State.CLOSED) def test_server_receives_eof_between_frames(self): server = Connection(Side.SERVER) server.receive_eof() self.assertIsInstance(server.parser_exc, EOFError) self.assertEqual(str(server.parser_exc), "unexpected end of stream") + self.assertIs(server.state, State.CLOSED) def test_client_receives_eof_inside_frame(self): client = Connection(Side.CLIENT) @@ -1414,6 +1418,7 @@ def test_client_receives_eof_inside_frame(self): str(client.parser_exc), "stream ends after 1 bytes, expected 2 bytes", ) + self.assertIs(client.state, State.CLOSED) def test_server_receives_eof_inside_frame(self): server = Connection(Side.SERVER) @@ -1424,6 +1429,7 @@ def test_server_receives_eof_inside_frame(self): str(server.parser_exc), "stream ends after 1 bytes, expected 2 bytes", ) + self.assertIs(server.state, State.CLOSED) def test_client_receives_data_after_exception(self): client = Connection(Side.CLIENT) From 48527e4c1c07a9d2c68ac16f91dd2f0c267c18b5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Jun 2021 18:43:56 +0200 Subject: [PATCH 112/762] Replace set_state by a property setter. --- src/websockets/client.py | 2 +- src/websockets/connection.py | 35 +++++++++++++++++++++-------------- src/websockets/server.py | 2 +- 3 files changed, 23 insertions(+), 16 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index cc807e97c..6a0bb580e 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -307,7 +307,7 @@ def parse(self) -> Generator[None, None, None]: response.exception = exc else: assert self.state is CONNECTING - self.set_state(OPEN) + self.state = OPEN finally: self.events.append(response) diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 2e6871c00..8f407f5c7 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -83,8 +83,8 @@ def __init__( # Connection side. CLIENT or SERVER. self.side = side - # Connnection state. CONNECTING and CLOSED states are handled in subclasses. - self.set_state(state) + # Connnection state. Initially OPEN because subclasses handle CONNECTING. + self.state = state # Maximum size of incoming messages in bytes. self.max_size = max_size @@ -120,6 +120,20 @@ def __init__( # Public attributes + @property + def state(self) -> State: + """ + WebSocket connection state. + + """ + return self._state + + @state.setter + def state(self, state: State) -> None: + if self.debug: + self.logger.debug("= connection is %s", state.name) + self._state = state + @property def close_code(self) -> Optional[int]: """ @@ -150,13 +164,6 @@ def close_reason(self) -> Optional[str]: else: return self.close_rcvd.reason - # Private attributes - - def set_state(self, state: State) -> None: - if self.debug: - self.logger.debug("= connection is %s", state.name) - self.state = state - # Public methods for receiving data. def receive_data(self, data: bytes) -> None: @@ -241,7 +248,7 @@ def send_close(self, code: Optional[int] = None, reason: str = "") -> None: # 7.1.3. The WebSocket Closing Handshake is Started self.send_frame(Frame(OP_CLOSE, data)) self.close_sent = close - self.set_state(CLOSING) + self.state = CLOSING def send_ping(self, data: bytes) -> None: """ @@ -273,7 +280,7 @@ def fail(self, code: int, reason: str = "") -> None: data = close.serialize() self.send_frame(Frame(OP_CLOSE, data)) self.close_sent = close - self.set_state(CLOSING) + self.state = CLOSING if self.side is SERVER and not self.eof_sent: self.send_eof() @@ -340,7 +347,7 @@ def parse(self) -> Generator[None, None, None]: if self.close_rcvd_then_sent is not None: if self.side is CLIENT: self.send_eof() - self.set_state(CLOSED) + self.state = CLOSED # If parse() completes normally, execution ends here. yield # Once the reader reaches EOF, its feed_data/eof() @@ -403,7 +410,7 @@ def discard(self) -> Generator[None, None, None]: if self.side is CLIENT: self.send_eof() # If discard() completes normally, execution ends here. - self.set_state(CLOSED) + self.state = CLOSED yield # Once the reader reaches EOF, its feed_data/eof() # methods raise an error, so our receive_data/eof() @@ -475,7 +482,7 @@ def recv_frame(self, frame: Frame) -> None: self.send_frame(Frame(OP_CLOSE, frame.data)) self.close_sent = self.close_rcvd self.close_rcvd_then_sent = True - self.set_state(CLOSING) + self.state = CLOSING # 7.1.2. Start the WebSocket Closing Handshake: "Once an # endpoint has both sent and received a Close control frame, diff --git a/src/websockets/server.py b/src/websockets/server.py index 545b38c98..d6d2143bb 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -432,7 +432,7 @@ def send_response(self, response: Response) -> None: if response.status_code == 101: assert self.state is CONNECTING - self.set_state(OPEN) + self.state = OPEN else: self.send_eof() From b96d82202cf6d7fbe2196c07c23679cf2318055b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Jun 2021 21:26:47 +0200 Subject: [PATCH 113/762] Add API for getting connection closed exception. --- src/websockets/connection.py | 49 +++++++++++++++-- tests/test_connection.py | 103 ++++++++++++++++++++++++++++++++++- 2 files changed, 145 insertions(+), 7 deletions(-) diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 8f407f5c7..4d8851787 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -3,11 +3,19 @@ import enum import logging import uuid -from typing import Generator, List, Optional, Union - -from .exceptions import InvalidState, PayloadTooBig, ProtocolError +from typing import Generator, List, Optional, Type, Union + +from .exceptions import ( + ConnectionClosed, + ConnectionClosedError, + ConnectionClosedOK, + InvalidState, + PayloadTooBig, + ProtocolError, +) from .extensions import Extension from .frames import ( + OK_CLOSE_CODES, OP_BINARY, OP_CLOSE, OP_CONT, @@ -123,7 +131,7 @@ def __init__( @property def state(self) -> State: """ - WebSocket connection state. + Connection State defined in 4.1, 4.2, 7.1.3, and 7.1.4 of :rfc:`6455`. """ return self._state @@ -137,7 +145,7 @@ def state(self, state: State) -> None: @property def close_code(self) -> Optional[int]: """ - WebSocket close code received in a close frame. + Connection Close Code defined in 7.1.5 of :rfc:`6455`. Available once the connection is closed. @@ -152,7 +160,7 @@ def close_code(self) -> Optional[int]: @property def close_reason(self) -> Optional[str]: """ - WebSocket close reason received in a close frame. + Connection Close Reason defined in 7.1.6 of :rfc:`6455`. Available once the connection is closed. @@ -164,6 +172,35 @@ def close_reason(self) -> Optional[str]: else: return self.close_rcvd.reason + @property + def connection_closed_exc(self) -> ConnectionClosed: + """ + Exception raised when trying to interact with a closed connection. + + Available once the connection is closed. If you need to raise this + exception while the connection is closing, wait until it's closed. + + """ + assert self.state is CLOSED + exc_type: Type[ConnectionClosed] + if ( + self.close_rcvd is not None + and self.close_sent is not None + and self.close_rcvd.code in OK_CLOSE_CODES + and self.close_sent.code in OK_CLOSE_CODES + ): + exc_type = ConnectionClosedOK + else: + exc_type = ConnectionClosedError + exc: ConnectionClosed = exc_type( + self.close_rcvd, + self.close_sent, + self.close_rcvd_then_sent, + ) + # Chain to the exception raised in the parser, if any. + exc.__cause__ = self.parser_exc + return exc + # Public methods for receiving data. def receive_data(self, data: bytes) -> None: diff --git a/tests/test_connection.py b/tests/test_connection.py index fa559bc43..10eb58eeb 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,7 +1,13 @@ import unittest.mock from websockets.connection import * -from websockets.exceptions import InvalidState, PayloadTooBig, ProtocolError +from websockets.exceptions import ( + ConnectionClosedError, + ConnectionClosedOK, + InvalidState, + PayloadTooBig, + ProtocolError, +) from websockets.frames import ( OP_BINARY, OP_CLOSE, @@ -1569,6 +1575,101 @@ def test_server_fails_connection(self): self.assertTrue(server.close_expected()) +class ConnectionClosedTests(ConnectionTestCase): + """ + Test connection closed exception. + + """ + + def test_client_sends_close_then_receives_close(self): + # Client-initiated close handshake on the client side complete. + client = Connection(Side.CLIENT) + client.send_close(1000, "") + client.receive_data(b"\x88\x02\x03\xe8") + client.receive_eof() + exc = client.connection_closed_exc + self.assertIsInstance(exc, ConnectionClosedOK) + self.assertEqual(exc.rcvd, Close(1000, "")) + self.assertEqual(exc.sent, Close(1000, "")) + self.assertFalse(exc.rcvd_then_sent) + + def test_server_sends_close_then_receives_close(self): + # Server-initiated close handshake on the server side complete. + server = Connection(Side.SERVER) + server.send_close(1000, "") + server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe8") + server.receive_eof() + exc = server.connection_closed_exc + self.assertIsInstance(exc, ConnectionClosedOK) + self.assertEqual(exc.rcvd, Close(1000, "")) + self.assertEqual(exc.sent, Close(1000, "")) + self.assertFalse(exc.rcvd_then_sent) + + def test_client_receives_close_then_sends_close(self): + # Server-initiated close handshake on the client side complete. + client = Connection(Side.CLIENT) + client.receive_data(b"\x88\x02\x03\xe8") + client.receive_eof() + exc = client.connection_closed_exc + self.assertIsInstance(exc, ConnectionClosedOK) + self.assertEqual(exc.rcvd, Close(1000, "")) + self.assertEqual(exc.sent, Close(1000, "")) + self.assertTrue(exc.rcvd_then_sent) + + def test_server_receives_close_then_sends_close(self): + # Client-initiated close handshake on the server side complete. + server = Connection(Side.SERVER) + server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe8") + server.receive_eof() + exc = server.connection_closed_exc + self.assertIsInstance(exc, ConnectionClosedOK) + self.assertEqual(exc.rcvd, Close(1000, "")) + self.assertEqual(exc.sent, Close(1000, "")) + self.assertTrue(exc.rcvd_then_sent) + + def test_client_sends_close_then_receives_eof(self): + # Client-initiated close handshake on the client side times out. + client = Connection(Side.CLIENT) + client.send_close(1000, "") + client.receive_eof() + exc = client.connection_closed_exc + self.assertIsInstance(exc, ConnectionClosedError) + self.assertIsNone(exc.rcvd) + self.assertEqual(exc.sent, Close(1000, "")) + self.assertIsNone(exc.rcvd_then_sent) + + def test_server_sends_close_then_receives_eof(self): + # Server-initiated close handshake on the server side times out. + server = Connection(Side.SERVER) + server.send_close(1000, "") + server.receive_eof() + exc = server.connection_closed_exc + self.assertIsInstance(exc, ConnectionClosedError) + self.assertIsNone(exc.rcvd) + self.assertEqual(exc.sent, Close(1000, "")) + self.assertIsNone(exc.rcvd_then_sent) + + def test_client_receives_eof(self): + # Server-initiated close handshake on the client side times out. + client = Connection(Side.CLIENT) + client.receive_eof() + exc = client.connection_closed_exc + self.assertIsInstance(exc, ConnectionClosedError) + self.assertIsNone(exc.rcvd) + self.assertIsNone(exc.sent) + self.assertIsNone(exc.rcvd_then_sent) + + def test_server_receives_eof(self): + # Client-initiated close handshake on the server side times out. + server = Connection(Side.SERVER) + server.receive_eof() + exc = server.connection_closed_exc + self.assertIsInstance(exc, ConnectionClosedError) + self.assertIsNone(exc.rcvd) + self.assertIsNone(exc.sent) + self.assertIsNone(exc.rcvd_then_sent) + + class ErrorTests(ConnectionTestCase): """ Test other error cases. From a9de47d9247335d4a4b2df551b1732049597ec58 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 23 Jun 2021 08:29:04 +0200 Subject: [PATCH 114/762] Rename API with a shorter name. --- src/websockets/connection.py | 2 +- tests/test_connection.py | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 4d8851787..4a4b5ddd0 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -173,7 +173,7 @@ def close_reason(self) -> Optional[str]: return self.close_rcvd.reason @property - def connection_closed_exc(self) -> ConnectionClosed: + def close_exc(self) -> ConnectionClosed: """ Exception raised when trying to interact with a closed connection. diff --git a/tests/test_connection.py b/tests/test_connection.py index 10eb58eeb..b82428eab 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1587,7 +1587,7 @@ def test_client_sends_close_then_receives_close(self): client.send_close(1000, "") client.receive_data(b"\x88\x02\x03\xe8") client.receive_eof() - exc = client.connection_closed_exc + exc = client.close_exc self.assertIsInstance(exc, ConnectionClosedOK) self.assertEqual(exc.rcvd, Close(1000, "")) self.assertEqual(exc.sent, Close(1000, "")) @@ -1599,7 +1599,7 @@ def test_server_sends_close_then_receives_close(self): server.send_close(1000, "") server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe8") server.receive_eof() - exc = server.connection_closed_exc + exc = server.close_exc self.assertIsInstance(exc, ConnectionClosedOK) self.assertEqual(exc.rcvd, Close(1000, "")) self.assertEqual(exc.sent, Close(1000, "")) @@ -1610,7 +1610,7 @@ def test_client_receives_close_then_sends_close(self): client = Connection(Side.CLIENT) client.receive_data(b"\x88\x02\x03\xe8") client.receive_eof() - exc = client.connection_closed_exc + exc = client.close_exc self.assertIsInstance(exc, ConnectionClosedOK) self.assertEqual(exc.rcvd, Close(1000, "")) self.assertEqual(exc.sent, Close(1000, "")) @@ -1621,7 +1621,7 @@ def test_server_receives_close_then_sends_close(self): server = Connection(Side.SERVER) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe8") server.receive_eof() - exc = server.connection_closed_exc + exc = server.close_exc self.assertIsInstance(exc, ConnectionClosedOK) self.assertEqual(exc.rcvd, Close(1000, "")) self.assertEqual(exc.sent, Close(1000, "")) @@ -1632,7 +1632,7 @@ def test_client_sends_close_then_receives_eof(self): client = Connection(Side.CLIENT) client.send_close(1000, "") client.receive_eof() - exc = client.connection_closed_exc + exc = client.close_exc self.assertIsInstance(exc, ConnectionClosedError) self.assertIsNone(exc.rcvd) self.assertEqual(exc.sent, Close(1000, "")) @@ -1643,7 +1643,7 @@ def test_server_sends_close_then_receives_eof(self): server = Connection(Side.SERVER) server.send_close(1000, "") server.receive_eof() - exc = server.connection_closed_exc + exc = server.close_exc self.assertIsInstance(exc, ConnectionClosedError) self.assertIsNone(exc.rcvd) self.assertEqual(exc.sent, Close(1000, "")) @@ -1653,7 +1653,7 @@ def test_client_receives_eof(self): # Server-initiated close handshake on the client side times out. client = Connection(Side.CLIENT) client.receive_eof() - exc = client.connection_closed_exc + exc = client.close_exc self.assertIsInstance(exc, ConnectionClosedError) self.assertIsNone(exc.rcvd) self.assertIsNone(exc.sent) @@ -1663,7 +1663,7 @@ def test_server_receives_eof(self): # Client-initiated close handshake on the server side times out. server = Connection(Side.SERVER) server.receive_eof() - exc = server.connection_closed_exc + exc = server.close_exc self.assertIsInstance(exc, ConnectionClosedError) self.assertIsNone(exc.rcvd) self.assertIsNone(exc.sent) From 5d8e0c5dbefa1c487c14736943f858e3b0dc9921 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 23 Jun 2021 08:44:11 +0200 Subject: [PATCH 115/762] Stop processing data after receiving close frame. --- src/websockets/connection.py | 46 +++++++++++++----------------------- tests/test_connection.py | 46 ++++++++++++++---------------------- 2 files changed, 34 insertions(+), 58 deletions(-) diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 4a4b5ddd0..87a4e01b1 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -381,20 +381,11 @@ def parse(self) -> Generator[None, None, None]: if (yield from self.reader.at_eof()): if self.debug: self.logger.debug("< EOF") - if self.close_rcvd_then_sent is not None: - if self.side is CLIENT: - self.send_eof() - self.state = CLOSED - # If parse() completes normally, execution ends here. - yield - # Once the reader reaches EOF, its feed_data/eof() - # methods raise an error, so our receive_data/eof() - # methods don't step parse(). - raise AssertionError( - "parse() shouldn't step after EOF" - ) # pragma: no cover - else: - raise EOFError("unexpected end of stream") + # If the WebSocket connection is closed cleanly, with a + # closing handhshake, recv_frame() substitutes parse() + # with discard(). This branch is reached only when the + # connection isn't closed cleanly. + raise EOFError("unexpected end of stream") if self.max_size is None: max_size = None @@ -444,6 +435,8 @@ def parse(self) -> Generator[None, None, None]: def discard(self) -> Generator[None, None, None]: while not (yield from self.reader.at_eof()): self.reader.discard() + if self.debug: + self.logger.debug("< EOF") if self.side is CLIENT: self.send_eof() # If discard() completes normally, execution ends here. @@ -456,11 +449,6 @@ def discard(self) -> Generator[None, None, None]: def recv_frame(self, frame: Frame) -> None: if frame.opcode is OP_TEXT or frame.opcode is OP_BINARY: - # 5.5.1 Close: "The application MUST NOT send any more data - # frames after sending a Close frame." - if self.close_rcvd is not None: - raise ProtocolError("data frame after close frame") - if self.cur_size is not None: raise ProtocolError("expected a continuation frame") if frame.fin: @@ -469,11 +457,6 @@ def recv_frame(self, frame: Frame) -> None: self.cur_size = len(frame.data) elif frame.opcode is OP_CONT: - # 5.5.1 Close: "The application MUST NOT send any more data - # frames after sending a Close frame." - if self.close_rcvd is not None: - raise ProtocolError("data frame after close frame") - if self.cur_size is None: raise ProtocolError("unexpected continuation frame") if frame.fin: @@ -483,11 +466,9 @@ def recv_frame(self, frame: Frame) -> None: elif frame.opcode is OP_PING: # 5.5.2. Ping: "Upon receipt of a Ping frame, an endpoint MUST - # send a Pong frame in response, unless it already received a - # Close frame." - if self.close_rcvd is None: - pong_frame = Frame(OP_PONG, frame.data) - self.send_frame(pong_frame) + # send a Pong frame in response" + pong_frame = Frame(OP_PONG, frame.data) + self.send_frame(pong_frame) elif frame.opcode is OP_PONG: # 5.5.3 Pong: "A response to an unsolicited Pong frame is not @@ -528,6 +509,12 @@ def recv_frame(self, frame: Frame) -> None: if self.side is SERVER: self.send_eof() + # 1.4. Closing Handshake: "after receiving a control frame + # indicating the connection should be closed, a peer discards + # any further data received." + self.parser = self.discard() + next(self.parser) # start coroutine + else: # pragma: no cover # This can't happen because Frame.parse() validates opcodes. raise AssertionError(f"unexpected opcode: {frame.opcode:02x}") @@ -537,7 +524,6 @@ def recv_frame(self, frame: Frame) -> None: # Private methods for sending events. def send_frame(self, frame: Frame) -> None: - # Defensive assertion for protocol compliance. if self.state is not OPEN: raise InvalidState( f"cannot write to a WebSocket in the {self.state.name} state" diff --git a/tests/test_connection.py b/tests/test_connection.py index b82428eab..ac0df46ce 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -205,16 +205,16 @@ def test_client_receives_continuation_after_receiving_close(self): client.receive_data(b"\x88\x02\x03\xe8") self.assertConnectionClosing(client, 1000) client.receive_data(b"\x00\x00") - self.assertIsInstance(client.parser_exc, ProtocolError) - self.assertEqual(str(client.parser_exc), "data frame after close frame") + self.assertFrameReceived(client, None) + self.assertFrameSent(client, None) def test_server_receives_continuation_after_receiving_close(self): server = Connection(Side.SERVER) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") self.assertConnectionClosing(server, 1001) server.receive_data(b"\x00\x80\x00\xff\x00\xff") - self.assertIsInstance(server.parser_exc, ProtocolError) - self.assertEqual(str(server.parser_exc), "data frame after close frame") + self.assertFrameReceived(server, None) + self.assertFrameSent(server, None) class TextTests(ConnectionTestCase): @@ -459,16 +459,16 @@ def test_client_receives_text_after_receiving_close(self): client.receive_data(b"\x88\x02\x03\xe8") self.assertConnectionClosing(client, 1000) client.receive_data(b"\x81\x00") - self.assertIsInstance(client.parser_exc, ProtocolError) - self.assertEqual(str(client.parser_exc), "data frame after close frame") + self.assertFrameReceived(client, None) + self.assertFrameSent(client, None) def test_server_receives_text_after_receiving_close(self): server = Connection(Side.SERVER) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") self.assertConnectionClosing(server, 1001) server.receive_data(b"\x81\x80\x00\xff\x00\xff") - self.assertIsInstance(server.parser_exc, ProtocolError) - self.assertEqual(str(server.parser_exc), "data frame after close frame") + self.assertFrameReceived(server, None) + self.assertFrameSent(server, None) class BinaryTests(ConnectionTestCase): @@ -661,16 +661,16 @@ def test_client_receives_binary_after_receiving_close(self): client.receive_data(b"\x88\x02\x03\xe8") self.assertConnectionClosing(client, 1000) client.receive_data(b"\x82\x00") - self.assertIsInstance(client.parser_exc, ProtocolError) - self.assertEqual(str(client.parser_exc), "data frame after close frame") + self.assertFrameReceived(client, None) + self.assertFrameSent(client, None) def test_server_receives_binary_after_receiving_close(self): server = Connection(Side.SERVER) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") self.assertConnectionClosing(server, 1001) server.receive_data(b"\x82\x80\x00\xff\x00\xff") - self.assertIsInstance(server.parser_exc, ProtocolError) - self.assertEqual(str(server.parser_exc), "data frame after close frame") + self.assertFrameReceived(server, None) + self.assertFrameSent(server, None) class CloseTests(ConnectionTestCase): @@ -1056,10 +1056,7 @@ def test_client_receives_ping_after_receiving_close(self): client.receive_data(b"\x88\x02\x03\xe8") self.assertConnectionClosing(client, 1000) client.receive_data(b"\x89\x04\x22\x66\xaa\xee") - self.assertFrameReceived( - client, - Frame(OP_PING, b"\x22\x66\xaa\xee"), - ) + self.assertFrameReceived(client, None) self.assertFrameSent(client, None) def test_server_receives_ping_after_receiving_close(self): @@ -1067,10 +1064,7 @@ def test_server_receives_ping_after_receiving_close(self): server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") self.assertConnectionClosing(server, 1001) server.receive_data(b"\x89\x84\x00\x44\x88\xcc\x22\x22\x22\x22") - self.assertFrameReceived( - server, - Frame(OP_PING, b"\x22\x66\xaa\xee"), - ) + self.assertFrameReceived(server, None) self.assertFrameSent(server, None) @@ -1186,20 +1180,16 @@ def test_client_receives_pong_after_receiving_close(self): client.receive_data(b"\x88\x02\x03\xe8") self.assertConnectionClosing(client, 1000) client.receive_data(b"\x8a\x04\x22\x66\xaa\xee") - self.assertFrameReceived( - client, - Frame(OP_PONG, b"\x22\x66\xaa\xee"), - ) + self.assertFrameReceived(client, None) + self.assertFrameSent(client, None) def test_server_receives_pong_after_receiving_close(self): server = Connection(Side.SERVER) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") self.assertConnectionClosing(server, 1001) server.receive_data(b"\x8a\x84\x00\x44\x88\xcc\x22\x22\x22\x22") - self.assertFrameReceived( - server, - Frame(OP_PONG, b"\x22\x66\xaa\xee"), - ) + self.assertFrameReceived(server, None) + self.assertFrameSent(server, None) class FailTests(ConnectionTestCase): From d5cf1efb737a943583b1e5b4ceca5376bdc3995f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 26 Jun 2021 08:22:32 +0200 Subject: [PATCH 116/762] Improve comments on EOF handling. --- src/websockets/connection.py | 68 ++++++++++++++++++++++++++++-------- 1 file changed, 54 insertions(+), 14 deletions(-) diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 87a4e01b1..857698f4f 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -319,12 +319,17 @@ def fail(self, code: int, reason: str = "") -> None: self.close_sent = close self.state = CLOSING + # When failing the connection, a server closes the TCP connection + # without waiting for the client to complete the handshake, while a + # client waits for the server to close the TCP connection, possibly + # after sending a close frame that the client will ignore. if self.side is SERVER and not self.eof_sent: self.send_eof() - # "An endpoint MUST NOT continue to attempt to process data - # (including a responding Close frame) from the remote endpoint - # after being instructed to _Fail the WebSocket Connection_." + # 7.1.7. Fail the WebSocket Connection "An endpoint MUST NOT continue + # to attempt to process data(including a responding Close frame) from + # the remote endpoint after being instructed to _Fail the WebSocket + # Connection_." self.parser = self.discard() next(self.parser) # start coroutine @@ -362,20 +367,30 @@ def close_expected(self) -> bool: Tell whether the TCP connection is expected to close soon. Call this method immediately after calling any of the ``receive_*()`` - methods and, if it returns ``True``, schedule closing the TCP - connection after a short timeout. + or ``fail_*()`` methods and, if it returns ``True``, schedule closing + the TCP connection after a short timeout. """ # We already got a TCP Close if and only if the state is CLOSED. # We expect a TCP close if and only if we sent a close frame: - # * Normal closure: once we send a close frame, we expect a TCP close. - # * Abnormal closure: we always send a close frame except on EOFError, - # but that's fine because we already got the TCP close. + # * Normal closure: once we send a close frame, we expect a TCP close: + # server waits for client to complete the TCP closing handshake; + # client waits for server to initiate the TCP closing handshake. + # * Abnormal closure: we always send a close frame and the same logic + # applies, except on EOFError where we don't send a close frame + # because we already received the TCP close, so we don't expect it. return self.state is not CLOSED and self.close_sent is not None # Private methods for receiving data. def parse(self) -> Generator[None, None, None]: + """ + Parse incoming data into frames. + + :meth:`receive_data` and :meth:`receive_eof` run this generator + coroutine until it needs more data or reaches EOF. + + """ try: while True: if (yield from self.reader.at_eof()): @@ -394,6 +409,9 @@ def parse(self) -> Generator[None, None, None]: else: max_size = self.max_size - self.cur_size + # During a normal closure, execution ends here on the next + # iteration of the loop after receiving a close frame. At + # this point, recv_frame() replaced parse() by discard(). frame = yield from Frame.parse( self.reader.read_exact, mask=self.side is SERVER, @@ -428,26 +446,46 @@ def parse(self) -> Generator[None, None, None]: self.fail(1011) self.parser_exc = exc + # During an abnormal closure, execution ends here after catching an + # exception. At this point, fail() replaced parse() by discard(). yield - # If an error occurs, parse() is replaced by discard(). - raise AssertionError("parse() shouldn't step after EOF") # pragma: no cover + raise AssertionError("parse() shouldn't step after error") # pragma: no cover def discard(self) -> Generator[None, None, None]: + """ + Discard incoming data. + + This coroutine replaces :meth:`parse`: + + - after receiving a close frame, during a normal closure (1.4); + - after sending a close frame, during an abnormal closure (7.1.7). + + """ + # The server close the TCP connection in the same circumstances where + # discard() replaces parse(). The client closes the connection later, + # after the server closes the connection or a timeout elapses. + # (The latter case cannot be handled in this Sans-I/O layer.) + assert (self.side is SERVER) == (self.eof_sent) while not (yield from self.reader.at_eof()): self.reader.discard() if self.debug: self.logger.debug("< EOF") + # A server closes the TCP connection immediately, while a client + # waits for the server to close the TCP connection. if self.side is CLIENT: self.send_eof() - # If discard() completes normally, execution ends here. self.state = CLOSED + # If discard() completes normally, execution ends here. yield - # Once the reader reaches EOF, its feed_data/eof() - # methods raise an error, so our receive_data/eof() - # methods don't step the generator. + # Once the reader reaches EOF, its feed_data/eof() methods raise an + # error, so our receive_data/eof() methods don't step the generator. raise AssertionError("discard() shouldn't step after EOF") # pragma: no cover def recv_frame(self, frame: Frame) -> None: + """ + Process an incoming frame. + + """ if frame.opcode is OP_TEXT or frame.opcode is OP_BINARY: if self.cur_size is not None: raise ProtocolError("expected a continuation frame") @@ -506,6 +544,8 @@ def recv_frame(self, frame: Frame) -> None: # endpoint has both sent and received a Close control frame, # that endpoint SHOULD _Close the WebSocket Connection_" + # A server closes the TCP connection immediately, while a client + # waits for the server to close the TCP connection. if self.side is SERVER: self.send_eof() From 4ed76b2cc43b9480428121cc9f94ee36af4ea27f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 25 Jul 2021 09:43:37 +0200 Subject: [PATCH 117/762] Document all types visible in public APIs. --- docs/reference/utilities.rst | 20 +++++++++++++++-- src/websockets/__init__.py | 2 +- src/websockets/datastructures.py | 3 +-- src/websockets/typing.py | 38 ++++++++++++++++++-------------- 4 files changed, 42 insertions(+), 21 deletions(-) diff --git a/docs/reference/utilities.rst b/docs/reference/utilities.rst index f1d89eddc..e7f489fbd 100644 --- a/docs/reference/utilities.rst +++ b/docs/reference/utilities.rst @@ -10,7 +10,12 @@ Data structures --------------- .. automodule:: websockets.datastructures - :members: + + .. autoclass:: Headers + + .. autodata:: HeadersLike + + .. autoexception:: MultipleValuesError Exceptions ---------- @@ -22,4 +27,15 @@ Types ----- .. automodule:: websockets.typing - :members: + + .. autodata:: Data + + .. autodata:: LoggerLike + + .. autodata:: Origin + + .. autodata:: Subprotocol + + .. autodata:: ExtensionName + + .. autodata:: ExtensionParameter diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index 6378b82cf..5883c3d65 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -14,7 +14,7 @@ "ConnectionClosedOK", "Data", "DuplicateParameter", - "ExtensionHeader", + "ExtensionName", "ExtensionParameter", "InvalidHandshake", "InvalidHeader", diff --git a/src/websockets/datastructures.py b/src/websockets/datastructures.py index 117ffd4f2..3e3f9705a 100644 --- a/src/websockets/datastructures.py +++ b/src/websockets/datastructures.py @@ -159,5 +159,4 @@ def raw_items(self) -> Iterator[Tuple[str, str]]: HeadersLike = Union[Headers, Mapping[str, str], Iterable[Tuple[str, str]]] -HeadersLike__doc__ = """Types accepted wherever :class:`Headers` is expected""" -HeadersLike.__doc__ = HeadersLike__doc__ +HeadersLike.__doc__ = """Types accepted where :class:`Headers` is expected""" diff --git a/src/websockets/typing.py b/src/websockets/typing.py index 13b172f15..1bd118071 100644 --- a/src/websockets/typing.py +++ b/src/websockets/typing.py @@ -8,46 +8,52 @@ "Data", "LoggerLike", "Origin", - "ExtensionHeader", - "ExtensionParameter", "Subprotocol", + "ExtensionName", + "ExtensionParameter", ] + +# Public types used in the signature of public APIs + Data = Union[str, bytes] Data.__doc__ = """ Types supported in a WebSocket message: -- :class:`str` for text messages -- :class:`bytes` for binary messages - +- :class:`str` for text messages; +- :class:`bytes` for binary messages. """ + LoggerLike = Union[logging.Logger, logging.LoggerAdapter] -LoggerLike.__doc__ = """"Types accepted where :class:`~logging.Logger` is expected""" +LoggerLike.__doc__ = """Types accepted where :class:`~logging.Logger` is expected.""" + Origin = NewType("Origin", str) -Origin.__doc__ = """Value of a Origin header""" +Origin.__doc__ = """Value of a Origin header.""" + + +Subprotocol = NewType("Subprotocol", str) +Subprotocol.__doc__ = """Subprotocol in a Sec-WebSocket-Protocol header.""" ExtensionName = NewType("ExtensionName", str) -ExtensionName.__doc__ = """Name of a WebSocket extension""" +ExtensionName.__doc__ = """Name of a WebSocket extension.""" ExtensionParameter = Tuple[str, Optional[str]] -ExtensionParameter.__doc__ = """Parameter of a WebSocket extension""" +ExtensionParameter.__doc__ = """Parameter of a WebSocket extension.""" -ExtensionHeader = Tuple[ExtensionName, List[ExtensionParameter]] -ExtensionHeader.__doc__ = """Extension in a Sec-WebSocket-Extensions header""" +# Private types - -Subprotocol = NewType("Subprotocol", str) -Subprotocol.__doc__ = """Subprotocol value in a Sec-WebSocket-Protocol header""" +ExtensionHeader = Tuple[ExtensionName, List[ExtensionParameter]] +ExtensionHeader.__doc__ = """Extension in a Sec-WebSocket-Extensions header.""" ConnectionOption = NewType("ConnectionOption", str) -ConnectionOption.__doc__ = """Connection option in a Connection header""" +ConnectionOption.__doc__ = """Connection option in a Connection header.""" UpgradeProtocol = NewType("UpgradeProtocol", str) -UpgradeProtocol.__doc__ = """Upgrade protocol in an Upgrade header""" +UpgradeProtocol.__doc__ = """Upgrade protocol in an Upgrade header.""" From 6aae7c69e30c74bbdc9bd81c52f7a5f90b4c8037 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 25 Jul 2021 18:21:36 +0200 Subject: [PATCH 118/762] Explain ping timeout errors in the FAQ. Fix #1012. --- docs/howto/faq.rst | 54 +++++++++++++++++++++++++++++++++++++--- docs/topics/timeouts.rst | 3 ++- 2 files changed, 53 insertions(+), 4 deletions(-) diff --git a/docs/howto/faq.rst b/docs/howto/faq.rst index a4ad84680..7cec6d43e 100644 --- a/docs/howto/faq.rst +++ b/docs/howto/faq.rst @@ -251,10 +251,10 @@ If you're seeing this traceback in the logs of a server: .. code-block:: pytb - Error in connection handler + connection handler failed Traceback (most recent call last): ... - asyncio.streams.IncompleteReadError: 0 bytes read on a total of 2 expected bytes + asyncio.exceptions.IncompleteReadError: 0 bytes read on a total of 2 expected bytes The above exception was the direct cause of the following exception: @@ -290,11 +290,59 @@ There are several reasons why long-lived connections may be lost: a wired network, enter airplane mode, be put to sleep, etc. * HTTP load balancers or proxies that aren't configured for long-lived connections may terminate connections after a short amount of time, usually - 30 seconds. + 30 seconds, despite websockets' keepalive mechanism. If you're facing a reproducible issue, :ref:`enable debug logs ` to see when and how connections are closed. +What does ``ConnectionClosedError: sent 1011 (unexpected error) keepalive ping timeout; no close frame received`` mean? +....................................................................................................................... + +If you're seeing this traceback in the logs of a server: + +.. code-block:: pytb + + connection handler failed + Traceback (most recent call last): + ... + asyncio.exceptions.CancelledError + + The above exception was the direct cause of the following exception: + + Traceback (most recent call last): + ... + websockets.exceptions.ConnectionClosedError: sent 1011 (unexpected error) keepalive ping timeout; no close frame received + +or if a client crashes with this traceback: + +.. code-block:: pytb + + Traceback (most recent call last): + ... + asyncio.exceptions.CancelledError + + The above exception was the direct cause of the following exception: + + Traceback (most recent call last): + ... + websockets.exceptions.ConnectionClosedError: sent 1011 (unexpected error) keepalive ping timeout; no close frame received + +it means that the WebSocket connection suffered from excessive latency and was +closed after reaching the timeout of websockets' keepalive mechanism. + +You can catch and handle :exc:`~exceptions.ConnectionClosed` to prevent it +from being logged. + +There are two main reasons why latency may increase: + +* Poor network connectivity. +* More traffic than the recipient can handle. + +See the discussion of :doc:`timeouts <../topics/timeouts>` for details. + +If websockets' default timeout of 20 seconds is too short for your use case, +you can adjust it with the ``ping_timeout`` argument. + How do I set a timeout on ``recv()``? ..................................... diff --git a/docs/topics/timeouts.rst b/docs/topics/timeouts.rst index 828d22a0b..8febfce9f 100644 --- a/docs/topics/timeouts.rst +++ b/docs/topics/timeouts.rst @@ -20,7 +20,8 @@ By default, websockets waits 20 seconds, then sends a Ping frame, and expects to receive the corresponding Pong frame within 20 seconds. Else, it considers the connection broken and closes it. -Timings are configurable with ``ping_interval`` and ``ping_timeout``. +Timings are configurable with the ``ping_interval`` and ``ping_timeout`` +arguments of :func:`~websockets.connect` and :func:`~websockets.serve`. While WebSocket runs on top of TCP, websockets doesn't rely on TCP keepalive because it's disabled by default and, if enabled, the default interval is no From 4e1dac362a3639b9cf0e5bcf382601e6e32cfede Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 25 Jul 2021 18:22:26 +0200 Subject: [PATCH 119/762] Show DEBUG log level in example. --- docs/howto/cheatsheet.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/howto/cheatsheet.rst b/docs/howto/cheatsheet.rst index 86684c44c..edfb00baa 100644 --- a/docs/howto/cheatsheet.rst +++ b/docs/howto/cheatsheet.rst @@ -69,7 +69,7 @@ If you don't understand what websockets is doing, enable logging:: import logging logger = logging.getLogger('websockets') - logger.setLevel(logging.INFO) + logger.setLevel(logging.DEBUG) logger.addHandler(logging.StreamHandler()) The logs contain: From 21af14cdd61cb31a4a4c003d73f9eb489a7a22f3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 25 Jul 2021 21:37:01 +0200 Subject: [PATCH 120/762] Fix display of license in docs. --- docs/project/license.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/project/license.rst b/docs/project/license.rst index 426110020..0a3b8703d 100644 --- a/docs/project/license.rst +++ b/docs/project/license.rst @@ -1,4 +1,4 @@ License ======= -.. literalinclude:: ../../LICENSE +.. include:: ../../LICENSE From bf02ad01e91a05777d607f926673f8d8b0efc2e7 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 25 Jul 2021 21:42:10 +0200 Subject: [PATCH 121/762] Memory use => usage. --- README.rst | 2 +- docs/topics/memory.rst | 2 +- docs/topics/security.rst | 12 ++++++------ src/websockets/legacy/protocol.py | 2 +- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/README.rst b/README.rst index 70dec88db..6b9b67672 100644 --- a/README.rst +++ b/README.rst @@ -104,7 +104,7 @@ The development of ``websockets`` is shaped by four principles: under 100% branch coverage. Also it passes the industry-standard `Autobahn Testsuite`_. -4. **Performance**: memory use is configurable. An extension written in C +4. **Performance**: memory usage is configurable. An extension written in C accelerates expensive operations. It's pre-compiled for Linux, macOS and Windows and packaged in the wheel format for each system and Python version. diff --git a/docs/topics/memory.rst b/docs/topics/memory.rst index e5b9ed9a2..c880d5579 100644 --- a/docs/topics/memory.rst +++ b/docs/topics/memory.rst @@ -26,7 +26,7 @@ Buffers Under normal circumstances, buffers are almost always empty. Under high load, if a server receives more messages than it can process, -bufferbloat can result in excessive memory use. +bufferbloat can result in excessive memory usage. By default websockets has generous limits. It is strongly recommended to adapt them to your application. When you call :func:`~legacy.server.serve`: diff --git a/docs/topics/security.rst b/docs/topics/security.rst index 6c541db06..d3dec21bd 100644 --- a/docs/topics/security.rst +++ b/docs/topics/security.rst @@ -9,8 +9,8 @@ For production use, a server should require encrypted connections. See this example of :ref:`encrypting connections with TLS `. -Memory use ----------- +Memory usage +------------ .. warning:: @@ -21,11 +21,11 @@ Memory use With the default settings, opening a connection uses 70 KiB of memory. -Sending some highly compressed messages could use up to 128 MiB of memory -with an amplification factor of 1000 between network traffic and memory use. +Sending some highly compressed messages could use up to 128 MiB of memory with +an amplification factor of 1000 between network traffic and memory usage. -Configuring a server to :doc:`memory` will improve security in addition to -improving performance. +Configuring a server to :doc:`optimize memory usage ` will improve +security in addition to improving performance. Other limits ------------ diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index df74bdb38..b0d2ed832 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -1432,7 +1432,7 @@ def broadcast(websockets: Iterable[WebSocketCommonProtocol], message: Data) -> N If you broadcast messages faster than a connection can handle them, messages will pile up in its write buffer until the connection times out. Keep low values for ``ping_interval`` and ``ping_timeout`` to prevent - excessive memory use by slow connections when you use :func:`broadcast`. + excessive memory usage by slow connections when you use :func:`broadcast`. Unlike :meth:`~websockets.server.WebSocketServerProtocol.send`, :func:`broadcast` doesn't support sending fragmented messages. Indeed, From 59eb1b53a9b3b868a05e0eecc8255ee3b11b8bcc Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 25 Jul 2021 21:45:52 +0200 Subject: [PATCH 122/762] Fix typo. --- docs/topics/broadcast.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/topics/broadcast.rst b/docs/topics/broadcast.rst index c43a8fa36..f9cd9e281 100644 --- a/docs/topics/broadcast.rst +++ b/docs/topics/broadcast.rst @@ -12,7 +12,7 @@ Broadcasting messages If you want to learn about its design in depth, continue reading this document. -WebSocket servers often send the same message to all connected clients to a +WebSocket servers often send the same message to all connected clients or to a subset of clients for which the message is relevant. Let's explore options for broadcasting a message, explain the design From 0867d9f0f9fa0dfc267e6e0b91cb1aaa5858ccaf Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 25 Jul 2021 21:51:41 +0200 Subject: [PATCH 123/762] Clarify error message. --- src/websockets/frames.py | 4 ++-- src/websockets/legacy/protocol.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 839e2f7b7..a0ce1d350 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -340,7 +340,7 @@ def prepare_data(data: Data) -> Tuple[int, bytes]: elif isinstance(data, BytesLike): return OP_BINARY, data else: - raise TypeError("data must be bytes-like or str") + raise TypeError("data must be str or bytes-like") def prepare_ctrl(data: Data) -> bytes: @@ -362,7 +362,7 @@ def prepare_ctrl(data: Data) -> bytes: elif isinstance(data, BytesLike): return bytes(data) else: - raise TypeError("data must be bytes-like or str") + raise TypeError("data must be str or bytes-like") @dataclasses.dataclass diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index b0d2ed832..339733db2 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -507,7 +507,7 @@ async def send( :raises ~websockets.exceptions.ConnectionClosed: when the connection is closed - :raises TypeError: for unsupported inputs + :raises TypeError: if ``message`` doesn't have a supported type """ await self.ensure_open() @@ -613,7 +613,7 @@ async def send( self._fragmented_message_waiter = None else: - raise TypeError("data must be bytes, str, or iterable") + raise TypeError("data must be str, bytes-like, or iterable") async def close(self, code: int = 1000, reason: str = "") -> None: """ @@ -1441,11 +1441,11 @@ def broadcast(websockets: Iterable[WebSocketCommonProtocol], message: Data) -> N as fast as possible. :raises RuntimeError: if a connection is busy sending a fragmented message - :raises TypeError: for unsupported inputs + :raises TypeError: if ``message`` doesn't have a supported type """ if not isinstance(message, (str, bytes, bytearray, memoryview)): - raise TypeError("data must be bytes, str, or iterable") + raise TypeError("data must be str or bytes-like") opcode, data = prepare_data(message) From 786134542d9f0d6ad3959198c5e543ce55fff4aa Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 31 Jul 2021 09:35:19 +0200 Subject: [PATCH 124/762] Document that time.sleep doesn't work in asyncio. This belongs in asyncio's docs, not in websockets', but I can't find it (easily) in Python docs and it keeps tripping people up. Fix #1025, #1024, #923, #901, #703, etc. --- docs/howto/faq.rst | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/docs/howto/faq.rst b/docs/howto/faq.rst index 7cec6d43e..394920b58 100644 --- a/docs/howto/faq.rst +++ b/docs/howto/faq.rst @@ -214,7 +214,7 @@ the connection terminates. Why does my program never receives any messages? ................................................ -Your program runs a coroutine that never yield control to the event loop. The +Your program runs a coroutine that never yields control to the event loop. The coroutine that receives messages never gets a chance to run. Putting an ``await`` statement in a ``for`` or a ``while`` loop isn't enough @@ -240,6 +240,15 @@ See `issue 867`_. .. _issue 867: https://github.com/aaugustin/websockets/issues/867 +Why does my very simple program misbehave mysteriously? +....................................................... + +You are using :func:`time.sleep` instead of :func:`asyncio.sleep`, which +blocks the event loop and prevents asyncio from operating normally. + +This may lead to messages getting send but not received, to connection +timeouts, and to unexpected results of shotgun debugging e.g. adding an +unnecessary call to ``send()`` makes the program functional. Both sides ---------- From b343fc60824737d4faa421aafe7983a8d3d0c9df Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 6 Aug 2021 22:42:41 +0200 Subject: [PATCH 125/762] Refactor tests for redirects. --- tests/legacy/test_client_server.py | 116 ++++++++--------------------- 1 file changed, 29 insertions(+), 87 deletions(-) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 755fcefdd..6ce100fa2 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -75,6 +75,20 @@ async def default_handler(ws, path): await ws.send((await ws.recv())) +async def redirect_request(path, headers, test, status): + if path == "/redirect": + location = get_server_uri(test.server, test.secure, "/") + elif path == "/infinite": + location = get_server_uri(test.server, test.secure, "/infinite") + elif path == "/force_insecure": + location = get_server_uri(test.server, False, "/") + elif path == "/missing_location": + return status, {}, b"" + else: + return None + return status, {"Location": location}, b"" + + @contextlib.contextmanager def temp_test_server(test, **kwargs): test.start_server(**kwargs) @@ -84,15 +98,9 @@ def temp_test_server(test, **kwargs): test.stop_server() -@contextlib.contextmanager -def temp_test_redirecting_server( - test, status, include_location=True, force_insecure=False, **kwargs -): - test.start_redirecting_server(status, include_location, force_insecure, **kwargs) - try: - yield - finally: - test.stop_redirecting_server() +def temp_test_redirecting_server(test, status=http.HTTPStatus.FOUND, **kwargs): + process_request = functools.partial(redirect_request, test=test, status=status) + return temp_test_server(test, process_request=process_request, **kwargs) @contextlib.contextmanager @@ -201,11 +209,6 @@ class ClientServerTestsMixin: def setUp(self): super().setUp() self.server = None - self.redirecting_server = None - - @property - def server_context(self): - return None def start_server(self, deprecation_warnings=None, **kwargs): handler = kwargs.pop("handler", default_handler) @@ -226,42 +229,6 @@ def start_server(self, deprecation_warnings=None, **kwargs): expected_warnings += ["There is no current event loop"] self.assertDeprecationWarnings(recorded_warnings, expected_warnings) - def start_redirecting_server( - self, - status, - include_location=True, - force_insecure=False, - deprecation_warnings=None, - **kwargs, - ): - async def process_request(path, headers): - server_uri = get_server_uri(self.server, self.secure, path) - if force_insecure: - server_uri = server_uri.replace("wss:", "ws:") - headers = {"Location": server_uri} if include_location else [] - return status, headers, b"" - - with warnings.catch_warnings(record=True) as recorded_warnings: - start_server = serve( - default_handler, - "localhost", - 0, - compression=None, - ping_interval=None, - process_request=process_request, - ssl=self.server_context, - **kwargs, - ) - self.redirecting_server = self.loop.run_until_complete(start_server) - - expected_warnings = [] if deprecation_warnings is None else deprecation_warnings - if ( - sys.version_info[:2] >= (3, 10) - and "remove loop argument" not in expected_warnings - ): # pragma: no cover - expected_warnings += ["There is no current event loop"] - self.assertDeprecationWarnings(recorded_warnings, expected_warnings) - def start_client( self, resource_name="/", user_info=None, deprecation_warnings=None, **kwargs ): @@ -274,8 +241,7 @@ def start_client( try: server_uri = kwargs.pop("uri") except KeyError: - server = self.redirecting_server if self.redirecting_server else self.server - server_uri = get_server_uri(server, secure, resource_name, user_info) + server_uri = get_server_uri(self.server, secure, resource_name, user_info) with warnings.catch_warnings(record=True) as recorded_warnings: start_client = connect(server_uri, **kwargs) @@ -306,17 +272,6 @@ def stop_server(self): except asyncio.TimeoutError: # pragma: no cover self.fail("Server failed to stop") - def stop_redirecting_server(self): - self.redirecting_server.close() - try: - self.loop.run_until_complete( - asyncio.wait_for(self.redirecting_server.wait_closed(), timeout=1) - ) - except asyncio.TimeoutError: # pragma: no cover - self.fail("Redirecting server failed to stop") - finally: - self.redirecting_server = None - @contextlib.contextmanager def temp_server(self, **kwargs): with temp_test_server(self, **kwargs): @@ -388,7 +343,6 @@ def test_basic(self): reply = self.loop.run_until_complete(self.client.recv()) self.assertEqual(reply, "Hello!") - @with_server() def test_redirect(self): redirect_statuses = [ http.HTTPStatus.MOVED_PERMANENTLY, @@ -399,40 +353,31 @@ def test_redirect(self): ] for status in redirect_statuses: with temp_test_redirecting_server(self, status): - with temp_test_client(self): + with self.temp_client("/redirect"): self.loop.run_until_complete(self.client.send("Hello!")) reply = self.loop.run_until_complete(self.client.recv()) self.assertEqual(reply, "Hello!") def test_infinite_redirect(self): - with temp_test_redirecting_server( - self, - http.HTTPStatus.FOUND, - ): - self.server = self.redirecting_server + with temp_test_redirecting_server(self): with self.assertRaises(InvalidHandshake): - with temp_test_client(self): + with self.temp_client("/infinite"): self.fail("Did not raise") # pragma: no cover - @with_server() def test_redirect_missing_location(self): - with temp_test_redirecting_server( - self, - http.HTTPStatus.FOUND, - include_location=False, - loop=self.loop, - deprecation_warnings=["remove loop argument"], - ): + with temp_test_redirecting_server(self): with self.assertRaises(InvalidHeader): - with temp_test_client(self): + with self.temp_client("/missing_location"): self.fail("Did not raise") # pragma: no cover def test_loop_backwards_compatibility(self): with self.temp_server( - loop=self.loop, deprecation_warnings=["remove loop argument"] + loop=self.loop, + deprecation_warnings=["remove loop argument"], ): with self.temp_client( - loop=self.loop, deprecation_warnings=["remove loop argument"] + loop=self.loop, + deprecation_warnings=["remove loop argument"], ): self.loop.run_until_complete(self.client.send("Hello!")) reply = self.loop.run_until_complete(self.client.recv()) @@ -1274,13 +1219,10 @@ def test_ws_uri_is_rejected(self): uri=get_server_uri(self.server, secure=False), ssl=self.client_context ) - @with_server() def test_redirect_insecure(self): - with temp_test_redirecting_server( - self, http.HTTPStatus.FOUND, force_insecure=True - ): + with temp_test_redirecting_server(self): with self.assertRaises(InvalidHandshake): - with temp_test_client(self): + with self.temp_client("/force_insecure"): self.fail("Did not raise") # pragma: no cover From 1f462000ac9b4f8ad80deb565083ffe430e09acf Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 11 Aug 2021 09:53:50 +0200 Subject: [PATCH 126/762] Fix handling of relative redirects. Fix #1023. --- docs/project/changelog.rst | 2 ++ src/websockets/legacy/client.py | 7 ++++++- tests/legacy/test_client_server.py | 13 +++++++++++-- 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index cf38c6159..277ee5022 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -74,6 +74,8 @@ They may change at any time. * Made it easier to customize authentication with :meth:`~auth.BasicAuthWebSocketServerProtocol.check_credentials`. +* Fixed handling of relative redirects in :func:`~legacy.client.connect`. + 9.1 ... diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index c87b5f8d5..5d6d8130a 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -9,6 +9,7 @@ import collections.abc import functools import logging +import urllib.parse import warnings from types import TracebackType from typing import ( @@ -611,12 +612,15 @@ def __init__( # This is a coroutine function. self._create_connection = create_connection + self._uri = uri self._wsuri = wsuri def handle_redirect(self, uri: str) -> None: # Update the state of this instance to connect to a new URI. + old_uri = self._uri old_wsuri = self._wsuri - new_wsuri = parse_uri(uri) + new_uri = urllib.parse.urljoin(old_uri, uri) + new_wsuri = parse_uri(new_uri) # Forbid TLS downgrade. if old_wsuri.secure and not new_wsuri.secure: @@ -645,6 +649,7 @@ def handle_redirect(self, uri: str) -> None: ) # Set the new WebSocket URI. This suffices for same-origin redirects. + self._uri = new_uri self._wsuri = new_wsuri # async for ... in connect(...): diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 6ce100fa2..02d7a7aa3 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -76,8 +76,10 @@ async def default_handler(ws, path): async def redirect_request(path, headers, test, status): - if path == "/redirect": + if path == "/absolute_redirect": location = get_server_uri(test.server, test.secure, "/") + elif path == "/relative_redirect": + location = "/" elif path == "/infinite": location = get_server_uri(test.server, test.secure, "/infinite") elif path == "/force_insecure": @@ -353,11 +355,18 @@ def test_redirect(self): ] for status in redirect_statuses: with temp_test_redirecting_server(self, status): - with self.temp_client("/redirect"): + with self.temp_client("/absolute_redirect"): self.loop.run_until_complete(self.client.send("Hello!")) reply = self.loop.run_until_complete(self.client.recv()) self.assertEqual(reply, "Hello!") + def test_redirect_relative_location(self): + with temp_test_redirecting_server(self): + with self.temp_client("/relative_redirect"): + self.loop.run_until_complete(self.client.send("Hello!")) + reply = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(reply, "Hello!") + def test_infinite_redirect(self): with temp_test_redirecting_server(self): with self.assertRaises(InvalidHandshake): From 90a20f550ea57f3ae6474d560d623158de97e490 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 12 Aug 2021 07:09:04 +0200 Subject: [PATCH 127/762] Simplify exceptions raised by send_close. --- src/websockets/connection.py | 2 +- tests/test_connection.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 857698f4f..52fd9bb81 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -275,7 +275,7 @@ def send_close(self, code: Optional[int] = None, reason: str = "") -> None: raise ProtocolError("expected a continuation frame") if code is None: if reason != "": - raise ValueError("cannot send a reason without a code") + raise ProtocolError("cannot send a reason without a code") close = Close(1005, "") data = b"" else: diff --git a/tests/test_connection.py b/tests/test_connection.py index ac0df46ce..3d4d98436 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -865,13 +865,13 @@ def test_server_receives_close_with_code_and_reason(self): def test_client_sends_close_with_reason_only(self): client = Connection(Side.CLIENT) - with self.assertRaises(ValueError) as raised: + with self.assertRaises(ProtocolError) as raised: client.send_close(reason="going away") self.assertEqual(str(raised.exception), "cannot send a reason without a code") def test_server_sends_close_with_reason_only(self): server = Connection(Side.SERVER) - with self.assertRaises(ValueError) as raised: + with self.assertRaises(ProtocolError) as raised: server.send_close(reason="OK") self.assertEqual(str(raised.exception), "cannot send a reason without a code") From dc42ecba8d809e14f1856e2e65816c27eaeb8fb6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 12 Aug 2021 08:43:04 +0200 Subject: [PATCH 128/762] Make Headers comparison case-insensitive. --- src/websockets/datastructures.py | 2 +- tests/test_datastructures.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/websockets/datastructures.py b/src/websockets/datastructures.py index 3e3f9705a..42282e209 100644 --- a/src/websockets/datastructures.py +++ b/src/websockets/datastructures.py @@ -129,7 +129,7 @@ def __delitem__(self, key: str) -> None: def __eq__(self, other: Any) -> bool: if not isinstance(other, Headers): return NotImplemented - return self._list == other._list + return self._dict == other._dict def clear(self) -> None: """ diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py index 628cbcb02..a361a1b4d 100644 --- a/tests/test_datastructures.py +++ b/tests/test_datastructures.py @@ -97,6 +97,10 @@ def test_eq(self): other_headers = Headers([("Connection", "Upgrade"), ("Server", "websockets")]) self.assertEqual(self.headers, other_headers) + def test_eq_case_insensitive(self): + other_headers = Headers(connection="Upgrade", server="websockets") + self.assertEqual(self.headers, other_headers) + def test_eq_not_equal(self): other_headers = Headers([("Connection", "close"), ("Server", "websockets")]) self.assertNotEqual(self.headers, other_headers) From 5de7b41a2b2003aaf1db4042ac3d5e2ab4c24cdb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 12 Aug 2021 08:43:46 +0200 Subject: [PATCH 129/762] Fix initializing Headers from Headers. It didn't work when a header had multiple values. --- src/websockets/datastructures.py | 16 +++- tests/test_datastructures.py | 149 ++++++++++++++++++++++++++----- 2 files changed, 138 insertions(+), 27 deletions(-) diff --git a/src/websockets/datastructures.py b/src/websockets/datastructures.py index 42282e209..65c5d4115 100644 --- a/src/websockets/datastructures.py +++ b/src/websockets/datastructures.py @@ -74,10 +74,10 @@ class Headers(MutableMapping[str, str]): __slots__ = ["_dict", "_list"] - def __init__(self, *args: Any, **kwargs: str) -> None: + # Like dict, Headers accepts an optional "mapping or iterable" argument. + def __init__(self, *args: HeadersLike, **kwargs: str) -> None: self._dict: Dict[str, List[str]] = {} self._list: List[Tuple[str, str]] = [] - # MutableMapping.update calls __setitem__ for each (name, value) pair. self.update(*args, **kwargs) def __str__(self) -> str: @@ -86,7 +86,7 @@ def __str__(self) -> str: def __repr__(self) -> str: return f"{self.__class__.__name__}({self._list!r})" - def copy(self) -> "Headers": + def copy(self) -> Headers: copy = self.__class__() copy._dict = self._dict.copy() copy._list = self._list.copy() @@ -139,6 +139,16 @@ def clear(self) -> None: self._dict = {} self._list = [] + def update(self, *args: HeadersLike, **kwargs: str) -> None: + """ + Update from a Headers instance and/or keyword arguments. + + """ + args = tuple( + arg.raw_items() if isinstance(arg, Headers) else arg for arg in args + ) + super().update(*args, **kwargs) + # Methods for handling multiple values def get_all(self, key: str) -> List[str]: diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py index a361a1b4d..32b79817a 100644 --- a/tests/test_datastructures.py +++ b/tests/test_datastructures.py @@ -3,10 +3,68 @@ from websockets.datastructures import * +class MultipleValuesErrorTests(unittest.TestCase): + def test_multiple_values_error_str(self): + self.assertEqual(str(MultipleValuesError("Connection")), "'Connection'") + self.assertEqual(str(MultipleValuesError()), "") + + class HeadersTests(unittest.TestCase): def setUp(self): self.headers = Headers([("Connection", "Upgrade"), ("Server", "websockets")]) + def test_init(self): + self.assertEqual( + Headers(), + Headers(), + ) + + def test_init_from_kwargs(self): + self.assertEqual( + Headers(connection="Upgrade", server="websockets"), + self.headers, + ) + + def test_init_from_headers(self): + self.assertEqual( + Headers(self.headers), + self.headers, + ) + + def test_init_from_headers_and_kwargs(self): + self.assertEqual( + Headers(Headers(connection="Upgrade"), server="websockets"), + self.headers, + ) + + def test_init_from_mapping(self): + self.assertEqual( + Headers({"Connection": "Upgrade", "Server": "websockets"}), + self.headers, + ) + + def test_init_from_mapping_and_kwargs(self): + self.assertEqual( + Headers({"Connection": "Upgrade"}, server="websockets"), + self.headers, + ) + + def test_init_from_iterable(self): + self.assertEqual( + Headers([("Connection", "Upgrade"), ("Server", "websockets")]), + self.headers, + ) + + def test_init_from_iterable_and_kwargs(self): + self.assertEqual( + Headers([("Connection", "Upgrade")], server="websockets"), + self.headers, + ) + + def test_init_multiple_positional_arguments(self): + with self.assertRaises(TypeError): + Headers(Headers(connection="Upgrade"), Headers(server="websockets")) + def test_str(self): self.assertEqual( str(self.headers), "Connection: Upgrade\r\nServer: websockets\r\n\r\n" @@ -27,10 +85,6 @@ def test_serialize(self): b"Connection: Upgrade\r\nServer: websockets\r\n\r\n", ) - def test_multiple_values_error_str(self): - self.assertEqual(str(MultipleValuesError("Connection")), "'Connection'") - self.assertEqual(str(MultipleValuesError()), "") - def test_contains(self): self.assertIn("Server", self.headers) @@ -59,11 +113,6 @@ def test_getitem_key_error(self): with self.assertRaises(KeyError): self.headers["Upgrade"] - def test_getitem_multiple_values_error(self): - self.headers["Server"] = "2" - with self.assertRaises(MultipleValuesError): - self.headers["Server"] - def test_setitem(self): self.headers["Upgrade"] = "websocket" self.assertEqual(self.headers["Upgrade"], "websocket") @@ -72,11 +121,6 @@ def test_setitem_case_insensitive(self): self.headers["upgrade"] = "websocket" self.assertEqual(self.headers["Upgrade"], "websocket") - def test_setitem_multiple_values(self): - self.headers["Connection"] = "close" - with self.assertRaises(MultipleValuesError): - self.headers["Connection"] - def test_delitem(self): del self.headers["Connection"] with self.assertRaises(KeyError): @@ -87,12 +131,6 @@ def test_delitem_case_insensitive(self): with self.assertRaises(KeyError): self.headers["Connection"] - def test_delitem_multiple_values(self): - self.headers["Connection"] = "close" - del self.headers["Connection"] - with self.assertRaises(KeyError): - self.headers["Connection"] - def test_eq(self): other_headers = Headers([("Connection", "Upgrade"), ("Server", "websockets")]) self.assertEqual(self.headers, other_headers) @@ -124,12 +162,75 @@ def test_get_all_case_insensitive(self): def test_get_all_no_values(self): self.assertEqual(self.headers.get_all("Upgrade"), []) - def test_get_all_multiple_values(self): - self.headers["Connection"] = "close" - self.assertEqual(self.headers.get_all("Connection"), ["Upgrade", "close"]) - def test_raw_items(self): self.assertEqual( list(self.headers.raw_items()), [("Connection", "Upgrade"), ("Server", "websockets")], ) + + +class MultiValueHeadersTests(unittest.TestCase): + def setUp(self): + self.headers = Headers([("Server", "Python"), ("Server", "websockets")]) + + def test_init_from_headers(self): + self.assertEqual( + Headers(self.headers), + self.headers, + ) + + def test_init_from_headers_and_kwargs(self): + self.assertEqual( + Headers(Headers(server="Python"), server="websockets"), + self.headers, + ) + + def test_str(self): + self.assertEqual( + str(self.headers), "Server: Python\r\nServer: websockets\r\n\r\n" + ) + + def test_repr(self): + self.assertEqual( + repr(self.headers), + "Headers([('Server', 'Python'), ('Server', 'websockets')])", + ) + + def test_copy(self): + self.assertEqual(repr(self.headers.copy()), repr(self.headers)) + + def test_serialize(self): + self.assertEqual( + self.headers.serialize(), + b"Server: Python\r\nServer: websockets\r\n\r\n", + ) + + def test_iter(self): + self.assertEqual(set(iter(self.headers)), {"server"}) + + def test_len(self): + self.assertEqual(len(self.headers), 1) + + def test_getitem_multiple_values_error(self): + with self.assertRaises(MultipleValuesError): + self.headers["Server"] + + def test_setitem(self): + self.headers["Server"] = "redux" + self.assertEqual( + self.headers.get_all("Server"), ["Python", "websockets", "redux"] + ) + + def test_delitem(self): + del self.headers["Server"] + with self.assertRaises(KeyError): + self.headers["Server"] + + def test_get_all(self): + self.assertEqual(self.headers.get_all("Server"), ["Python", "websockets"]) + + def test_raw_items(self): + self.assertEqual( + list(self.headers.raw_items()), + [("Server", "Python"), ("Server", "websockets")], + ) From 26c17794dffef336ec5f43405d09608a42a78bca Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 12 Aug 2021 09:12:23 +0200 Subject: [PATCH 130/762] Simplify extra_headers implementation. --- src/websockets/client.py | 9 +----- src/websockets/legacy/client.py | 10 ++---- src/websockets/legacy/server.py | 8 +---- src/websockets/server.py | 8 +---- tests/legacy/test_client_server.py | 52 +++--------------------------- 5 files changed, 10 insertions(+), 77 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 6a0bb580e..105f42775 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -1,6 +1,5 @@ from __future__ import annotations -import collections from typing import Generator, List, Optional, Sequence from .connection import CLIENT, CONNECTING, OPEN, Connection, State @@ -105,13 +104,7 @@ def connect(self) -> Request: # noqa: F811 headers["Sec-WebSocket-Protocol"] = protocol_header if self.extra_headers is not None: - extra_headers = self.extra_headers - if isinstance(extra_headers, Headers): - extra_headers = extra_headers.raw_items() - elif isinstance(extra_headers, collections.abc.Mapping): - extra_headers = extra_headers.items() - for name, value in extra_headers: - headers[name] = value + headers.update(self.extra_headers) headers.setdefault("User-Agent", USER_AGENT) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 5d6d8130a..bad2cbeea 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -6,7 +6,6 @@ from __future__ import annotations import asyncio -import collections.abc import functools import logging import urllib.parse @@ -387,13 +386,8 @@ async def handshake( protocol_header = build_subprotocol(available_subprotocols) request_headers["Sec-WebSocket-Protocol"] = protocol_header - if extra_headers is not None: - if isinstance(extra_headers, Headers): - extra_headers = extra_headers.raw_items() - elif isinstance(extra_headers, collections.abc.Mapping): - extra_headers = extra_headers.items() - for name, value in extra_headers: - request_headers[name] = value + if self.extra_headers is not None: + request_headers.update(self.extra_headers) request_headers.setdefault("User-Agent", USER_AGENT) diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 3fde99568..297fc5664 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -6,7 +6,6 @@ from __future__ import annotations import asyncio -import collections.abc import email.utils import functools import http @@ -697,12 +696,7 @@ async def handshake( if callable(extra_headers): extra_headers = extra_headers(path, self.request_headers) if extra_headers is not None: - if isinstance(extra_headers, Headers): - extra_headers = extra_headers.raw_items() - elif isinstance(extra_headers, collections.abc.Mapping): - extra_headers = extra_headers.items() - for name, value in extra_headers: - response_headers[name] = value + response_headers.update(extra_headers) response_headers.setdefault("Date", email.utils.formatdate(usegmt=True)) response_headers.setdefault("Server", USER_AGENT) diff --git a/src/websockets/server.py b/src/websockets/server.py index d6d2143bb..ae0cfdd58 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -2,7 +2,6 @@ import base64 import binascii -import collections import email.utils import http from typing import Callable, Generator, List, Optional, Sequence, Tuple, Union, cast @@ -142,12 +141,7 @@ def accept(self, request: Request) -> Response: else: extra_headers = self.extra_headers if extra_headers is not None: - if isinstance(extra_headers, Headers): - extra_headers = extra_headers.raw_items() - elif isinstance(extra_headers, collections.abc.Mapping): - extra_headers = extra_headers.items() - for name, value in extra_headers: - headers[name] = value + headers.update(extra_headers) headers.setdefault("Date", email.utils.formatdate(usegmt=True)) headers.setdefault("Server", USER_AGENT) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 02d7a7aa3..3fcd0b044 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -608,52 +608,24 @@ def test_protocol_headers(self): self.assertEqual(server_req, repr(client_req)) self.assertEqual(server_resp, repr(client_resp)) - @with_server() - @with_client("/headers", extra_headers=Headers({"X-Spam": "Eggs"})) - def test_protocol_custom_request_headers(self): - req_headers = self.loop.run_until_complete(self.client.recv()) - self.loop.run_until_complete(self.client.recv()) - self.assertIn("('X-Spam', 'Eggs')", req_headers) - @with_server() @with_client("/headers", extra_headers={"X-Spam": "Eggs"}) - def test_protocol_custom_request_headers_dict(self): - req_headers = self.loop.run_until_complete(self.client.recv()) - self.loop.run_until_complete(self.client.recv()) - self.assertIn("('X-Spam', 'Eggs')", req_headers) - - @with_server() - @with_client("/headers", extra_headers=[("X-Spam", "Eggs")]) - def test_protocol_custom_request_headers_list(self): + def test_protocol_custom_request_headers(self): req_headers = self.loop.run_until_complete(self.client.recv()) self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", req_headers) @with_server() - @with_client("/headers", extra_headers=[("User-Agent", "Eggs")]) + @with_client("/headers", extra_headers={"User-Agent": "Eggs"}) def test_protocol_custom_request_user_agent(self): req_headers = self.loop.run_until_complete(self.client.recv()) self.loop.run_until_complete(self.client.recv()) self.assertEqual(req_headers.count("User-Agent"), 1) self.assertIn("('User-Agent', 'Eggs')", req_headers) - @with_server(extra_headers=lambda p, r: Headers({"X-Spam": "Eggs"})) - @with_client("/headers") - def test_protocol_custom_response_headers_callable(self): - self.loop.run_until_complete(self.client.recv()) - resp_headers = self.loop.run_until_complete(self.client.recv()) - self.assertIn("('X-Spam', 'Eggs')", resp_headers) - @with_server(extra_headers=lambda p, r: {"X-Spam": "Eggs"}) @with_client("/headers") - def test_protocol_custom_response_headers_callable_dict(self): - self.loop.run_until_complete(self.client.recv()) - resp_headers = self.loop.run_until_complete(self.client.recv()) - self.assertIn("('X-Spam', 'Eggs')", resp_headers) - - @with_server(extra_headers=lambda p, r: [("X-Spam", "Eggs")]) - @with_client("/headers") - def test_protocol_custom_response_headers_callable_list(self): + def test_protocol_custom_response_headers_callable(self): self.loop.run_until_complete(self.client.recv()) resp_headers = self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", resp_headers) @@ -664,28 +636,14 @@ def test_protocol_custom_response_headers_callable_none(self): self.loop.run_until_complete(self.client.recv()) # doesn't crash self.loop.run_until_complete(self.client.recv()) # nothing to check - @with_server(extra_headers=Headers({"X-Spam": "Eggs"})) - @with_client("/headers") - def test_protocol_custom_response_headers(self): - self.loop.run_until_complete(self.client.recv()) - resp_headers = self.loop.run_until_complete(self.client.recv()) - self.assertIn("('X-Spam', 'Eggs')", resp_headers) - @with_server(extra_headers={"X-Spam": "Eggs"}) @with_client("/headers") - def test_protocol_custom_response_headers_dict(self): - self.loop.run_until_complete(self.client.recv()) - resp_headers = self.loop.run_until_complete(self.client.recv()) - self.assertIn("('X-Spam', 'Eggs')", resp_headers) - - @with_server(extra_headers=[("X-Spam", "Eggs")]) - @with_client("/headers") - def test_protocol_custom_response_headers_list(self): + def test_protocol_custom_response_headers(self): self.loop.run_until_complete(self.client.recv()) resp_headers = self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", resp_headers) - @with_server(extra_headers=[("Server", "Eggs")]) + @with_server(extra_headers={"Server": "Eggs"}) @with_client("/headers") def test_protocol_custom_response_user_agent(self): self.loop.run_until_complete(self.client.recv()) From 6edb363af07365867b4c372405b6ab3177a57830 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 12 Aug 2021 09:22:38 +0200 Subject: [PATCH 131/762] Remove extra_headers from Sans-I/O layer. Request and response objects are explicitly passed through the integration layer, making this API unnecessary. --- src/websockets/client.py | 9 ++------- src/websockets/server.py | 26 +++++++------------------- tests/test_client.py | 22 ---------------------- tests/test_server.py | 37 ++++++------------------------------- 4 files changed, 15 insertions(+), 79 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 105f42775..e21fd36c0 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -3,7 +3,7 @@ from typing import Generator, List, Optional, Sequence from .connection import CLIENT, CONNECTING, OPEN, Connection, State -from .datastructures import Headers, HeadersLike, MultipleValuesError +from .datastructures import Headers, MultipleValuesError from .exceptions import ( InvalidHandshake, InvalidHeader, @@ -50,7 +50,6 @@ def __init__( origin: Optional[Origin] = None, extensions: Optional[Sequence[ClientExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, - extra_headers: Optional[HeadersLike] = None, state: State = CONNECTING, max_size: Optional[int] = 2 ** 20, logger: Optional[LoggerLike] = None, @@ -65,7 +64,6 @@ def __init__( self.origin = origin self.available_extensions = extensions self.available_subprotocols = subprotocols - self.extra_headers = extra_headers self.key = generate_key() def connect(self) -> Request: # noqa: F811 @@ -103,10 +101,7 @@ def connect(self) -> Request: # noqa: F811 protocol_header = build_subprotocol(self.available_subprotocols) headers["Sec-WebSocket-Protocol"] = protocol_header - if self.extra_headers is not None: - headers.update(self.extra_headers) - - headers.setdefault("User-Agent", USER_AGENT) + headers["User-Agent"] = USER_AGENT return Request(self.wsuri.resource_name, headers) diff --git a/src/websockets/server.py b/src/websockets/server.py index ae0cfdd58..67183c685 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -4,10 +4,10 @@ import binascii import email.utils import http -from typing import Callable, Generator, List, Optional, Sequence, Tuple, Union, cast +from typing import Generator, List, Optional, Sequence, Tuple, cast from .connection import CONNECTING, OPEN, SERVER, Connection, State -from .datastructures import Headers, HeadersLike, MultipleValuesError +from .datastructures import Headers, MultipleValuesError from .exceptions import ( InvalidHandshake, InvalidHeader, @@ -44,9 +44,6 @@ __all__ = ["ServerConnection"] -HeadersLikeOrCallable = Union[HeadersLike, Callable[[str, Headers], HeadersLike]] - - class ServerConnection(Connection): side = SERVER @@ -56,7 +53,6 @@ def __init__( origins: Optional[Sequence[Optional[Origin]]] = None, extensions: Optional[Sequence[ServerExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, - extra_headers: Optional[HeadersLikeOrCallable] = None, state: State = CONNECTING, max_size: Optional[int] = 2 ** 20, logger: Optional[LoggerLike] = None, @@ -70,7 +66,6 @@ def __init__( self.origins = origins self.available_extensions = extensions self.available_subprotocols = subprotocols - self.extra_headers = extra_headers def accept(self, request: Request) -> Response: """ @@ -125,6 +120,8 @@ def accept(self, request: Request) -> Response: headers = Headers() + headers["Date"] = email.utils.formatdate(usegmt=True) + headers["Upgrade"] = "websocket" headers["Connection"] = "Upgrade" headers["Sec-WebSocket-Accept"] = accept_key(key) @@ -135,16 +132,7 @@ def accept(self, request: Request) -> Response: if protocol_header is not None: headers["Sec-WebSocket-Protocol"] = protocol_header - extra_headers: Optional[HeadersLike] - if callable(self.extra_headers): - extra_headers = self.extra_headers(request.path, request.headers) - else: - extra_headers = self.extra_headers - if extra_headers is not None: - headers.update(extra_headers) - - headers.setdefault("Date", email.utils.formatdate(usegmt=True)) - headers.setdefault("Server", USER_AGENT) + headers["Server"] = USER_AGENT self.logger.info("connection open") return Response(101, "Switching Protocols", headers) @@ -402,10 +390,10 @@ def reject( if headers is None: headers = Headers() headers.setdefault("Date", email.utils.formatdate(usegmt=True)) - headers.setdefault("Server", USER_AGENT) + headers.setdefault("Connection", "close") headers.setdefault("Content-Length", str(len(body))) headers.setdefault("Content-Type", "text/plain; charset=utf-8") - headers.setdefault("Connection", "close") + headers.setdefault("Server", USER_AGENT) self.logger.info("connection failed (%d %s)", status.value, status.phrase) return Response(status.value, status.phrase, headers, body) diff --git a/tests/test_client.py b/tests/test_client.py index 2ef1f6a95..015b93b3f 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -108,28 +108,6 @@ def test_subprotocols(self): self.assertEqual(request.headers["Sec-WebSocket-Protocol"], "chat") - def test_extra_headers(self): - for extra_headers in [ - Headers({"X-Spam": "Eggs"}), - {"X-Spam": "Eggs"}, - [("X-Spam", "Eggs")], - ]: - with self.subTest(extra_headers=extra_headers): - client = ClientConnection( - "wss://example.com/", extra_headers=extra_headers - ) - request = client.connect() - - self.assertEqual(request.headers["X-Spam"], "Eggs") - - def test_extra_headers_overrides_user_agent(self): - client = ClientConnection( - "wss://example.com/", extra_headers={"User-Agent": "Other"} - ) - request = client.connect() - - self.assertEqual(request.headers["User-Agent"], "Other") - class AcceptRejectTests(unittest.TestCase): def test_receive_accept(self): diff --git a/tests/test_server.py b/tests/test_server.py index 042d64a31..54699c3ef 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -96,10 +96,10 @@ def test_send_accept(self): server.data_to_send(), [ f"HTTP/1.1 101 Switching Protocols\r\n" + f"Date: {DATE}\r\n" f"Upgrade: websocket\r\n" f"Connection: Upgrade\r\n" f"Sec-WebSocket-Accept: {ACCEPT}\r\n" - f"Date: {DATE}\r\n" f"Server: {USER_AGENT}\r\n" f"\r\n".encode() ], @@ -117,10 +117,10 @@ def test_send_reject(self): [ f"HTTP/1.1 404 Not Found\r\n" f"Date: {DATE}\r\n" - f"Server: {USER_AGENT}\r\n" + f"Connection: close\r\n" f"Content-Length: 13\r\n" f"Content-Type: text/plain; charset=utf-8\r\n" - f"Connection: close\r\n" + f"Server: {USER_AGENT}\r\n" f"\r\n" f"Sorry folks.\n".encode(), b"", @@ -139,10 +139,10 @@ def test_accept_response(self): response.headers, Headers( { + "Date": DATE, "Upgrade": "websocket", "Connection": "Upgrade", "Sec-WebSocket-Accept": ACCEPT, - "Date": DATE, "Server": USER_AGENT, } ), @@ -161,10 +161,10 @@ def test_reject_response(self): Headers( { "Date": DATE, - "Server": USER_AGENT, + "Connection": "close", "Content-Length": "13", "Content-Type": "text/plain; charset=utf-8", - "Connection": "close", + "Server": USER_AGENT, } ), ) @@ -604,31 +604,6 @@ def test_unsupported_subprotocol(self): self.assertNotIn("Sec-WebSocket-Protocol", response.headers) self.assertIsNone(server.subprotocol) - def test_extra_headers(self): - for extra_headers in [ - Headers({"X-Spam": "Eggs"}), - {"X-Spam": "Eggs"}, - [("X-Spam", "Eggs")], - lambda path, headers: Headers({"X-Spam": "Eggs"}), - lambda path, headers: {"X-Spam": "Eggs"}, - lambda path, headers: [("X-Spam", "Eggs")], - ]: - with self.subTest(extra_headers=extra_headers): - server = ServerConnection(extra_headers=extra_headers) - request = self.make_request() - response = server.accept(request) - - self.assertEqual(response.status_code, 101) - self.assertEqual(response.headers["X-Spam"], "Eggs") - - def test_extra_headers_overrides_server(self): - server = ServerConnection(extra_headers={"Server": "Other"}) - request = self.make_request() - response = server.accept(request) - - self.assertEqual(response.status_code, 101) - self.assertEqual(response.headers["Server"], "Other") - class MiscTests(unittest.TestCase): def test_bypass_handshake(self): From af0f5646ad14e62ee614fa898cc4651ae7d02b93 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 14 Aug 2021 08:36:27 +0200 Subject: [PATCH 132/762] Fix rendering of extensions how-to. --- docs/howto/extensions.rst | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/howto/extensions.rst b/docs/howto/extensions.rst index fdaf09f63..9c49de172 100644 --- a/docs/howto/extensions.rst +++ b/docs/howto/extensions.rst @@ -1,6 +1,8 @@ Writing an extension ==================== +.. currentmodule:: websockets.extensions + During the opening handshake, WebSocket clients and servers negotiate which extensions will be used with which parameters. Then each frame is processed by extensions before being sent or after being received. @@ -24,8 +26,8 @@ As a consequence, writing an extension requires implementing several classes: websockets provides abstract base classes for extension factories and extensions. See the API documentation for details on their methods: -* :class:`ClientExtensionFactory` and class:`ServerExtensionFactory` for - :extension factories, +* :class:`ClientExtensionFactory` and :class:`ServerExtensionFactory` for + extension factories, * :class:`Extension` for extensions. From 017a832fb0070318a874626b3c69e36d2af669b3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 14 Aug 2021 08:45:55 +0200 Subject: [PATCH 133/762] Regenerate Sphinx configuration. Recent Sphinx versions have a simpler template. --- docs/Makefile | 166 +++----------------------------- docs/conf.py | 258 +++++++------------------------------------------- docs/make.bat | 35 +++++++ 3 files changed, 83 insertions(+), 376 deletions(-) create mode 100644 docs/make.bat diff --git a/docs/Makefile b/docs/Makefile index bb25aa49d..d4bb2cbb9 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -1,160 +1,20 @@ -# Makefile for Sphinx documentation +# Minimal makefile for Sphinx documentation # -# You can set these variables from the command line. -SPHINXOPTS = -SPHINXBUILD = sphinx-build -PAPER = +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = . BUILDDIR = _build -# Internal variables. -PAPEROPT_a4 = -D latex_paper_size=a4 -PAPEROPT_letter = -D latex_paper_size=letter -ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . -# the i18n builder cannot share the environment and doctrees with the others -I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . - -.PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext - +# Put it first so that "make" without argument is like "make help". help: - @echo "Please use \`make ' where is one of" - @echo " html to make standalone HTML files" - @echo " dirhtml to make HTML files named index.html in directories" - @echo " singlehtml to make a single large HTML file" - @echo " pickle to make pickle files" - @echo " json to make JSON files" - @echo " htmlhelp to make HTML files and a HTML help project" - @echo " qthelp to make HTML files and a qthelp project" - @echo " devhelp to make HTML files and a Devhelp project" - @echo " epub to make an epub" - @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" - @echo " latexpdf to make LaTeX files and run them through pdflatex" - @echo " text to make text files" - @echo " man to make manual pages" - @echo " texinfo to make Texinfo files" - @echo " info to make Texinfo files and run them through makeinfo" - @echo " gettext to make PO message catalogs" - @echo " changes to make an overview of all changed/added/deprecated items" - @echo " linkcheck to check all external links for integrity" - @echo " doctest to run all doctests embedded in the documentation (if enabled)" - @echo " spelling to check for typos in documentation" - -clean: - -rm -rf $(BUILDDIR)/* - -html: - $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html - @echo - @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." - -dirhtml: - $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml - @echo - @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." - -singlehtml: - $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml - @echo - @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." - -pickle: - $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle - @echo - @echo "Build finished; now you can process the pickle files." - -json: - $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json - @echo - @echo "Build finished; now you can process the JSON files." - -htmlhelp: - $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp - @echo - @echo "Build finished; now you can run HTML Help Workshop with the" \ - ".hhp project file in $(BUILDDIR)/htmlhelp." - -qthelp: - $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp - @echo - @echo "Build finished; now you can run "qcollectiongenerator" with the" \ - ".qhcp project file in $(BUILDDIR)/qthelp, like this:" - @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/websockets.qhcp" - @echo "To view the help file:" - @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/websockets.qhc" - -devhelp: - $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp - @echo - @echo "Build finished." - @echo "To view the help file:" - @echo "# mkdir -p $$HOME/.local/share/devhelp/websockets" - @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/websockets" - @echo "# devhelp" - -epub: - $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub - @echo - @echo "Build finished. The epub file is in $(BUILDDIR)/epub." - -latex: - $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex - @echo - @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." - @echo "Run \`make' in that directory to run these through (pdf)latex" \ - "(use \`make latexpdf' here to do that automatically)." - -latexpdf: - $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex - @echo "Running LaTeX files through pdflatex..." - $(MAKE) -C $(BUILDDIR)/latex all-pdf - @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." - -text: - $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text - @echo - @echo "Build finished. The text files are in $(BUILDDIR)/text." - -man: - $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man - @echo - @echo "Build finished. The manual pages are in $(BUILDDIR)/man." - -texinfo: - $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo - @echo - @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." - @echo "Run \`make' in that directory to run these through makeinfo" \ - "(use \`make info' here to do that automatically)." - -info: - $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo - @echo "Running Texinfo files through makeinfo..." - make -C $(BUILDDIR)/texinfo info - @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." - -gettext: - $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale - @echo - @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." - -changes: - $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes - @echo - @echo "The overview file is in $(BUILDDIR)/changes." - -linkcheck: - $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck - @echo - @echo "Link check complete; look for any errors in the above output " \ - "or in $(BUILDDIR)/linkcheck/output.txt." + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -doctest: - $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest - @echo "Testing of doctests in the sources finished, look at the " \ - "results in $(BUILDDIR)/doctest/output.txt." +.PHONY: help Makefile -spelling: - $(SPHINXBUILD) -b spelling $(ALLSPHINXOPTS) $(BUILDDIR)/spelling - @echo - @echo "Check finished. Wrong words can be found in " \ - "$(BUILDDIR)/spelling/output.txt." +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/conf.py b/docs/conf.py index 2246c0287..145abafa4 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,112 +1,64 @@ -# -*- coding: utf-8 -*- +# Configuration file for the Sphinx documentation builder. # -# websockets documentation build configuration file, created by -# sphinx-quickstart on Sun Mar 31 20:48:44 2013. -# -# This file is execfile()d with the current directory set to its containing dir. -# -# Note that not all possible configuration values are present in this -# autogenerated file. -# -# All configuration values have a default; values that are commented out -# serve to show the default. +# This file only contains a selection of the most common options. For a full +# list see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html -import sys, os, datetime +import datetime +import os +import sys + +# -- Path setup -------------------------------------------------------------- # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. + sys.path.insert(0, os.path.join(os.path.abspath('..'), 'src')) -# -- General configuration ----------------------------------------------------- -# If your documentation needs a minimal Sphinx version, state it here. -#needs_sphinx = '1.0' +# -- Project information ----------------------------------------------------- + +project = 'websockets' +copyright = f'2013-{datetime.date.today().year}, Aymeric Augustin and contributors' +author = 'Aymeric Augustin' + +# The full version, including alpha/beta/rc tags +release = '9.1' + -# Add any Sphinx extension module names here, as strings. They can be extensions -# coming with Sphinx (named 'sphinx.ext.*') or your custom ones. +# -- General configuration --------------------------------------------------- + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. extensions = [ 'sphinx.ext.autodoc', 'sphinx.ext.intersphinx', 'sphinx.ext.viewcode', 'sphinx_autodoc_typehints', + 'sphinxcontrib.spelling', 'sphinxcontrib_trio', - ] +] -# Spelling check needs an additional module that is not installed by default. -# Add it only if spelling check is requested so docs can be generated without it. -if 'spelling' in sys.argv: - extensions.append('sphinxcontrib.spelling') +intersphinx_mapping = {'python': ('https://docs.python.org/3', None)} # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates'] -# The suffix of source filenames. -source_suffix = '.rst' - -# The encoding of source files. -#source_encoding = 'utf-8-sig' - -# The master toctree document. -master_doc = 'index' - -# General information about the project. -project = 'websockets' -copyright = f'2013-{datetime.date.today().year}, Aymeric Augustin and contributors' - -# The version info for the project you're documenting, acts as replacement for -# |version| and |release|, also used in various other places throughout the -# built documents. -# -# The short X.Y version. -version = '9.1' -# The full version, including alpha/beta/rc tags. -release = '9.1' - -# The language for content autogenerated by Sphinx. Refer to documentation -# for a list of supported languages. -#language = None - -# There are two options for replacing |today|: either, you set today to some -# non-false value, then it is used: -#today = '' -# Else, today_fmt is used as the format for a strftime call. -#today_fmt = '%B %d, %Y' - # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. -exclude_patterns = ['_build'] - -# The reST default role (used for this markup: `text`) to use for all documents. -#default_role = None - -# If true, '()' will be appended to :func: etc. cross-reference text. -#add_function_parentheses = True - -# If true, the current module name will be prepended to all description -# unit titles (such as .. function::). -#add_module_names = True +# This pattern also affects html_static_path and html_extra_path. +exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] -# If true, sectionauthor and moduleauthor directives will be shown in the -# output. They are ignored by default. -#show_authors = False -# The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' - -# A list of ignored prefixes for module index sorting. -#modindex_common_prefix = [] - - -# -- Options for HTML output --------------------------------------------------- +# -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. +# html_theme = 'alabaster' -# Theme options are theme-specific and customize the look and feel of a theme -# further. For a list of options available for each theme, see the -# documentation. html_theme_options = { 'logo': 'websockets.svg', 'description': 'A library for building WebSocket servers and clients in Python with a focus on correctness and simplicity.', @@ -117,39 +69,6 @@ 'tidelift_url': 'https://tidelift.com/subscription/pkg/pypi-websockets?utm_source=pypi-websockets&utm_medium=referral&utm_campaign=docs', } -# Add any paths that contain custom themes here, relative to this directory. -#html_theme_path = [] - -# The name for this set of Sphinx documents. If None, it defaults to -# " v documentation". -#html_title = None - -# A shorter title for the navigation bar. Default is the same as html_title. -#html_short_title = None - -# The name of an image file (relative to this directory) to place at the top -# of the sidebar. -#html_logo = None - -# The name of an image file (within the static path) to use as favicon of the -# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 -# pixels large. -#html_favicon = None - -# Add any paths that contain custom static files (such as style sheets) here, -# relative to this directory. They are copied after the builtin static files, -# so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] - -# If not '', a 'Last updated on:' timestamp is inserted at every page bottom, -# using the given strftime format. -#html_last_updated_fmt = '%b %d, %Y' - -# If true, SmartyPants will be used to convert quotes and dashes to -# typographically correct entities. -#html_use_smartypants = True - -# Custom sidebar templates, maps document names to template names. html_sidebars = { '**': [ 'about.html', @@ -160,114 +79,7 @@ ] } -# Additional templates that should be rendered to pages, maps page names to -# template names. -#html_additional_pages = {} - -# If false, no module index is generated. -#html_domain_indices = True - -# If false, no index is generated. -#html_use_index = True - -# If true, the index is split into individual pages for each letter. -#html_split_index = False - -# If true, links to the reST sources are added to the pages. -#html_show_sourcelink = True - -# If true, "Created using Sphinx" is shown in the HTML footer. Default is True. -#html_show_sphinx = True - -# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. -#html_show_copyright = True - -# If true, an OpenSearch description file will be output, and all pages will -# contain a tag referring to it. The value of this option must be the -# base URL from which the finished HTML is served. -#html_use_opensearch = '' - -# This is the file name suffix for HTML files (e.g. ".xhtml"). -#html_file_suffix = None - -# Output file base name for HTML help builder. -htmlhelp_basename = 'websocketsdoc' - - -# -- Options for LaTeX output -------------------------------------------------- - -latex_elements = { -# The paper size ('letterpaper' or 'a4paper'). -#'papersize': 'letterpaper', - -# The font size ('10pt', '11pt' or '12pt'). -#'pointsize': '10pt', - -# Additional stuff for the LaTeX preamble. -#'preamble': '', -} - -# Grouping the document tree into LaTeX files. List of tuples -# (source start file, target name, title, author, documentclass [howto/manual]). -latex_documents = [ - ('index', 'websockets.tex', 'websockets Documentation', - 'Aymeric Augustin', 'manual'), -] - -# The name of an image file (relative to this directory) to place at the top of -# the title page. -#latex_logo = None - -# For "manual" documents, if this is true, then toplevel headings are parts, -# not chapters. -#latex_use_parts = False - -# If true, show page references after internal links. -#latex_show_pagerefs = False - -# If true, show URL addresses after external links. -#latex_show_urls = False - -# Documents to append as an appendix to all manuals. -#latex_appendices = [] - -# If false, no module index is generated. -#latex_domain_indices = True - - -# -- Options for manual page output -------------------------------------------- - -# One entry per manual page. List of tuples -# (source start file, name, description, authors, manual section). -man_pages = [ - ('index', 'websockets', 'websockets Documentation', - ['Aymeric Augustin'], 1) -] - -# If true, show URL addresses after external links. -#man_show_urls = False - - -# -- Options for Texinfo output ------------------------------------------------ - -# Grouping the document tree into Texinfo files. List of tuples -# (source start file, target name, title, author, -# dir menu entry, description, category) -texinfo_documents = [ - ('index', 'websockets', 'websockets Documentation', - 'Aymeric Augustin', 'websockets', 'One line description of project.', - 'Miscellaneous'), -] - -# Documents to append as an appendix to all manuals. -#texinfo_appendices = [] - -# If false, no module index is generated. -#texinfo_domain_indices = True - -# How to display URL addresses: 'footnote', 'no', or 'inline'. -#texinfo_show_urls = 'footnote' - - -# Example configuration for intersphinx: refer to the Python standard library. -intersphinx_mapping = {'https://docs.python.org/3/': None} +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ['_static'] diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 000000000..2119f5109 --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=. +set BUILDDIR=_build + +if "%1" == "" goto help + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.http://sphinx-doc.org/ + exit /b 1 +) + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd From cc35695df01241752d01266ad3d86510505af070 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 14 Aug 2021 09:10:48 +0200 Subject: [PATCH 134/762] Reformat conf.py with black. --- docs/conf.py | 60 +++++++++++++++++++++++++--------------------------- 1 file changed, 29 insertions(+), 31 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 145abafa4..3523477c2 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -13,18 +13,17 @@ # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. - -sys.path.insert(0, os.path.join(os.path.abspath('..'), 'src')) +sys.path.insert(0, os.path.join(os.path.abspath(".."), "src")) # -- Project information ----------------------------------------------------- -project = 'websockets' -copyright = f'2013-{datetime.date.today().year}, Aymeric Augustin and contributors' -author = 'Aymeric Augustin' +project = "websockets" +copyright = f"2013-{datetime.date.today().year}, Aymeric Augustin and contributors" +author = "Aymeric Augustin" # The full version, including alpha/beta/rc tags -release = '9.1' +release = "9.1" # -- General configuration --------------------------------------------------- @@ -33,53 +32,52 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.intersphinx', - 'sphinx.ext.viewcode', - 'sphinx_autodoc_typehints', - 'sphinxcontrib.spelling', - 'sphinxcontrib_trio', + "sphinx.ext.autodoc", + "sphinx.ext.intersphinx", + "sphinx.ext.viewcode", + "sphinx_autodoc_typehints", + "sphinxcontrib.spelling", + "sphinxcontrib_trio", ] -intersphinx_mapping = {'python': ('https://docs.python.org/3', None)} +intersphinx_mapping = {"python": ("https://docs.python.org/3", None)} # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. -# -html_theme = 'alabaster' +html_theme = "alabaster" html_theme_options = { - 'logo': 'websockets.svg', - 'description': 'A library for building WebSocket servers and clients in Python with a focus on correctness and simplicity.', - 'github_button': True, - 'github_type': 'star', - 'github_user': 'aaugustin', - 'github_repo': 'websockets', - 'tidelift_url': 'https://tidelift.com/subscription/pkg/pypi-websockets?utm_source=pypi-websockets&utm_medium=referral&utm_campaign=docs', + "logo": "websockets.svg", + "description": "A library for building WebSocket servers and clients in Python with a focus on correctness and simplicity.", + "github_button": True, + "github_type": "star", + "github_user": "aaugustin", + "github_repo": "websockets", + "tidelift_url": "https://tidelift.com/subscription/pkg/pypi-websockets?utm_source=pypi-websockets&utm_medium=referral&utm_campaign=docs", } html_sidebars = { - '**': [ - 'about.html', - 'searchbox.html', - 'navigation.html', - 'relations.html', - 'donate.html', + "**": [ + "about.html", + "searchbox.html", + "navigation.html", + "relations.html", + "donate.html", ] } # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] From 323c93297a0876f7fc47993b908200888a50afbb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 14 Aug 2021 09:07:58 +0200 Subject: [PATCH 135/762] Update spelling wordlist. List generated with this command, then checked and adjusted manually: $ make -C docs spelling | grep 'Spell check' | awk '{ print $4 }' \ sed 's/:$//' | tr '[:upper:]' '[:lower:]' | sort | uniq --- docs/conf.py | 2 ++ docs/spelling_wordlist.txt | 22 +++------------------- 2 files changed, 5 insertions(+), 19 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 3523477c2..e4273570d 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -42,6 +42,8 @@ intersphinx_mapping = {"python": ("https://docs.python.org/3", None)} +spelling_show_suggestions = True + # Add any paths that contain templates here, relative to this directory. templates_path = ["_templates"] diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index b3389a920..3d05752d5 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -1,16 +1,12 @@ -api -attr augustin auth autoscaler -awaitable aymeric backend backoff backpressure balancer balancers -bitcoin bottlenecked bufferbloat bugfix @@ -22,14 +18,10 @@ coroutines cryptocurrencies cryptocurrency ctrl -daemonize -datastructures django dyno -formatter fractalideas gunicorn -haproxy hypercorn iframe IPv @@ -37,48 +29,40 @@ istio iterable keepalive KiB -Kubernetes +kubernetes lifecycle linkerd liveness lookups MiB -mypy nginx -onmessage -parsers permessage pid -pong -pongs proxying pythonic reconnection redis +redistributions retransmit runtime scalable -serializers stateful subclasses subclassing +subpackages subprotocol subprotocols supervisord tidelift tls tox -unparse unregister uple -username uvicorn -uvloop virtualenv WebSocket websocket websockets ws wsgi -wss www From 5a6bd258ca29665ad8acd7709d989c57c2a8b2bc Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 14 Aug 2021 11:31:15 +0200 Subject: [PATCH 136/762] Enable and configure Furo theme. This removes two features from the sidebar: the GitHub "stars" count and the Tidelift banner. --- docs/Makefile | 3 ++ docs/_static/favicon.ico | 1 + docs/conf.py | 37 +++++++++++++------------ docs/requirements.txt | 5 ++++ experiments/authentication/favicon.ico | Bin 5430 -> 22 bytes logo/favicon.ico | Bin 0 -> 5430 bytes 6 files changed, 29 insertions(+), 17 deletions(-) create mode 120000 docs/_static/favicon.ico mode change 100644 => 120000 experiments/authentication/favicon.ico create mode 100644 logo/favicon.ico diff --git a/docs/Makefile b/docs/Makefile index d4bb2cbb9..7a04f7827 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -18,3 +18,6 @@ help: # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). %: Makefile @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +livehtml: + sphinx-autobuild "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/_static/favicon.ico b/docs/_static/favicon.ico new file mode 120000 index 000000000..dd7df921e --- /dev/null +++ b/docs/_static/favicon.ico @@ -0,0 +1 @@ +../../logo/favicon.ico \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py index e4273570d..04d4c0557 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -36,8 +36,11 @@ "sphinx.ext.intersphinx", "sphinx.ext.viewcode", "sphinx_autodoc_typehints", + "sphinx_copybutton", + "sphinx_inline_tabs", "sphinxcontrib.spelling", "sphinxcontrib_trio", + "sphinxext.opengraph", ] intersphinx_mapping = {"python": ("https://docs.python.org/3", None)} @@ -57,29 +60,29 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. -html_theme = "alabaster" +html_theme = "furo" html_theme_options = { - "logo": "websockets.svg", - "description": "A library for building WebSocket servers and clients in Python with a focus on correctness and simplicity.", - "github_button": True, - "github_type": "star", - "github_user": "aaugustin", - "github_repo": "websockets", - "tidelift_url": "https://tidelift.com/subscription/pkg/pypi-websockets?utm_source=pypi-websockets&utm_medium=referral&utm_campaign=docs", + "light_css_variables": { + "color-brand-primary": "#306998", # blue from logo + "color-brand-content": "#0b487a", # blue more saturated and less dark + }, + "dark_css_variables": { + "color-brand-primary": "#ffd43bcc", # yellow from logo, more muted than content + "color-brand-content": "#ffd43bd9", # yellow from logo, transparent like text + }, + "sidebar_hide_name": True, } -html_sidebars = { - "**": [ - "about.html", - "searchbox.html", - "navigation.html", - "relations.html", - "donate.html", - ] -} +html_logo = "_static/websockets.svg" + +html_favicon = "_static/favicon.ico" # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". html_static_path = ["_static"] + +html_copy_source = False + +html_show_sphinx = False diff --git a/docs/requirements.txt b/docs/requirements.txt index 0eaf94fbe..b9c371228 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,4 +1,9 @@ +furo sphinx +sphinx-autobuild sphinx-autodoc-typehints +sphinx-copybutton +sphinx-inline-tabs sphinxcontrib-spelling sphinxcontrib-trio +sphinxext-opengraph diff --git a/experiments/authentication/favicon.ico b/experiments/authentication/favicon.ico deleted file mode 100644 index 36e855029d705e72d44428bda6e8cb6d3dd317ed..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 5430 zcmeH~eNdFw6~-?im3Ahb_(wY9G?@uAnl{s!P5}k6EMLo)Am6YM4N*)C21rB%KSG3E zv;hMxiW&u@(MF93wOCQ3rD-sRXbhdyP7o3K#!S*q0u${7l(0J*XQZVahM#RhT&Mw%^BqMk{U@J1~1H2BJEU64r)0!B+s<_@|*2 zqhiBw1#eC^Kd~FBvDYyg*$#wV#Xwjm9t&v2p)n2r@@vI8+WgsgcG9|#V(P|dLI)7j zfdSf?v#0~lgtfyurUjuNe*Ta2+T*8ca~J%;(ZdETkY^gva%XNcqcj2kdN-G|0yeVF}ZFP=4Z!J2R#i0#02+Ir+*zVG@y z*5YS?yJuBTDvF9 z+e4qHUo$?zisgL(ZFDT{$HRZz5|LeB6ok5h0Mx8E{O!%Wpod<2-phQZDi;R}!Uptt z_Sffbz;E>+X1=#0_Q!9(68Y81O(FP$*?76Z67toHt(cKzH39p~78|EMbIvF7dZI9CxZHU=B0 z<9Jap>Rt-q+)$@;bvu`KC6(K3*mRuQ9MaslWomPE8-LktGQPhv{>AfKV-~kmXM9f| z|2D6HQbVN2tW*=RJWA^l;TTP>949?wSa4Hnl)qW|Kbb4f8Fvyy2h!?Rc2a%__sk=;L*m&%aOk`2^Q3=D%J$`lrna32wt&`ua?J=IYOX`)ey| zsDH*}F--J3_a!wE;q=XC^%sGy0H6D|y~p17*k>&_PGWEtANNEx9#Nf`$LIWDXd7VU z-lc0`F21g#6;|0_CyK+_IPED%Vmjd+bq#|NSGn)ee~4%xL9&NlK@0!?spEQZeA9Sw zOg84YE{XKby>e*LH9+>d2-$me4ei*@T2#ST@l)pX-aNEoMD}B9TkIH*)9-fc#{JDF zsS`tE`y^thLD#X6zKIf~ACX*0P0;qeI{2O1>Ze+hd$YWhu%sV8DP0&!?gSFKPsfuu z=^dyc`Wh;U66imQy~wci5ZZc7tYMe4y3>E-zOC46crNM2&=cI_Q|MnbFb{~Qf9lvq zl)zjpqW|G=)`eCR4jO{~dh5&OHAgQLZpnV9ydbWbam~@=o9J4d>8Y0WqGf3>kS6_3 zH=ygFmaRky%tJKeAJJ-p{_kqA)#Yh(*$efw^_AAoNKZthCz1G^u_xP8J>82TYJl`x zPtw0=VBb|il)!vMb4^CrJ8A0?#hMfy%!Slv9#Q+&`23lKphoC2`RmZF?C`@YT|BPQis$fd)qZc3H#T~U(G%W)mfglo{sU;_HV?#Giog+ zWdZKb7(gC1)GY7CGNOdEKE$SWVT5(5)r3};^sUup2XfhuD&bJR@3U{@`dIhp%!9oK z`hCKgZ~0C9j|Y3#aMr|#eN9k{rX#rbA2Ga;A`w_nTG-C ztBDe-vb<91XNIG%>P!!=&(~otdZi|$F7=v_x=f{SqJ-l`K`Vb(5MVvJ!Jze!ng;g7 z>?QBK{=AowD1llEwR+6*IO(WiJm0l|D{Ep{uL)g4S}$^l5>9OnTX~|`Xtnjy{y6%g zO~ax51$Wq&ClMvEKQ9vBhc`yq?ukr~>U(@1IJr6Gc0;i-fhgfV}k+ptmI=o*Qx0@RDGY}eIiR4)*!i%Vr#UY#Q>p{*{nuh%O zw-{rDA=?=E^vn-xf;^a)yc&4S#>DB;t~+h17%G6M9Y7Y%ttpeEt)~ z*kB1&)1jq0b@m6ll1AkWCmP=Q^&;&&DfRHy%i&r*$hok-P}6XevH3fT@AR)C)pTj8 zN-n>!BLP|-wu%PrJgTk5T@A${HyM{c{o!N z{uyKPol0D`gW)%O|F9}G?&6Mwv&}o=zc^nRJKR(eeeKWX(PfP1M^laIn|c0j*0+@b zZDY{i*oZ&{!b(QfqZlSb-*-dLEEhC+xWJOGU};h)wgh3bYC?&N1*5L&qtO{IrNa-n mx{c17;Wp~=fE$Kp5swEk(0J*XQZVahM#RhT&Mw%^BqMk{U@J1~1H2BJEU64r)0!B+s<_@|*2 zqhiBw1#eC^Kd~FBvDYyg*$#wV#Xwjm9t&v2p)n2r@@vI8+WgsgcG9|#V(P|dLI)7j zfdSf?v#0~lgtfyurUjuNe*Ta2+T*8ca~J%;(ZdETkY^gva%XNcqcj2kdN-G|0yeVF}ZFP=4Z!J2R#i0#02+Ir+*zVG@y z*5YS?yJuBTDvF9 z+e4qHUo$?zisgL(ZFDT{$HRZz5|LeB6ok5h0Mx8E{O!%Wpod<2-phQZDi;R}!Uptt z_Sffbz;E>+X1=#0_Q!9(68Y81O(FP$*?76Z67toHt(cKzH39p~78|EMbIvF7dZI9CxZHU=B0 z<9Jap>Rt-q+)$@;bvu`KC6(K3*mRuQ9MaslWomPE8-LktGQPhv{>AfKV-~kmXM9f| z|2D6HQbVN2tW*=RJWA^l;TTP>949?wSa4Hnl)qW|Kbb4f8Fvyy2h!?Rc2a%__sk=;L*m&%aOk`2^Q3=D%J$`lrna32wt&`ua?J=IYOX`)ey| zsDH*}F--J3_a!wE;q=XC^%sGy0H6D|y~p17*k>&_PGWEtANNEx9#Nf`$LIWDXd7VU z-lc0`F21g#6;|0_CyK+_IPED%Vmjd+bq#|NSGn)ee~4%xL9&NlK@0!?spEQZeA9Sw zOg84YE{XKby>e*LH9+>d2-$me4ei*@T2#ST@l)pX-aNEoMD}B9TkIH*)9-fc#{JDF zsS`tE`y^thLD#X6zKIf~ACX*0P0;qeI{2O1>Ze+hd$YWhu%sV8DP0&!?gSFKPsfuu z=^dyc`Wh;U66imQy~wci5ZZc7tYMe4y3>E-zOC46crNM2&=cI_Q|MnbFb{~Qf9lvq zl)zjpqW|G=)`eCR4jO{~dh5&OHAgQLZpnV9ydbWbam~@=o9J4d>8Y0WqGf3>kS6_3 zH=ygFmaRky%tJKeAJJ-p{_kqA)#Yh(*$efw^_AAoNKZthCz1G^u_xP8J>82TYJl`x zPtw0=VBb|il)!vMb4^CrJ8A0?#hMfy%!Slv9#Q+&`23lKphoC2`RmZF?C`@YT|BPQis$fd)qZc3H#T~U(G%W)mfglo{sU;_HV?#Giog+ zWdZKb7(gC1)GY7CGNOdEKE$SWVT5(5)r3};^sUup2XfhuD&bJR@3U{@`dIhp%!9oK z`hCKgZ~0C9j|Y3#aMr|#eN9k{rX#rbA2Ga;A`w_nTG-C ztBDe-vb<91XNIG%>P!!=&(~otdZi|$F7=v_x=f{SqJ-l`K`Vb(5MVvJ!Jze!ng;g7 z>?QBK{=AowD1llEwR+6*IO(WiJm0l|D{Ep{uL)g4S}$^l5>9OnTX~|`Xtnjy{y6%g zO~ax51$Wq&ClMvEKQ9vBhc`yq?ukr~>U(@1IJr6Gc0;i-fhgfV}k+ptmI=o*Qx0@RDGY}eIiR4)*!i%Vr#UY#Q>p{*{nuh%O zw-{rDA=?=E^vn-xf;^a)yc&4S#>DB;t~+h17%G6M9Y7Y%ttpeEt)~ z*kB1&)1jq0b@m6ll1AkWCmP=Q^&;&&DfRHy%i&r*$hok-P}6XevH3fT@AR)C)pTj8 zN-n>!BLP|-wu%PrJgTk5T@A${HyM{c{o!N z{uyKPol0D`gW)%O|F9}G?&6Mwv&}o=zc^nRJKR(eeeKWX(PfP1M^laIn|c0j*0+@b zZDY{i*oZ&{!b(QfqZlSb-*-dLEEhC+xWJOGU};h)wgh3bYC?&N1*5L&qtO{IrNa-n mx{c17;Wp~=fE$Kp5swEk Date: Sun, 15 Aug 2021 07:34:33 +0200 Subject: [PATCH 137/762] Update links to RFC after IETF website changes. --- docs/howto/logging.rst | 4 ++-- docs/project/changelog.rst | 2 +- docs/reference/extensions.rst | 2 +- docs/reference/limitations.rst | 4 +++- docs/topics/compression.rst | 2 +- docs/topics/design.rst | 10 ++++---- src/websockets/extensions/base.py | 2 +- .../extensions/permessage_deflate.py | 4 ++-- src/websockets/headers.py | 23 ++++++++++--------- src/websockets/http.py | 2 +- src/websockets/http11.py | 14 +++++------ src/websockets/legacy/client.py | 4 ++-- src/websockets/legacy/framing.py | 2 +- src/websockets/legacy/handshake.py | 2 +- src/websockets/legacy/http.py | 10 ++++---- src/websockets/legacy/protocol.py | 10 ++++---- src/websockets/legacy/server.py | 6 ++--- src/websockets/server.py | 2 +- src/websockets/uri.py | 4 ++-- 19 files changed, 56 insertions(+), 53 deletions(-) diff --git a/docs/howto/logging.rst b/docs/howto/logging.rst index 824812959..4e91557fa 100644 --- a/docs/howto/logging.rst +++ b/docs/howto/logging.rst @@ -35,8 +35,8 @@ Instead, when running as a server, websockets logs one event when a `connection is established`_ and another event when a `connection is closed`_. -.. _connection is established: https://datatracker.ietf.org/doc/html/rfc6455#section-4 -.. _connection is closed: https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.4 +.. _connection is established: https://www.rfc-editor.org/rfc/rfc6455.html#section-4 +.. _connection is closed: https://www.rfc-editor.org/rfc/rfc6455.html#section-7.1.4 websockets doesn't log an event for every message because that would be excessive for many applications exchanging small messages at a fast rate. diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 277ee5022..07942783f 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -747,7 +747,7 @@ Also: * Added support for providing and checking Origin_. -.. _Origin: https://tools.ietf.org/html/rfc6455#section-10.2 +.. _Origin: https://www.rfc-editor.org/rfc/rfc6455.html#section-10.2 2.0 ... diff --git a/docs/reference/extensions.rst b/docs/reference/extensions.rst index 6ca20dbc8..bae583a21 100644 --- a/docs/reference/extensions.rst +++ b/docs/reference/extensions.rst @@ -8,7 +8,7 @@ The WebSocket protocol supports extensions_. At the time of writing, there's only one `registered extension`_ with a public specification, WebSocket Per-Message Deflate, specified in :rfc:`7692`. -.. _extensions: https://tools.ietf.org/html/rfc6455#section-9 +.. _extensions: https://www.rfc-editor.org/rfc/rfc6455.html#section-9 .. _registered extension: https://www.iana.org/assignments/websocket/websocket.xhtml#extension-name Per-Message Deflate diff --git a/docs/reference/limitations.rst b/docs/reference/limitations.rst index ecdde23bf..81f1445b5 100644 --- a/docs/reference/limitations.rst +++ b/docs/reference/limitations.rst @@ -6,9 +6,11 @@ Client The client doesn't attempt to guarantee that there is no more than one connection to a given IP address in a CONNECTING state. This behavior is -mandated by :rfc:`6455`. However, :func:`~websockets.connect()` isn't the +`mandated by RFC 6455`_. However, :func:`~websockets.connect()` isn't the right layer for enforcing this constraint. It's the caller's responsibility. +.. _mandated by RFC 6455: https://www.rfc-editor.org/rfc/rfc6455.html#section-4.1 + The client doesn't support connecting through a HTTP proxy (`issue 364`_) or a SOCKS proxy (`issue 475`_). diff --git a/docs/topics/compression.rst b/docs/topics/compression.rst index e23319636..f78e32748 100644 --- a/docs/topics/compression.rst +++ b/docs/topics/compression.rst @@ -162,4 +162,4 @@ settings affect memory usage and how to optimize them. This `experiment by Peter Thorson`_ recommends Window Bits = 11 and Memory Level = 4 for optimizing memory usage. -.. _experiment by Peter Thorson: https://www.ietf.org/mail-archive/web/hybi/current/msg10222.html +.. _experiment by Peter Thorson: https://mailarchive.ietf.org/arch/msg/hybi/F9t4uPufVEy8KBLuL36cZjCmM_Y/ diff --git a/docs/topics/design.rst b/docs/topics/design.rst index fa2093433..2c9d505aa 100644 --- a/docs/topics/design.rst +++ b/docs/topics/design.rst @@ -173,16 +173,16 @@ differences between a server and a client: - `closing the TCP connection`_: the server closes the connection immediately; the client waits for the server to do it. -.. _client-to-server masking: https://tools.ietf.org/html/rfc6455#section-5.3 -.. _closing the TCP connection: https://tools.ietf.org/html/rfc6455#section-5.5.1 +.. _client-to-server masking: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.3 +.. _closing the TCP connection: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.1 These differences are so minor that all the logic for `data framing`_, for `sending and receiving data`_ and for `closing the connection`_ is implemented in the same class, :class:`~legacy.protocol.WebSocketCommonProtocol`. -.. _data framing: https://tools.ietf.org/html/rfc6455#section-5 -.. _sending and receiving data: https://tools.ietf.org/html/rfc6455#section-6 -.. _closing the connection: https://tools.ietf.org/html/rfc6455#section-7 +.. _data framing: https://www.rfc-editor.org/rfc/rfc6455.html#section-5 +.. _sending and receiving data: https://www.rfc-editor.org/rfc/rfc6455.html#section-6 +.. _closing the connection: https://www.rfc-editor.org/rfc/rfc6455.html#section-7 The :attr:`~legacy.protocol.WebSocketCommonProtocol.is_client` attribute tells which side a protocol instance is managing. This attribute is defined on the diff --git a/src/websockets/extensions/base.py b/src/websockets/extensions/base.py index 2de9176bd..7217aa513 100644 --- a/src/websockets/extensions/base.py +++ b/src/websockets/extensions/base.py @@ -4,7 +4,7 @@ See `section 9 of RFC 6455`_. -.. _section 9 of RFC 6455: http://tools.ietf.org/html/rfc6455#section-9 +.. _section 9 of RFC 6455: https://www.rfc-editor.org/rfc/rfc6455.html#section-9 """ diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index 0c9088a9e..59f019d53 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -268,7 +268,7 @@ class ClientPerMessageDeflateFactory(ClientExtensionFactory): ``True`` to include them in the negotiation offer without a value or to an integer value to include them with this value. - .. _section 7.1 of RFC 7692: https://tools.ietf.org/html/rfc7692#section-7.1 + .. _section 7.1 of RFC 7692: https://www.rfc-editor.org/rfc/rfc7692.html#section-7.1 :param server_no_context_takeover: defaults to ``False`` :param client_no_context_takeover: defaults to ``False`` @@ -466,7 +466,7 @@ class ServerPerMessageDeflateFactory(ServerExtensionFactory): ``True`` to include them in the negotiation offer without a value or to an integer value to include them with this value. - .. _section 7.1 of RFC 7692: https://tools.ietf.org/html/rfc7692#section-7.1 + .. _section 7.1 of RFC 7692: https://www.rfc-editor.org/rfc/rfc7692.html#section-7.1 :param server_no_context_takeover: defaults to ``False`` :param client_no_context_takeover: defaults to ``False`` diff --git a/src/websockets/headers.py b/src/websockets/headers.py index 181e976e3..ee6dd1672 100644 --- a/src/websockets/headers.py +++ b/src/websockets/headers.py @@ -40,8 +40,8 @@ # To avoid a dependency on a parsing library, we implement manually the ABNF -# described in https://tools.ietf.org/html/rfc6455#section-9.1 with the -# definitions from https://tools.ietf.org/html/rfc7230#appendix-B. +# described in https://www.rfc-editor.org/rfc/rfc6455.html#section-9.1 and +# https://www.rfc-editor.org/rfc/rfc7230.html#appendix-B. def peek_ahead(header: str, pos: int) -> Optional[str]: @@ -161,9 +161,9 @@ def parse_list( :raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs. """ - # Per https://tools.ietf.org/html/rfc7230#section-7, "a recipient MUST - # parse and ignore a reasonable number of empty list elements"; hence - # while loops that remove extra delimiters. + # Per https://www.rfc-editor.org/rfc/rfc7230.html#section-7, "a recipient + # MUST parse and ignore a reasonable number of empty list elements"; + # hence while loops that remove extra delimiters. # Remove extra delimiters before the first item. while peek_ahead(header, pos) == ",": @@ -289,8 +289,9 @@ def parse_extension_item_param( if peek_ahead(header, pos) == '"': pos_before = pos # for proper error reporting below value, pos = parse_quoted_string(header, pos, header_name) - # https://tools.ietf.org/html/rfc6455#section-9.1 says: the value - # after quoted-string unescaping MUST conform to the 'token' ABNF. + # https://www.rfc-editor.org/rfc/rfc6455.html#section-9.1 says: + # the value after quoted-string unescaping MUST conform to + # the 'token' ABNF. if _token_re.fullmatch(value) is None: raise exceptions.InvalidHeaderFormat( header_name, "invalid quoted header content", header, pos_before @@ -452,7 +453,7 @@ def build_www_authenticate_basic(realm: str) -> str: :param realm: authentication realm """ - # https://tools.ietf.org/html/rfc7617#section-2 + # https://www.rfc-editor.org/rfc/rfc7617.html#section-2 realm = build_quoted_string(realm) charset = build_quoted_string("UTF-8") return f"Basic realm={realm}, charset={charset}" @@ -498,8 +499,8 @@ def parse_authorization_basic(header: str) -> Tuple[str, str]: :raises InvalidHeaderValue: on unsupported inputs """ - # https://tools.ietf.org/html/rfc7235#section-2.1 - # https://tools.ietf.org/html/rfc7617#section-2 + # https://www.rfc-editor.org/rfc/rfc7235.html#section-2.1 + # https://www.rfc-editor.org/rfc/rfc7617.html#section-2 scheme, pos = parse_token(header, 0, "Authorization") if scheme.lower() != "basic": raise exceptions.InvalidHeaderValue( @@ -539,7 +540,7 @@ def build_authorization_basic(username: str, password: str) -> str: This is the reverse of :func:`parse_authorization_basic`. """ - # https://tools.ietf.org/html/rfc7617#section-2 + # https://www.rfc-editor.org/rfc/rfc7617.html#section-2 assert ":" not in username user_pass = f"{username}:{password}" basic_credentials = base64.b64encode(user_pass.encode()).decode() diff --git a/src/websockets/http.py b/src/websockets/http.py index 6168c5144..38848b56d 100644 --- a/src/websockets/http.py +++ b/src/websockets/http.py @@ -36,7 +36,7 @@ def build_host(host: str, port: int, secure: bool) -> str: Build a ``Host`` header. """ - # https://tools.ietf.org/html/rfc3986#section-3.2.2 + # https://www.rfc-editor.org/rfc/rfc3986.html#section-3.2.2 # IPv6 addresses must be enclosed in brackets. try: address = ipaddress.ip_address(host) diff --git a/src/websockets/http11.py b/src/websockets/http11.py index 11b9d7f39..daa0efffb 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -19,7 +19,7 @@ def d(value: bytes) -> str: return value.decode(errors="backslashreplace") -# See https://tools.ietf.org/html/rfc7230#appendix-B. +# See https://www.rfc-editor.org/rfc/rfc7230.html#appendix-B. # Regex for validating header names. @@ -78,7 +78,7 @@ def parse( :raises ValueError: if the request isn't well formatted """ - # https://tools.ietf.org/html/rfc7230#section-3.1.1 + # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.1.1 # Parsing is simple because fixed values are expected for method and # version and because path isn't checked. Since WebSocket software tends @@ -102,7 +102,7 @@ def parse( headers = yield from parse_headers(read_line) - # https://tools.ietf.org/html/rfc7230#section-3.3.3 + # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.3.3 if "Transfer-Encoding" in headers: raise NotImplementedError("transfer codings aren't supported") @@ -166,7 +166,7 @@ def parse( :raises ValueError: if the response isn't well formatted """ - # https://tools.ietf.org/html/rfc7230#section-3.1.2 + # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.1.2 # As in parse_request, parsing is simple because a fixed value is expected # for version, status_code is a 3-digit number, and reason can be ignored. @@ -197,7 +197,7 @@ def parse( headers = yield from parse_headers(read_line) - # https://tools.ietf.org/html/rfc7230#section-3.3.3 + # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.3.3 if "Transfer-Encoding" in headers: raise NotImplementedError("transfer codings aren't supported") @@ -251,7 +251,7 @@ def parse_headers( line or raises an exception if there isn't enough data """ - # https://tools.ietf.org/html/rfc7230#section-3.2 + # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.2 # We don't attempt to support obsolete line folding. @@ -301,7 +301,7 @@ def parse_line( # Security: this guarantees header values are small (hard-coded = 4 KiB) if len(line) > MAX_LINE: raise exceptions.SecurityError("line too long") - # Not mandatory but safe - https://tools.ietf.org/html/rfc7230#section-3.5 + # Not mandatory but safe - https://www.rfc-editor.org/rfc/rfc7230.html#section-3.5 if not line.endswith(b"\r\n"): raise EOFError("line without CRLF") return line[:-2] diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index bad2cbeea..6d976e0df 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -83,14 +83,14 @@ class WebSocketClientProtocol(WebSocketCommonProtocol): timeouts on inactive connections. Set ``ping_interval`` to ``None`` to disable this behavior. - .. _Ping frame: https://tools.ietf.org/html/rfc6455#section-5.5.2 + .. _Ping frame: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 If the corresponding `Pong frame`_ isn't received within ``ping_timeout`` seconds, the connection is considered unusable and is closed with code 1011. This ensures that the remote endpoint remains responsive. Set ``ping_timeout`` to ``None`` to disable this behavior. - .. _Pong frame: https://tools.ietf.org/html/rfc6455#section-5.5.3 + .. _Pong frame: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 The ``close_timeout`` parameter defines a maximum wait time for completing the closing handshake and terminating the TCP connection. For legacy diff --git a/src/websockets/legacy/framing.py b/src/websockets/legacy/framing.py index c8ae48690..40cbd41bf 100644 --- a/src/websockets/legacy/framing.py +++ b/src/websockets/legacy/framing.py @@ -6,7 +6,7 @@ See `section 5 of RFC 6455`_. -.. _section 5 of RFC 6455: http://tools.ietf.org/html/rfc6455#section-5 +.. _section 5 of RFC 6455: https://www.rfc-editor.org/rfc/rfc6455.html#section-5 """ diff --git a/src/websockets/legacy/handshake.py b/src/websockets/legacy/handshake.py index 49d08cfe8..7cde58ac1 100644 --- a/src/websockets/legacy/handshake.py +++ b/src/websockets/legacy/handshake.py @@ -3,7 +3,7 @@ See `section 4 of RFC 6455`_. -.. _section 4 of RFC 6455: http://tools.ietf.org/html/rfc6455#section-4 +.. _section 4 of RFC 6455: https://www.rfc-editor.org/rfc/rfc6455.html#section-4 Some checks cannot be performed because they depend too much on the context; instead, they're documented below. diff --git a/src/websockets/legacy/http.py b/src/websockets/legacy/http.py index 0b9a92267..3725fa9c3 100644 --- a/src/websockets/legacy/http.py +++ b/src/websockets/legacy/http.py @@ -22,7 +22,7 @@ def d(value: bytes) -> str: return value.decode(errors="backslashreplace") -# See https://tools.ietf.org/html/rfc7230#appendix-B. +# See https://www.rfc-editor.org/rfc/rfc7230.html#appendix-B. # Regex for validating header names. @@ -61,7 +61,7 @@ async def read_request(stream: asyncio.StreamReader) -> Tuple[str, Headers]: :raises ValueError: if the request isn't well formatted """ - # https://tools.ietf.org/html/rfc7230#section-3.1.1 + # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.1.1 # Parsing is simple because fixed values are expected for method and # version and because path isn't checked. Since WebSocket software tends @@ -105,7 +105,7 @@ async def read_response(stream: asyncio.StreamReader) -> Tuple[int, str, Headers :raises ValueError: if the response isn't well formatted """ - # https://tools.ietf.org/html/rfc7230#section-3.1.2 + # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.1.2 # As in read_request, parsing is simple because a fixed value is expected # for version, status_code is a 3-digit number, and reason can be ignored. @@ -144,7 +144,7 @@ async def read_headers(stream: asyncio.StreamReader) -> Headers: Non-ASCII characters are represented with surrogate escapes. """ - # https://tools.ietf.org/html/rfc7230#section-3.2 + # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.2 # We don't attempt to support obsolete line folding. @@ -189,7 +189,7 @@ async def read_line(stream: asyncio.StreamReader) -> bytes: # Security: this guarantees header values are small (hard-coded = 4 KiB) if len(line) > MAX_LINE: raise SecurityError("line too long") - # Not mandatory but safe - https://tools.ietf.org/html/rfc7230#section-3.5 + # Not mandatory but safe - https://www.rfc-editor.org/rfc/rfc7230.html#section-3.5 if not line.endswith(b"\r\n"): raise EOFError("line without CRLF") return line[:-2] diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 339733db2..7b24ccf98 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -3,7 +3,7 @@ See `sections 4 to 8 of RFC 6455`_. -.. _sections 4 to 8 of RFC 6455: http://tools.ietf.org/html/rfc6455#section-4 +.. _sections 4 to 8 of RFC 6455: https://www.rfc-editor.org/rfc/rfc6455.html#section-4 """ @@ -480,8 +480,8 @@ async def send( bytes-like object (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) is sent as a `Binary frame`_. - .. _Text frame: https://tools.ietf.org/html/rfc6455#section-5.6 - .. _Binary frame: https://tools.ietf.org/html/rfc6455#section-5.6 + .. _Text frame: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Binary frame: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 :meth:`send` also accepts an iterable or an asynchronous iterable of strings, bytestrings, or bytes-like objects. In that case the message @@ -1420,8 +1420,8 @@ def broadcast(websockets: Iterable[WebSocketCommonProtocol], message: Data) -> N bytes-like object (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) is sent as a `Binary frame`_. - .. _Text frame: https://tools.ietf.org/html/rfc6455#section-5.6 - .. _Binary frame: https://tools.ietf.org/html/rfc6455#section-5.6 + .. _Text frame: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Binary frame: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 :func:`broadcast` pushes the message synchronously to all connections even if their write buffers overflow ``write_limit``. There's no backpressure. diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 297fc5664..ddf6d9f87 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -101,14 +101,14 @@ class WebSocketServerProtocol(WebSocketCommonProtocol): timeouts on inactive connections. Set ``ping_interval`` to ``None`` to disable this behavior. - .. _Ping frame: https://tools.ietf.org/html/rfc6455#section-5.5.2 + .. _Ping frame: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 If the corresponding `Pong frame`_ isn't received within ``ping_timeout`` seconds, the connection is considered unusable and is closed with code 1011. This ensures that the remote endpoint remains responsive. Set ``ping_timeout`` to ``None`` to disable this behavior. - .. _Pong frame: https://tools.ietf.org/html/rfc6455#section-5.5.3 + .. _Pong frame: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 The ``close_timeout`` parameter defines a maximum wait time for completing the closing handshake and terminating the TCP connection. For legacy @@ -454,7 +454,7 @@ def process_origin( """ # "The user agent MUST NOT include more than one Origin header field" - # per https://tools.ietf.org/html/rfc6454#section-7.3. + # per https://www.rfc-editor.org/rfc/rfc6454.html#section-7.3. try: origin = cast(Optional[Origin], headers.get("Origin")) except MultipleValuesError as exc: diff --git a/src/websockets/server.py b/src/websockets/server.py index 67183c685..0ae0ae940 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -222,7 +222,7 @@ def process_origin(self, headers: Headers) -> Optional[Origin]: """ # "The user agent MUST NOT include more than one Origin header field" - # per https://tools.ietf.org/html/rfc6454#section-7.3. + # per https://www.rfc-editor.org/rfc/rfc6454.html#section-7.3. try: origin = cast(Optional[Origin], headers.get("Origin")) except MultipleValuesError as exc: diff --git a/src/websockets/uri.py b/src/websockets/uri.py index ed4521d53..1ab895a21 100644 --- a/src/websockets/uri.py +++ b/src/websockets/uri.py @@ -3,7 +3,7 @@ See `section 3 of RFC 6455`_. -.. _section 3 of RFC 6455: http://tools.ietf.org/html/rfc6455#section-3 +.. _section 3 of RFC 6455: https://www.rfc-editor.org/rfc/rfc6455.html#section-3 """ @@ -31,7 +31,7 @@ class WebSocketURI: :param str user_info: ``(username, password)`` tuple when the URI contains `User Information`_, else ``None``. - .. _User Information: https://tools.ietf.org/html/rfc3986#section-3.2.1 + .. _User Information: https://www.rfc-editor.org/rfc/rfc3986.html#section-3.2.1 """ secure: bool From 4407d02b8f69f654c90cfdc8ce261377e008ce16 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 15 Aug 2021 08:01:56 +0200 Subject: [PATCH 138/762] Auto-rebuild docs on docstring changes. --- docs/Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/Makefile b/docs/Makefile index 7a04f7827..045870645 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -20,4 +20,4 @@ help: @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) livehtml: - sphinx-autobuild "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + sphinx-autobuild --watch "$(SOURCEDIR)/../src" "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) From 232351294683b422c98451536824d5d1cdca7f58 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 15 Aug 2021 09:19:35 +0200 Subject: [PATCH 139/762] Use a better exception type. --- src/websockets/legacy/protocol.py | 4 ++-- tests/legacy/test_protocol.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 7b24ccf98..99a821be6 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -697,7 +697,7 @@ async def ping(self, data: Optional[Data] = None) -> Awaitable[None]: :raises ~websockets.exceptions.ConnectionClosed: when the connection is closed - :raises ValueError: if another ping was sent with the same data and + :raises RuntimeError: if another ping was sent with the same data and the corresponding pong wasn't received yet """ @@ -708,7 +708,7 @@ async def ping(self, data: Optional[Data] = None) -> Awaitable[None]: # Protect against duplicates if a payload is explicitly set. if data in self.pings: - raise ValueError("already waiting for a pong with the same data") + raise RuntimeError("already waiting for a pong with the same data") # Generate a unique random payload otherwise. while data is None or data in self.pings: diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index 61a5fe7cf..1672ab1ed 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -1021,7 +1021,7 @@ def test_canceled_ping(self): def test_duplicate_ping(self): self.loop.run_until_complete(self.protocol.ping(b"foobar")) self.assertOneFrameSent(True, OP_PING, b"foobar") - with self.assertRaises(ValueError): + with self.assertRaises(RuntimeError): self.loop.run_until_complete(self.protocol.ping(b"foobar")) self.assertNoFrameSent() From 6a0cb60d069c9c4db6be1da3e539b672b54a8c93 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 17 Aug 2021 08:08:57 +0200 Subject: [PATCH 140/762] Move logging document to topics. It's more a discussion than a how-to. --- docs/howto/index.rst | 3 +-- docs/topics/index.rst | 1 + docs/{howto => topics}/logging.rst | 0 3 files changed, 2 insertions(+), 2 deletions(-) rename docs/{howto => topics}/logging.rst (100%) diff --git a/docs/howto/index.rst b/docs/howto/index.rst index e5af8488e..a5779f6c9 100644 --- a/docs/howto/index.rst +++ b/docs/howto/index.rst @@ -9,13 +9,12 @@ If you're stuck, perhaps you'll find the answer in the FAQ or the cheat sheet. faq cheatsheet -The following guides will help you integrate websockets into a broader system. +This guide will help you integrate websockets into a broader system. .. toctree:: :maxdepth: 2 django - logging The WebSocket protocol makes provisions for extending or specializing its features, which websockets supports fully. diff --git a/docs/topics/index.rst b/docs/topics/index.rst index 993303106..e3b20d73f 100644 --- a/docs/topics/index.rst +++ b/docs/topics/index.rst @@ -5,6 +5,7 @@ Topics :maxdepth: 2 deployment + logging authentication broadcast compression diff --git a/docs/howto/logging.rst b/docs/topics/logging.rst similarity index 100% rename from docs/howto/logging.rst rename to docs/topics/logging.rst From 6881cea2cea2ecfc9430673ac042e7188b8a4125 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 17 Aug 2021 08:47:36 +0200 Subject: [PATCH 141/762] Improve intro slightly. --- docs/intro/index.rst | 32 +++++++++++++++++++++----------- example/hello.py | 3 +-- example/secure_client.py | 5 +---- example/secure_server.py | 5 +---- 4 files changed, 24 insertions(+), 21 deletions(-) diff --git a/docs/intro/index.rst b/docs/intro/index.rst index 1df92ed59..ff3ae0ffd 100644 --- a/docs/intro/index.rst +++ b/docs/intro/index.rst @@ -8,9 +8,11 @@ Requirements websockets requires Python ≥ 3.7. -You should use the latest version of Python if possible. If you're using an -older version, be aware that for each minor version (3.x), only the latest -bugfix release (3.x.y) is officially supported. +.. admonition:: Use the most recent Python release + :class: tip + + For each minor version (3.x), only the latest bugfix or security release + (3.x.y) is officially supported. Installation ------------ @@ -33,7 +35,7 @@ It reads a name from the client, sends a greeting, and closes the connection. .. _client-example: -On the server side, websockets executes the handler coroutine ``hello`` once +On the server side, websockets executes the handler coroutine ``hello()`` once for each WebSocket connection. It closes the connection when the handler coroutine returns. @@ -42,8 +44,8 @@ Here's a corresponding WebSocket client example. .. literalinclude:: ../../example/client.py :emphasize-lines: 10 -Using :func:`connect` as an asynchronous context manager ensures the -connection is closed before exiting the ``hello`` coroutine. +Using :func:`~client.connect` as an asynchronous context manager ensures the +connection is closed before exiting the ``hello()`` coroutine. .. _secure-server-example: @@ -53,20 +55,28 @@ Secure example Secure WebSocket connections improve confidentiality and also reliability because they reduce the risk of interference by bad proxies. -The WSS protocol is to WS what HTTPS is to HTTP: the connection is encrypted -with Transport Layer Security (TLS) — which is often referred to as Secure -Sockets Layer (SSL). WSS requires TLS certificates like HTTPS. +The ``wss`` protocol is to ``ws`` what ``https`` is to ``http``. The +connection is encrypted with TLS_ (Transport Layer Security). ``wss`` +requires certificates like ``https``. + +.. _TLS: https://developer.mozilla.org/en-US/docs/Web/Security/Transport_Layer_Security + +.. admonition:: TLS vs. SSL + :class: tip + + TLS is sometimes referred to as SSL (Secure Sockets Layer). SSL was an + earlier encryption protocol; the name stuck. Here's how to adapt the server example to provide secure connections. See the documentation of the :mod:`ssl` module for configuring the context securely. .. literalinclude:: ../../example/secure_server.py - :emphasize-lines: 19-21,26 + :emphasize-lines: 19-21,24 Here's how to adapt the client. .. literalinclude:: ../../example/secure_client.py - :emphasize-lines: 10-12,18 + :emphasize-lines: 10-12,16 This client needs a context because the server uses a self-signed certificate. diff --git a/example/hello.py b/example/hello.py index 96095dd02..84f55dc52 100755 --- a/example/hello.py +++ b/example/hello.py @@ -4,8 +4,7 @@ import websockets async def hello(): - uri = "ws://localhost:8765" - async with websockets.connect(uri) as websocket: + async with websockets.connect("ws://localhost:8765") as websocket: await websocket.send("Hello world!") await websocket.recv() diff --git a/example/secure_client.py b/example/secure_client.py index 518819dd1..8a1551e29 100755 --- a/example/secure_client.py +++ b/example/secure_client.py @@ -13,10 +13,7 @@ async def hello(): uri = "wss://localhost:8765" - async with websockets.connect( - uri, - ssl=ssl_context, - ) as websocket: + async with websockets.connect(uri, ssl=ssl_context) as websocket: name = input("What's your name? ") await websocket.send(name) diff --git a/example/secure_server.py b/example/secure_server.py index 96c300390..cd8ee0cc1 100755 --- a/example/secure_server.py +++ b/example/secure_server.py @@ -21,10 +21,7 @@ async def hello(websocket, path): ssl_context.load_cert_chain(localhost_pem) async def main(): - async with websockets.serve( - hello, "localhost", 8765, - ssl=ssl_context, - ): + async with websockets.serve(hello, "localhost", 8765, ssl=ssl_context): await asyncio.Future() # run forever asyncio.run(main()) From 73ae74cdbcb5b43efeaa1615da88ca90bb62ecc4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 17 Aug 2021 08:51:31 +0200 Subject: [PATCH 142/762] Clean up tables of contents. --- docs/howto/index.rst | 8 ++++---- docs/index.rst | 43 +--------------------------------------- docs/project/index.rst | 8 +++++--- docs/reference/index.rst | 6 +++--- docs/topics/index.rst | 8 +++++--- 5 files changed, 18 insertions(+), 55 deletions(-) diff --git a/docs/howto/index.rst b/docs/howto/index.rst index a5779f6c9..d7c83dd7a 100644 --- a/docs/howto/index.rst +++ b/docs/howto/index.rst @@ -4,7 +4,7 @@ How-to guides If you're stuck, perhaps you'll find the answer in the FAQ or the cheat sheet. .. toctree:: - :maxdepth: 2 + :titlesonly: faq cheatsheet @@ -12,7 +12,7 @@ If you're stuck, perhaps you'll find the answer in the FAQ or the cheat sheet. This guide will help you integrate websockets into a broader system. .. toctree:: - :maxdepth: 2 + :titlesonly: django @@ -20,7 +20,7 @@ The WebSocket protocol makes provisions for extending or specializing its features, which websockets supports fully. .. toctree:: - :maxdepth: 2 + :titlesonly: extensions @@ -29,7 +29,7 @@ features, which websockets supports fully. Once your application is ready, learn how to deploy it on various platforms. .. toctree:: - :maxdepth: 2 + :titlesonly: heroku kubernetes diff --git a/docs/index.rst b/docs/index.rst index 000c19ca7..30d01c2f8 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -40,52 +40,11 @@ And here's an echo server: Do you like it? Let's dive in! -Tutorials ---------- - -If you're new to websockets, this is the place to start. - .. toctree:: - :maxdepth: 2 + :hidden: intro/index - -How-to guides -------------- - -These guides will help you build and deploy a websockets application. - -.. toctree:: - :maxdepth: 2 - howto/index - -Reference ---------- - -Find all the details you could ask for, and then some. - -.. toctree:: - :maxdepth: 2 - reference/index - -Topics ------- - -Get a deeper understanding of how websockets is built and why. - -.. toctree:: - :maxdepth: 2 - topics/index - -Project -------- - -This is about websockets-the-project rather than websockets-the-software. - -.. toctree:: - :maxdepth: 2 - project/index diff --git a/docs/project/index.rst b/docs/project/index.rst index 931fbe2d4..459146345 100644 --- a/docs/project/index.rst +++ b/docs/project/index.rst @@ -1,8 +1,10 @@ -Project -======= +About websockets +================ + +This is about websockets-the-project rather than websockets-the-software. .. toctree:: - :maxdepth: 2 + :titlesonly: changelog contributing diff --git a/docs/reference/index.rst b/docs/reference/index.rst index 9fc4a0092..8d01c5b40 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -1,5 +1,5 @@ -Reference -========= +API reference +============= websockets provides complete client and server implementations, as shown in the :doc:`getting started guide <../intro/index>`. @@ -38,7 +38,7 @@ client and server connections. For convenience, common methods are documented both in the client API and server API. .. toctree:: - :maxdepth: 2 + :titlesonly: client server diff --git a/docs/topics/index.rst b/docs/topics/index.rst index e3b20d73f..120a3dd32 100644 --- a/docs/topics/index.rst +++ b/docs/topics/index.rst @@ -1,8 +1,10 @@ -Topics -====== +Topic guides +============ + +Get a deeper understanding of how websockets is built and why. .. toctree:: - :maxdepth: 2 + :titlesonly: deployment logging From f1d6345d2d88109f907e7f4e9d71f6204508e746 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 18 Aug 2021 08:23:33 +0200 Subject: [PATCH 143/762] Review logging doc. --- docs/topics/logging.rst | 50 +++++++++++++++++++++++------------------ 1 file changed, 28 insertions(+), 22 deletions(-) diff --git a/docs/topics/logging.rst b/docs/topics/logging.rst index 4e91557fa..393efeb71 100644 --- a/docs/topics/logging.rst +++ b/docs/topics/logging.rst @@ -10,7 +10,7 @@ When you run a WebSocket client, your code calls coroutines provided by websockets. If an error occurs, websockets tells you by raising an exception. For example, -it raises a :exc:`~exception.ConnectionClosed` exception if the other side +it raises a :exc:`~exceptions.ConnectionClosed` exception if the other side closes the connection. When you run a WebSocket server, websockets accepts connections, performs the @@ -38,12 +38,15 @@ closed`_. .. _connection is established: https://www.rfc-editor.org/rfc/rfc6455.html#section-4 .. _connection is closed: https://www.rfc-editor.org/rfc/rfc6455.html#section-7.1.4 -websockets doesn't log an event for every message because that would be -excessive for many applications exchanging small messages at a fast rate. -However, you could add this level of logging in your own code if necessary. +By default, websockets doesn't log an event for every message. That would be +excessive for many applications exchanging small messages at a fast rate. If +you need this level of detail, you could add logging in your own code. -See :ref:`log levels ` below for details of events logged by -websockets at each level. +Finally, you can enable debug logs to get details about everything websockets +is doing. This can be useful when developing clients as well as servers. + +See :ref:`log levels ` below for a list of events logged by +websockets logs at each log level. Configure logging ----------------- @@ -94,13 +97,14 @@ However, this technique runs into two problems: * Even with :meth:`str.format` style, you're restricted to attribute and index lookups, which isn't enough to implement some fairly simple requirements. -There's a better way. :func:`~server.serve` accepts a ``logger`` argument to -override the default :class:`~logging.Logger`. You can set ``logger`` to -a :class:`~logging.LoggerAdapter` that enriches logs. +There's a better way. :func:`~client.connect` and :func:`~server.serve` accept +a ``logger`` argument to override the default :class:`~logging.Logger`. You +can set ``logger`` to a :class:`~logging.LoggerAdapter` that enriches logs. -For example, if the server is behind a reverse proxy, ``remote_address`` gives +For example, if the server is behind a reverse +proxy, :attr:`~legacy.protocol.WebSocketCommonProtocol.remote_address` gives the IP address of the proxy, which isn't useful. IP addresses of clients are -generally available in a HTTP header set by the proxy. +provided in a HTTP header set by the proxy. Here's how to include them in logs, assuming they're in the ``X-Forwarded-For`` header:: @@ -111,6 +115,7 @@ Here's how to include them in logs, assuming they're in the ) class LoggerAdapter(logging.LoggerAdapter): + """Add connection ID and client IP address to websockets logs.""" def process(self, msg, kwargs): try: websocket = kwargs["extra"]["websocket"] @@ -148,6 +153,7 @@ Finally, we populate the ``event_data`` custom attribute in log records with a :class:`~logging.LoggerAdapter`:: class LoggerAdapter(logging.LoggerAdapter): + """Add connection ID and client IP address to websockets logs.""" def process(self, msg, kwargs): try: websocket = kwargs["extra"]["websocket"] @@ -169,9 +175,9 @@ Disable logging --------------- If your application doesn't configure :mod:`logging`, Python outputs messages -of severity :data:`~logging.WARNING` and higher to :data:`~sys.stderr`. As a -consequence, you will see a message and a stack trace if a connection handler -coroutine crashes or if you hit a bug in websockets. +of severity ``WARNING`` and higher to :data:`~sys.stderr`. As a consequence, +you will see a message and a stack trace if a connection handler coroutine +crashes or if you hit a bug in websockets. If you want to disable this behavior for websockets, you can add a :class:`~logging.NullHandler`:: @@ -183,8 +189,8 @@ propagation to the root logger, or else its handlers could output logs:: logging.getLogger("websockets").propagate = False -Alternatively, you could set the log level to :data:`~logging.CRITICAL` for -websockets, as the highest level currently used is :data:`~logging.ERROR`:: +Alternatively, you could set the log level to ``CRITICAL`` for the +``"websockets"`` logger, as the highest level currently used is ``ERROR``:: logging.getLogger("websockets").setLevel(logging.CRITICAL) @@ -199,21 +205,21 @@ Log levels Here's what websockets logs at each level. -:attr:`~logging.ERROR` -...................... +``ERROR`` +......... * Exceptions raised by connection handler coroutines in servers * Exceptions resulting from bugs in websockets -:attr:`~logging.INFO` -..................... +``INFO`` +........ * Server starting and stopping * Server establishing and closing connections * Client reconnecting automatically -:attr:`~logging.DEBUG` -...................... +``DEBUG`` +......... * Changes to the state of connections * Handshake requests and responses From c7fc0d36bd8ea2aeb7c4321f53d208fb1297db85 Mon Sep 17 00:00:00 2001 From: Ben Hoyt Date: Fri, 20 Aug 2021 01:53:46 +1200 Subject: [PATCH 144/762] Speed up Python apply_mask 20x by using int.from_bytes/to_bytes (#1034) Speed up Python apply_mask 20x by using int.from_bytes/to_bytes This speeds up the Python version of utils.apply_mask about 20 times, using int.from_bytes so that the XOR is done in a single Python operation -- in other words, the loop over the bytes is in C rather than in Python. Note that it is a trade-off as it uses more memory: this version allocates roughly len(data) bytes for each of the intermediate values (e.g., data_int, mask_repeated, mask_int, the XOR result); whereas I believe the original version only allocates for the return value. Still, most websocket packets aren't huge, and I believe the massive speed gain here makes it worth it. (And people that use the speedups.c version won't be affected.) Obviously the speedups.c version is still significantly faster again, but this change makes the library more usable in environments when it's not feasible to use the C extension. Data Size ForLoop IntXor Speedups ------------------------------------ 1KB 78.6us 3.79us 151ns 1MB 79.7ms 4.38ms 55.4us I got these timings by using commands like the following (with the function call adjusted, and 1024 replaced with 1024*1024 as needed). python3 -m timeit \ -s 'from websockets.utils import apply_mask' \ -s 'data=b"x"*1024; mask=b"abcd"' \ 'apply_mask(data, mask)' This idea came from Will McGugan's blog post "Speeding up Websockets 60X": https://www.willmcgugan.com/blog/tech/post/speeding-up-websockets-60x/ That post contains an ever faster (about 50% faster) way to solve it using a pre-calculated XOR lookup table, but that pre-allocates a 64K-entry table at import time, which didn't seem ideal. Still, that is how aiohttp does it, so maybe it's worth considering: https://github.com/aio-libs/aiohttp/blob/6ec33c5d841c8e845c27ebdd9384bbf72651cbb8/aiohttp/http_websocket.py#L115-L140 The int.from_bytes approach is also the approach used by the websocket-client library: https://github.com/websocket-client/websocket-client/blob/5f32b3c0cfb836c016ad2a5f6caeff2978a6a16f/websocket/_abnf.py#L46-L50 Co-authored-by: Aymeric Augustin --- src/websockets/utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/websockets/utils.py b/src/websockets/utils.py index ffb706963..c6e4b788c 100644 --- a/src/websockets/utils.py +++ b/src/websockets/utils.py @@ -2,8 +2,8 @@ import base64 import hashlib -import itertools import secrets +import sys __all__ = ["accept_key", "apply_mask"] @@ -43,4 +43,7 @@ def apply_mask(data: bytes, mask: bytes) -> bytes: if len(mask) != 4: raise ValueError("mask must contain 4 bytes") - return bytes(b ^ m for b, m in zip(data, itertools.cycle(mask))) + data_int = int.from_bytes(data, sys.byteorder) + mask_repeated = mask * (len(data) // 4) + mask[: len(data) % 4] + mask_int = int.from_bytes(mask_repeated, sys.byteorder) + return (data_int ^ mask_int).to_bytes(len(data), sys.byteorder) From 4b10c2c8e0334535ae4874d95dc7fdd6392bf68a Mon Sep 17 00:00:00 2001 From: Stefan Wojcik Date: Wed, 25 Aug 2021 11:02:42 +0200 Subject: [PATCH 145/762] Fix typo in the FAQ. --- docs/howto/faq.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/howto/faq.rst b/docs/howto/faq.rst index 394920b58..d43d5fd67 100644 --- a/docs/howto/faq.rst +++ b/docs/howto/faq.rst @@ -389,7 +389,7 @@ How do I keep idle connections open? websockets sends pings at 20 seconds intervals to keep the connection open. -In closes the connection if it doesn't get a pong within 20 seconds. +It closes the connection if it doesn't get a pong within 20 seconds. You can adjust this behavior with ``ping_interval`` and ``ping_timeout``. From 9e75d1e6fc4edc53f27d546362711d3bd22291b3 Mon Sep 17 00:00:00 2001 From: Ben Hoyt Date: Fri, 27 Aug 2021 14:14:48 +1200 Subject: [PATCH 146/762] Fix docstring of parse_uri to reflect the exception type raised It doesn't raise a ValueError, but an InvalidURI error (see below). If we want the exception type to be a ValueError as well, you could make InvalidURI inherit from ValueError as well as WebSocketException. But just changing the doc seems like the right fix here. >>> from websockets.uri import parse_uri >>> parse_uri('foo://example.com') Traceback (most recent call last): File "/home/ben/w/websockets/src/websockets/uri.py", line 58, in parse_uri assert parsed.scheme in ["ws", "wss"] AssertionError The above exception was the direct cause of the following exception: Traceback (most recent call last): File "", line 1, in File "/home/ben/w/websockets/src/websockets/uri.py", line 63, in parse_uri raise exceptions.InvalidURI(uri) from exc websockets.exceptions.InvalidURI: foo://example.com isn't a valid URI >>> import sys >>> exc = sys.last_value >>> isinstance(exc, ValueError) False --- src/websockets/uri.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/websockets/uri.py b/src/websockets/uri.py index 1ab895a21..c99c3f16e 100644 --- a/src/websockets/uri.py +++ b/src/websockets/uri.py @@ -49,7 +49,8 @@ def parse_uri(uri: str) -> WebSocketURI: """ Parse and validate a WebSocket URI. - :raises ValueError: if ``uri`` isn't a valid WebSocket URI. + :raises ~websockets.exceptions.InvalidURI: if ``uri`` isn't a valid + WebSocket URI. """ parsed = urllib.parse.urlparse(uri) From 5e0f002152fd42f08eb8b9f050fff6211384e30b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 28 Aug 2021 08:36:36 +0200 Subject: [PATCH 147/762] Switch from viewcode to linkcode. This makes docs builds noticeably faster as they no longer need to syntax-highlight and render all source code. --- docs/conf.py | 51 ++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 50 insertions(+), 1 deletion(-) diff --git a/docs/conf.py b/docs/conf.py index 04d4c0557..d5f527bf6 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -5,7 +5,10 @@ # https://www.sphinx-doc.org/en/master/usage/configuration.html import datetime +import importlib +import inspect import os +import subprocess import sys # -- Path setup -------------------------------------------------------------- @@ -34,7 +37,7 @@ extensions = [ "sphinx.ext.autodoc", "sphinx.ext.intersphinx", - "sphinx.ext.viewcode", + "sphinx.ext.linkcode", "sphinx_autodoc_typehints", "sphinx_copybutton", "sphinx_inline_tabs", @@ -55,6 +58,52 @@ # This pattern also affects html_static_path and html_extra_path. exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] +# Configure viewcode extension. +try: + git_sha1 = subprocess.run( + "git rev-parse --short HEAD", + capture_output=True, + shell=True, + check=True, + text=True, + ).stdout.strip() +except subprocess.SubprocessError as exc: + print("Cannot get git commit, disabling linkcode:", exc) + extensions.remove("sphinx.ext.linkcode") +else: + code_url = f"https://github.com/aaugustin/websockets/blob/{git_sha1}" + + +def linkcode_resolve(domain, info): + assert domain == "py" + + mod = importlib.import_module(info["module"]) + if "." in info["fullname"]: + objname, attrname = info["fullname"].split(".") + obj = getattr(mod, objname) + try: + # object is a method of a class + obj = getattr(obj, attrname) + except AttributeError: + # object is an attribute of a class + return None + else: + obj = getattr(mod, info["fullname"]) + + try: + file = inspect.getsourcefile(obj) + lines = inspect.getsourcelines(obj) + except TypeError: + # e.g. object is a typing.Union + return None + file = os.path.relpath(file, os.path.abspath("..")) + if not file.startswith("src/websockets"): + # e.g. object is a typing.NewType + return None + start, end = lines[1], lines[1] + len(lines[0]) - 1 + + return f"{code_url}/{file}#L{start}-L{end}" + # -- Options for HTML output ------------------------------------------------- From 1b0b2de59f13b79062a888a8d17ae5e9e5a60ef9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 30 Aug 2021 22:44:09 +0200 Subject: [PATCH 148/762] Advertise support for client_max_window_bits by default. This doesn't change anything in the default setup (compression="deflate") because it sets client_max_window_bits to 12 explicitly. If a client is configured with extensions=[ClientPerMessageDeflateFactory()] this change gives the server the option to limit client_max_window_bits. Specifically, this makes such a client compatible with a websockets server using the default setup (compression="deflate"). --- src/websockets/extensions/permessage_deflate.py | 2 +- tests/legacy/test_client_server.py | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index 59f019d53..a377abb55 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -286,7 +286,7 @@ def __init__( server_no_context_takeover: bool = False, client_no_context_takeover: bool = False, server_max_window_bits: Optional[int] = None, - client_max_window_bits: Optional[Union[int, bool]] = None, + client_max_window_bits: Optional[Union[int, bool]] = True, compress_settings: Optional[Dict[str, Any]] = None, ) -> None: """ diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 3fcd0b044..016f08e73 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -834,7 +834,12 @@ def test_extension_client_rejection(self): ServerPerMessageDeflateFactory(), ] ) - @with_client("/extensions", extensions=[ClientPerMessageDeflateFactory()]) + @with_client( + "/extensions", + extensions=[ + ClientPerMessageDeflateFactory(client_max_window_bits=None), + ], + ) def test_extension_no_match_then_match(self): # The order requested by the client has priority. server_extensions = self.loop.run_until_complete(self.client.recv()) From 88ae5eb25fa3e8caa3f5d7be7517ec8d5c68847d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 3 Sep 2021 20:50:03 +0200 Subject: [PATCH 149/762] Make error messages more error friendly. Avoid weird exception chaining to assertion errors. --- docs/project/changelog.rst | 11 +++++++++++ src/websockets/exceptions.py | 5 +++-- src/websockets/uri.py | 15 +++++++-------- tests/test_exceptions.py | 4 ++-- tests/test_uri.py | 5 +++++ 5 files changed, 28 insertions(+), 12 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 07942783f..b9329d7b5 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -56,6 +56,15 @@ They may change at any time. If you raise :exc:`~exceptions.ConnectionClosed` or a subclass — rather than catch them when websockets raises them — you must change your code. +.. note:: + + **Version 10.0 adds a ``msg`` parameter to** ``InvalidURI.__init__`` **.** + + If you raise :exc:`~exceptions.InvalidURI` — rather than catch them when + websockets raises them — you must change your code. + +Also: + * Added compatibility with Python 3.10. * Added :func:`~websockets.broadcast` to send a message to many clients. @@ -150,6 +159,8 @@ They may change at any time. from websockets.client import connect from websockets.server import serve +Also: + * Added compatibility with Python 3.9. * Added support for IRIs in addition to URIs. diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index b3462484f..6bbea324c 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -374,11 +374,12 @@ class InvalidURI(WebSocketException): """ - def __init__(self, uri: str) -> None: + def __init__(self, uri: str, msg: str) -> None: self.uri = uri + self.msg = msg def __str__(self) -> str: - return f"{self.uri} isn't a valid URI" + return f"{self.uri} isn't a valid URI: {self.msg}" class PayloadTooBig(WebSocketException): diff --git a/src/websockets/uri.py b/src/websockets/uri.py index c99c3f16e..397c23116 100644 --- a/src/websockets/uri.py +++ b/src/websockets/uri.py @@ -54,13 +54,12 @@ def parse_uri(uri: str) -> WebSocketURI: """ parsed = urllib.parse.urlparse(uri) - try: - assert parsed.scheme in ["ws", "wss"] - assert parsed.params == "" - assert parsed.fragment == "" - assert parsed.hostname is not None - except AssertionError as exc: - raise exceptions.InvalidURI(uri) from exc + if parsed.scheme not in ["ws", "wss"]: + raise exceptions.InvalidURI(uri, "scheme isn't ws or wss") + if parsed.hostname is None: + raise exceptions.InvalidURI(uri, "hostname isn't provided") + if parsed.fragment != "": + raise exceptions.InvalidURI(uri, "fragment identifier is meaningless") secure = parsed.scheme == "wss" host = parsed.hostname @@ -73,7 +72,7 @@ def parse_uri(uri: str) -> WebSocketURI: # urllib.parse.urlparse accepts URLs with a username but without a # password. This doesn't make sense for HTTP Basic Auth credentials. if parsed.password is None: - raise exceptions.InvalidURI(uri) + raise exceptions.InvalidURI(uri, "username provided without password") user_info = (parsed.username, parsed.password) try: diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index e172cdd02..3ede25fdb 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -142,8 +142,8 @@ def test_str(self): "WebSocket connection isn't established yet", ), ( - InvalidURI("|"), - "| isn't a valid URI", + InvalidURI("|", "not at all!"), + "| isn't a valid URI: not at all!", ), ( PayloadTooBig("payload length exceeds limit: 2 > 1 bytes"), diff --git a/tests/test_uri.py b/tests/test_uri.py index a91bcb083..f937d2949 100644 --- a/tests/test_uri.py +++ b/tests/test_uri.py @@ -17,6 +17,10 @@ "ws://localhost/path?query", WebSocketURI(False, "localhost", 80, "/path?query", None), ), + ( + "ws://localhost/path;params", + WebSocketURI(False, "localhost", 80, "/path;params", None), + ), ( "WS://LOCALHOST/PATH?QUERY", WebSocketURI(False, "localhost", 80, "/PATH?QUERY", None), @@ -39,6 +43,7 @@ "https://localhost/", "ws://localhost/path#fragment", "ws://user@localhost/", + "ws:///path", ] From 27f6539abc678182970af621c007646403968f82 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 14 Aug 2021 15:01:47 +0200 Subject: [PATCH 150/762] Review reference docs. Remove sphinx-autodoc-typehints and use native autodoc features instead, mainly because I can't find a workaround for: https://github.com/agronholm/sphinx-autodoc-typehints/issues/132 Re-introduce the "common" docs, else linking to send() and recv() is a mess. --- docs/conf.py | 33 +- docs/howto/cheatsheet.rst | 16 +- docs/howto/django.rst | 2 +- docs/howto/extensions.rst | 15 +- docs/howto/faq.rst | 16 +- docs/intro/index.rst | 2 +- docs/project/changelog.rst | 108 ++--- docs/reference/client.rst | 62 +-- docs/reference/common.rst | 51 ++ docs/reference/exceptions.rst | 6 + docs/reference/extensions.rst | 36 +- docs/reference/index.rst | 46 +- docs/reference/limitations.rst | 4 +- docs/reference/server.rst | 84 ++-- docs/reference/types.rst | 20 + docs/reference/utilities.rst | 36 +- docs/requirements.txt | 1 - docs/spelling_wordlist.txt | 2 + docs/topics/broadcast.rst | 24 +- docs/topics/compression.rst | 7 +- docs/topics/deployment.rst | 2 +- docs/topics/design.rst | 38 +- docs/topics/memory.rst | 10 +- docs/topics/timeouts.rst | 9 +- src/websockets/__init__.py | 2 + src/websockets/__main__.py | 2 +- src/websockets/auth.py | 2 + src/websockets/client.py | 40 +- src/websockets/connection.py | 10 +- src/websockets/datastructures.py | 10 +- src/websockets/exceptions.py | 31 +- src/websockets/extensions/base.py | 100 ++-- .../extensions/permessage_deflate.py | 57 +-- src/websockets/frames.py | 94 ++-- src/websockets/headers.py | 69 +-- src/websockets/http11.py | 116 +++-- src/websockets/imports.py | 14 +- src/websockets/legacy/auth.py | 71 +-- src/websockets/legacy/client.py | 288 +++++------- src/websockets/legacy/framing.py | 65 ++- src/websockets/legacy/handshake.py | 62 +-- src/websockets/legacy/http.py | 22 +- src/websockets/legacy/protocol.py | 380 ++++++++++----- src/websockets/legacy/server.py | 442 ++++++++---------- src/websockets/server.py | 109 +++-- src/websockets/streams.py | 36 +- src/websockets/typing.py | 21 +- src/websockets/uri.py | 30 +- src/websockets/utils.py | 8 +- 49 files changed, 1467 insertions(+), 1244 deletions(-) create mode 100644 docs/reference/common.rst create mode 100644 docs/reference/exceptions.rst create mode 100644 docs/reference/types.rst diff --git a/docs/conf.py b/docs/conf.py index d5f527bf6..d9e3cd598 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -31,6 +31,28 @@ # -- General configuration --------------------------------------------------- +nitpicky = True + +nitpick_ignore = [ + # topics/design.rst discusses undocumented APIs + ("py:meth", "client.WebSocketClientProtocol.handshake"), + ("py:meth", "server.WebSocketServerProtocol.handshake"), + ("py:attr", "legacy.protocol.WebSocketCommonProtocol.is_client"), + ("py:attr", "legacy.protocol.WebSocketCommonProtocol.messages"), + ("py:meth", "legacy.protocol.WebSocketCommonProtocol.close_connection"), + ("py:attr", "legacy.protocol.WebSocketCommonProtocol.close_connection_task"), + ("py:meth", "legacy.protocol.WebSocketCommonProtocol.keepalive_ping"), + ("py:attr", "legacy.protocol.WebSocketCommonProtocol.keepalive_ping_task"), + ("py:meth", "legacy.protocol.WebSocketCommonProtocol.transfer_data"), + ("py:attr", "legacy.protocol.WebSocketCommonProtocol.transfer_data_task"), + ("py:meth", "legacy.protocol.WebSocketCommonProtocol.connection_open"), + ("py:meth", "legacy.protocol.WebSocketCommonProtocol.ensure_open"), + ("py:meth", "legacy.protocol.WebSocketCommonProtocol.fail_connection"), + ("py:meth", "legacy.protocol.WebSocketCommonProtocol.connection_lost"), + ("py:meth", "legacy.protocol.WebSocketCommonProtocol.read_message"), + ("py:meth", "legacy.protocol.WebSocketCommonProtocol.write_frame"), +] + # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. @@ -38,7 +60,7 @@ "sphinx.ext.autodoc", "sphinx.ext.intersphinx", "sphinx.ext.linkcode", - "sphinx_autodoc_typehints", + "sphinx.ext.napoleon", "sphinx_copybutton", "sphinx_inline_tabs", "sphinxcontrib.spelling", @@ -46,6 +68,15 @@ "sphinxext.opengraph", ] +autodoc_typehints = "description" + +autodoc_typehints_description_target = "documented" + +# Workaround for https://github.com/sphinx-doc/sphinx/issues/9560 +from sphinx.domains.python import PythonDomain +assert PythonDomain.object_types['data'].roles == ('data', 'obj') +PythonDomain.object_types['data'].roles = ('data', 'class', 'obj') + intersphinx_mapping = {"python": ("https://docs.python.org/3", None)} spelling_show_suggestions = True diff --git a/docs/howto/cheatsheet.rst b/docs/howto/cheatsheet.rst index edfb00baa..95b551f67 100644 --- a/docs/howto/cheatsheet.rst +++ b/docs/howto/cheatsheet.rst @@ -26,27 +26,27 @@ Server :meth:`~legacy.protocol.WebSocketCommonProtocol.pong` if you wish but it isn't needed in general. -* Create a server with :func:`~legacy.server.serve` which is similar to asyncio's - :meth:`~asyncio.AbstractEventLoop.create_server`. You can also use it as an - asynchronous context manager. +* Create a server with :func:`~server.serve` which is similar to asyncio's + :meth:`~asyncio.loop.create_server`. You can also use it as an asynchronous + context manager. * The server takes care of establishing connections, then lets the handler execute the application logic, and finally closes the connection after the handler exits normally or with an exception. * For advanced customization, you may subclass - :class:`~legacy.server.WebSocketServerProtocol` and pass either this subclass or + :class:`~server.WebSocketServerProtocol` and pass either this subclass or a factory function as the ``create_protocol`` argument. Client ------ -* Create a client with :func:`~legacy.client.connect` which is similar to asyncio's - :meth:`~asyncio.BaseEventLoop.create_connection`. You can also use it as an +* Create a client with :func:`~client.connect` which is similar to asyncio's + :meth:`~asyncio.loop.create_connection`. You can also use it as an asynchronous context manager. * For advanced customization, you may subclass - :class:`~legacy.server.WebSocketClientProtocol` and pass either this subclass or + :class:`~client.WebSocketClientProtocol` and pass either this subclass or a factory function as the ``create_protocol`` argument. * Call :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` and @@ -57,7 +57,7 @@ Client :meth:`~legacy.protocol.WebSocketCommonProtocol.pong` if you wish but it isn't needed in general. -* If you aren't using :func:`~legacy.client.connect` as a context manager, call +* If you aren't using :func:`~client.connect` as a context manager, call :meth:`~legacy.protocol.WebSocketCommonProtocol.close` to terminate the connection. .. _debugging: diff --git a/docs/howto/django.rst b/docs/howto/django.rst index c87c3821c..67bf582f1 100644 --- a/docs/howto/django.rst +++ b/docs/howto/django.rst @@ -124,7 +124,7 @@ support asynchronous I/O. It would block the event loop if it didn't run in a separate thread. :func:`~asyncio.to_thread` is available since Python 3.9. In earlier versions, use :meth:`~asyncio.loop.run_in_executor` instead. -Finally, we start a server with :func:`~websockets.serve`. +Finally, we start a server with :func:`~websockets.server.serve`. We're ready to test! diff --git a/docs/howto/extensions.rst b/docs/howto/extensions.rst index 9c49de172..2baead3f0 100644 --- a/docs/howto/extensions.rst +++ b/docs/howto/extensions.rst @@ -4,8 +4,10 @@ Writing an extension .. currentmodule:: websockets.extensions During the opening handshake, WebSocket clients and servers negotiate which -extensions will be used with which parameters. Then each frame is processed by -extensions before being sent or after being received. +extensions_ will be used with which parameters. Then each frame is processed +by extensions before being sent or after being received. + +.. _extensions: https://www.rfc-editor.org/rfc/rfc6455.html#section-9 As a consequence, writing an extension requires implementing several classes: @@ -23,11 +25,8 @@ As a consequence, writing an extension requires implementing several classes: Extensions are initialized by extension factories, so they don't need to be part of the public API of an extension. -websockets provides abstract base classes for extension factories and -extensions. See the API documentation for details on their methods: - -* :class:`ClientExtensionFactory` and :class:`ServerExtensionFactory` for - extension factories, -* :class:`Extension` for extensions. +websockets provides base classes for extension factories and extensions. +See :class:`ClientExtensionFactory`, :class:`ServerExtensionFactory`, +and :class:`Extension` for details. diff --git a/docs/howto/faq.rst b/docs/howto/faq.rst index d43d5fd67..ec657bb06 100644 --- a/docs/howto/faq.rst +++ b/docs/howto/faq.rst @@ -5,7 +5,7 @@ FAQ .. note:: - Many questions asked in :mod:`websockets`' issue tracker are actually + Many questions asked in websockets' issue tracker are actually about :mod:`asyncio`. Python's documentation about `developing with asyncio`_ is a good complement. @@ -99,13 +99,13 @@ How do I get access HTTP headers, for example cookies? ...................................................... To access HTTP headers during the WebSocket handshake, you can override -:attr:`~legacy.server.WebSocketServerProtocol.process_request`:: +:attr:`~server.WebSocketServerProtocol.process_request`:: async def process_request(self, path, request_headers): cookies = request_header["Cookie"] Once the connection is established, they're available in -:attr:`~legacy.protocol.WebSocketServerProtocol.request_headers`:: +:attr:`~server.WebSocketServerProtocol.request_headers`:: async def handler(websocket, path): cookies = websocket.request_headers["Cookie"] @@ -123,7 +123,7 @@ How do I set which IP addresses my server listens to? Look at the ``host`` argument of :meth:`~asyncio.loop.create_server`. -:func:`serve` accepts the same arguments as +:func:`~server.serve` accepts the same arguments as :meth:`~asyncio.loop.create_server`. How do I close a connection properly? @@ -143,7 +143,7 @@ Providing a HTTP server is out of scope for websockets. It only aims at providing a WebSocket server. There's limited support for returning HTTP responses with the -:attr:`~legacy.server.WebSocketServerProtocol.process_request` hook. +:attr:`~server.WebSocketServerProtocol.process_request` hook. If you need more, pick a HTTP server and run it separately. @@ -169,7 +169,7 @@ change it to:: How do I close a connection properly? ..................................... -The easiest is to use :func:`connect` as a context manager:: +The easiest is to use :func:`~client.connect` as a context manager:: async with connect(...) as websocket: ... @@ -196,7 +196,7 @@ How do I disable TLS/SSL certificate verification? Look at the ``ssl`` argument of :meth:`~asyncio.loop.create_connection`. -:func:`connect` accepts the same arguments as +:func:`~client.connect` accepts the same arguments as :meth:`~asyncio.loop.create_connection`. asyncio usage @@ -449,4 +449,4 @@ I'm having problems with threads You shouldn't use threads. Use tasks instead. -:func:`~asyncio.AbstractEventLoop.call_soon_threadsafe` may help. +:meth:`~asyncio.loop.call_soon_threadsafe` may help. diff --git a/docs/intro/index.rst b/docs/intro/index.rst index ff3ae0ffd..c8426719c 100644 --- a/docs/intro/index.rst +++ b/docs/intro/index.rst @@ -82,7 +82,7 @@ This client needs a context because the server uses a self-signed certificate. A client connecting to a secure WebSocket server with a valid certificate (i.e. signed by a CA that your Python installation trusts) can simply pass -``ssl=True`` to :func:`connect` instead of building a context. +``ssl=True`` to :func:`~client.connect` instead of building a context. Browser-based example --------------------- diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index b9329d7b5..073514ecd 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -37,10 +37,10 @@ They may change at any time. .. note:: **Version 10.0 enables a timeout of 10 seconds on** - :func:`~legacy.client.connect` **by default.** + :func:`~client.connect` **by default.** You can adjust the timeout with the ``open_timeout`` parameter. Set it to - ``None`` to disable the timeout entirely. + :obj:`None` to disable the timeout entirely. .. note:: @@ -51,7 +51,8 @@ They may change at any time. .. note:: - **Version 10.0 changes parameters of** ``ConnectionClosed.__init__`` **.** + **Version 10.0 changes arguments of** + :exc:`~exceptions.ConnectionClosed` **.** If you raise :exc:`~exceptions.ConnectionClosed` or a subclass — rather than catch them when websockets raises them — you must change your code. @@ -70,27 +71,30 @@ Also: * Added :func:`~websockets.broadcast` to send a message to many clients. * Added support for reconnecting automatically by using - :func:`~legacy.client.connect` as an asynchronous iterator. + :func:`~client.connect` as an asynchronous iterator. -* Added ``open_timeout`` to :func:`~legacy.client.connect`. +* Added ``open_timeout`` to :func:`~client.connect`. * Improved logging. -* Provided additional information in :exc:`ConnectionClosed` exceptions. +* Provided additional information in :exc:`~exceptions.ConnectionClosed` + exceptions. * Optimized default compression settings to reduce memory usage. * Made it easier to customize authentication with :meth:`~auth.BasicAuthWebSocketServerProtocol.check_credentials`. -* Fixed handling of relative redirects in :func:`~legacy.client.connect`. +* Fixed handling of relative redirects in :func:`~client.connect`. + +* Improved API documentation. 9.1 ... *May 27, 2021* -.. note:: +.. caution:: **Version 9.1 fixes a security issue introduced in version 8.0.** @@ -197,7 +201,7 @@ Also: *July 31, 2019* * Restored the ability to pass a socket with the ``sock`` parameter of - :func:`~legacy.server.serve`. + :func:`~server.serve`. * Removed an incorrect assertion when a connection drops. @@ -224,9 +228,9 @@ Also: Previously, it could be a function or a coroutine. If you're passing a ``process_request`` argument to - :func:`~legacy.server.serve` - or :class:`~legacy.server.WebSocketServerProtocol`, or if you're overriding - :meth:`~legacy.server.WebSocketServerProtocol.process_request` in a subclass, + :func:`~server.serve` or :class:`~server.WebSocketServerProtocol`, or if + you're overriding + :meth:`~server.WebSocketServerProtocol.process_request` in a subclass, define it with ``async def`` instead of ``def``. For backwards compatibility, functions are still mostly supported, but @@ -274,15 +278,15 @@ Also: :exc:`~exceptions.ConnectionClosed` to tell apart normal connection termination from errors. -* Added :func:`~legacy.auth.basic_auth_protocol_factory` to enforce HTTP +* Added :func:`~auth.basic_auth_protocol_factory` to enforce HTTP Basic Auth on the server side. -* :func:`~legacy.client.connect` handles redirects from the server during the +* :func:`~client.connect` handles redirects from the server during the handshake. -* :func:`~legacy.client.connect` supports overriding ``host`` and ``port``. +* :func:`~client.connect` supports overriding ``host`` and ``port``. -* Added :func:`~legacy.client.unix_connect` for connecting to Unix sockets. +* Added :func:`~client.unix_connect` for connecting to Unix sockets. * Improved support for sending fragmented messages by accepting asynchronous iterators in :meth:`~legacy.protocol.WebSocketCommonProtocol.send`. @@ -292,10 +296,10 @@ Also: as a workaround, you can remove it. * Changed :meth:`WebSocketServer.close() - ` to perform a proper closing handshake + ` to perform a proper closing handshake instead of failing the connection. -* Avoided a crash when a ``extra_headers`` callable returns ``None``. +* Avoided a crash when a ``extra_headers`` callable returns :obj:`None`. * Improved error messages when HTTP parsing fails. @@ -327,7 +331,7 @@ Also: **Version 7.0 changes how a server terminates connections when it's closed with** :meth:`WebSocketServer.close() - ` **.** + ` **.** Previously, connections handlers were canceled. Now, connections are closed with close code 1001 (going away). From the perspective of the @@ -345,7 +349,7 @@ Also: .. note:: **Version 7.0 renames the** ``timeout`` **argument of** - :func:`~legacy.server.serve` **and** :func:`~legacy.client.connect` **to** + :func:`~server.serve` **and** :func:`~client.connect` **to** ``close_timeout`` **.** This prevents confusion with ``ping_timeout``. @@ -375,11 +379,11 @@ Also: Also: * Added ``process_request`` and ``select_subprotocol`` arguments to - :func:`~legacy.server.serve` and - :class:`~legacy.server.WebSocketServerProtocol` to customize - :meth:`~legacy.server.WebSocketServerProtocol.process_request` and - :meth:`~legacy.server.WebSocketServerProtocol.select_subprotocol` without - subclassing :class:`~legacy.server.WebSocketServerProtocol`. + :func:`~server.serve` and + :class:`~server.WebSocketServerProtocol` to customize + :meth:`~server.WebSocketServerProtocol.process_request` and + :meth:`~server.WebSocketServerProtocol.select_subprotocol` without + subclassing :class:`~server.WebSocketServerProtocol`. * Added support for sending fragmented messages. @@ -389,7 +393,7 @@ Also: * Added an interactive client: ``python -m websockets ``. * Changed the ``origins`` argument to represent the lack of an origin with - ``None`` rather than ``''``. + :obj:`None` rather than ``''``. * Fixed a data loss bug in :meth:`~legacy.protocol.WebSocketCommonProtocol.recv`: @@ -409,7 +413,7 @@ Also: **Version 6.0 introduces the** :class:`~datastructures.Headers` **class for managing HTTP headers and changes several public APIs:** - * :meth:`~legacy.server.WebSocketServerProtocol.process_request` now + * :meth:`~server.WebSocketServerProtocol.process_request` now receives a :class:`~datastructures.Headers` instead of a ``http.client.HTTPMessage`` in the ``request_headers`` argument. @@ -445,14 +449,14 @@ Also: *May 24, 2018* * Fixed a regression in 5.0 that broke some invocations of - :func:`~legacy.server.serve` and :func:`~legacy.client.connect`. + :func:`~server.serve` and :func:`~client.connect`. 5.0 ... *May 22, 2018* -.. note:: +.. caution:: **Version 5.0 fixes a security issue introduced in version 4.0.** @@ -465,14 +469,14 @@ Also: .. note:: **Version 5.0 adds a** ``user_info`` **field to the return value of** - :func:`~uri.parse_uri` **and** :class:`~uri.WebSocketURI` **.** + ``parse_uri`` **and** ``WebSocketURI`` **.** - If you're unpacking :class:`~uri.WebSocketURI` into four variables, adjust - your code to account for that fifth field. + If you're unpacking ``WebSocketURI`` into four variables, adjust your code + to account for that fifth field. Also: -* :func:`~legacy.client.connect` performs HTTP Basic Auth when the URI contains +* :func:`~client.connect` performs HTTP Basic Auth when the URI contains credentials. * Iterating on incoming messages no longer raises an exception when the @@ -481,7 +485,7 @@ Also: * A plain HTTP request now receives a 426 Upgrade Required response and doesn't log a stack trace. -* :func:`~legacy.server.unix_serve` can be used as an asynchronous context +* :func:`~server.unix_serve` can be used as an asynchronous context manager on Python ≥ 3.5.1. * Added the :attr:`~legacy.protocol.WebSocketCommonProtocol.closed` property @@ -536,7 +540,7 @@ Also: Compression should improve performance but it increases RAM and CPU use. If you want to disable compression, add ``compression=None`` when calling - :func:`~legacy.server.serve` or :func:`~legacy.client.connect`. + :func:`~server.serve` or :func:`~client.connect`. .. note:: @@ -549,10 +553,10 @@ Also: * :class:`~legacy.protocol.WebSocketCommonProtocol` instances can be used as asynchronous iterators on Python ≥ 3.6. They yield incoming messages. -* Added :func:`~legacy.server.unix_serve` for listening on Unix sockets. +* Added :func:`~server.unix_serve` for listening on Unix sockets. -* Added the :attr:`~legacy.server.WebSocketServer.sockets` attribute to the - return value of :func:`~legacy.server.serve`. +* Added the :attr:`~server.WebSocketServer.sockets` attribute to the + return value of :func:`~server.serve`. * Reorganized and extended documentation. @@ -572,15 +576,15 @@ Also: *August 20, 2017* -* Renamed :func:`~legacy.server.serve` and :func:`~legacy.client.connect`'s +* Renamed :func:`~server.serve` and :func:`~client.connect`'s ``klass`` argument to ``create_protocol`` to reflect that it can also be a callable. For backwards compatibility, ``klass`` is still supported. -* :func:`~legacy.server.serve` can be used as an asynchronous context manager +* :func:`~server.serve` can be used as an asynchronous context manager on Python ≥ 3.5.1. * Added support for customizing handling of incoming connections with - :meth:`~legacy.server.WebSocketServerProtocol.process_request`. + :meth:`~server.WebSocketServerProtocol.process_request`. * Made read and write buffer sizes configurable. @@ -588,10 +592,10 @@ Also: * Added an optional C extension to speed up low-level operations. -* An invalid response status code during :func:`~legacy.client.connect` now +* An invalid response status code during :func:`~client.connect` now raises :class:`~exceptions.InvalidStatusCode`. -* Providing a ``sock`` argument to :func:`~legacy.client.connect` no longer +* Providing a ``sock`` argument to :func:`~client.connect` no longer crashes. 3.3 @@ -611,7 +615,7 @@ Also: *August 17, 2016* * Added ``timeout``, ``max_size``, and ``max_queue`` arguments to - :func:`~legacy.client.connect` and :func:`~legacy.server.serve`. + :func:`~client.connect` and :func:`~server.serve`. * Made server shutdown more robust. @@ -637,8 +641,8 @@ Also: **If you're upgrading from 2.x or earlier, please read this carefully.** :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` used to return - ``None`` when the connection was closed. This required checking the return - value of every call:: + :obj:`None` when the connection was closed. This required checking the + return value of every call:: message = await websocket.recv() if message is None: @@ -655,14 +659,14 @@ Also: In order to avoid stranding projects built upon an earlier version, the previous behavior can be restored by passing ``legacy_recv=True`` to - :func:`~legacy.server.serve`, :func:`~legacy.client.connect`, - :class:`~legacy.server.WebSocketServerProtocol`, or - :class:`~legacy.client.WebSocketClientProtocol`. ``legacy_recv`` isn't + :func:`~server.serve`, :func:`~client.connect`, + :class:`~server.WebSocketServerProtocol`, or + :class:`~client.WebSocketClientProtocol`. ``legacy_recv`` isn't documented in their signatures but isn't scheduled for deprecation either. Also: -* :func:`~legacy.client.connect` can be used as an asynchronous context +* :func:`~client.connect` can be used as an asynchronous context manager on Python ≥ 3.5.1. * Updated documentation with ``await`` and ``async`` syntax from Python 3.5. @@ -732,8 +736,8 @@ Also: * Added support for subprotocols. -* Added ``loop`` argument to :func:`~legacy.client.connect` and - :func:`~legacy.server.serve`. +* Added ``loop`` argument to :func:`~client.connect` and + :func:`~server.serve`. 2.3 ... diff --git a/docs/reference/client.rst b/docs/reference/client.rst index 84f66a19a..eaa6cdd76 100644 --- a/docs/reference/client.rst +++ b/docs/reference/client.rst @@ -6,67 +6,55 @@ Client Opening a connection -------------------- - .. autofunction:: connect(uri, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, compression='deflate', origin=None, extensions=None, subprotocols=None, extra_headers=None, logger=None, **kwds) + .. autofunction:: connect(uri, *, create_protocol=None, logger=None, compression="deflate", origin=None, extensions=None, subprotocols=None, extra_headers=None, open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) :async: - .. autofunction:: unix_connect(path, uri="ws://localhost/", *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, compression='deflate', origin=None, extensions=None, subprotocols=None, extra_headers=None, logger=None, **kwds) + .. autofunction:: unix_connect(path, uri="ws://localhost/", *, create_protocol=None, logger=None, compression="deflate", origin=None, extensions=None, subprotocols=None, extra_headers=None, open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) :async: Using a connection ------------------ - .. autoclass:: WebSocketClientProtocol(*, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, origin=None, extensions=None, subprotocols=None, extra_headers=None, logger=None) + .. autoclass:: WebSocketClientProtocol(*, logger=None, origin=None, extensions=None, subprotocols=None, extra_headers=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) - .. attribute:: id - - UUID for the connection. - - Useful for identifying connections in logs. - - .. autoattribute:: local_address - - .. autoattribute:: remote_address - - .. autoattribute:: open - - .. autoattribute:: closed - - .. attribute:: path + .. automethod:: recv - Path of the HTTP request. + .. automethod:: send - Available once the connection is open. + .. automethod:: close - .. attribute:: request_headers + .. automethod:: wait_closed - HTTP request headers as a :class:`~websockets.http.Headers` instance. + .. automethod:: ping - Available once the connection is open. + .. automethod:: pong - .. attribute:: response_headers + WebSocket connection objects also provide these attributes: - HTTP response headers as a :class:`~websockets.http.Headers` instance. + .. autoattribute:: id - Available once the connection is open. + .. autoproperty:: local_address - .. attribute:: subprotocol + .. autoproperty:: remote_address - Subprotocol, if one was negotiated. + .. autoproperty:: open - Available once the connection is open. + .. autoproperty:: closed - .. autoattribute:: close_code + The following attributes are available after the opening handshake, + once the WebSocket connection is open: - .. autoattribute:: close_reason + .. autoattribute:: path - .. automethod:: recv + .. autoattribute:: request_headers - .. automethod:: send + .. autoattribute:: response_headers - .. automethod:: ping + .. autoattribute:: subprotocol - .. automethod:: pong + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: - .. automethod:: close + .. autoproperty:: close_code - .. automethod:: wait_closed + .. autoproperty:: close_reason diff --git a/docs/reference/common.rst b/docs/reference/common.rst new file mode 100644 index 000000000..3b9f34a57 --- /dev/null +++ b/docs/reference/common.rst @@ -0,0 +1,51 @@ +Both sides +========== + +.. automodule:: websockets.legacy.protocol + + Using a connection + ------------------ + + .. autoclass:: WebSocketCommonProtocol(*, logger=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) + + .. automethod:: recv + + .. automethod:: send + + .. automethod:: close + + .. automethod:: wait_closed + + .. automethod:: ping + + .. automethod:: pong + + WebSocket connection objects also provide these attributes: + + .. autoattribute:: id + + .. autoproperty:: local_address + + .. autoproperty:: remote_address + + .. autoproperty:: open + + .. autoproperty:: closed + + The following attributes are available after the opening handshake, + once the WebSocket connection is open: + + .. autoattribute:: path + + .. autoattribute:: request_headers + + .. autoattribute:: response_headers + + .. autoattribute:: subprotocol + + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: + + .. autoproperty:: close_code + + .. autoproperty:: close_reason diff --git a/docs/reference/exceptions.rst b/docs/reference/exceptions.rst new file mode 100644 index 000000000..907a650d2 --- /dev/null +++ b/docs/reference/exceptions.rst @@ -0,0 +1,6 @@ +Exceptions +========== + +.. automodule:: websockets.exceptions + :members: + diff --git a/docs/reference/extensions.rst b/docs/reference/extensions.rst index bae583a21..a70f1b1e5 100644 --- a/docs/reference/extensions.rst +++ b/docs/reference/extensions.rst @@ -6,7 +6,7 @@ Extensions The WebSocket protocol supports extensions_. At the time of writing, there's only one `registered extension`_ with a public -specification, WebSocket Per-Message Deflate, specified in :rfc:`7692`. +specification, WebSocket Per-Message Deflate. .. _extensions: https://www.rfc-editor.org/rfc/rfc6455.html#section-9 .. _registered extension: https://www.iana.org/assignments/websocket/websocket.xhtml#extension-name @@ -16,21 +16,45 @@ Per-Message Deflate .. automodule:: websockets.extensions.permessage_deflate + :mod:`websockets.extensions.permessage_deflate` implements WebSocket + Per-Message Deflate. + + This extension is specified in :rfc:`7692`. + + Refer to the :doc:`topic guide on compression <../topics/compression>` to + learn more about tuning compression settings. + .. autoclass:: ClientPerMessageDeflateFactory .. autoclass:: ServerPerMessageDeflateFactory -Abstract classes ----------------- +Base classes +------------ .. automodule:: websockets.extensions + :mod:`websockets.extensions` defines base classes for implementing + extensions. + + Refer to the :doc:`how-to guide on extensions <../howto/extensions>` to + learn more about writing an extension. + .. autoclass:: Extension - :members: + + .. autoattribute:: name + + .. automethod:: decode + + .. automethod:: encode .. autoclass:: ClientExtensionFactory - :members: + + .. autoattribute:: name + + .. automethod:: get_request_params + + .. automethod:: process_response_params .. autoclass:: ServerExtensionFactory - :members: + .. automethod:: process_request_params diff --git a/docs/reference/index.rst b/docs/reference/index.rst index 8d01c5b40..385beab29 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -1,24 +1,27 @@ API reference ============= -websockets provides complete client and server implementations, as shown in +.. currentmodule:: websockets + +websockets provides client and server implementations, as shown in the :doc:`getting started guide <../intro/index>`. The process for opening and closing a WebSocket connection depends on which side you're implementing. -* On the client side, connecting to a server with :class:`~websockets.connect` +* On the client side, connecting to a server with :func:`~client.connect` yields a connection object that provides methods for interacting with the connection. Your code can open a connection, then send or receive messages. - If you use :class:`~websockets.connect` as an asynchronous context manager, + If you use :func:`~client.connect` as an asynchronous context manager, then websockets closes the connection on exit. If not, then your code is responsible for closing the connection. -* On the server side, :class:`~websockets.serve` starts listening for client - connections and yields an server object that supports closing the server. +* On the server side, :func:`~server.serve` starts listening for client + connections and yields an server object that you can use to shut down + the server. - Then, when clients connects, the server initializes a connection object and + Then, when a client connects, the server initializes a connection object and passes it to a handler coroutine, which is where your code can send or receive messages. This pattern is called `inversion of control`_. It's common in frameworks implementing servers. @@ -29,28 +32,35 @@ side you're implementing. .. _inversion of control: https://en.wikipedia.org/wiki/Inversion_of_control Once the connection is open, the WebSocket protocol is symmetrical, except for -low-level details that websockets manages under the hood. The same methods are -available on client connections created with :class:`~websockets.connect` and -on server connections passed to the connection handler in the arguments. +low-level details that websockets manages under the hood. The same methods +are available on client connections created with :func:`~client.connect` and +on server connections received in argument by the connection handler +of :func:`~server.serve`. -At this point, websockets provides the same API — and uses the same code — for -client and server connections. For convenience, common methods are documented -both in the client API and server API. +Since websockets provides the same API — and uses the same code — for client +and server connections, common methods are documented in a "Both sides" page. .. toctree:: :titlesonly: client server - extensions + common utilities + exceptions + types + extensions limitations -All public APIs can be imported from the :mod:`websockets` package, unless -noted otherwise. This convenience feature is incompatible with static code -analysis tools such as mypy_, though. +Public API documented in the API reference are subject to the +:ref:`backwards-compatibility policy `. + +Anything that isn't listed in the API reference is a private API. There's no +guarantees of behavior or backwards-compatibility for private APIs. + +For convenience, many public APIs can be imported from the ``websockets`` +package. This feature is incompatible with static code analysis tools such as +mypy_, though. If you're using such tools, use the full import path. .. _mypy: https://github.com/python/mypy -Anything that isn't listed in this API documentation is a private API. There's -no guarantees of behavior or backwards-compatibility for private APIs. diff --git a/docs/reference/limitations.rst b/docs/reference/limitations.rst index 81f1445b5..3304bdb8c 100644 --- a/docs/reference/limitations.rst +++ b/docs/reference/limitations.rst @@ -1,12 +1,14 @@ Limitations =========== +.. currentmodule:: websockets + Client ------ The client doesn't attempt to guarantee that there is no more than one connection to a given IP address in a CONNECTING state. This behavior is -`mandated by RFC 6455`_. However, :func:`~websockets.connect()` isn't the +`mandated by RFC 6455`_. However, :func:`~client.connect()` isn't the right layer for enforcing this constraint. It's the caller's responsibility. .. _mandated by RFC 6455: https://www.rfc-editor.org/rfc/rfc6455.html#section-4.1 diff --git a/docs/reference/server.rst b/docs/reference/server.rst index 667c0b9d0..0a5a060f3 100644 --- a/docs/reference/server.rst +++ b/docs/reference/server.rst @@ -6,10 +6,10 @@ Server Starting a server ----------------- - .. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, compression='deflate', origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, logger=None, **kwds) + .. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, logger=None, compression="deflate", origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) :async: - .. autofunction:: unix_serve(ws_handler, path, *, create_protocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, compression='deflate', origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, logger=None, **kwds) + .. autofunction:: unix_serve(ws_handler, path=None, *, create_protocol=None, logger=None, compression="deflate", origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) :async: Stopping a server @@ -17,92 +17,80 @@ Server .. autoclass:: WebSocketServer - .. autoattribute:: sockets - .. automethod:: close + .. automethod:: wait_closed + .. autoattribute:: sockets + Using a connection ------------------ - .. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, logger=None) - - .. attribute:: id - - UUID for the connection. - - Useful for identifying connections in logs. + .. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, logger=None, origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) - .. autoattribute:: local_address - - .. autoattribute:: remote_address - - .. autoattribute:: open + .. automethod:: recv - .. autoattribute:: closed + .. automethod:: send - .. attribute:: path + .. automethod:: close - Path of the HTTP request. + .. automethod:: wait_closed - Available once the connection is open. + .. automethod:: ping - .. attribute:: request_headers + .. automethod:: pong - HTTP request headers as a :class:`~websockets.http.Headers` instance. + You can customize the opening handshake in a subclass by overriding these methods: - Available once the connection is open. + .. automethod:: process_request - .. attribute:: response_headers + .. automethod:: select_subprotocol - HTTP response headers as a :class:`~websockets.http.Headers` instance. + WebSocket connection objects also provide these attributes: - Available once the connection is open. + .. autoattribute:: id - .. attribute:: subprotocol + .. autoproperty:: local_address - Subprotocol, if one was negotiated. + .. autoproperty:: remote_address - Available once the connection is open. + .. autoproperty:: open - .. autoattribute:: close_code + .. autoproperty:: closed - .. autoattribute:: close_reason + The following attributes are available after the opening handshake, + once the WebSocket connection is open: - .. automethod:: process_request + .. autoattribute:: path - .. automethod:: select_subprotocol + .. autoattribute:: request_headers - .. automethod:: recv + .. autoattribute:: response_headers - .. automethod:: send + .. autoattribute:: subprotocol - .. automethod:: ping + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: - .. automethod:: pong + .. autoproperty:: close_code - .. automethod:: close + .. autoproperty:: close_reason - .. automethod:: wait_closed Basic authentication -------------------- .. automodule:: websockets.auth + websockets supports HTTP Basic Authentication according to + :rfc:`7235` and :rfc:`7617`. + .. autofunction:: basic_auth_protocol_factory .. autoclass:: BasicAuthWebSocketServerProtocol - .. attribute:: realm - - Scope of protection. - - If provided, it should contain only ASCII characters because the - encoding of non-ASCII characters is undefined. - - .. attribute:: username + .. autoattribute:: realm - Username of the authenticated user. + .. autoattribute:: username .. automethod:: check_credentials diff --git a/docs/reference/types.rst b/docs/reference/types.rst new file mode 100644 index 000000000..3dab553af --- /dev/null +++ b/docs/reference/types.rst @@ -0,0 +1,20 @@ +Types +===== + +.. autodata:: websockets.datastructures.HeadersLike + +.. automodule:: websockets.typing + + .. autodata:: Data + + .. autodata:: LoggerLike + + .. autodata:: Origin + + .. autodata:: Subprotocol + + .. autodata:: ExtensionName + + .. autodata:: ExtensionParameter + + diff --git a/docs/reference/utilities.rst b/docs/reference/utilities.rst index e7f489fbd..dc6333847 100644 --- a/docs/reference/utilities.rst +++ b/docs/reference/utilities.rst @@ -6,36 +6,32 @@ Broadcast .. autofunction:: websockets.broadcast -Data structures ---------------- +WebSocket events +---------------- -.. automodule:: websockets.datastructures - - .. autoclass:: Headers +.. automodule:: websockets.frames - .. autodata:: HeadersLike + .. autoclass:: Frame - .. autoexception:: MultipleValuesError + .. autoclass:: Opcode -Exceptions ----------- + .. autoclass:: Close -.. automodule:: websockets.exceptions - :members: +HTTP events +----------- -Types ------ +.. automodule:: websockets.http11 -.. automodule:: websockets.typing + .. autoclass:: Request - .. autodata:: Data + .. autoclass:: Response - .. autodata:: LoggerLike +.. automodule:: websockets.datastructures - .. autodata:: Origin + .. autoclass:: Headers - .. autodata:: Subprotocol + .. automethod:: get_all - .. autodata:: ExtensionName + .. automethod:: raw_items - .. autodata:: ExtensionParameter + .. autoexception:: MultipleValuesError diff --git a/docs/requirements.txt b/docs/requirements.txt index b9c371228..bcd1d7114 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,7 +1,6 @@ furo sphinx sphinx-autobuild -sphinx-autodoc-typehints sphinx-copybutton sphinx-inline-tabs sphinxcontrib-spelling diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 3d05752d5..b57d3c77f 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -35,6 +35,7 @@ linkerd liveness lookups MiB +mypy nginx permessage pid @@ -59,6 +60,7 @@ tox unregister uple uvicorn +uvloop virtualenv WebSocket websocket diff --git a/docs/topics/broadcast.rst b/docs/topics/broadcast.rst index f9cd9e281..a90cc2d70 100644 --- a/docs/topics/broadcast.rst +++ b/docs/topics/broadcast.rst @@ -1,13 +1,13 @@ Broadcasting messages ===================== -.. currentmodule: websockets +.. currentmodule:: websockets .. note:: If you just want to send a message to all connected clients, use - :func:`~websockets.broadcast`. + :func:`broadcast`. If you want to learn about its design in depth, continue reading this document. @@ -16,7 +16,7 @@ WebSocket servers often send the same message to all connected clients or to a subset of clients for which the message is relevant. Let's explore options for broadcasting a message, explain the design -of :func:`~websockets.broadcast`, and discuss alternatives. +of :func:`broadcast`, and discuss alternatives. For each option, we'll provide a connection handler called ``handler()`` and a function or coroutine called ``broadcast()`` that sends a message to all @@ -122,7 +122,7 @@ connections before the write buffer has time to fill up. Don't set extreme ``write_limit``, ``ping_interval``, and ``ping_timeout`` values to ensure that this condition holds. Set reasonable values and use the -built-in :func:`~websockets.broadcast` function instead. +built-in :func:`broadcast` function instead. The concurrent way ------------------ @@ -207,11 +207,11 @@ If a client gets too far behind, eventually it reaches the limit defined by ``ping_timeout`` and websockets terminates the connection. You can read the discussion of :doc:`keepalive and timeouts <./timeouts>` for details. -How :func:`~websockets.broadcast` works ---------------------------------------- +How :func:`broadcast` works +--------------------------- -The built-in :func:`~websockets.broadcast` function is similar to the naive -way. The main difference is that it doesn't apply backpressure. +The built-in :func:`broadcast` function is similar to the naive way. The main +difference is that it doesn't apply backpressure. This provides the best performance by avoiding the overhead of scheduling and running one task per client. @@ -321,9 +321,9 @@ the asynchronous iterator returned by ``subscribe()``. Performance considerations -------------------------- -The built-in :func:`~websockets.broadcast` function sends all messages without -yielding control to the event loop. So does the naive way when the network -and clients are fast and reliable. +The built-in :func:`broadcast` function sends all messages without yielding +control to the event loop. So does the naive way when the network and clients +are fast and reliable. For each client, a WebSocket frame is prepared and sent to the network. This is the minimum amount of work required to broadcast a message. @@ -343,7 +343,7 @@ However, this isn't possible in general for two reasons: All other patterns discussed above yield control to the event loop once per client because messages are sent by different tasks. This makes them slower -than the built-in :func:`~websockets.broadcast` function. +than the built-in :func:`broadcast` function. There is no major difference between the performance of per-message queues and publish–subscribe. diff --git a/docs/topics/compression.rst b/docs/topics/compression.rst index f78e32748..d40c4257d 100644 --- a/docs/topics/compression.rst +++ b/docs/topics/compression.rst @@ -1,6 +1,8 @@ Compression =========== +.. currentmodule:: websockets.extensions.permessage_deflate + Most WebSocket servers exchange JSON messages because they're convenient to parse and serialize in a browser. These messages contain text data and tend to be repetitive. @@ -29,9 +31,8 @@ If you want to disable compression, set ``compression=None``:: websockets.serve(..., compression=None) If you want to customize compression settings, you can enable the Per-Message -Deflate extension explicitly with -:class:`~permessage_deflate.ClientPerMessageDeflateFactory` or -:class:`~permessage_deflate.ServerPerMessageDeflateFactory`:: +Deflate extension explicitly with :class:`ClientPerMessageDeflateFactory` or +:class:`ServerPerMessageDeflateFactory`:: import websockets from websockets.extensions import permessage_deflate diff --git a/docs/topics/deployment.rst b/docs/topics/deployment.rst index d30c5568e..ac0a8ed4c 100644 --- a/docs/topics/deployment.rst +++ b/docs/topics/deployment.rst @@ -78,7 +78,7 @@ Option 2 almost always combines with option 3. How do I start a process? ......................... -Run a Python program that invokes :func:`~serve`. That's it. +Run a Python program that invokes :func:`~server.serve`. That's it. Don't run an ASGI server such as Uvicorn, Hypercorn, or Daphne. They're alternatives to websockets, not complements. diff --git a/docs/topics/design.rst b/docs/topics/design.rst index 2c9d505aa..b5c55afc9 100644 --- a/docs/topics/design.rst +++ b/docs/topics/design.rst @@ -35,7 +35,7 @@ Transitions happen in the following places: :meth:`~legacy.protocol.WebSocketCommonProtocol.connection_open` which runs when the :ref:`opening handshake ` completes and the WebSocket connection is established — not to be confused with - :meth:`~asyncio.Protocol.connection_made` which runs when the TCP connection + :meth:`~asyncio.BaseProtocol.connection_made` which runs when the TCP connection is established; - ``OPEN -> CLOSING``: in :meth:`~legacy.protocol.WebSocketCommonProtocol.write_frame` immediately before @@ -58,7 +58,7 @@ connection lifecycle on the client side. :target: _images/lifecycle.svg The lifecycle is identical on the server side, except inversion of control -makes the equivalent of :meth:`~legacy.client.connect` implicit. +makes the equivalent of :meth:`~client.connect` implicit. Coroutines shown in green are called by the application. Multiple coroutines may interact with the WebSocket connection concurrently. @@ -113,7 +113,7 @@ Opening handshake ----------------- websockets performs the opening handshake when establishing a WebSocket -connection. On the client side, :meth:`~legacy.client.connect` executes it +connection. On the client side, :meth:`~client.connect` executes it before returning the protocol to the caller. On the server side, it's executed before passing the protocol to the ``ws_handler`` coroutine handling the connection. @@ -123,26 +123,26 @@ request and the server replies with an HTTP Switching Protocols response — websockets aims at keeping the implementation of both sides consistent with one another. -On the client side, :meth:`~legacy.client.WebSocketClientProtocol.handshake`: +On the client side, :meth:`~client.WebSocketClientProtocol.handshake`: - builds a HTTP request based on the ``uri`` and parameters passed to - :meth:`~legacy.client.connect`; + :meth:`~client.connect`; - writes the HTTP request to the network; - reads a HTTP response from the network; - checks the HTTP response, validates ``extensions`` and ``subprotocol``, and configures the protocol accordingly; - moves to the ``OPEN`` state. -On the server side, :meth:`~legacy.server.WebSocketServerProtocol.handshake`: +On the server side, :meth:`~server.WebSocketServerProtocol.handshake`: - reads a HTTP request from the network; -- calls :meth:`~legacy.server.WebSocketServerProtocol.process_request` which may +- calls :meth:`~server.WebSocketServerProtocol.process_request` which may abort the WebSocket handshake and return a HTTP response instead; this hook only makes sense on the server side; - checks the HTTP request, negotiates ``extensions`` and ``subprotocol``, and configures the protocol accordingly; - builds a HTTP response based on the above and parameters passed to - :meth:`~legacy.server.serve`; + :meth:`~server.serve`; - writes the HTTP response to the network; - moves to the ``OPEN`` state; - returns the ``path`` part of the ``uri``. @@ -186,8 +186,8 @@ in the same class, :class:`~legacy.protocol.WebSocketCommonProtocol`. The :attr:`~legacy.protocol.WebSocketCommonProtocol.is_client` attribute tells which side a protocol instance is managing. This attribute is defined on the -:attr:`~legacy.server.WebSocketServerProtocol` and -:attr:`~legacy.client.WebSocketClientProtocol` classes. +:attr:`~server.WebSocketServerProtocol` and +:attr:`~client.WebSocketClientProtocol` classes. Data flow ......... @@ -264,14 +264,14 @@ Closing handshake When the other side of the connection initiates the closing handshake, :meth:`~legacy.protocol.WebSocketCommonProtocol.read_message` receives a close frame while in the ``OPEN`` state. It moves to the ``CLOSING`` state, sends a -close frame, and returns ``None``, causing +close frame, and returns :obj:`None`, causing :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` to terminate. When this side of the connection initiates the closing handshake with :meth:`~legacy.protocol.WebSocketCommonProtocol.close`, it moves to the ``CLOSING`` state and sends a close frame. When the other side sends a close frame, :meth:`~legacy.protocol.WebSocketCommonProtocol.read_message` receives it in the -``CLOSING`` state and returns ``None``, also causing +``CLOSING`` state and returns :obj:`None`, also causing :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` to terminate. If the other side doesn't send a close frame within the connection's close @@ -313,7 +313,7 @@ of canceling :attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_ta and failing to close the TCP connection, thus leaking resources. Then :attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task` cancels -:attr:`~legacy.protocol.WebSocketCommonProtocol.keepalive_ping`. This task has no +:meth:`~legacy.protocol.WebSocketCommonProtocol.keepalive_ping`. This task has no protocol compliance responsibilities. Terminating it to avoid leaking it is the only concern. @@ -445,15 +445,15 @@ is canceled, which is correct at this point. to prevent cancellation. :meth:`~legacy.protocol.WebSocketCommonProtocol.close` and -:func:`~legacy.protocol.WebSocketCommonProtocol.fail_connection` are the only +:meth:`~legacy.protocol.WebSocketCommonProtocol.fail_connection` are the only places where :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` may be canceled. -:attr:`~legacy.protocol.WebSocketCommonProtocol.close_connnection_task` starts by +:attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task` starts by waiting for :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task`. It catches :exc:`~asyncio.CancelledError` to prevent a cancellation of :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` from propagating -to :attr:`~legacy.protocol.WebSocketCommonProtocol.close_connnection_task`. +to :attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task`. .. _backpressure: @@ -520,21 +520,21 @@ For each connection, the receiving side contains these buffers: - OS buffers: tuning them is an advanced optimization. - :class:`~asyncio.StreamReader` bytes buffer: the default limit is 64 KiB. You can set another limit by passing a ``read_limit`` keyword argument to - :func:`~legacy.client.connect()` or :func:`~legacy.server.serve`. + :func:`~client.connect()` or :func:`~server.serve`. - Incoming messages :class:`~collections.deque`: its size depends both on the size and the number of messages it contains. By default the maximum UTF-8 encoded size is 1 MiB and the maximum number is 32. In the worst case, after UTF-8 decoding, a single message could take up to 4 MiB of memory and the overall memory consumption could reach 128 MiB. You should adjust these limits by setting the ``max_size`` and ``max_queue`` keyword arguments of - :func:`~legacy.client.connect()` or :func:`~legacy.server.serve` according to your + :func:`~client.connect()` or :func:`~server.serve` according to your application's requirements. For each connection, the sending side contains these buffers: - :class:`~asyncio.StreamWriter` bytes buffer: the default size is 64 KiB. You can set another limit by passing a ``write_limit`` keyword argument to - :func:`~legacy.client.connect()` or :func:`~legacy.server.serve`. + :func:`~client.connect()` or :func:`~server.serve`. - OS buffers: tuning them is an advanced optimization. Concurrency diff --git a/docs/topics/memory.rst b/docs/topics/memory.rst index c880d5579..ee0109c35 100644 --- a/docs/topics/memory.rst +++ b/docs/topics/memory.rst @@ -1,6 +1,8 @@ Memory usage ============ +.. currentmodule:: websockets + In most cases, memory usage of a WebSocket server is proportional to the number of open connections. When a server handles thousands of connections, memory usage can become a bottleneck. @@ -17,8 +19,8 @@ Baseline Compression settings are the main factor affecting the baseline amount of memory used by each connection. -Read to the topic guide on :doc:`../topics/compression` to learn more about -tuning compression settings. +Refer to the :doc:`topic guide on compression <../topics/compression>` to +learn more about tuning compression settings. Buffers ------- @@ -29,7 +31,7 @@ Under high load, if a server receives more messages than it can process, bufferbloat can result in excessive memory usage. By default websockets has generous limits. It is strongly recommended to adapt -them to your application. When you call :func:`~legacy.server.serve`: +them to your application. When you call :func:`~server.serve`: - Set ``max_size`` (default: 1 MiB, UTF-8 encoded) to the maximum size of messages your application generates. @@ -40,4 +42,4 @@ them to your application. When you call :func:`~legacy.server.serve`: Furthermore, you can lower ``read_limit`` and ``write_limit`` (default: 64 KiB) to reduce the size of buffers for incoming and outgoing data. -The design document provides :ref:`more details about buffers`. +The design document provides :ref:`more details about buffers `. diff --git a/docs/topics/timeouts.rst b/docs/topics/timeouts.rst index 8febfce9f..51666ceea 100644 --- a/docs/topics/timeouts.rst +++ b/docs/topics/timeouts.rst @@ -1,6 +1,8 @@ Timeouts ======== +.. currentmodule:: websockets + Since the WebSocket protocol is intended for real-time communications over long-lived connections, it is desirable to ensure that connections don't break, and if they do, to report the problem quickly. @@ -13,15 +15,18 @@ As a consequence, proxies may terminate WebSocket connections prematurely, when no message was exchanged in 30 seconds. In order to avoid this problem, websockets implements a keepalive mechanism -based on WebSocket Ping and Pong frames. Ping and Pong are designed for this +based on WebSocket Ping_ and Pong_ frames. Ping and Pong are designed for this purpose. +.. _Ping: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 +.. _Pong: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 + By default, websockets waits 20 seconds, then sends a Ping frame, and expects to receive the corresponding Pong frame within 20 seconds. Else, it considers the connection broken and closes it. Timings are configurable with the ``ping_interval`` and ``ping_timeout`` -arguments of :func:`~websockets.connect` and :func:`~websockets.serve`. +arguments of :func:`~client.connect` and :func:`~server.serve`. While WebSocket runs on top of TCP, websockets doesn't rely on TCP keepalive because it's disabled by default and, if enabled, the default interval is no diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index 5883c3d65..ec3484124 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from .imports import lazy_import from .version import version as __version__ # noqa diff --git a/src/websockets/__main__.py b/src/websockets/__main__.py index 785d2c3c9..860e4b1fa 100644 --- a/src/websockets/__main__.py +++ b/src/websockets/__main__.py @@ -183,7 +183,7 @@ def main() -> None: # Due to zealous removal of the loop parameter in the Queue constructor, # we need a factory coroutine to run in the freshly created event loop. - async def queue_factory() -> "asyncio.Queue[str]": + async def queue_factory() -> asyncio.Queue[str]: return asyncio.Queue() # Create a queue of user inputs. There's no need to limit its size. diff --git a/src/websockets/auth.py b/src/websockets/auth.py index f97c1feb0..afcb38cff 100644 --- a/src/websockets/auth.py +++ b/src/websockets/auth.py @@ -1,2 +1,4 @@ +from __future__ import annotations + # See #940 for why lazy_import isn't used here for backwards compatibility. from .legacy.auth import * # noqa diff --git a/src/websockets/client.py b/src/websockets/client.py index e21fd36c0..13c8a8ad0 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -68,7 +68,7 @@ def __init__( def connect(self) -> Request: # noqa: F811 """ - Create a WebSocket handshake request event to send to the server. + Create a WebSocket handshake request event to open a connection. """ headers = Headers() @@ -107,12 +107,13 @@ def connect(self) -> Request: # noqa: F811 def process_response(self, response: Response) -> None: """ - Check a handshake response received from the server. + Check a handshake response. - :param response: response - :param key: comes from :func:`build_request` - :raises ~websockets.exceptions.InvalidHandshake: if the handshake response - is invalid + Args: + request: WebSocket handshake response received from the server. + + Raises: + InvalidHandshake: if the handshake response is invalid. """ @@ -162,11 +163,6 @@ def process_extensions(self, headers: Headers) -> List[Extension]: Check that each extension is supported, as well as its parameters. - Return the list of accepted extensions. - - Raise :exc:`~websockets.exceptions.InvalidHandshake` to abort the - connection. - :rfc:`6455` leaves the rules up to the specification of each extension. @@ -182,6 +178,15 @@ def process_extensions(self, headers: Headers) -> List[Extension]: Other requirements, for example related to mandatory extensions or the order of extensions, may be implemented by overriding this method. + Args: + headers: WebSocket handshake response headers. + + Returns: + List[Extension]: List of accepted extensions. + + Raises: + InvalidHandshake: to abort the handshake. + """ accepted_extensions: List[Extension] = [] @@ -232,9 +237,13 @@ def process_subprotocol(self, headers: Headers) -> Optional[Subprotocol]: """ Handle the Sec-WebSocket-Protocol HTTP response header. - Check that it contains exactly one supported subprotocol. + If provided, check that it contains exactly one supported subprotocol. - Return the selected subprotocol. + Args: + headers: WebSocket handshake response headers. + + Returns: + Optional[Subprotocol]: Subprotocol, if one was selected. """ subprotocol: Optional[Subprotocol] = None @@ -263,7 +272,10 @@ def process_subprotocol(self, headers: Headers) -> Optional[Subprotocol]: def send_request(self, request: Request) -> None: """ - Send a WebSocket handshake request to the server. + Send a handshake request to the server. + + Args: + request: WebSocket handshake request event to send. """ if self.debug: diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 52fd9bb81..684664860 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -212,7 +212,8 @@ def receive_data(self, data: bytes) -> None: - You must call :meth:`data_to_send` and send this data. - You should call :meth:`events_received` and process these events. - :raises EOFError: if :meth:`receive_eof` was called before + Raises: + EOFError: if :meth:`receive_eof` was called before. """ self.reader.feed_data(data) @@ -228,7 +229,8 @@ def receive_eof(self) -> None: - You aren't exepcted to call :meth:`events_received` as it won't return any new events. - :raises EOFError: if :meth:`receive_eof` was called before + Raises: + EOFError: if :meth:`receive_eof` was called before. """ self.reader.feed_eof() @@ -367,8 +369,8 @@ def close_expected(self) -> bool: Tell whether the TCP connection is expected to close soon. Call this method immediately after calling any of the ``receive_*()`` - or ``fail_*()`` methods and, if it returns ``True``, schedule closing - the TCP connection after a short timeout. + or ``fail_*()`` methods and, if it returns :obj:`True`, schedule + closing the TCP connection after a short timeout. """ # We already got a TCP Close if and only if the state is CLOSED. diff --git a/src/websockets/datastructures.py b/src/websockets/datastructures.py index 65c5d4115..1ff586abd 100644 --- a/src/websockets/datastructures.py +++ b/src/websockets/datastructures.py @@ -1,8 +1,3 @@ -""" -:mod:`websockets.datastructures` defines a class for manipulating HTTP headers. - -""" - from __future__ import annotations from typing import ( @@ -141,7 +136,7 @@ def clear(self) -> None: def update(self, *args: HeadersLike, **kwargs: str) -> None: """ - Update from a Headers instance and/or keyword arguments. + Update from a :class:`Headers` instance and/or keyword arguments. """ args = tuple( @@ -155,7 +150,8 @@ def get_all(self, key: str) -> List[str]: """ Return the (possibly empty) list of all values for a header. - :param key: header name + Args: + key: header name. """ return self._dict.get(key.lower(), []) diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 6bbea324c..0c4fc5185 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -67,7 +67,7 @@ class WebSocketException(Exception): """ - Base class for all exceptions defined by :mod:`websockets`. + Base class for all exceptions defined by websockets. """ @@ -76,17 +76,14 @@ class ConnectionClosed(WebSocketException): """ Raised when trying to interact with a closed connection. - If a close frame was received, its code and reason are available in the - ``rcvd.code`` and ``rcvd.reason`` attributes. Else, the ``rcvd`` - attribute is ``None``. - - Likewise, if a close frame was sent, its code and reason are available in - the ``sent.code`` and ``sent.reason`` attributes. Else, the ``sent`` - attribute is ``None``. - - If close frames were received and sent, the ``rcvd_then_sent`` attribute - tells in which order this happened, from the perspective of this side of - the connection. + Attributes: + rcvd (Optional[Close]): if a close frame was received, its code and + reason are available in ``rcvd.code`` and ``rcvd.reason``. + sent (Optional[Close]): if a close frame was sent, its code and reason + are available in ``sent.code`` and ``sent.reason``. + rcvd_then_sent (Optional[bool]): if close frames were received and + sent, this attribute tells in which order this happened, from the + perspective of this side of the connection. """ @@ -249,9 +246,6 @@ class InvalidStatusCode(InvalidHandshake): """ Raised when a handshake response status code is invalid. - The integer status code is available in the ``status_code`` attribute and - HTTP headers in the ``headers`` attribute. - """ def __init__(self, status_code: int, headers: datastructures.Headers) -> None: @@ -320,8 +314,13 @@ class AbortHandshake(InvalidHandshake): This exception is an implementation detail. - The public API is :meth:`~legacy.server.WebSocketServerProtocol.process_request`. + The public API + is :meth:`~websockets.server.WebSocketServerProtocol.process_request`. + Attributes: + status (~http.HTTPStatus): HTTP status code. + headers (Headers): HTTP response headers. + body (bytes): HTTP response body. """ def __init__( diff --git a/src/websockets/extensions/base.py b/src/websockets/extensions/base.py index 7217aa513..060967618 100644 --- a/src/websockets/extensions/base.py +++ b/src/websockets/extensions/base.py @@ -1,13 +1,3 @@ -""" -:mod:`websockets.extensions.base` defines abstract classes for implementing -extensions. - -See `section 9 of RFC 6455`_. - -.. _section 9 of RFC 6455: https://www.rfc-editor.org/rfc/rfc6455.html#section-9 - -""" - from __future__ import annotations from typing import List, Optional, Sequence, Tuple @@ -21,16 +11,12 @@ class Extension: """ - Abstract class for extensions. + Base class for extensions. """ - @property - def name(self) -> ExtensionName: - """ - Extension identifier. - - """ + name: ExtensionName + """Extension identifier.""" def decode( self, @@ -41,8 +27,15 @@ def decode( """ Decode an incoming frame. - :param frame: incoming frame - :param max_size: maximum payload size in bytes + Args: + frame (Frame): incoming frame. + max_size: maximum payload size in bytes. + + Returns: + Frame: Decoded frame. + + Raises: + PayloadTooBig: if decoding the payload exceeds ``max_size``. """ @@ -50,29 +43,30 @@ def encode(self, frame: frames.Frame) -> frames.Frame: """ Encode an outgoing frame. - :param frame: outgoing frame + Args: + frame (Frame): outgoing frame. + + Returns: + Frame: Encoded frame. """ class ClientExtensionFactory: """ - Abstract class for client-side extension factories. + Base class for client-side extension factories. """ - @property - def name(self) -> ExtensionName: - """ - Extension identifier. - - """ + name: ExtensionName + """Extension identifier.""" def get_request_params(self) -> List[ExtensionParameter]: """ - Build request parameters. + Build parameters to send to the server for this extension. - Return a list of ``(name, value)`` pairs. + Returns: + List[ExtensionParameter]: Parameters to send to the server. """ @@ -82,28 +76,31 @@ def process_response_params( accepted_extensions: Sequence[Extension], ) -> Extension: """ - Process response parameters received from the server. + Process parameters received from the server. + + Args: + params (Sequence[ExtensionParameter]): parameters received from + the server for this extension. + accepted_extensions (Sequence[Extension]): list of previously + accepted extensions. - :param params: list of ``(name, value)`` pairs. - :param accepted_extensions: list of previously accepted extensions. - :raises ~websockets.exceptions.NegotiationError: if parameters aren't - acceptable + Returns: + Extension: An extension instance. + + Raises: + NegotiationError: if parameters aren't acceptable. """ class ServerExtensionFactory: """ - Abstract class for server-side extension factories. + Base class for server-side extension factories. """ - @property - def name(self) -> ExtensionName: - """ - Extension identifier. - - """ + name: ExtensionName + """Extension identifier.""" def process_request_params( self, @@ -111,16 +108,21 @@ def process_request_params( accepted_extensions: Sequence[Extension], ) -> Tuple[List[ExtensionParameter], Extension]: """ - Process request parameters received from the client. + Process parameters received from the client. - To accept the offer, return a 2-uple containing: + Args: + params (Sequence[ExtensionParameter]): parameters received from + the client for this extension. + accepted_extensions (Sequence[Extension]): list of previously + accepted extensions. - - response parameters: a list of ``(name, value)`` pairs - - an extension: an instance of a subclass of :class:`Extension` + Returns: + Tuple[List[ExtensionParameter], Extension]: To accept the offer, + parameters to send to the client for this extension and an + extension instance. - :param params: list of ``(name, value)`` pairs. - :param accepted_extensions: list of previously accepted extensions. - :raises ~websockets.exceptions.NegotiationError: to reject the offer, - if parameters aren't acceptable + Raises: + NegotiationError: to reject the offer, if parameters received from + the client aren't acceptable. """ diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index a377abb55..da2bc153e 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -1,9 +1,3 @@ -""" -:mod:`websockets.extensions.permessage_deflate` implements the Compression -Extensions for WebSocket as specified in :rfc:`7692`. - -""" - from __future__ import annotations import dataclasses @@ -204,8 +198,8 @@ def _extract_parameters( """ Extract compression parameters from a list of ``(name, value)`` pairs. - If ``is_server`` is ``True``, ``client_max_window_bits`` may be provided - without a value. This is only allow in handshake requests. + If ``is_server`` is :obj:`True`, ``client_max_window_bits`` may be + provided without a value. This is only allowed in handshake requests. """ server_no_context_takeover: bool = False @@ -264,18 +258,23 @@ class ClientPerMessageDeflateFactory(ClientExtensionFactory): """ Client-side extension factory for the Per-Message Deflate extension. - Parameters behave as described in `section 7.1 of RFC 7692`_. Set them to - ``True`` to include them in the negotiation offer without a value or to an - integer value to include them with this value. + Parameters behave as described in `section 7.1 of RFC 7692`_. .. _section 7.1 of RFC 7692: https://www.rfc-editor.org/rfc/rfc7692.html#section-7.1 - :param server_no_context_takeover: defaults to ``False`` - :param client_no_context_takeover: defaults to ``False`` - :param server_max_window_bits: optional, defaults to ``None`` - :param client_max_window_bits: optional, defaults to ``None`` - :param compress_settings: optional, keyword arguments for - :func:`zlib.compressobj`, excluding ``wbits`` + Set them to :obj:`True` to include them in the negotiation offer without a + value or to an integer value to include them with this value. + + Args: + server_no_context_takeover: prevent server from using context takeover. + client_no_context_takeover: prevent client from using context takeover. + server_max_window_bits: maximum size of the server's LZ77 sliding window + in bits, between 8 and 15. + client_max_window_bits: maximum size of the client's LZ77 sliding window + in bits, between 8 and 15, or :obj:`True` to indicate support without + setting a limit. + compress_settings: additional keyword arguments for :func:`zlib.compressobj`, + excluding ``wbits``. """ @@ -440,7 +439,6 @@ def enable_client_permessage_deflate( If the extension is already present, perhaps with non-default settings, the configuration isn't changed. - """ if extensions is None: extensions = [] @@ -462,18 +460,23 @@ class ServerPerMessageDeflateFactory(ServerExtensionFactory): """ Server-side extension factory for the Per-Message Deflate extension. - Parameters behave as described in `section 7.1 of RFC 7692`_. Set them to - ``True`` to include them in the negotiation offer without a value or to an - integer value to include them with this value. + Parameters behave as described in `section 7.1 of RFC 7692`_. .. _section 7.1 of RFC 7692: https://www.rfc-editor.org/rfc/rfc7692.html#section-7.1 - :param server_no_context_takeover: defaults to ``False`` - :param client_no_context_takeover: defaults to ``False`` - :param server_max_window_bits: optional, defaults to ``None`` - :param client_max_window_bits: optional, defaults to ``None`` - :param compress_settings: optional, keyword arguments for - :func:`zlib.compressobj`, excluding ``wbits`` + Set them to :obj:`True` to include them in the negotiation offer without a + value or to an integer value to include them with this value. + + Args: + server_no_context_takeover: prevent server from using context takeover. + client_no_context_takeover: prevent client from using context takeover. + server_max_window_bits: maximum size of the server's LZ77 sliding window + in bits, between 8 and 15. + client_max_window_bits: maximum size of the client's LZ77 sliding window + in bits, between 8 and 15, or :obj:`True` to indicate support without + setting a limit. + compress_settings: additional keyword arguments for :func:`zlib.compressobj`, + excluding ``wbits``. """ diff --git a/src/websockets/frames.py b/src/websockets/frames.py index a0ce1d350..9a97f2530 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -1,8 +1,3 @@ -""" -Parse and serialize WebSocket frames. - -""" - from __future__ import annotations import dataclasses @@ -104,15 +99,16 @@ class Frame: """ WebSocket frame. - :param int opcode: opcode - :param bytes data: payload data - :param bool fin: FIN bit - :param bool rsv1: RSV1 bit - :param bool rsv2: RSV2 bit - :param bool rsv3: RSV3 bit + Args: + opcode: opcode. + data: payload data. + fin: FIN bit. + rsv1: RSV1 bit. + rsv2: RSV2 bit. + rsv3: RSV3 bit. Only these fields are needed. The MASK bit, payload length and masking-key - are handled on the fly by :meth:`parse` and :meth:`serialize`. + are handled on the fly when parsing and serializing frames. """ @@ -176,22 +172,23 @@ def parse( mask: bool, max_size: Optional[int] = None, extensions: Optional[Sequence[extensions.Extension]] = None, - ) -> Generator[None, None, "Frame"]: + ) -> Generator[None, None, Frame]: """ - Read a WebSocket frame. - - :param read_exact: generator-based coroutine that reads the requested - number of bytes or raises an exception if there isn't enough data - :param mask: whether the frame should be masked i.e. whether the read - happens on the server side - :param max_size: maximum payload size in bytes - :param extensions: list of classes with a ``decode()`` method that - transforms the frame and return a new frame; extensions are applied - in reverse order - :raises ~websockets.exceptions.PayloadTooBig: if the frame exceeds - ``max_size`` - :raises ~websockets.exceptions.ProtocolError: if the frame - contains incorrect values + Parse a WebSocket frame. + + This is a generator-based coroutine. + + Args: + read_exact: generator-based coroutine that reads the requested + bytes or raises an exception if there isn't enough data. + mask: whether the frame should be masked i.e. whether the read + happens on the server side. + max_size: maximum payload size in bytes. + extensions: list of extensions, applied in reverse order. + + Raises: + PayloadTooBig: if the frame's payload size exceeds ``max_size``. + ProtocolError: if the frame contains incorrect values. """ # Read the header. @@ -249,16 +246,15 @@ def serialize( extensions: Optional[Sequence[extensions.Extension]] = None, ) -> bytes: """ - Write a WebSocket frame. + Serialize a WebSocket frame. + + Args: + mask: whether the frame should be masked i.e. whether the write + happens on the client side. + extensions: list of extensions, applied in order. - :param frame: frame to write - :param mask: whether the frame should be masked i.e. whether the write - happens on the client side - :param extensions: list of classes with an ``encode()`` method that - transform the frame and return a new frame; extensions are applied - in order - :raises ~websockets.exceptions.ProtocolError: if the frame - contains incorrect values + Raises: + ProtocolError: if the frame contains incorrect values. """ self.check() @@ -306,8 +302,8 @@ def check(self) -> None: """ Check that reserved bits and opcode have acceptable values. - :raises ~websockets.exceptions.ProtocolError: if a reserved - bit or the opcode is invalid + Raises: + ProtocolError: if a reserved bit or the opcode is invalid. """ if self.rsv1 or self.rsv2 or self.rsv3: @@ -332,7 +328,8 @@ def prepare_data(data: Data) -> Tuple[int, bytes]: If ``data`` is a bytes-like object, return ``OP_BINARY`` and a bytes-like object. - :raises TypeError: if ``data`` doesn't have a supported type + Raises: + TypeError: if ``data`` doesn't have a supported type. """ if isinstance(data, str): @@ -354,7 +351,8 @@ def prepare_ctrl(data: Data) -> bytes: If ``data`` is a bytes-like object, return a :class:`bytes` object. - :raises TypeError: if ``data`` doesn't have a supported type + Raises: + TypeError: if ``data`` doesn't have a supported type. """ if isinstance(data, str): @@ -398,8 +396,12 @@ def parse(cls, data: bytes) -> Close: """ Parse the payload of a close frame. - :raises ~websockets.exceptions.ProtocolError: if data is ill-formed - :raises UnicodeDecodeError: if the reason isn't valid UTF-8 + Args: + data: payload of the close frame. + + Raises: + ProtocolError: if data is ill-formed. + UnicodeDecodeError: if the reason isn't valid UTF-8. """ if len(data) >= 2: @@ -415,9 +417,7 @@ def parse(cls, data: bytes) -> Close: def serialize(self) -> bytes: """ - Serialize the payload for a close frame. - - This is the reverse of :meth:`parse`. + Serialize the payload of a close frame. """ self.check() @@ -427,8 +427,8 @@ def check(self) -> None: """ Check that the close code has a valid value for a close frame. - :raises ~websockets.exceptions.ProtocolError: if the close code - is invalid + Raises: + ProtocolError: if the close code is invalid. """ if not (self.code in EXTERNAL_CLOSE_CODES or 3000 <= self.code < 5000): diff --git a/src/websockets/headers.py b/src/websockets/headers.py index ee6dd1672..a2fdfdd30 100644 --- a/src/websockets/headers.py +++ b/src/websockets/headers.py @@ -1,9 +1,3 @@ -""" -:mod:`websockets.headers` provides parsers and serializers for HTTP headers -used in WebSocket handshake messages. - -""" - from __future__ import annotations import base64 @@ -48,7 +42,7 @@ def peek_ahead(header: str, pos: int) -> Optional[str]: """ Return the next character from ``header`` at the given position. - Return ``None`` at the end of ``header``. + Return :obj:`None` at the end of ``header``. We never need to peek more than one character ahead. @@ -83,7 +77,8 @@ def parse_token(header: str, pos: int, header_name: str) -> Tuple[str, int]: Return the token value and the new position. - :raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs. + Raises: + InvalidHeaderFormat: on invalid inputs. """ match = _token_re.match(header, pos) @@ -106,7 +101,8 @@ def parse_quoted_string(header: str, pos: int, header_name: str) -> Tuple[str, i Return the unquoted value and the new position. - :raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs. + Raises: + InvalidHeaderFormat: on invalid inputs. """ match = _quoted_string_re.match(header, pos) @@ -158,7 +154,8 @@ def parse_list( Return a list of items. - :raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs. + Raises: + InvalidHeaderFormat: on invalid inputs. """ # Per https://www.rfc-editor.org/rfc/rfc7230.html#section-7, "a recipient @@ -211,7 +208,8 @@ def parse_connection_option( Return the protocol value and the new position. - :raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs. + Raises: + InvalidHeaderFormat: on invalid inputs. """ item, pos = parse_token(header, pos, header_name) @@ -224,8 +222,11 @@ def parse_connection(header: str) -> List[ConnectionOption]: Return a list of HTTP connection options. - :param header: value of the ``Connection`` header - :raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs. + Args + header: value of the ``Connection`` header. + + Raises: + InvalidHeaderFormat: on invalid inputs. """ return parse_list(parse_connection_option, header, 0, "Connection") @@ -244,7 +245,8 @@ def parse_upgrade_protocol( Return the protocol value and the new position. - :raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs. + Raises: + InvalidHeaderFormat: on invalid inputs. """ match = _protocol_re.match(header, pos) @@ -261,8 +263,11 @@ def parse_upgrade(header: str) -> List[UpgradeProtocol]: Return a list of HTTP protocols. - :param header: value of the ``Upgrade`` header - :raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs. + Args: + header: value of the ``Upgrade`` header. + + Raises: + InvalidHeaderFormat: on invalid inputs. """ return parse_list(parse_upgrade_protocol, header, 0, "Upgrade") @@ -276,7 +281,8 @@ def parse_extension_item_param( Return a ``(name, value)`` pair and the new position. - :raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs. + Raises: + InvalidHeaderFormat: on invalid inputs. """ # Extract parameter name. @@ -312,7 +318,8 @@ def parse_extension_item( Return an ``(extension name, parameters)`` pair, where ``parameters`` is a list of ``(name, value)`` pairs, and the new position. - :raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs. + Raises: + InvalidHeaderFormat: on invalid inputs. """ # Extract extension name. @@ -344,9 +351,10 @@ def parse_extension(header: str) -> List[ExtensionHeader]: ... ] - Parameter values are ``None`` when no value is provided. + Parameter values are :obj:`None` when no value is provided. - :raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs. + Raises: + InvalidHeaderFormat: on invalid inputs. """ return parse_list(parse_extension_item, header, 0, "Sec-WebSocket-Extensions") @@ -397,7 +405,8 @@ def parse_subprotocol_item( Return the subprotocol value and the new position. - :raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs. + Raises: + InvalidHeaderFormat: on invalid inputs. """ item, pos = parse_token(header, pos, header_name) @@ -410,7 +419,8 @@ def parse_subprotocol(header: str) -> List[Subprotocol]: Return a list of WebSocket subprotocols. - :raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs. + Raises: + InvalidHeaderFormat: on invalid inputs. """ return parse_list(parse_subprotocol_item, header, 0, "Sec-WebSocket-Protocol") @@ -450,7 +460,8 @@ def build_www_authenticate_basic(realm: str) -> str: """ Build a ``WWW-Authenticate`` header for HTTP Basic Auth. - :param realm: authentication realm + Args: + realm: identifier of the protection space. """ # https://www.rfc-editor.org/rfc/rfc7617.html#section-2 @@ -468,7 +479,8 @@ def parse_token68(header: str, pos: int, header_name: str) -> Tuple[str, int]: Return the token value and the new position. - :raises ~websockets.exceptions.InvalidHeaderFormat: on invalid inputs. + Raises: + InvalidHeaderFormat: on invalid inputs. """ match = _token68_re.match(header, pos) @@ -494,9 +506,12 @@ def parse_authorization_basic(header: str) -> Tuple[str, str]: Return a ``(username, password)`` tuple. - :param header: value of the ``Authorization`` header - :raises InvalidHeaderFormat: on invalid inputs - :raises InvalidHeaderValue: on unsupported inputs + Args: + header: value of the ``Authorization`` header. + + Raises: + InvalidHeaderFormat: on invalid inputs. + InvalidHeaderValue: on unsupported inputs. """ # https://www.rfc-editor.org/rfc/rfc7235.html#section-2.1 diff --git a/src/websockets/http11.py b/src/websockets/http11.py index daa0efffb..b82a0bfdc 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -44,38 +44,46 @@ class Request: """ WebSocket handshake request. - :param path: path and optional query - :param headers: + Attributes: + path: Request path, including optional query. + headers: Request headers. + exception: If processing the response triggers an exception, + the exception is stored in this attribute. """ path: str headers: datastructures.Headers - # body isn't useful is the context of this library + # body isn't useful is the context of this library. - # If processing the request triggers an exception, it's stored here. exception: Optional[Exception] = None @classmethod def parse( - cls, read_line: Callable[[], Generator[None, None, bytes]] - ) -> Generator[None, None, "Request"]: + cls, + read_line: Callable[[], Generator[None, None, bytes]], + ) -> Generator[None, None, Request]: """ - Parse an HTTP/1.1 GET request and return ``(path, headers)``. + Parse a WebSocket handshake request. + + This is a generator-based coroutine. - ``path`` isn't URL-decoded or validated in any way. + The request path isn't URL-decoded or validated in any way. - ``path`` and ``headers`` are expected to contain only ASCII characters. - Other characters are represented with surrogate escapes. + The request path and headers are expected to contain only ASCII + characters. Other characters are represented with surrogate escapes. - :func:`parse_request` doesn't attempt to read the request body because + :meth:`parse` doesn't attempt to read the request body because WebSocket handshake requests don't have one. If the request contains a - body, it may be read from ``stream`` after this coroutine returns. + body, it may be read from the data stream after :meth:`parse` returns. - :param read_line: generator-based coroutine that reads a LF-terminated - line or raises an exception if there isn't enough data - :raises EOFError: if the connection is closed without a full HTTP request - :raises exceptions.SecurityError: if the request exceeds a security limit - :raises ValueError: if the request isn't well formatted + Args: + read_line: generator-based coroutine that reads a LF-terminated + line or raises an exception if there isn't enough data + + Raises: + EOFError: if the connection is closed without a full HTTP request. + SecurityError: if the request exceeds a security limit. + ValueError: if the request isn't well formatted. """ # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.1.1 @@ -114,10 +122,10 @@ def parse( def serialize(self) -> bytes: """ - Serialize an HTTP/1.1 GET request. + Serialize a WebSocket handshake request. """ - # Since the path and headers only contain ASCII characters, + # Since the request line and headers only contain ASCII characters, # we can keep this simple. request = f"GET {self.path} HTTP/1.1\r\n".encode() request += self.headers.serialize() @@ -129,6 +137,14 @@ class Response: """ WebSocket handshake response. + Attributes: + status_code: Response code. + reason_phrase: Response reason. + headers: Response headers. + body: Response body, if any. + exception: if processing the response triggers an exception, + the exception is stored in this attribute. + """ status_code: int @@ -136,7 +152,6 @@ class Response: headers: datastructures.Headers body: Optional[bytes] = None - # If processing the response triggers an exception, it's stored here. exception: Optional[Exception] = None @classmethod @@ -145,32 +160,32 @@ def parse( read_line: Callable[[], Generator[None, None, bytes]], read_exact: Callable[[int], Generator[None, None, bytes]], read_to_eof: Callable[[], Generator[None, None, bytes]], - ) -> Generator[None, None, "Response"]: + ) -> Generator[None, None, Response]: """ - Parse an HTTP/1.1 response and return ``(status_code, reason, headers)``. + Parse a WebSocket handshake response. - ``reason`` and ``headers`` are expected to contain only ASCII characters. - Other characters are represented with surrogate escapes. + This is a generator-based coroutine. - :func:`parse_request` doesn't attempt to read the response body because - WebSocket handshake responses don't have one. If the response contains a - body, it may be read from ``stream`` after this coroutine returns. + The reason phrase and headers are expected to contain only ASCII + characters. Other characters are represented with surrogate escapes. - :param read_line: generator-based coroutine that reads a LF-terminated - line or raises an exception if there isn't enough data - :param read_exact: generator-based coroutine that reads the requested - number of bytes or raises an exception if there isn't enough data - :raises EOFError: if the connection is closed without a full HTTP response - :raises exceptions.SecurityError: if the response exceeds a security limit - :raises LookupError: if the response isn't well formatted - :raises ValueError: if the response isn't well formatted + Args: + read_line: generator-based coroutine that reads a LF-terminated + line or raises an exception if there isn't enough data. + read_exact: generator-based coroutine that reads the requested + bytes or raises an exception if there isn't enough data. + read_to_eof: generator-based coroutine that reads until the end + of the strem. + + Raises: + EOFError: if the connection is closed without a full HTTP response. + SecurityError: if the response exceeds a security limit. + LookupError: if the response isn't well formatted. + ValueError: if the response isn't well formatted. """ # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.1.2 - # As in parse_request, parsing is simple because a fixed value is expected - # for version, status_code is a 3-digit number, and reason can be ignored. - try: status_line = yield from parse_line(read_line) except EOFError as exc: @@ -227,7 +242,7 @@ def parse( def serialize(self) -> bytes: """ - Serialize an HTTP/1.1 GET response. + Serialize a WebSocket handshake response. """ # Since the status line and headers only contain ASCII characters, @@ -240,15 +255,21 @@ def serialize(self) -> bytes: def parse_headers( - read_line: Callable[[], Generator[None, None, bytes]] + read_line: Callable[[], Generator[None, None, bytes]], ) -> Generator[None, None, datastructures.Headers]: """ Parse HTTP headers. Non-ASCII characters are represented with surrogate escapes. - :param read_line: generator-based coroutine that reads a LF-terminated - line or raises an exception if there isn't enough data + Args: + read_line: generator-based coroutine that reads a LF-terminated line + or raises an exception if there isn't enough data. + + Raises: + EOFError: if the connection is closed without complete headers. + SecurityError: if the request exceeds a security limit. + ValueError: if the request isn't well formatted. """ # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.2 @@ -285,15 +306,20 @@ def parse_headers( def parse_line( - read_line: Callable[[], Generator[None, None, bytes]] + read_line: Callable[[], Generator[None, None, bytes]], ) -> Generator[None, None, bytes]: """ Parse a single line. CRLF is stripped from the return value. - :param read_line: generator-based coroutine that reads a LF-terminated - line or raises an exception if there isn't enough data + Args: + read_line: generator-based coroutine that reads a LF-terminated line + or raises an exception if there isn't enough data. + + Raises: + EOFError: if the connection is closed without a CRLF. + SecurityError: if the response exceeds a security limit. """ # Security: TODO: add a limit here diff --git a/src/websockets/imports.py b/src/websockets/imports.py index c9508d188..a6a59d4c2 100644 --- a/src/websockets/imports.py +++ b/src/websockets/imports.py @@ -9,15 +9,15 @@ def import_name(name: str, source: str, namespace: Dict[str, Any]) -> Any: """ - Import from in . + Import ``name`` from ``source`` in ``namespace``. - There are two cases: + There are two use cases: - - is an object defined in - - is a submodule of source + - ``name`` is an object defined in ``source``; + - ``name`` is a submodule of ``source``. - Neither __import__ nor importlib.import_module does exactly this. - __import__ is closer to the intended behavior. + Neither :func:`__import__` nor :func:`~importlib.import_module` does + exactly this. :func:`__import__` is closer to the intended behavior. """ level = 0 @@ -49,7 +49,7 @@ def lazy_import( } ) - This function defines __getattr__ and __dir__ per PEP 562. + This function defines ``__getattr__`` and ``__dir__`` per :pep:`562`. """ if aliases is None: diff --git a/src/websockets/legacy/auth.py b/src/websockets/legacy/auth.py index 5f2b1311a..8825c14ec 100644 --- a/src/websockets/legacy/auth.py +++ b/src/websockets/legacy/auth.py @@ -1,9 +1,3 @@ -""" -:mod:`websockets.legacy.auth` provides HTTP Basic Authentication according to -:rfc:`7235` and :rfc:`7617`. - -""" - from __future__ import annotations import functools @@ -37,7 +31,16 @@ class BasicAuthWebSocketServerProtocol(WebSocketServerProtocol): """ - realm = "" + realm: str = "" + """ + Scope of protection. + + If provided, it should contain only ASCII characters because the + encoding of non-ASCII characters is undefined. + """ + + username: Optional[str] = None + """Username of the authenticated user.""" def __init__( self, @@ -55,13 +58,17 @@ async def check_credentials(self, username: str, password: str) -> bool: """ Check whether credentials are authorized. - If ``check_credentials`` returns ``True``, the WebSocket handshake - continues. If it returns ``False``, the handshake fails with a HTTP - 401 error. - This coroutine may be overridden in a subclass, for example to authenticate against a database or an external service. + Args: + username: HTTP Basic Auth username. + password: HTTP Basic Auth password. + + Returns: + bool: :obj:`True` if the handshake should continue; + :obj:`False` if it should fail with a HTTP 401 error. + """ if self._check_credentials is not None: return await self._check_credentials(username, password) @@ -116,8 +123,8 @@ def basic_auth_protocol_factory( """ Protocol factory that enforces HTTP Basic Auth. - ``basic_auth_protocol_factory`` is designed to integrate with - :func:`~websockets.legacy.server.serve` like this:: + :func:`basic_auth_protocol_factory` is designed to integrate with + :func:`~websockets.server.serve` like this:: websockets.serve( ..., @@ -127,28 +134,22 @@ def basic_auth_protocol_factory( ) ) - ``realm`` indicates the scope of protection. It should contain only ASCII - characters because the encoding of non-ASCII characters is undefined. - Refer to section 2.2 of :rfc:`7235` for details. - - ``credentials`` defines hard coded authorized credentials. It can be a - ``(username, password)`` pair or a list of such pairs. - - ``check_credentials`` defines a coroutine that checks whether credentials - are authorized. This coroutine receives ``username`` and ``password`` - arguments and returns a :class:`bool`. - - One of ``credentials`` or ``check_credentials`` must be provided but not - both. - - By default, ``basic_auth_protocol_factory`` creates a factory for building - :class:`BasicAuthWebSocketServerProtocol` instances. You can override this - with the ``create_protocol`` parameter. - - :param realm: scope of protection - :param credentials: hard coded credentials - :param check_credentials: coroutine that verifies credentials - :raises TypeError: if the credentials argument has the wrong type + Args: + realm: indicates the scope of protection. It should contain only ASCII + characters because the encoding of non-ASCII characters is + undefined. Refer to section 2.2 of :rfc:`7235` for details. + credentials: defines hard coded authorized credentials. It can be a + ``(username, password)`` pair or a list of such pairs. + check_credentials: defines a coroutine that verifies credentials. + This coroutine receives ``username`` and ``password`` arguments + and returns a :class:`bool`. One of ``credentials`` or + ``check_credentials`` must be provided but not both. + create_protocol: factory that creates the protocol. By default, this + is :class:`BasicAuthWebSocketServerProtocol`. It can be replaced + by a subclass. + Raises: + TypeError: if the ``credentials`` or ``check_credentials`` argument is + wrong. """ if (credentials is None) == (check_credentials is None): diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 6d976e0df..e5743cc0e 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -1,8 +1,3 @@ -""" -:mod:`websockets.legacy.client` defines the WebSocket client APIs. - -""" - from __future__ import annotations import asyncio @@ -57,99 +52,27 @@ class WebSocketClientProtocol(WebSocketCommonProtocol): """ - :class:`~asyncio.Protocol` subclass implementing a WebSocket client. - - :class:`WebSocketClientProtocol`: + WebSocket client connection. - * performs the opening handshake to establish the connection; - * provides :meth:`recv` and :meth:`send` coroutines for receiving and - sending messages; - * deals with control frames automatically; - * performs the closing handshake to terminate the connection. + :class:`WebSocketClientProtocol` provides :meth:`recv` and :meth:`send` + coroutines for receiving and sending messages. - :class:`WebSocketClientProtocol` supports asynchronous iteration:: + It supports asynchronous iteration to receive incoming messages:: async for message in websocket: await process(message) - The iterator yields incoming messages. It exits normally when the - connection is closed with the close code 1000 (OK) or 1001 (going away). - It raises a :exc:`~websockets.exceptions.ConnectionClosedError` exception - when the connection is closed with any other code. - - Once the connection is open, a `Ping frame`_ is sent every - ``ping_interval`` seconds. This serves as a keepalive. It helps keeping - the connection open, especially in the presence of proxies with short - timeouts on inactive connections. Set ``ping_interval`` to ``None`` to - disable this behavior. - - .. _Ping frame: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 - - If the corresponding `Pong frame`_ isn't received within ``ping_timeout`` - seconds, the connection is considered unusable and is closed with - code 1011. This ensures that the remote endpoint remains responsive. Set - ``ping_timeout`` to ``None`` to disable this behavior. - - .. _Pong frame: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 - - The ``close_timeout`` parameter defines a maximum wait time for completing - the closing handshake and terminating the TCP connection. For legacy - reasons, :meth:`close` completes in at most ``5 * close_timeout`` seconds. - - ``close_timeout`` needs to be a parameter of the protocol because - websockets usually calls :meth:`close` implicitly upon exit when - :func:`connect` is used as a context manager. - - To apply a timeout to any other API, wrap it in :func:`~asyncio.wait_for`. - - The ``max_size`` parameter enforces the maximum size for incoming messages - in bytes. The default value is 1 MiB. ``None`` disables the limit. If a - message larger than the maximum size is received, :meth:`recv` will - raise :exc:`~websockets.exceptions.ConnectionClosedError` and the - connection will be closed with code 1009. - - The ``max_queue`` parameter sets the maximum length of the queue that - holds incoming messages. The default value is ``32``. ``None`` disables - the limit. Messages are added to an in-memory queue when they're received; - then :meth:`recv` pops from that queue. In order to prevent excessive - memory consumption when messages are received faster than they can be - processed, the queue must be bounded. If the queue fills up, the protocol - stops processing incoming data until :meth:`recv` is called. In this - situation, various receive buffers (at least in :mod:`asyncio` and in the - OS) will fill up, then the TCP receive window will shrink, slowing down - transmission to avoid packet loss. - - Since Python can use up to 4 bytes of memory to represent a single - character, each connection may use up to ``4 * max_size * max_queue`` - bytes of memory to store incoming messages. By default, this is 128 MiB. - You may want to lower the limits, depending on your application's - requirements. - - The ``read_limit`` argument sets the high-water limit of the buffer for - incoming bytes. The low-water limit is half the high-water limit. The - default value is 64 KiB, half of asyncio's default (based on the current - implementation of :class:`~asyncio.StreamReader`). - - The ``write_limit`` argument sets the high-water limit of the buffer for - outgoing bytes. The low-water limit is a quarter of the high-water limit. - The default value is 64 KiB, equal to asyncio's default (based on the - current implementation of ``FlowControlMixin``). - - As soon as the HTTP request and response in the opening handshake are - processed: - - * the request path is available in the :attr:`path` attribute; - * the request and response HTTP headers are available in the - :attr:`request_headers` and :attr:`response_headers` attributes, - which are :class:`~websockets.http.Headers` instances. - - If a subprotocol was negotiated, it's available in the :attr:`subprotocol` - attribute. - - Once the connection is closed, the code is available in the - :attr:`close_code` attribute and the reason in :attr:`close_reason`. - - All attributes must be treated as read-only. + The iterator exits normally when the connection is closed with close code + 1000 (OK) or 1001 (going away). It raises + a :exc:`~websockets.exceptions.ConnectionClosedError` when the connection + is closed with any other code. + + See :func:`connect` for the documentation of ``logger``, ``origin``, + ``extensions``, ``subprotocols``, and ``extra_headers``. + + See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the + documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``, + ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit``. """ @@ -159,11 +82,11 @@ class WebSocketClientProtocol(WebSocketCommonProtocol): def __init__( self, *, + logger: Optional[LoggerLike] = None, origin: Optional[Origin] = None, extensions: Optional[Sequence[ClientExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, extra_headers: Optional[HeadersLike] = None, - logger: Optional[LoggerLike] = None, **kwargs: Any, ) -> None: if logger is None: @@ -201,8 +124,9 @@ async def read_http_response(self) -> Tuple[int, Headers]: If the response contains a body, it may be read from ``self.reader`` after this coroutine returns. - :raises ~websockets.exceptions.InvalidMessage: if the HTTP message is - malformed or isn't an HTTP/1.1 GET response + Raises: + InvalidMessage: if the HTTP message is malformed or isn't an + HTTP/1.1 GET response. """ try: @@ -346,17 +270,17 @@ async def handshake( """ Perform the client side of the opening handshake. - :param origin: sets the Origin HTTP header - :param available_extensions: list of supported extensions in the order - in which they should be used - :param available_subprotocols: list of supported subprotocols in order - of decreasing preference - :param extra_headers: sets additional HTTP request headers; it must be - a :class:`~websockets.http.Headers` instance, a - :class:`~collections.abc.Mapping`, or an iterable of ``(name, - value)`` pairs - :raises ~websockets.exceptions.InvalidHandshake: if the handshake - fails + Args: + wsuri: URI of the WebSocket server. + origin: value of the ``Origin`` header. + available_extensions: list of supported extensions, in order in + which they should be tried. + available_subprotocols: list of supported subprotocols, in order + of decreasing preference. + extra_headers: arbitrary HTTP headers to add to the request. + + Raises: + InvalidHandshake: if the handshake fails. """ request_headers = Headers() @@ -416,14 +340,14 @@ async def handshake( class Connect: """ - Connect to the WebSocket server at the given ``uri``. + Connect to the WebSocket server at ``uri``. Awaiting :func:`connect` yields a :class:`WebSocketClientProtocol` which can then be used to send and receive messages. :func:`connect` can be used as a asynchronous context manager:: - async with connect(...) as websocket: + async with websockets.connect(...) as websocket: ... The connection is closed automatically when exiting the context. @@ -431,59 +355,71 @@ class Connect: :func:`connect` can be used as an infinite asynchronous iterator to reconnect automatically on errors:: - async for websocket in connect(...): - ... + async for websocket in websockets.connect(...): + try: + ... + except websockets.ConnectionClosed: + continue - You must catch all exceptions, or else you will exit the loop prematurely. - As above, connections are closed automatically. Connection attempts are - delayed with exponential backoff, starting at three seconds and - increasing up to one minute. - - :func:`connect` is a wrapper around the event loop's - :meth:`~asyncio.loop.create_connection` method. Unknown keyword arguments - are passed to :meth:`~asyncio.loop.create_connection`. - - For example, you can set the ``ssl`` keyword argument to a - :class:`~ssl.SSLContext` to enforce some TLS settings. When connecting to - a ``wss://`` URI, if this argument isn't provided explicitly, - :func:`ssl.create_default_context` is called to create a context. - - You can connect to a different host and port from those found in ``uri`` - by setting ``host`` and ``port`` keyword arguments. This only changes the - destination of the TCP connection. The host name from ``uri`` is still - used in the TLS handshake for secure connections and in the ``Host`` HTTP - header. - - ``create_protocol`` defaults to :class:`WebSocketClientProtocol`. It may - be replaced by a wrapper or a subclass to customize the protocol that - manages the connection. - - If the WebSocket connection isn't established within ``open_timeout`` - seconds, :func:`connect` raises :exc:`~asyncio.TimeoutError`. The default - is 10 seconds. Set ``open_timeout`` to ``None`` to disable the timeout. - - The behavior of ``ping_interval``, ``ping_timeout``, ``close_timeout``, - ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit`` is - described in :class:`WebSocketClientProtocol`. - - :func:`connect` also accepts the following optional arguments: - - * ``compression`` is a shortcut to configure compression extensions; - by default it enables the "permessage-deflate" extension; set it to - ``None`` to disable compression. - * ``origin`` sets the Origin HTTP header. - * ``extensions`` is a list of supported extensions in order of - decreasing preference. - * ``subprotocols`` is a list of supported subprotocols in order of - decreasing preference. - * ``extra_headers`` sets additional HTTP request headers; it can be a - :class:`~websockets.http.Headers` instance, a - :class:`~collections.abc.Mapping`, or an iterable of ``(name, value)`` - pairs. - - :raises ~websockets.uri.InvalidURI: if ``uri`` is invalid - :raises ~websockets.handshake.InvalidHandshake: if the opening handshake - fails + The connection is closed automatically after each iteration of the loop. + + If an error occurs while establishing the connection, :func:`connect` + retries with exponential backoff. The backoff delay starts at three + seconds and increases up to one minute. + + If an error occurs in the body of the loop, you can handle the exception + and :func:`connect` will reconnect with the next iteration; or you can + let the exception bubble up and break out of the loop. This lets you + decide which errors trigger a reconnection and which errors are fatal. + + Args: + uri: URI of the WebSocket server. + create_protocol: factory for the :class:`asyncio.Protocol` managing + the connection; defaults to :class:`WebSocketClientProtocol`; may + be set to a wrapper or a subclass to customize connection handling. + logger: logger for this connection; + defaults to ``logging.getLogger("websockets.client")``; + see the :doc:`logging guide <../topics/logging>` for details. + compression: shortcut that enables the "permessage-deflate" extension + by default; may be set to :obj:`None` to disable compression; + see the :doc:`compression guide <../topics/compression>` for details. + origin: value of the ``Origin`` header. This is useful when connecting + to a server that validates the ``Origin`` header to defend against + Cross-Site WebSocket Hijacking attacks. + extensions: list of supported extensions, in order in which they + should be tried. + subprotocols: list of supported subprotocols, in order of decreasing + preference. + extra_headers: arbitrary HTTP headers to add to the request. + open_timeout: timeout for opening the connection in seconds; + :obj:`None` to disable the timeout + + See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the + documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``, + ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit``. + + Any other keyword arguments are passed the event loop's + :meth:`~asyncio.loop.create_connection` method. + + For example: + + * You can set ``ssl`` to a :class:`~ssl.SSLContext` to enforce TLS + settings. When connecting to a ``wss://`` URI, if ``ssl`` isn't + provided, a TLS context is created + with :func:`~ssl.create_default_context`. + + * You can set ``host`` and ``port`` to connect to a different host and + port from those found in ``uri``. This only changes the destination of + the TCP connection. The host name from ``uri`` is still used in the TLS + handshake for secure connections and in the ``Host`` header. + + Returns: + WebSocketClientProtocol: WebSocket connection. + + Raises: + InvalidURI: if ``uri`` isn't a valid WebSocket URI. + InvalidHandshake: if the opening handshake fails. + ~asyncio.TimeoutError: if the opening handshake times out. """ @@ -494,6 +430,12 @@ def __init__( uri: str, *, create_protocol: Optional[Callable[[Any], WebSocketClientProtocol]] = None, + logger: Optional[LoggerLike] = None, + compression: Optional[str] = "deflate", + origin: Optional[Origin] = None, + extensions: Optional[Sequence[ClientExtensionFactory]] = None, + subprotocols: Optional[Sequence[Subprotocol]] = None, + extra_headers: Optional[HeadersLike] = None, open_timeout: Optional[float] = 10, ping_interval: Optional[float] = 20, ping_timeout: Optional[float] = 20, @@ -502,12 +444,6 @@ def __init__( max_queue: Optional[int] = 2 ** 5, read_limit: int = 2 ** 16, write_limit: int = 2 ** 16, - compression: Optional[str] = "deflate", - origin: Optional[Origin] = None, - extensions: Optional[Sequence[ClientExtensionFactory]] = None, - subprotocols: Optional[Sequence[Subprotocol]] = None, - extra_headers: Optional[HeadersLike] = None, - logger: Optional[LoggerLike] = None, **kwargs: Any, ) -> None: # Backwards compatibility: close_timeout used to be called timeout. @@ -560,6 +496,11 @@ def __init__( factory = functools.partial( create_protocol, + logger=logger, + origin=origin, + extensions=extensions, + subprotocols=subprotocols, + extra_headers=extra_headers, ping_interval=ping_interval, ping_timeout=ping_timeout, close_timeout=close_timeout, @@ -567,16 +508,11 @@ def __init__( max_queue=max_queue, read_limit=read_limit, write_limit=write_limit, - loop=_loop, host=wsuri.host, port=wsuri.port, secure=wsuri.secure, legacy_recv=legacy_recv, - origin=origin, - extensions=extensions, - subprotocols=subprotocols, - extra_headers=extra_headers, - logger=logger, + loop=_loop, ) if kwargs.pop("unix", False): @@ -743,15 +679,17 @@ def unix_connect( """ Similar to :func:`connect`, but for connecting to a Unix socket. - This function calls the event loop's + This function builds upon the event loop's :meth:`~asyncio.loop.create_unix_connection` method. It is only available on Unix. It's mainly useful for debugging servers listening on Unix sockets. - :param path: file system path to the Unix socket - :param uri: WebSocket URI + Args: + path: file system path to the Unix socket. + uri: URI of the WebSocket server; the host is used in the TLS + handshake for secure connections and in the ``Host`` header. """ return connect(uri=uri, path=path, unix=True, **kwargs) diff --git a/src/websockets/legacy/framing.py b/src/websockets/legacy/framing.py index 40cbd41bf..c4de7eb28 100644 --- a/src/websockets/legacy/framing.py +++ b/src/websockets/legacy/framing.py @@ -1,15 +1,3 @@ -""" -:mod:`websockets.legacy.framing` reads and writes WebSocket frames. - -It deals with a single frame at a time. Anything that depends on the sequence -of frames is implemented in :mod:`websockets.legacy.protocol`. - -See `section 5 of RFC 6455`_. - -.. _section 5 of RFC 6455: https://www.rfc-editor.org/rfc/rfc6455.html#section-5 - -""" - from __future__ import annotations import dataclasses @@ -60,22 +48,21 @@ async def read( mask: bool, max_size: Optional[int] = None, extensions: Optional[Sequence[extensions.Extension]] = None, - ) -> "Frame": + ) -> Frame: """ Read a WebSocket frame. - :param reader: coroutine that reads exactly the requested number of - bytes, unless the end of file is reached - :param mask: whether the frame should be masked i.e. whether the read - happens on the server side - :param max_size: maximum payload size in bytes - :param extensions: list of classes with a ``decode()`` method that - transforms the frame and return a new frame; extensions are applied - in reverse order - :raises ~websockets.exceptions.PayloadTooBig: if the frame exceeds - ``max_size`` - :raises ~websockets.exceptions.ProtocolError: if the frame - contains incorrect values + Args: + reader: coroutine that reads exactly the requested number of + bytes, unless the end of file is reached. + mask: whether the frame should be masked i.e. whether the read + happens on the server side. + max_size: maximum payload size in bytes. + extensions: list of extensions, applied in reverse order. + + Raises: + PayloadTooBig: if the frame exceeds ``max_size``. + ProtocolError: if the frame contains incorrect values. """ @@ -142,15 +129,15 @@ def write( """ Write a WebSocket frame. - :param frame: frame to write - :param write: function that writes bytes - :param mask: whether the frame should be masked i.e. whether the write - happens on the client side - :param extensions: list of classes with an ``encode()`` method that - transform the frame and return a new frame; extensions are applied - in order - :raises ~websockets.exceptions.ProtocolError: if the frame - contains incorrect values + Args: + frame: frame to write. + write: function that writes bytes. + mask: whether the frame should be masked i.e. whether the write + happens on the client side. + extensions: list of extensions, applied in order. + + Raises: + ProtocolError: if the frame contains incorrect values. """ # The frame is written in a single call to write in order to prevent @@ -168,10 +155,12 @@ def parse_close(data: bytes) -> Tuple[int, str]: """ Parse the payload from a close frame. - Return ``(code, reason)``. + Returns: + Tuple[int, str]: close code and reason. - :raises ~websockets.exceptions.ProtocolError: if data is ill-formed - :raises UnicodeDecodeError: if the reason isn't valid UTF-8 + Raises: + ProtocolError: if data is ill-formed. + UnicodeDecodeError: if the reason isn't valid UTF-8. """ return dataclasses.astuple(Close.parse(data)) # type: ignore @@ -181,7 +170,5 @@ def serialize_close(code: int, reason: str) -> bytes: """ Serialize the payload for a close frame. - This is the reverse of :func:`parse_close`. - """ return Close(code, reason).serialize() diff --git a/src/websockets/legacy/handshake.py b/src/websockets/legacy/handshake.py index 7cde58ac1..569937bb9 100644 --- a/src/websockets/legacy/handshake.py +++ b/src/websockets/legacy/handshake.py @@ -1,30 +1,3 @@ -""" -:mod:`websockets.legacy.handshake` provides helpers for the WebSocket handshake. - -See `section 4 of RFC 6455`_. - -.. _section 4 of RFC 6455: https://www.rfc-editor.org/rfc/rfc6455.html#section-4 - -Some checks cannot be performed because they depend too much on the -context; instead, they're documented below. - -To accept a connection, a server must: - -- Read the request, check that the method is GET, and check the headers with - :func:`check_request`, -- Send a 101 response to the client with the headers created by - :func:`build_response` if the request is valid; otherwise, send an - appropriate HTTP error code. - -To open a connection, a client must: - -- Send a GET request to the server with the headers created by - :func:`build_request`, -- Read the response, check that the status code is 101, and check the headers - with :func:`check_response`. - -""" - from __future__ import annotations import base64 @@ -47,8 +20,11 @@ def build_request(headers: Headers) -> str: Update request headers passed in argument. - :param headers: request headers - :returns: ``key`` which must be passed to :func:`check_response` + Args: + headers: handshake request headers. + + Returns: + str: ``key`` that must be passed to :func:`check_response`. """ key = generate_key() @@ -68,10 +44,15 @@ def check_request(headers: Headers) -> str: are usually performed earlier in the HTTP request handling code. They're the responsibility of the caller. - :param headers: request headers - :returns: ``key`` which must be passed to :func:`build_response` - :raises ~websockets.exceptions.InvalidHandshake: if the handshake request - is invalid; then the server must return 400 Bad Request error + Args: + headers: handshake request headers. + + Returns: + str: ``key`` that must be passed to :func:`build_response`. + + Raises: + InvalidHandshake: if the handshake request is invalid; + then the server must return 400 Bad Request error. """ connection: List[ConnectionOption] = sum( @@ -128,8 +109,9 @@ def build_response(headers: Headers, key: str) -> None: Update response headers passed in argument. - :param headers: response headers - :param key: comes from :func:`check_request` + Args: + headers: handshake response headers. + key: returned by :func:`check_request`. """ headers["Upgrade"] = "websocket" @@ -145,10 +127,12 @@ def check_response(headers: Headers, key: str) -> None: response with a 101 status code. These controls are the responsibility of the caller. - :param headers: response headers - :param key: comes from :func:`build_request` - :raises ~websockets.exceptions.InvalidHandshake: if the handshake response - is invalid + Args: + headers: handshake response headers. + key: returned by :func:`build_request`. + + Raises: + InvalidHandshake: if the handshake response is invalid. """ connection: List[ConnectionOption] = sum( diff --git a/src/websockets/legacy/http.py b/src/websockets/legacy/http.py index 3725fa9c3..cc2ef1f06 100644 --- a/src/websockets/legacy/http.py +++ b/src/websockets/legacy/http.py @@ -55,10 +55,13 @@ async def read_request(stream: asyncio.StreamReader) -> Tuple[str, Headers]: WebSocket handshake requests don't have one. If the request contains a body, it may be read from ``stream`` after this coroutine returns. - :param stream: input to read the request from - :raises EOFError: if the connection is closed without a full HTTP request - :raises SecurityError: if the request exceeds a security limit - :raises ValueError: if the request isn't well formatted + Args: + stream: input to read the request from + + Raises: + EOFError: if the connection is closed without a full HTTP request + SecurityError: if the request exceeds a security limit + ValueError: if the request isn't well formatted """ # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.1.1 @@ -99,10 +102,13 @@ async def read_response(stream: asyncio.StreamReader) -> Tuple[int, str, Headers WebSocket handshake responses don't have one. If the response contains a body, it may be read from ``stream`` after this coroutine returns. - :param stream: input to read the response from - :raises EOFError: if the connection is closed without a full HTTP response - :raises SecurityError: if the response exceeds a security limit - :raises ValueError: if the response isn't well formatted + Args: + stream: input to read the response from + + Raises: + EOFError: if the connection is closed without a full HTTP response + SecurityError: if the response exceeds a security limit + ValueError: if the response isn't well formatted """ # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.1.2 diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 99a821be6..4631151e6 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -1,12 +1,3 @@ -""" -:mod:`websockets.legacy.protocol` handles WebSocket control and data frames. - -See `sections 4 to 8 of RFC 6455`_. - -.. _sections 4 to 8 of RFC 6455: https://www.rfc-editor.org/rfc/rfc6455.html#section-4 - -""" - from __future__ import annotations import asyncio @@ -71,17 +62,99 @@ class WebSocketCommonProtocol(asyncio.Protocol): """ - :class:`~asyncio.Protocol` subclass implementing the data transfer phase. - - Once the WebSocket connection is established, during the data transfer - phase, the protocol is almost symmetrical between the server side and the - client side. :class:`WebSocketCommonProtocol` implements logic that's - shared between servers and clients. - - Subclasses such as - :class:`~websockets.legacy.server.WebSocketServerProtocol` and - :class:`~websockets.legacy.client.WebSocketClientProtocol` implement the - opening handshake, which is different between servers and clients. + WebSocket connection. + + :class:`WebSocketCommonProtocol` provides APIs shared between WebSocket + servers and clients. You shouldn't use it directly. Instead, use + :class:`~websockets.client.WebSocketClientProtocol` or + :class:`~websockets.server.WebSocketServerProtocol`. + + This documentation focuses on low-level details that aren't covered in the + documentation of :class:`~websockets.client.WebSocketClientProtocol` and + :class:`~websockets.server.WebSocketServerProtocol` for the sake of + simplicity. + + Once the connection is open, a Ping_ frame is sent every ``ping_interval`` + seconds. This serves as a keepalive. It helps keeping the connection + open, especially in the presence of proxies with short timeouts on + inactive connections. Set ``ping_interval`` to :obj:`None` to disable + this behavior. + + .. _Ping: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 + + If the corresponding Pong_ frame isn't received within ``ping_timeout`` + seconds, the connection is considered unusable and is closed with code + 1011. This ensures that the remote endpoint remains responsive. Set + ``ping_timeout`` to :obj:`None` to disable this behavior. + + .. _Pong: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 + + The ``close_timeout`` parameter defines a maximum wait time for completing + the closing handshake and terminating the TCP connection. For legacy + reasons, :meth:`close` completes in at most ``5 * close_timeout`` seconds + for clients and ``4 * close_timeout`` for servers. + + See the discussion of :doc:`timeouts <../topics/timeouts>` for details. + + ``close_timeout`` needs to be a parameter of the protocol because + websockets usually calls :meth:`close` implicitly upon exit: + + * on the client side, when :func:`~websockets.client.connect` is used as a + context manager; + * on the server side, when the connection handler terminates; + + To apply a timeout to any other API, wrap it in :func:`~asyncio.wait_for`. + + The ``max_size`` parameter enforces the maximum size for incoming messages + in bytes. The default value is 1 MiB. If a larger message is received, + :meth:`recv` will raise :exc:`~websockets.exceptions.ConnectionClosedError` + and the connection will be closed with code 1009. + + The ``max_queue`` parameter sets the maximum length of the queue that + holds incoming messages. The default value is ``32``. Messages are added + to an in-memory queue when they're received; then :meth:`recv` pops from + that queue. In order to prevent excessive memory consumption when + messages are received faster than they can be processed, the queue must + be bounded. If the queue fills up, the protocol stops processing incoming + data until :meth:`recv` is called. In this situation, various receive + buffers (at least in :mod:`asyncio` and in the OS) will fill up, then the + TCP receive window will shrink, slowing down transmission to avoid packet + loss. + + Since Python can use up to 4 bytes of memory to represent a single + character, each connection may use up to ``4 * max_size * max_queue`` + bytes of memory to store incoming messages. By default, this is 128 MiB. + You may want to lower the limits, depending on your application's + requirements. + + The ``read_limit`` argument sets the high-water limit of the buffer for + incoming bytes. The low-water limit is half the high-water limit. The + default value is 64 KiB, half of asyncio's default (based on the current + implementation of :class:`~asyncio.StreamReader`). + + The ``write_limit`` argument sets the high-water limit of the buffer for + outgoing bytes. The low-water limit is a quarter of the high-water limit. + The default value is 64 KiB, equal to asyncio's default (based on the + current implementation of ``FlowControlMixin``). + + See the discussion of :doc:`memory usage <../topics/memory>` for details. + + Args: + logger: logger for this connection; + defaults to ``logging.getLogger("websockets.protocol")``; + see the :doc:`logging guide <../topics/logging>` for details. + ping_interval: delay between keepalive pings in seconds; + :obj:`None` to disable keepalive pings. + ping_timeout: timeout for keepalive pings in seconds; + :obj:`None` to disable timeouts. + close_timeout: timeout for closing the connection in seconds; + for legacy reasons, the actual timeout is 4 or 5 times larger. + max_size: maximum size of incoming messages in bytes; + :obj:`None` to disable the limit. + max_queue: maximum number of incoming messages in receive buffer; + :obj:`None` to disable the limit. + read_limit: high-water mark of read buffer in bytes. + write_limit: high-water mark of write buffer in bytes. """ @@ -94,6 +167,7 @@ class WebSocketCommonProtocol(asyncio.Protocol): def __init__( self, *, + logger: Optional[LoggerLike] = None, ping_interval: Optional[float] = 20, ping_timeout: Optional[float] = 20, close_timeout: Optional[float] = None, @@ -101,14 +175,13 @@ def __init__( max_queue: Optional[int] = 2 ** 5, read_limit: int = 2 ** 16, write_limit: int = 2 ** 16, - logger: Optional[LoggerLike] = None, # The following arguments are kept only for backwards compatibility. host: Optional[str] = None, port: Optional[int] = None, secure: Optional[bool] = None, - timeout: Optional[float] = None, legacy_recv: bool = False, loop: Optional[asyncio.AbstractEventLoop] = None, + timeout: Optional[float] = None, ) -> None: # Backwards compatibility: close_timeout used to be called timeout. if timeout is None: @@ -134,7 +207,8 @@ def __init__( self.write_limit = write_limit # Unique identifier. For logs. - self.id = uuid.uuid4() + self.id: uuid.UUID = uuid.uuid4() + """Unique identifier of the connection. Useful in logs.""" # Logger or LoggerAdapter for this connection. if logger is None: @@ -175,12 +249,16 @@ def __init__( # HTTP protocol parameters. self.path: str + """Path of the opening handshake request.""" self.request_headers: Headers + """Opening handshake request headers.""" self.response_headers: Headers + """Opening handshake response headers.""" # WebSocket protocol parameters. self.extensions: List[Extension] = [] self.subprotocol: Optional[Subprotocol] = None + """Subprotocol, if one was negotiated.""" # Close code and reason, set when a close frame is sent or received. self.close_rcvd: Optional[Close] = None @@ -286,9 +364,14 @@ def secure(self) -> Optional[bool]: @property def local_address(self) -> Any: """ - Local address of the connection as a ``(host, port)`` tuple. + Local address of the connection. + + For IPv4 connections, this is a ``(host, port)`` tuple. + + The format of the address depends on the address family; + see :meth:`~socket.socket.getsockname`. - When the connection isn't open, ``local_address`` is ``None``. + :obj:`None` if the TCP connection isn't established yet. """ try: @@ -301,9 +384,14 @@ def local_address(self) -> Any: @property def remote_address(self) -> Any: """ - Remote address of the connection as a ``(host, port)`` tuple. + Remote address of the connection. - When the connection isn't open, ``remote_address`` is ``None``. + For IPv4 connections, this is a ``(host, port)`` tuple. + + The format of the address depends on the address family; + see :meth:`~socket.socket.getpeername`. + + :obj:`None` if the TCP connection isn't established yet. """ try: @@ -316,13 +404,11 @@ def remote_address(self) -> Any: @property def open(self) -> bool: """ - ``True`` when the connection is usable. + :obj:`True` when the connection is open; :obj:`False` otherwise. - It may be used to detect disconnections. However, this approach is - discouraged per the EAFP_ principle. - - When ``open`` is ``False``, using the connection raises a - :exc:`~websockets.exceptions.ConnectionClosed` exception. + This attribute may be used to detect disconnections. However, this + approach is discouraged per the EAFP_ principle. Instead, you should + handle :exc:`~websockets.exceptions.ConnectionClosed` exceptions. .. _EAFP: https://docs.python.org/3/glossary.html#term-eafp @@ -332,10 +418,10 @@ def open(self) -> bool: @property def closed(self) -> bool: """ - ``True`` once the connection is closed. + :obj:`True` when the connection is closed; :obj:`False` otherwise. - Be aware that both :attr:`open` and :attr:`closed` are ``False`` during - the opening and closing sequences. + Be aware that both :attr:`open` and :attr:`closed` are :obj:`False` + during the opening and closing sequences. """ return self.state is State.CLOSED @@ -343,9 +429,12 @@ def closed(self) -> bool: @property def close_code(self) -> Optional[int]: """ - WebSocket close code received in a close frame. + WebSocket close code, defined in `section 7.1.5 of RFC 6455`_. + + .. _section 7.1.5 of RFC 6455: + https://www.rfc-editor.org/rfc/rfc6455.html#section-7.1.5 - Available once the connection is closed. + :obj:`None` if the connection isn't closed yet. """ if self.state is not State.CLOSED: @@ -358,9 +447,12 @@ def close_code(self) -> Optional[int]: @property def close_reason(self) -> Optional[str]: """ - WebSocket close reason received in a close frame. + WebSocket close reason, defined in `section 7.1.6 of RFC 6455`_. - Available once the connection is closed. + .. _section 7.1.6 of RFC 6455: + https://www.rfc-editor.org/rfc/rfc6455.html#section-7.1.6 + + :obj:`None` if the connection isn't closed yet. """ if self.state is not State.CLOSED: @@ -370,25 +462,14 @@ def close_reason(self) -> Optional[str]: else: return self.close_rcvd.reason - async def wait_closed(self) -> None: - """ - Wait until the connection is closed. - - This is identical to :attr:`closed`, except it can be awaited. - - This can make it easier to handle connection termination, regardless - of its cause, in tasks that interact with the WebSocket connection. - - """ - await asyncio.shield(self.connection_lost_waiter) - async def __aiter__(self) -> AsyncIterator[Data]: """ - Iterate on received messages. - - Exit normally when the connection is closed with code 1000 or 1001. + Iterate on incoming messages. - Raise an exception in other cases. + The iterator exits normally when the connection is closed with the + close code 1000 (OK) or 1001(going away). It raises + a :exc:`~websockets.exceptions.ConnectionClosedError` exception when + the connection is closed with any other code. """ try: @@ -401,24 +482,30 @@ async def recv(self) -> Data: """ Receive the next message. - Return a :class:`str` for a text frame and :class:`bytes` for a binary - frame. - - When the end of the message stream is reached, :meth:`recv` raises + When the connection is closed, :meth:`recv` raises :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal connection closure and :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol - error or a network failure. + error or a network failure. This is how you detect the end of the + message stream. Canceling :meth:`recv` is safe. There's no risk of losing the next - message. The next invocation of :meth:`recv` will return it. This - makes it possible to enforce a timeout by wrapping :meth:`recv` in - :func:`~asyncio.wait_for`. + message. The next invocation of :meth:`recv` will return it. + + This makes it possible to enforce a timeout by wrapping :meth:`recv` + in :func:`~asyncio.wait_for`. - :raises ~websockets.exceptions.ConnectionClosed: when the - connection is closed - :raises RuntimeError: if two coroutines call :meth:`recv` concurrently + Returns: + Data: A string (:class:`str`) for a Text_ frame. A bytestring + (:class:`bytes`) for a Binary_ frame. + + .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + + Raises: + ConnectionClosed: when the connection is closed. + RuntimeError: if two coroutines call :meth:`recv` concurrently. """ if self._pop_message_waiter is not None: @@ -471,43 +558,58 @@ async def recv(self) -> Data: return message async def send( - self, message: Union[Data, Iterable[Data], AsyncIterable[Data]] + self, + message: Union[Data, Iterable[Data], AsyncIterable[Data]], ) -> None: """ Send a message. - A string (:class:`str`) is sent as a `Text frame`_. A bytestring or + A string (:class:`str`) is sent as a Text_ frame. A bytestring or bytes-like object (:class:`bytes`, :class:`bytearray`, or - :class:`memoryview`) is sent as a `Binary frame`_. + :class:`memoryview`) is sent as a Binary_ frame. - .. _Text frame: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 - .. _Binary frame: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 :meth:`send` also accepts an iterable or an asynchronous iterable of - strings, bytestrings, or bytes-like objects. In that case the message - is fragmented. Each item is treated as a message fragment and sent in - its own frame. All items must be of the same type, or else - :meth:`send` will raise a :exc:`TypeError` and the connection will be - closed. + strings, bytestrings, or bytes-like objects to enable fragmentation_. + Each item is treated as a message fragment and sent in its own frame. + All items must be of the same type, or else :meth:`send` will raise a + :exc:`TypeError` and the connection will be closed. + + .. _fragmentation: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.4 :meth:`send` rejects dict-like objects because this is often an error. - If you wish to send the keys of a dict-like object as fragments, call - its :meth:`~dict.keys` method and pass the result to :meth:`send`. + (If you want to send the keys of a dict-like object as fragments, call + its :meth:`~dict.keys` method and pass the result to :meth:`send`.) Canceling :meth:`send` is discouraged. Instead, you should close the connection with :meth:`close`. Indeed, there are only two situations - where :meth:`send` may yield control to the event loop: + where :meth:`send` may yield control to the event loop and then get + canceled; in both cases, :meth:`close` has the same effect and is + more clear: 1. The write buffer is full. If you don't want to wait until enough data is sent, your only alternative is to close the connection. :meth:`close` will likely time out then abort the TCP connection. 2. ``message`` is an asynchronous iterator that yields control. Stopping in the middle of a fragmented message will cause a - protocol error. Closing the connection has the same effect. + protocol error and the connection will be closed. + + When the connection is closed, :meth:`send` raises + :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it + raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal + connection closure and + :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol + error or a network failure. - :raises ~websockets.exceptions.ConnectionClosed: when the - connection is closed - :raises TypeError: if ``message`` doesn't have a supported type + Args: + message (Union[Data, Iterable[Data], AsyncIterable[Data]): message + to send. + + Raises: + ConnectionClosed: when the connection is closed. + TypeError: if ``message`` doesn't have a supported type. """ await self.ensure_open() @@ -621,7 +723,7 @@ async def close(self, code: int = 1000, reason: str = "") -> None: :meth:`close` waits for the other end to complete the handshake and for the TCP connection to terminate. As a consequence, there's no need - to await :meth:`wait_closed`; :meth:`close` already does it. + to await :meth:`wait_closed` after :meth:`close`. :meth:`close` is idempotent: it doesn't do anything once the connection is closed. @@ -631,10 +733,12 @@ async def close(self, code: int = 1000, reason: str = "") -> None: Canceling :meth:`close` is discouraged. If it takes too long, you can set a shorter ``close_timeout``. If you don't want to wait, let the - Python process exit, then the OS will close the TCP connection. + Python process exit, then the OS will take care of closing the TCP + connection. - :param code: WebSocket close code - :param reason: WebSocket close reason + Args: + code: WebSocket close code. + reason: WebSocket close reason. """ try: @@ -653,7 +757,7 @@ async def close(self, code: int = 1000, reason: str = "") -> None: # the data transfer task and raises TimeoutError. # If close() is called multiple times concurrently and one of these - # calls hits the timeout, the data transfer task will be cancelled. + # calls hits the timeout, the data transfer task will be canceled. # Other calls will receive a CancelledError here. try: @@ -670,23 +774,27 @@ async def close(self, code: int = 1000, reason: str = "") -> None: # Wait for the close connection task to close the TCP connection. await asyncio.shield(self.close_connection_task) - async def ping(self, data: Optional[Data] = None) -> Awaitable[None]: + async def wait_closed(self) -> None: """ - Send a ping. + Wait until the connection is closed. - Return a :class:`~asyncio.Future` that will be completed when the - corresponding pong is received. You can ignore it if you don't intend - to wait. + This coroutine is identical to the :attr:`closed` attribute, except it + can be awaited. - A ping may serve as a keepalive or as a check that the remote endpoint - received all messages up to this point:: + This can make it easier to detect connection termination, regardless + of its cause, in tasks that interact with the WebSocket connection. - pong_waiter = await ws.ping() - await pong_waiter # only if you want to wait for the pong + """ + await asyncio.shield(self.connection_lost_waiter) + + async def ping(self, data: Optional[Data] = None) -> Awaitable[None]: + """ + Send a Ping_. + + .. _Ping: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 - By default, the ping contains four random bytes. This payload may be - overridden with the optional ``data`` argument which must be a string - (which will be encoded to UTF-8) or a bytes-like object. + A ping may serve as a keepalive or as a check that the remote endpoint + received all messages up to this point Canceling :meth:`ping` is discouraged. If :meth:`ping` doesn't return immediately, it means the write buffer is full. If you don't want to @@ -695,10 +803,25 @@ async def ping(self, data: Optional[Data] = None) -> Awaitable[None]: Canceling the :class:`~asyncio.Future` returned by :meth:`ping` has no effect. - :raises ~websockets.exceptions.ConnectionClosed: when the - connection is closed - :raises RuntimeError: if another ping was sent with the same data and - the corresponding pong wasn't received yet + Args: + data (Optional[Data]): payload of the ping; a string will be + encoded to UTF-8; or :obj:`None` to generate a payload + containing four random bytes. + + Returns: + ~asyncio.Future: A future that will be completed when the + corresponding pong is received. You can ignore it if you + don't intend to wait. + + :: + + pong_waiter = await ws.ping() + await pong_waiter # only if you want to wait for the pong + + Raises: + ConnectionClosed: when the connection is closed. + RuntimeError: if another ping was sent with the same data and + the corresponding pong wasn't received yet. """ await self.ensure_open() @@ -722,18 +845,22 @@ async def ping(self, data: Optional[Data] = None) -> Awaitable[None]: async def pong(self, data: Data = b"") -> None: """ - Send a pong. + Send a Pong_. + + .. _Pong: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 An unsolicited pong may serve as a unidirectional heartbeat. - The payload may be set with the optional ``data`` argument which must - be a string (which will be encoded to UTF-8) or a bytes-like object. + Canceling :meth:`pong` is discouraged. If :meth:`pong` doesn't return + immediately, it means the write buffer is full. If you don't want to + wait, you should close the connection. - Canceling :meth:`pong` is discouraged for the same reason as - :meth:`ping`. + Args: + data (Data): payload of the pong; a string will be encoded to + UTF-8. - :raises ~websockets.exceptions.ConnectionClosed: when the - connection is closed + Raises: + ConnectionClosed: when the connection is closed. """ await self.ensure_open() @@ -876,7 +1003,7 @@ async def read_message(self) -> Optional[Data]: Re-assemble data frames if the message is fragmented. - Return ``None`` when the closing handshake is started. + Return :obj:`None` when the closing handshake is started. """ frame = await self.read_data_frame(max_size=self.max_size) @@ -950,7 +1077,7 @@ async def read_data_frame(self, max_size: Optional[int]) -> Optional[Frame]: Process control frames received before the next data frame. - Return ``None`` if a close frame is encountered before any data frame. + Return :obj:`None` if a close frame is encountered before any data frame. """ # 6.2. Receiving Data @@ -1224,7 +1351,8 @@ async def wait_for_connection_lost(self) -> bool: """ Wait until the TCP connection is closed or ``self.close_timeout`` elapses. - Return ``True`` if the connection is closed and ``False`` otherwise. + Return :obj:`True` if the connection is closed and :obj:`False` + otherwise. """ if not self.connection_lost_waiter.done(): @@ -1404,7 +1532,7 @@ def eof_received(self) -> None: the TCP or TLS connection after sending and receiving a close frame. As a consequence, they never need to write after receiving EOF, so - there's no reason to keep the transport open by returning ``True``. + there's no reason to keep the transport open by returning :obj:`True`. Besides, that doesn't work on TLS connections. @@ -1416,15 +1544,15 @@ def broadcast(websockets: Iterable[WebSocketCommonProtocol], message: Data) -> N """ Broadcast a message to several WebSocket connections. - A string (:class:`str`) is sent as a `Text frame`_. A bytestring or + A string (:class:`str`) is sent as a Text_ frame. A bytestring or bytes-like object (:class:`bytes`, :class:`bytearray`, or - :class:`memoryview`) is sent as a `Binary frame`_. + :class:`memoryview`) is sent as a Binary_ frame. - .. _Text frame: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 - .. _Binary frame: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 :func:`broadcast` pushes the message synchronously to all connections even - if their write buffers overflow ``write_limit``. There's no backpressure. + if their write buffers are overflowing. There's no backpressure. :func:`broadcast` skips silently connections that aren't open in order to avoid errors on connections where the closing handshake is in progress. @@ -1440,8 +1568,14 @@ def broadcast(websockets: Iterable[WebSocketCommonProtocol], message: Data) -> N them in memory, while :func:`broadcast` buffers one copy per connection as fast as possible. - :raises RuntimeError: if a connection is busy sending a fragmented message - :raises TypeError: if ``message`` doesn't have a supported type + Args: + websockets (Iterable[WebSocketCommonProtocol]): WebSocket connections + to which the message will be sent. + message (Data): message to send. + + Raises: + RuntimeError: if a connection is busy sending a fragmented message. + TypeError: if ``message`` doesn't have a supported type. """ if not isinstance(message, (str, bytes, bytearray, memoryview)): diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index ddf6d9f87..673888c3d 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -1,8 +1,3 @@ -""" -:mod:`websockets.legacy.server` defines the WebSocket server APIs. - -""" - from __future__ import annotations import asyncio @@ -65,109 +60,33 @@ class WebSocketServerProtocol(WebSocketCommonProtocol): """ - :class:`~asyncio.Protocol` subclass implementing a WebSocket server. + WebSocket server connection. - :class:`WebSocketServerProtocol`: + :class:`WebSocketServerProtocol` provides :meth:`recv` and :meth:`send` + coroutines for receiving and sending messages. - * performs the opening handshake to establish the connection; - * provides :meth:`recv` and :meth:`send` coroutines for receiving and - sending messages; - * deals with control frames automatically; - * performs the closing handshake to terminate the connection. + It supports asynchronous iteration to receive messages:: - You may customize the opening handshake by subclassing - :class:`WebSocketServer` and overriding: + async for message in websocket: + await process(message) - * :meth:`process_request` to intercept the client request before any - processing and, if appropriate, to abort the WebSocket request and - return a HTTP response instead; - * :meth:`select_subprotocol` to select a subprotocol, if the client and - the server have multiple subprotocols in common and the default logic - for choosing one isn't suitable (this is rarely needed). + The iterator exits normally when the connection is closed with close code + 1000 (OK) or 1001 (going away). It raises + a :exc:`~websockets.exceptions.ConnectionClosedError` when the connection + is closed with any other code. - :class:`WebSocketServerProtocol` supports asynchronous iteration:: + You may customize the opening handshake in a subclass by + overriding :meth:`process_request` or :meth:`select_subprotocol`. - async for message in websocket: - await process(message) + Args: + ws_server: WebSocket server that created this connection. - The iterator yields incoming messages. It exits normally when the - connection is closed with the close code 1000 (OK) or 1001 (going away). - It raises a :exc:`~websockets.exceptions.ConnectionClosedError` exception - when the connection is closed with any other code. - - Once the connection is open, a `Ping frame`_ is sent every - ``ping_interval`` seconds. This serves as a keepalive. It helps keeping - the connection open, especially in the presence of proxies with short - timeouts on inactive connections. Set ``ping_interval`` to ``None`` to - disable this behavior. - - .. _Ping frame: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 - - If the corresponding `Pong frame`_ isn't received within ``ping_timeout`` - seconds, the connection is considered unusable and is closed with - code 1011. This ensures that the remote endpoint remains responsive. Set - ``ping_timeout`` to ``None`` to disable this behavior. - - .. _Pong frame: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 - - The ``close_timeout`` parameter defines a maximum wait time for completing - the closing handshake and terminating the TCP connection. For legacy - reasons, :meth:`close` completes in at most ``4 * close_timeout`` seconds. - - ``close_timeout`` needs to be a parameter of the protocol because - websockets usually calls :meth:`close` implicitly when the connection - handler terminates. - - To apply a timeout to any other API, wrap it in :func:`~asyncio.wait_for`. - - The ``max_size`` parameter enforces the maximum size for incoming messages - in bytes. The default value is 1 MiB. ``None`` disables the limit. If a - message larger than the maximum size is received, :meth:`recv` will - raise :exc:`~websockets.exceptions.ConnectionClosedError` and the - connection will be closed with code 1009. - - The ``max_queue`` parameter sets the maximum length of the queue that - holds incoming messages. The default value is ``32``. ``None`` disables - the limit. Messages are added to an in-memory queue when they're received; - then :meth:`recv` pops from that queue. In order to prevent excessive - memory consumption when messages are received faster than they can be - processed, the queue must be bounded. If the queue fills up, the protocol - stops processing incoming data until :meth:`recv` is called. In this - situation, various receive buffers (at least in :mod:`asyncio` and in the - OS) will fill up, then the TCP receive window will shrink, slowing down - transmission to avoid packet loss. - - Since Python can use up to 4 bytes of memory to represent a single - character, each connection may use up to ``4 * max_size * max_queue`` - bytes of memory to store incoming messages. By default, this is 128 MiB. - You may want to lower the limits, depending on your application's - requirements. - - The ``read_limit`` argument sets the high-water limit of the buffer for - incoming bytes. The low-water limit is half the high-water limit. The - default value is 64 KiB, half of asyncio's default (based on the current - implementation of :class:`~asyncio.StreamReader`). - - The ``write_limit`` argument sets the high-water limit of the buffer for - outgoing bytes. The low-water limit is a quarter of the high-water limit. - The default value is 64 KiB, equal to asyncio's default (based on the - current implementation of ``FlowControlMixin``). - - As soon as the HTTP request and response in the opening handshake are - processed: - - * the request path is available in the :attr:`path` attribute; - * the request and response HTTP headers are available in the - :attr:`request_headers` and :attr:`response_headers` attributes, - which are :class:`~websockets.http.Headers` instances. - - If a subprotocol was negotiated, it's available in the :attr:`subprotocol` - attribute. - - Once the connection is closed, the code is available in the - :attr:`close_code` attribute and the reason in :attr:`close_reason`. - - All attributes must be treated as read-only. + See :func:`serve` for the documentation of ``ws_handler``, ``logger``, ``origins``, + ``extensions``, ``subprotocols``, and ``extra_headers``. + + See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the + documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``, + ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit``. """ @@ -179,6 +98,7 @@ def __init__( ws_handler: Callable[[WebSocketServerProtocol, str], Awaitable[Any]], ws_server: WebSocketServer, *, + logger: Optional[LoggerLike] = None, origins: Optional[Sequence[Optional[Origin]]] = None, extensions: Optional[Sequence[ServerExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, @@ -189,7 +109,6 @@ def __init__( select_subprotocol: Optional[ Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] ] = None, - logger: Optional[LoggerLike] = None, **kwargs: Any, ) -> None: if logger is None: @@ -339,8 +258,9 @@ async def read_http_request(self) -> Tuple[str, Headers]: If the request contains a body, it may be read from ``self.reader`` after this coroutine returns. - :raises ~websockets.exceptions.InvalidMessage: if the HTTP message is - malformed or isn't an HTTP/1.1 GET request + Raises: + InvalidMessage: if the HTTP message is malformed or isn't an + HTTP/1.1 GET request. """ try: @@ -394,18 +314,7 @@ async def process_request( """ Intercept the HTTP request and return an HTTP response if appropriate. - If ``process_request`` returns ``None``, the WebSocket handshake - continues. If it returns 3-uple containing a status code, response - headers and a response body, that HTTP response is sent and the - connection is closed. In that case: - - * The HTTP status must be a :class:`~http.HTTPStatus`. - * HTTP headers must be a :class:`~websockets.http.Headers` instance, a - :class:`~collections.abc.Mapping`, or an iterable of ``(name, - value)`` pairs. - * The HTTP response body must be :class:`bytes`. It may be empty. - - This coroutine may be overridden in a :class:`WebSocketServerProtocol` + You may override this method in a :class:`WebSocketServerProtocol` subclass, for example: * to return a HTTP 200 OK response on a given path; then a load @@ -413,19 +322,27 @@ async def process_request( * to authenticate the request and return a HTTP 401 Unauthorized or a HTTP 403 Forbidden when authentication fails. - Instead of subclassing, it is possible to override this method by - passing a ``process_request`` argument to the :func:`serve` function - or the :class:`WebSocketServerProtocol` constructor. This is - equivalent, except ``process_request`` won't have access to the + You may also override this method with the ``process_request`` + argument of :func:`serve` and :class:`WebSocketServerProtocol`. This + is equivalent, except ``process_request`` won't have access to the protocol instance, so it can't store information for later use. - ``process_request`` is expected to complete quickly. If it may run for - a long time, then it should await :meth:`wait_closed` and exit if + :meth:`process_request` is expected to complete quickly. If it may run + for a long time, then it should await :meth:`wait_closed` and exit if :meth:`wait_closed` completes, or else it could prevent the server from shutting down. - :param path: request path, including optional query string - :param request_headers: request headers + Args: + path: request path, including optional query string. + request_headers: request headers. + + Returns: + Optional[Tuple[http.HTTPStatus, HeadersLike, bytes]]: :obj:`None` + to continue the WebSocket handshake normally. + + An HTTP response, represented by a 3-uple of the response status, + headers, and body, to abort the WebSocket handshake and return + that HTTP response instead. """ if self._process_request is not None: @@ -447,10 +364,12 @@ def process_origin( """ Handle the Origin HTTP request header. - :param headers: request headers - :param origins: optional list of acceptable origins - :raises ~websockets.exceptions.InvalidOrigin: if the origin isn't - acceptable + Args: + headers: request headers. + origins: optional list of acceptable origins. + + Raises: + InvalidOrigin: if the origin isn't acceptable. """ # "The user agent MUST NOT include more than one Origin header field" @@ -496,10 +415,12 @@ def process_extensions( Other requirements, for example related to mandatory extensions or the order of extensions, may be implemented by overriding this method. - :param headers: request headers - :param extensions: optional list of supported extensions - :raises ~websockets.exceptions.InvalidHandshake: to abort the - handshake with an HTTP 400 error code + Args: + headers: request headers. + extensions: optional list of supported extensions. + + Raises: + InvalidHandshake: to abort the handshake with an HTTP 400 error. """ response_header_value: Optional[str] = None @@ -557,10 +478,12 @@ def process_subprotocol( Return Sec-WebSocket-Protocol HTTP response header, which is the same as the selected subprotocol. - :param headers: request headers - :param available_subprotocols: optional list of supported subprotocols - :raises ~websockets.exceptions.InvalidHandshake: to abort the - handshake with an HTTP 400 error code + Args: + headers: request headers. + available_subprotocols: optional list of supported subprotocols. + + Raises: + InvalidHandshake: to abort the handshake with an HTTP 400 error. """ subprotocol: Optional[Subprotocol] = None @@ -588,22 +511,27 @@ def select_subprotocol( Pick a subprotocol among those offered by the client. If several subprotocols are supported by the client and the server, - the default implementation selects the preferred subprotocols by + the default implementation selects the preferred subprotocol by giving equal value to the priorities of the client and the server. - If no subprotocol is supported by the client and the server, it proceeds without a subprotocol. - This is unlikely to be the most useful implementation in practice, as - many servers providing a subprotocol will require that the client uses - that subprotocol. Such rules can be implemented in a subclass. + This is unlikely to be the most useful implementation in practice. + Many servers providing a subprotocol will require that the client + uses that subprotocol. Such rules can be implemented in a subclass. + + You may also override this method with the ``select_subprotocol`` + argument of :func:`serve` and :class:`WebSocketServerProtocol`. - Instead of subclassing, it is possible to override this method by - passing a ``select_subprotocol`` argument to the :func:`serve` - function or the :class:`WebSocketServerProtocol` constructor. + Args: + client_subprotocols: list of subprotocols offered by the client. + server_subprotocols: list of subprotocols available on the server. + + Returns: + Optional[Subprotocol]: Selected subprotocol. + + :obj:`None` to continue without a subprotocol. - :param client_subprotocols: list of subprotocols offered by the client - :param server_subprotocols: list of subprotocols available on the server """ if self._select_subprotocol is not None: @@ -629,19 +557,18 @@ async def handshake( Return the path of the URI of the request. - :param origins: list of acceptable values of the Origin HTTP header; - include ``None`` if the lack of an origin is acceptable - :param available_extensions: list of supported extensions in the order - in which they should be used - :param available_subprotocols: list of supported subprotocols in order - of decreasing preference - :param extra_headers: sets additional HTTP response headers when the - handshake succeeds; it can be a :class:`~websockets.http.Headers` - instance, a :class:`~collections.abc.Mapping`, an iterable of - ``(name, value)`` pairs, or a callable taking the request path and - headers in arguments and returning one of the above. - :raises ~websockets.exceptions.InvalidHandshake: if the handshake - fails + Args: + origins: list of acceptable values of the Origin HTTP header; + include :obj:`None` if the lack of an origin is acceptable. + extensions: list of supported extensions, in order in which they + should be tried. + subprotocols: list of supported subprotocols, in order of + decreasing preference. + extra_headers: arbitrary HTTP headers to add to the response when + the handshake succeeds. + + Raises: + InvalidHandshake: if the handshake fails. """ path, request_headers = await self.read_http_request() @@ -714,24 +641,24 @@ class WebSocketServer: """ WebSocket server returned by :func:`serve`. - This class provides the same interface as - :class:`~asyncio.AbstractServer`, namely the - :meth:`~asyncio.AbstractServer.close` and - :meth:`~asyncio.AbstractServer.wait_closed` methods. + This class provides the same interface as :class:`~asyncio.Server`, + notably the :meth:`~asyncio.Server.close` + and :meth:`~asyncio.Server.wait_closed` methods. It keeps track of WebSocket connections in order to close them properly when shutting down. - Instances of this class store a reference to the :class:`~asyncio.Server` - object returned by :meth:`~asyncio.loop.create_server` rather than inherit - from :class:`~asyncio.Server` in part because - :meth:`~asyncio.loop.create_server` doesn't support passing a custom - :class:`~asyncio.Server` class. + Args: + logger: logger for this server; + defaults to ``logging.getLogger("websockets.server")``; + see the :doc:`logging guide <../topics/logging>` for details. """ def __init__( - self, loop: asyncio.AbstractEventLoop, logger: Optional[LoggerLike] = None + self, + loop: asyncio.AbstractEventLoop, + logger: Optional[LoggerLike] = None, ) -> None: # Store a reference to loop to avoid relying on self.server._loop. self.loop = loop @@ -874,15 +801,26 @@ async def wait_closed(self) -> None: When :meth:`wait_closed` returns, all TCP connections are closed and all connection handlers have returned. + To ensure a fast shutdown, a connection handler should always be + awaiting at least one of: + + * :meth:`~WebSocketServerProtocol.recv`: when the connection is closed, + it raises :exc:`~websockets.exceptions.ConnectionClosedOK`; + * :meth:`~WebSocketServerProtocol.wait_closed`: when the connection is + closed, it returns. + + Then the connection handler is immediately notified of the shutdown; + it can clean up and exit. + """ await asyncio.shield(self.closed_waiter) @property def sockets(self) -> Optional[List[socket.socket]]: """ - List of :class:`~socket.socket` objects the server is listening to. + List of :obj:`~socket.socket` objects the server is listening on. - ``None`` if the server is closed. + :obj:`None` if the server is closed. """ return self.server.sockets @@ -890,25 +828,21 @@ def sockets(self) -> Optional[List[socket.socket]]: class Serve: """ + Start a WebSocket server listening on ``host`` and ``port``. - Create, start, and return a WebSocket server on ``host`` and ``port``. - - Whenever a client connects, the server accepts the connection, creates a + Whenever a client connects, the server creates a :class:`WebSocketServerProtocol`, performs the opening handshake, and - delegates to the connection handler defined by ``ws_handler``. Once the - handler completes, either normally or with an exception, the server - performs the closing handshake and closes the connection. + delegates to the connection handler, ``ws_handler``. - Awaiting :func:`serve` yields a :class:`WebSocketServer`. This instance - provides :meth:`~WebSocketServer.close` and - :meth:`~WebSocketServer.wait_closed` methods for terminating the server - and cleaning up its resources. + The handler receives the :class:`WebSocketServerProtocol` and uses it to + send and receive messages. + + Once the handler completes, either normally or with an exception, the + server performs the closing handshake and closes the connection. - When a server is closed with :meth:`~WebSocketServer.close`, it closes all - connections with close code 1001 (going away). Connections handlers, which - are running the ``ws_handler`` coroutine, will receive a - :exc:`~websockets.exceptions.ConnectionClosedOK` exception on their - current or next interaction with the WebSocket connection. + Awaiting :func:`serve` yields a :class:`WebSocketServer`. This object + provides :meth:`~WebSocketServer.close` and + :meth:`~WebSocketServer.wait_closed` methods for shutting down the server. :func:`serve` can be used as an asynchronous context manager:: @@ -917,64 +851,61 @@ class Serve: async with serve(...): await stop - In this case, the server is shut down when exiting the context. - - :func:`serve` is a wrapper around the event loop's - :meth:`~asyncio.loop.create_server` method. It creates and starts a - :class:`asyncio.Server` with :meth:`~asyncio.loop.create_server`. Then it - wraps the :class:`asyncio.Server` in a :class:`WebSocketServer` and - returns the :class:`WebSocketServer`. - - ``ws_handler`` is the WebSocket handler. It must be a coroutine accepting - two arguments: the WebSocket connection, which is an instance of - :class:`WebSocketServerProtocol`, and the path of the request. - - The ``host`` and ``port`` arguments, as well as unrecognized keyword - arguments, are passed to :meth:`~asyncio.loop.create_server`. - - For example, you can set the ``ssl`` keyword argument to a - :class:`~ssl.SSLContext` to enable TLS. - - ``create_protocol`` defaults to :class:`WebSocketServerProtocol`. It may - be replaced by a wrapper or a subclass to customize the protocol that - manages the connection. - - The behavior of ``ping_interval``, ``ping_timeout``, ``close_timeout``, - ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit`` is - described in :class:`WebSocketServerProtocol`. - - :func:`serve` also accepts the following optional arguments: - - * ``compression`` is a shortcut to configure compression extensions; - by default it enables the "permessage-deflate" extension; set it to - ``None`` to disable compression. - * ``origins`` defines acceptable Origin HTTP headers; include ``None`` in - the list if the lack of an origin is acceptable. - * ``extensions`` is a list of supported extensions in order of - decreasing preference. - * ``subprotocols`` is a list of supported subprotocols in order of - decreasing preference. - * ``extra_headers`` sets additional HTTP response headers when the - handshake succeeds; it can be a :class:`~websockets.http.Headers` - instance, a :class:`~collections.abc.Mapping`, an iterable of ``(name, - value)`` pairs, or a callable taking the request path and headers in - arguments and returning one of the above. - * ``process_request`` allows intercepting the HTTP request; it must be a - coroutine taking the request path and headers in argument; see - :meth:`~WebSocketServerProtocol.process_request` for details. - * ``select_subprotocol`` allows customizing the logic for selecting a - subprotocol; it must be a callable taking the subprotocols offered by - the client and available on the server in argument; see - :meth:`~WebSocketServerProtocol.select_subprotocol` for details. - - Since there's no useful way to propagate exceptions triggered in handlers, - they're sent to the ``"websockets.server"`` logger instead. - Debugging is much easier if you configure logging to print them:: - - import logging - logger = logging.getLogger("websockets.server") - logger.setLevel(logging.ERROR) - logger.addHandler(logging.StreamHandler()) + The server is shut down automatically when exiting the context. + + Args: + ws_handler: connection handler. It must be a coroutine accepting + two arguments: the WebSocket connection, which is a + :class:`WebSocketServerProtocol`, and the path of the request. + host: network interfaces the server is bound to; + see :meth:`~asyncio.loop.create_server` for details. + port: TCP port the server listens on; + see :meth:`~asyncio.loop.create_server` for details. + create_protocol: factory for the :class:`asyncio.Protocol` managing + the connection; defaults to :class:`WebSocketServerProtocol`; may + be set to a wrapper or a subclass to customize connection handling. + logger: logger for this server; + defaults to ``logging.getLogger("websockets.server")``; + see the :doc:`logging guide <../topics/logging>` for details. + compression: shortcut that enables the "permessage-deflate" extension + by default; may be set to :obj:`None` to disable compression; + see the :doc:`compression guide <../topics/compression>` for details. + origins: acceptable values of the ``Origin`` header; include + :obj:`None` in the list if the lack of an origin is acceptable. + This is useful for defending against Cross-Site WebSocket + Hijacking attacks. + extensions: list of supported extensions, in order in which they + should be tried. + subprotocols: list of supported subprotocols, in order of decreasing + preference. + extra_headers (Union[HeadersLike, Callable[[str, Headers], HeadersLike]]): + arbitrary HTTP headers to add to the request; this can be + a :data:`~websockets.datastructures.HeadersLike` or a callable + taking the request path and headers in arguments and returning + a :data:`~websockets.datastructures.HeadersLike`. + process_request (Optional[Callable[[str, Headers], \ + Awaitable[Optional[Tuple[http.HTTPStatus, HeadersLike, bytes]]]]]): + intercept HTTP request before the opening handshake; + see :meth:`~WebSocketServerProtocol.process_request` for details. + select_subprotocol: select a subprotocol supported by the client; + see :meth:`~WebSocketServerProtocol.select_subprotocol` for details. + + See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the + documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``, + ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit``. + + Any other keyword arguments are passed the event loop's + :meth:`~asyncio.loop.create_server` method. + + For example: + + * You can set ``ssl`` to a :class:`~ssl.SSLContext` to enable TLS. + + * You can set ``sock`` to a :obj:`~socket.socket` that you created + outside of websockets. + + Returns: + WebSocketServer: WebSocket server. """ @@ -985,13 +916,7 @@ def __init__( port: Optional[int] = None, *, create_protocol: Optional[Callable[[Any], WebSocketServerProtocol]] = None, - ping_interval: Optional[float] = 20, - ping_timeout: Optional[float] = 20, - close_timeout: Optional[float] = None, - max_size: Optional[int] = 2 ** 20, - max_queue: Optional[int] = 2 ** 5, - read_limit: int = 2 ** 16, - write_limit: int = 2 ** 16, + logger: Optional[LoggerLike] = None, compression: Optional[str] = "deflate", origins: Optional[Sequence[Optional[Origin]]] = None, extensions: Optional[Sequence[ServerExtensionFactory]] = None, @@ -1003,7 +928,13 @@ def __init__( select_subprotocol: Optional[ Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] ] = None, - logger: Optional[LoggerLike] = None, + ping_interval: Optional[float] = 20, + ping_timeout: Optional[float] = 20, + close_timeout: Optional[float] = None, + max_size: Optional[int] = 2 ** 20, + max_queue: Optional[int] = 2 ** 5, + read_limit: int = 2 ** 16, + write_limit: int = 2 ** 16, **kwargs: Any, ) -> None: # Backwards compatibility: close_timeout used to be called timeout. @@ -1131,14 +1062,15 @@ def unix_serve( """ Similar to :func:`serve`, but for listening on Unix sockets. - This function calls the event loop's - :meth:`~asyncio.loop.create_unix_server` method. + This function builds upon the event + loop's :meth:`~asyncio.loop.create_unix_server` method. It is only available on Unix. It's useful for deploying a server behind a reverse proxy such as nginx. - :param path: file system path to the Unix socket + Args: + path: file system path to the Unix socket. """ return serve(ws_handler, path=path, unix=True, **kwargs) diff --git a/src/websockets/server.py b/src/websockets/server.py index 0ae0ae940..5f7bec30d 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -69,14 +69,24 @@ def __init__( def accept(self, request: Request) -> Response: """ - Create a WebSocket handshake response event to send to the client. + Create a WebSocket handshake response event to accept the connection. - If the connection cannot be established, the response rejects the - connection, which may be unexpected. + If the connection cannot be established, create a HTTP response event + to reject the handshake. + + Args: + request: handshake request event received from the client. + + Returns: + Response: handshake response event to send to the client. """ try: - key, extensions_header, protocol_header = self.process_request(request) + ( + accept_header, + extensions_header, + protocol_header, + ) = self.process_request(request) except InvalidOrigin as exc: request.exception = exc if self.debug: @@ -124,7 +134,7 @@ def accept(self, request: Request) -> Response: headers["Upgrade"] = "websocket" headers["Connection"] = "Upgrade" - headers["Sec-WebSocket-Accept"] = accept_key(key) + headers["Sec-WebSocket-Accept"] = accept_header if extensions_header is not None: headers["Sec-WebSocket-Extensions"] = extensions_header @@ -141,17 +151,24 @@ def process_request( self, request: Request ) -> Tuple[str, Optional[str], Optional[str]]: """ - Check a handshake request received from the client. + Check a handshake request and negociate extensions and subprotocol. - This function doesn't verify that the request is an HTTP/1.1 or higher GET - request and doesn't perform ``Host`` and ``Origin`` checks. These controls - are usually performed earlier in the HTTP request handling code. They're + This function doesn't verify that the request is an HTTP/1.1 or higher + GET request and doesn't check the ``Host`` header. These controls are + usually performed earlier in the HTTP request handling code. They're the responsibility of the caller. - :param request: request - :returns: ``key`` which must be passed to :func:`build_response` - :raises ~websockets.exceptions.InvalidHandshake: if the handshake request - is invalid; then the server must return 400 Bad Request error + Args: + request: WebSocket handshake request received from the client. + + Returns: + Tuple[str, Optional[str], Optional[str]]: + ``Sec-WebSocket-Accept``, ``Sec-WebSocket-Extensions``, and + ``Sec-WebSocket-Protocol`` headers for the handshake response. + + Raises: + InvalidHandshake: if the handshake request is invalid; + then the server must return 400 Bad Request error. """ headers = request.headers @@ -204,21 +221,32 @@ def process_request( if version != "13": raise InvalidHeaderValue("Sec-WebSocket-Version", version) + accept_header = accept_key(key) + self.origin = self.process_origin(headers) extensions_header, self.extensions = self.process_extensions(headers) protocol_header = self.subprotocol = self.process_subprotocol(headers) - return key, extensions_header, protocol_header + return ( + accept_header, + extensions_header, + protocol_header, + ) def process_origin(self, headers: Headers) -> Optional[Origin]: """ Handle the Origin HTTP request header. - :param headers: request headers - :raises ~websockets.exceptions.InvalidOrigin: if the origin isn't - acceptable + Args: + headers: WebSocket handshake request headers. + + Returns: + Optional[Origin]: origin, if it is acceptable. + + Raises: + InvalidOrigin: if the origin isn't acceptable. """ # "The user agent MUST NOT include more than one Origin header field" @@ -242,9 +270,6 @@ def process_extensions( Accept or reject each extension proposed in the client request. Negotiate parameters for accepted extensions. - Return the Sec-WebSocket-Extensions HTTP response header and the list - of accepted extensions. - :rfc:`6455` leaves the rules up to the specification of each :extension. @@ -263,9 +288,15 @@ def process_extensions( Other requirements, for example related to mandatory extensions or the order of extensions, may be implemented by overriding this method. - :param headers: request headers - :raises ~websockets.exceptions.InvalidHandshake: to abort the - handshake with an HTTP 400 error code + Args: + headers: WebSocket handshake request headers. + + Returns: + Tuple[Optional[str], List[Extension]]: ``Sec-WebSocket-Extensions`` + HTTP response header and list of accepted extensions. + + Raises: + InvalidHandshake: to abort the handshake with an HTTP 400 error. """ response_header_value: Optional[str] = None @@ -317,12 +348,15 @@ def process_subprotocol(self, headers: Headers) -> Optional[Subprotocol]: """ Handle the Sec-WebSocket-Protocol HTTP request header. - Return Sec-WebSocket-Protocol HTTP response header, which is the same - as the selected subprotocol. + Args: + headers: WebSocket handshake request headers. + + Returns: + Optional[Subprotocol]: Subprotocol, if one was selected; this is + also the value of the ``Sec-WebSocket-Protocol`` response header. - :param headers: request headers - :raises ~websockets.exceptions.InvalidHandshake: to abort the - handshake with an HTTP 400 error code + Raises: + InvalidHandshake: to abort the handshake with an HTTP 400 error. """ subprotocol: Optional[Subprotocol] = None @@ -360,8 +394,13 @@ def select_subprotocol( many servers providing a subprotocol will require that the client uses that subprotocol. - :param client_subprotocols: list of subprotocols offered by the client - :param server_subprotocols: list of subprotocols available on the server + Args: + client_subprotocols: list of subprotocols offered by the client. + server_subprotocols: list of subprotocols available on the server. + + Returns: + Optional[Subprotocol]: Subprotocol, if a common subprotocol was + found. """ subprotocols = set(client_subprotocols) & set(server_subprotocols) @@ -380,11 +419,14 @@ def reject( exception: Optional[Exception] = None, ) -> Response: """ - Create a HTTP response event to send to the client. + Create a HTTP response event to reject the connection. A short plain text response is the best fallback when failing to establish a WebSocket connection. + Returns: + Response: HTTP handshake response to send to the client. + """ body = text.encode() if headers is None: @@ -399,7 +441,10 @@ def reject( def send_response(self, response: Response) -> None: """ - Send a WebSocket handshake response to the client. + Send a handshake response to the client. + + Args: + response: WebSocket handshake response event to send. """ if self.debug: diff --git a/src/websockets/streams.py b/src/websockets/streams.py index d1ce377e7..094cbb53a 100644 --- a/src/websockets/streams.py +++ b/src/websockets/streams.py @@ -7,8 +7,8 @@ class StreamReader: """ Generator-based stream reader. - This class doesn't support concurrent calls to :meth:`read_line()`, - :meth:`read_exact()`, or :meth:`read_to_eof()`. Make sure calls are + This class doesn't support concurrent calls to :meth:`read_line`, + :meth:`read_exact`, or :meth:`read_to_eof`. Make sure calls are serialized. """ @@ -21,11 +21,12 @@ def read_line(self) -> Generator[None, None, bytes]: """ Read a LF-terminated line from the stream. - The return value includes the LF character. - This is a generator-based coroutine. - :raises EOFError: if the stream ends without a LF + The return value includes the LF character. + + Raises: + EOFError: if the stream ends without a LF. """ n = 0 # number of bytes to read @@ -44,11 +45,15 @@ def read_line(self) -> Generator[None, None, bytes]: def read_exact(self, n: int) -> Generator[None, None, bytes]: """ - Read ``n`` bytes from the stream. + Read a given number of bytes from the stream. This is a generator-based coroutine. - :raises EOFError: if the stream ends in less than ``n`` bytes + Args: + n: how many bytes to read. + + Raises: + EOFError: if the stream ends in less than ``n`` bytes. """ assert n >= 0 @@ -92,11 +97,15 @@ def at_eof(self) -> Generator[None, None, bool]: def feed_data(self, data: bytes) -> None: """ - Write ``data`` to the stream. + Write data to the stream. + + :meth:`feed_data` cannot be called after :meth:`feed_eof`. - :meth:`feed_data()` cannot be called after :meth:`feed_eof()`. + Args: + data: data to write. - :raises EOFError: if the stream has ended + Raises: + EOFError: if the stream has ended. """ if self.eof: @@ -107,9 +116,10 @@ def feed_eof(self) -> None: """ End the stream. - :meth:`feed_eof()` must be called at must once. + :meth:`feed_eof` cannot be called more than once. - :raises EOFError: if the stream has ended + Raises: + EOFError: if the stream has ended. """ if self.eof: @@ -118,7 +128,7 @@ def feed_eof(self) -> None: def discard(self) -> None: """ - Discarding all buffered data, but don't end the stream. + Discard all buffered data, but don't end the stream. """ del self.buffer[:] diff --git a/src/websockets/typing.py b/src/websockets/typing.py index 1bd118071..dadee7aba 100644 --- a/src/websockets/typing.py +++ b/src/websockets/typing.py @@ -17,24 +17,25 @@ # Public types used in the signature of public APIs Data = Union[str, bytes] -Data.__doc__ = """ -Types supported in a WebSocket message: +Data.__doc__ = """Types supported in a WebSocket message: +:class:`str` for a Text_ frame, :class:`bytes` for a Binary_. + +.. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 +.. _Binary : https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 -- :class:`str` for text messages; -- :class:`bytes` for binary messages. """ LoggerLike = Union[logging.Logger, logging.LoggerAdapter] -LoggerLike.__doc__ = """Types accepted where :class:`~logging.Logger` is expected.""" +LoggerLike.__doc__ = """Types accepted where a :class:`~logging.Logger` is expected.""" Origin = NewType("Origin", str) -Origin.__doc__ = """Value of a Origin header.""" +Origin.__doc__ = """Value of a ``Origin`` header.""" Subprotocol = NewType("Subprotocol", str) -Subprotocol.__doc__ = """Subprotocol in a Sec-WebSocket-Protocol header.""" +Subprotocol.__doc__ = """Subprotocol in a ``Sec-WebSocket-Protocol`` header.""" ExtensionName = NewType("ExtensionName", str) @@ -48,12 +49,12 @@ # Private types ExtensionHeader = Tuple[ExtensionName, List[ExtensionParameter]] -ExtensionHeader.__doc__ = """Extension in a Sec-WebSocket-Extensions header.""" +ExtensionHeader.__doc__ = """Extension in a ``Sec-WebSocket-Extensions`` header.""" ConnectionOption = NewType("ConnectionOption", str) -ConnectionOption.__doc__ = """Connection option in a Connection header.""" +ConnectionOption.__doc__ = """Connection option in a ``Connection`` header.""" UpgradeProtocol = NewType("UpgradeProtocol", str) -UpgradeProtocol.__doc__ = """Upgrade protocol in an Upgrade header.""" +UpgradeProtocol.__doc__ = """Upgrade protocol in an ``Upgrade`` header.""" diff --git a/src/websockets/uri.py b/src/websockets/uri.py index 397c23116..3d8f7cd95 100644 --- a/src/websockets/uri.py +++ b/src/websockets/uri.py @@ -1,12 +1,3 @@ -""" -:mod:`websockets.uri` parses WebSocket URIs. - -See `section 3 of RFC 6455`_. - -.. _section 3 of RFC 6455: https://www.rfc-editor.org/rfc/rfc6455.html#section-3 - -""" - from __future__ import annotations import dataclasses @@ -24,14 +15,16 @@ class WebSocketURI: """ WebSocket URI. - :param bool secure: secure flag - :param str host: lower-case host - :param int port: port, always set even if it's the default - :param str resource_name: path and optional query - :param str user_info: ``(username, password)`` tuple when the URI contains - `User Information`_, else ``None``. + Attributes: + secure: :obj:`True` for a ``wss`` URI, :obj:`False` for a ``ws`` URI. + host: Host, normalized to lower case. + port: Port, always set even if it's the default. + resource_name: Path and optional query. + user_info: ``(username, password)`` when the URI contains + `User Information`_, else :obj:`None`. .. _User Information: https://www.rfc-editor.org/rfc/rfc3986.html#section-3.2.1 + """ secure: bool @@ -49,8 +42,11 @@ def parse_uri(uri: str) -> WebSocketURI: """ Parse and validate a WebSocket URI. - :raises ~websockets.exceptions.InvalidURI: if ``uri`` isn't a valid - WebSocket URI. + Args: + uri: WebSocket URI. + + Raises: + InvalidURI: if ``uri`` isn't a valid WebSocket URI. """ parsed = urllib.parse.urlparse(uri) diff --git a/src/websockets/utils.py b/src/websockets/utils.py index c6e4b788c..c40404906 100644 --- a/src/websockets/utils.py +++ b/src/websockets/utils.py @@ -25,7 +25,8 @@ def accept_key(key: str) -> str: """ Compute the value of the Sec-WebSocket-Accept header. - :param key: value of the Sec-WebSocket-Key header + Args: + key: value of the Sec-WebSocket-Key header. """ sha1 = hashlib.sha1((key + GUID).encode()).digest() @@ -36,8 +37,9 @@ def apply_mask(data: bytes, mask: bytes) -> bytes: """ Apply masking to the data of a WebSocket message. - :param data: Data to mask - :param mask: 4-bytes mask + Args: + data: data to mask. + mask: 4-bytes mask. """ if len(mask) != 4: From f635e5e6afc01e0138cd54cacade764143f5b050 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 5 Sep 2021 08:53:30 +0200 Subject: [PATCH 151/762] Improve changlog. * Use admonitions with custom titles and classes. * Separate changes by category. --- docs/project/changelog.rst | 649 ++++++++++++++++++++++--------------- 1 file changed, 395 insertions(+), 254 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 073514ecd..9161176fc 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -6,7 +6,7 @@ Changelog .. _backwards-compatibility policy: Backwards-compatibility policy -.............................. +------------------------------ websockets is intended for production use. Therefore, stability is a goal. @@ -26,45 +26,44 @@ Only documented APIs are public. Undocumented APIs are considered private. They may change at any time. 10.0 -.... +---- *In development* -.. warning:: - - **Version 10.0 drops compatibility with Python 3.6.** - -.. note:: - - **Version 10.0 enables a timeout of 10 seconds on** - :func:`~client.connect` **by default.** +Backwards-incompatible changes +.............................. - You can adjust the timeout with the ``open_timeout`` parameter. Set it to - :obj:`None` to disable the timeout entirely. +.. admonition:: websockets 10.0 requires Python ≥ 3.7. + :class: tip -.. note:: + websockets 9.1 is the last version supporting Python 3.6. - **Version 10.0 deprecates the** ``loop`` **parameter from all APIs.** +.. admonition:: The ``loop`` parameter is deprecated from all APIs. + :class: caution This reflects a decision made in Python 3.8. See the release notes of Python 3.10 for details. -.. note:: +.. admonition:: :func:`~client.connect` times out after 10 seconds by default. + :class: note - **Version 10.0 changes arguments of** - :exc:`~exceptions.ConnectionClosed` **.** + You can adjust the timeout with the ``open_timeout`` parameter. Set it to + :obj:`None` to disable the timeout entirely. - If you raise :exc:`~exceptions.ConnectionClosed` or a subclass — rather - than catch them when websockets raises them — you must change your code. +.. admonition:: The signature of :exc:`~exceptions.ConnectionClosed` changed. + :class: note -.. note:: + If you raise :exc:`~exceptions.ConnectionClosed` or a subclass, rather + than catch them when websockets raises them, you must change your code. - **Version 10.0 adds a ``msg`` parameter to** ``InvalidURI.__init__`` **.** +.. admonition:: A ``msg`` parameter was added to :exc:`~exceptions.InvalidURI`. + :class: note - If you raise :exc:`~exceptions.InvalidURI` — rather than catch them when - websockets raises them — you must change your code. + If you raise :exc:`~exceptions.InvalidURI`, rather than catch it when + websockets raises it, you must change your code. -Also: +New features +............ * Added compatibility with Python 3.10. @@ -75,6 +74,9 @@ Also: * Added ``open_timeout`` to :func:`~client.connect`. +Improvements +............ + * Improved logging. * Provided additional information in :exc:`~exceptions.ConnectionClosed` @@ -85,88 +87,100 @@ Also: * Made it easier to customize authentication with :meth:`~auth.BasicAuthWebSocketServerProtocol.check_credentials`. -* Fixed handling of relative redirects in :func:`~client.connect`. +* Supported relative redirects in :func:`~client.connect`. * Improved API documentation. 9.1 -... +--- *May 27, 2021* -.. caution:: +Security fix +............ - **Version 9.1 fixes a security issue introduced in version 8.0.** +.. admonition:: websockets 9.1 fixes a security issue introduced in 8.0. + :class: important Version 8.0 was vulnerable to timing attacks on HTTP Basic Auth passwords (`CVE-2021-33880`_). .. _CVE-2021-33880: https://nvd.nist.gov/vuln/detail/CVE-2021-33880 - 9.0.2 -..... +----- *May 15, 2021* +Bug fixes +......... + * Restored compatibility of ``python -m websockets`` with Python < 3.9. * Restored compatibility with mypy. 9.0.1 -..... +----- *May 2, 2021* +Bug fixes +......... + * Fixed issues with the packaging of the 9.0 release. 9.0 -... +--- *May 1, 2021* -.. note:: +Backwards-incompatible changes +.............................. - **Version 9.0 moves or deprecates several APIs.** +.. admonition:: Several modules are moved or deprecated. + :class: caution - Aliases provide backwards compatibility for all previously public APIs. + Aliases provide compatibility for all previously public APIs according to + the `backwards-compatibility policy`_ * :class:`~datastructures.Headers` and - :exc:`~datastructures.MultipleValuesError` were moved from + :exc:`~datastructures.MultipleValuesError` are moved from ``websockets.http`` to :mod:`websockets.datastructures`. If you're using them, you should adjust the import path. * The ``client``, ``server``, ``protocol``, and ``auth`` modules were - moved from the websockets package to ``websockets.legacy`` sub-package, - as part of an upcoming refactoring. Despite the name, they're still - fully supported. The refactoring should be a transparent upgrade for - most uses when it's available. The legacy implementation will be - preserved according to the `backwards-compatibility policy`_. + moved from the ``websockets`` package to a ``websockets.legacy`` + sub-package. Despite the name, they're still fully supported. * The ``framing``, ``handshake``, ``headers``, ``http``, and ``uri`` - modules in the websockets package are deprecated. These modules provided - low-level APIs for reuse by other WebSocket implementations, but that - never happened. Keeping these APIs public makes it more difficult to - improve websockets for no actual benefit. + modules in the ``websockets`` package are deprecated. These modules + provided low-level APIs for reuse by other projects, but they didn't + reach that goal. Keeping these APIs public makes it more difficult to + improve websockets. -.. note:: + These changes pave the path for a refactoring that should be a transparent + upgrade for most uses and facilitate integration by other projects. - **Version 9.0 may require changes if you use static code analysis tools.** +.. admonition:: Convenience imports from ``websockets`` are performed lazily. + :class: note - Convenience imports from the websockets module are performed lazily. While - this is supported by Python, static code analysis tools such as mypy are + While Python supports this, static code analysis tools such as mypy are unable to understand the behavior. If you depend on such tools, use the real import path, which can be found - in the API documentation:: + in the API documentation, for example:: from websockets.client import connect from websockets.server import serve -Also: +New features +............ * Added compatibility with Python 3.9. +Improvements +............ + * Added support for IRIs in addition to URIs. * Added close codes 1012, 1013, and 1014. @@ -174,6 +188,11 @@ Also: * Raised an error when passing a :class:`dict` to :meth:`~legacy.protocol.WebSocketCommonProtocol.send`. +* Improved error reporting. + +Bug fixes +......... + * Fixed sending fragmented, compressed messages. * Fixed ``Host`` header sent when connecting to an IPv6 address. @@ -185,98 +204,93 @@ Also: * Ensured cancellation always propagates, even on Python versions where :exc:`~asyncio.CancelledError` inherits :exc:`Exception`. -* Improved error reporting. - - 8.1 -... +--- *November 1, 2019* +New features +............ + * Added compatibility with Python 3.8. 8.0.2 -..... +----- *July 31, 2019* +Bug fixes +......... + * Restored the ability to pass a socket with the ``sock`` parameter of :func:`~server.serve`. * Removed an incorrect assertion when a connection drops. 8.0.1 -..... +----- *July 21, 2019* -* Restored the ability to import ``WebSocketProtocolError`` from websockets. +Bug fixes +......... + +* Restored the ability to import ``WebSocketProtocolError`` from + ``websockets``. 8.0 -... +--- *July 7, 2019* -.. warning:: - - **Version 8.0 drops compatibility with Python 3.4 and 3.5.** +Backwards-incompatible changes +.............................. -.. note:: +.. admonition:: websockets 8.0 requires Python ≥ 3.6. + :class: tip - **Version 8.0 expects** ``process_request`` **to be a coroutine.** + websockets 7.0 is the last version supporting Python 3.4 and 3.5. - Previously, it could be a function or a coroutine. +.. admonition:: ``process_request`` is now expected to be a coroutine. + :class: note If you're passing a ``process_request`` argument to :func:`~server.serve` or :class:`~server.WebSocketServerProtocol`, or if you're overriding :meth:`~server.WebSocketServerProtocol.process_request` in a subclass, - define it with ``async def`` instead of ``def``. - - For backwards compatibility, functions are still mostly supported, but - mixing functions and coroutines won't work in some inheritance scenarios. + define it with ``async def`` instead of ``def``. Previously, both were supported. -.. note:: + For backwards compatibility, functions are still accepted, but mixing + functions and coroutines won't work in some inheritance scenarios. - **Version 8.0 changes the behavior of the** ``max_queue`` **parameter.** +.. admonition:: ``max_queue`` must be :obj:`None` to disable the limit. + :class: note If you were setting ``max_queue=0`` to make the queue of incoming messages unbounded, change it to ``max_queue=None``. -.. note:: - - **Version 8.0 deprecates the** ``host`` **,** ``port`` **, and** ``secure`` - **attributes of** :class:`~legacy.protocol.WebSocketCommonProtocol`. +.. admonition:: The ``host``, ``port``, and ``secure`` attributes + of :class:`~legacy.protocol.WebSocketCommonProtocol` are deprecated. + :class: note Use :attr:`~legacy.protocol.WebSocketCommonProtocol.local_address` in servers and :attr:`~legacy.protocol.WebSocketCommonProtocol.remote_address` in clients instead of ``host`` and ``port``. -.. note:: - - **Version 8.0 renames the** ``WebSocketProtocolError`` **exception** - to :exc:`~exceptions.ProtocolError` **.** +.. admonition:: ``WebSocketProtocolError`` is renamed + to :exc:`~exceptions.ProtocolError`. + :class: note - A ``WebSocketProtocolError`` alias provides backwards compatibility. + An alias provides backwards compatibility. -.. note:: +.. admonition:: ``read_response()`` now returns the reason phrase. + :class: note - **Version 8.0 adds the reason phrase to the return type of the low-level - API** ``read_response()`` **.** + If you're using this low-level API, you must change your code. -Also: - -* :meth:`~legacy.protocol.WebSocketCommonProtocol.send`, - :meth:`~legacy.protocol.WebSocketCommonProtocol.ping`, and - :meth:`~legacy.protocol.WebSocketCommonProtocol.pong` support bytes-like - types :class:`bytearray` and :class:`memoryview` in addition to - :class:`bytes`. - -* Added :exc:`~exceptions.ConnectionClosedOK` and - :exc:`~exceptions.ConnectionClosedError` subclasses of - :exc:`~exceptions.ConnectionClosed` to tell apart normal connection - termination from errors. +New features +............ * Added :func:`~auth.basic_auth_protocol_factory` to enforce HTTP Basic Auth on the server side. @@ -288,20 +302,9 @@ Also: * Added :func:`~client.unix_connect` for connecting to Unix sockets. -* Improved support for sending fragmented messages by accepting asynchronous - iterators in :meth:`~legacy.protocol.WebSocketCommonProtocol.send`. - -* Prevented spurious log messages about :exc:`~exceptions.ConnectionClosed` - exceptions in keepalive ping task. If you were using ``ping_timeout=None`` - as a workaround, you can remove it. - -* Changed :meth:`WebSocketServer.close() - ` to perform a proper closing handshake - instead of failing the connection. - -* Avoided a crash when a ``extra_headers`` callable returns :obj:`None`. - -* Improved error messages when HTTP parsing fails. +* Added support for asynchronous generators + in :meth:`~legacy.protocol.WebSocketCommonProtocol.send` + to generate fragmented messages incrementally. * Enabled readline in the interactive client. @@ -313,30 +316,61 @@ Also: * Documented how to optimize memory usage. +Improvements +............ + +* :meth:`~legacy.protocol.WebSocketCommonProtocol.send`, + :meth:`~legacy.protocol.WebSocketCommonProtocol.ping`, and + :meth:`~legacy.protocol.WebSocketCommonProtocol.pong` support bytes-like + types :class:`bytearray` and :class:`memoryview` in addition to + :class:`bytes`. + +* Added :exc:`~exceptions.ConnectionClosedOK` and + :exc:`~exceptions.ConnectionClosedError` subclasses of + :exc:`~exceptions.ConnectionClosed` to tell apart normal connection + termination from errors. + +* Changed :meth:`WebSocketServer.close() + ` to perform a proper closing handshake + instead of failing the connection. + +* Improved error messages when HTTP parsing fails. + * Improved API documentation. +Bug fixes +......... + +* Prevented spurious log messages about :exc:`~exceptions.ConnectionClosed` + exceptions in keepalive ping task. If you were using ``ping_timeout=None`` + as a workaround, you can remove it. + +* Avoided a crash when a ``extra_headers`` callable returns :obj:`None`. + 7.0 -... +--- *November 1, 2018* -.. warning:: +Backwards-incompatible changes +.............................. - websockets **now sends Ping frames at regular intervals and closes the - connection if it doesn't receive a matching Pong frame.** +.. admonition:: Keepalive is enabled by default. + :class: important + websockets now sends Ping frames at regular intervals and closes the + connection if it doesn't receive a matching Pong frame. See :class:`~legacy.protocol.WebSocketCommonProtocol` for details. -.. warning:: - - **Version 7.0 changes how a server terminates connections when it's closed - with** :meth:`WebSocketServer.close() - ` **.** +.. admonition:: Termination of connections by :meth:`WebSocketServer.close() + ` changes. + :class: caution Previously, connections handlers were canceled. Now, connections are - closed with close code 1001 (going away). From the perspective of the - connection handler, this is the same as if the remote endpoint was - disconnecting. This removes the need to prepare for + closed with close code 1001 (going away). + + From the perspective of the connection handler, this is the same as if the + remote endpoint was disconnecting. This removes the need to prepare for :exc:`~asyncio.CancelledError` in connection handlers. You can restore the previous behavior by adding the following line at the @@ -346,44 +380,45 @@ Also: closed = asyncio.ensure_future(websocket.wait_closed()) closed.add_done_callback(lambda task: task.cancel()) -.. note:: +.. admonition:: Calling :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` + concurrently raises a :exc:`RuntimeError`. + :class: note + + Concurrent calls lead to non-deterministic behavior because there are no + guarantees about which coroutine will receive which message. - **Version 7.0 renames the** ``timeout`` **argument of** - :func:`~server.serve` **and** :func:`~client.connect` **to** - ``close_timeout`` **.** +.. admonition:: The ``timeout`` argument of :func:`~server.serve` + and :func:`~client.connect` is renamed to ``close_timeout`` . + :class: note This prevents confusion with ``ping_timeout``. For backwards compatibility, ``timeout`` is still supported. -.. note:: - - **Version 7.0 changes how a** - :meth:`~legacy.protocol.WebSocketCommonProtocol.ping` **that hasn't - received a pong yet behaves when the connection is closed.** +.. admonition:: The ``origins`` argument of :func:`~server.serve` changes. + :class: note - The ping — as in ``ping = await websocket.ping()`` — used to be canceled - when the connection is closed, so that ``await ping`` raised - :exc:`~asyncio.CancelledError`. Now ``await ping`` raises - :exc:`~exceptions.ConnectionClosed` like other public APIs. + Include :obj:`None` in the list rather than ``''`` to allow requests that + don't contain an Origin header. -.. note:: +.. admonition:: Pending pings aren't canceled when the connection is closed. + :class: note - **Version 7.0 raises a** :exc:`RuntimeError` **exception if two coroutines - call** :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` - **concurrently.** + A ping — as in ``ping = await websocket.ping()`` — for which no pong was + received yet used to be canceled when the connection is closed, so that + ``await ping`` raised :exc:`~asyncio.CancelledError`. - Concurrent calls lead to non-deterministic behavior because there are no - guarantees about which coroutine will receive which message. + Now ``await ping`` raises :exc:`~exceptions.ConnectionClosed` like other + public APIs. -Also: +New features +............ * Added ``process_request`` and ``select_subprotocol`` arguments to :func:`~server.serve` and - :class:`~server.WebSocketServerProtocol` to customize + :class:`~server.WebSocketServerProtocol` to facilitate customization of :meth:`~server.WebSocketServerProtocol.process_request` and - :meth:`~server.WebSocketServerProtocol.select_subprotocol` without - subclassing :class:`~server.WebSocketServerProtocol`. + :meth:`~server.WebSocketServerProtocol.select_subprotocol`. * Added support for sending fragmented messages. @@ -392,33 +427,39 @@ Also: * Added an interactive client: ``python -m websockets ``. -* Changed the ``origins`` argument to represent the lack of an origin with - :obj:`None` rather than ``''``. - -* Fixed a data loss bug in - :meth:`~legacy.protocol.WebSocketCommonProtocol.recv`: - canceling it at the wrong time could result in messages being dropped. +Improvements +............ * Improved handling of multiple HTTP headers with the same name. * Improved error messages when a required HTTP header is missing. +Bug fixes +......... + +* Fixed a data loss bug in + :meth:`~legacy.protocol.WebSocketCommonProtocol.recv`: + canceling it at the wrong time could result in messages being dropped. + 6.0 -... +--- *July 16, 2018* -.. warning:: +Backwards-incompatible changes +.............................. - **Version 6.0 introduces the** :class:`~datastructures.Headers` **class - for managing HTTP headers and changes several public APIs:** +.. admonition:: The :class:`~datastructures.Headers` class is introduced and + several APIs are updated to use it. + :class: caution - * :meth:`~server.WebSocketServerProtocol.process_request` now - receives a :class:`~datastructures.Headers` instead of a - ``http.client.HTTPMessage`` in the ``request_headers`` argument. + * The ``request_headers`` argument + of :meth:`~server.WebSocketServerProtocol.process_request` is now + a :class:`~datastructures.Headers` instead of + an ``http.client.HTTPMessage``. * The ``request_headers`` and ``response_headers`` attributes of - :class:`~legacy.protocol.WebSocketCommonProtocol` are + :class:`~legacy.protocol.WebSocketCommonProtocol` are now :class:`~datastructures.Headers` instead of ``http.client.HTTPMessage``. * The ``raw_request_headers`` and ``raw_response_headers`` attributes of @@ -435,30 +476,35 @@ Also: pairs. Since :class:`~datastructures.Headers` and ``http.client.HTTPMessage`` - provide similar APIs, this change won't affect most of the code dealing - with HTTP headers. + provide similar APIs, much of the code dealing with HTTP headers won't + require changes. - -Also: +New features +............ * Added compatibility with Python 3.7. 5.0.1 -..... +----- *May 24, 2018* +Bug fixes +......... + * Fixed a regression in 5.0 that broke some invocations of :func:`~server.serve` and :func:`~client.connect`. 5.0 -... +--- *May 22, 2018* -.. caution:: +Security fix +............ - **Version 5.0 fixes a security issue introduced in version 4.0.** +.. admonition:: websockets 5.0 fixes a security issue introduced in 4.0. + :class: important Version 4.0 was vulnerable to denial of service by memory exhaustion because it didn't enforce ``max_size`` when decompressing compressed @@ -466,46 +512,43 @@ Also: .. _CVE-2018-1000518: https://nvd.nist.gov/vuln/detail/CVE-2018-1000518 -.. note:: +Backwards-incompatible changes +.............................. - **Version 5.0 adds a** ``user_info`` **field to the return value of** - ``parse_uri`` **and** ``WebSocketURI`` **.** +.. admonition:: A ``user_info`` field is added to the return value of + ``parse_uri`` and ``WebSocketURI``. + :class: note If you're unpacking ``WebSocketURI`` into four variables, adjust your code to account for that fifth field. -Also: +New features +............ * :func:`~client.connect` performs HTTP Basic Auth when the URI contains credentials. -* Iterating on incoming messages no longer raises an exception when the - connection terminates with close code 1001 (going away). - -* A plain HTTP request now receives a 426 Upgrade Required response and - doesn't log a stack trace. - * :func:`~server.unix_serve` can be used as an asynchronous context manager on Python ≥ 3.5.1. * Added the :attr:`~legacy.protocol.WebSocketCommonProtocol.closed` property to protocols. -* If a :meth:`~legacy.protocol.WebSocketCommonProtocol.ping` doesn't receive a - pong, it's canceled when the connection is closed. - -* Reported the cause of :exc:`~exceptions.ConnectionClosed` exceptions. - * Added new examples in the documentation. -* Updated documentation with new features from Python 3.6. +Improvements +............ -* Improved several other sections of the documentation. +* Iterating on incoming messages no longer raises an exception when the + connection terminates with close code 1001 (going away). -* Fixed missing close code, which caused :exc:`TypeError` on connection close. +* A plain HTTP request now receives a 426 Upgrade Required response and + doesn't log a stack trace. -* Fixed a race condition in the closing handshake that raised - :exc:`~exceptions.InvalidState`. +* If a :meth:`~legacy.protocol.WebSocketCommonProtocol.ping` doesn't receive a + pong, it's canceled when the connection is closed. + +* Reported the cause of :exc:`~exceptions.ConnectionClosed` exceptions. * Stopped logging stack traces when the TCP connection dies prematurely. @@ -515,40 +558,59 @@ Also: * Prevented processing of incoming frames after failing the connection. +* Updated documentation with new features from Python 3.6. + +* Improved several sections of the documentation. + +Bug fixes +......... + +* Prevented :exc:`TypeError` due to missing close code on connection close. + +* Fixed a race condition in the closing handshake that raised + :exc:`~exceptions.InvalidState`. + 4.0.1 -..... +----- *November 2, 2017* +Bug fixes +......... + * Fixed issues with the packaging of the 4.0 release. 4.0 -... +--- *November 2, 2017* -.. warning:: +Backwards-incompatible changes +.............................. - **Version 4.0 drops compatibility with Python 3.3.** +.. admonition:: websockets 4.0 requires Python ≥ 3.4. + :class: tip -.. note:: + websockets 3.4 is the last version supporting Python 3.3. - **Version 4.0 enables compression with the permessage-deflate extension.** +.. admonition:: Compression is enabled by default. + :class: important - In August 2017, Firefox and Chrome support it, but not Safari and IE. + In August 2017, Firefox and Chrome support the permessage-deflate + extension, but not Safari and IE. Compression should improve performance but it increases RAM and CPU use. If you want to disable compression, add ``compression=None`` when calling :func:`~server.serve` or :func:`~client.connect`. -.. note:: - - **Version 4.0 removes the** ``state_name`` **attribute of protocols.** +.. admonition:: The ``state_name`` attribute of protocols is deprecated. + :class: note Use ``protocol.state.name`` instead of ``protocol.state_name``. -Also: +New features +............ * :class:`~legacy.protocol.WebSocketCommonProtocol` instances can be used as asynchronous iterators on Python ≥ 3.6. They yield incoming messages. @@ -558,27 +620,42 @@ Also: * Added the :attr:`~server.WebSocketServer.sockets` attribute to the return value of :func:`~server.serve`. -* Reorganized and extended documentation. +* Allowed ``extra_headers`` to override ``Server`` and ``User-Agent`` headers. -* Aborted connections if they don't close within the configured ``timeout``. +Improvements +............ -* Rewrote connection termination to increase robustness in edge cases. +* Reorganized and extended documentation. -* Stopped leaking pending tasks when :meth:`~asyncio.Task.cancel` is called on - a connection while it's being closed. +* Rewrote connection termination to increase robustness in edge cases. * Reduced verbosity of "Failing the WebSocket connection" logs. -* Allowed ``extra_headers`` to override ``Server`` and ``User-Agent`` headers. +Bug fixes +......... + +* Aborted connections if they don't close within the configured ``timeout``. + +* Stopped leaking pending tasks when :meth:`~asyncio.Task.cancel` is called on + a connection while it's being closed. 3.4 -... +--- *August 20, 2017* -* Renamed :func:`~server.serve` and :func:`~client.connect`'s - ``klass`` argument to ``create_protocol`` to reflect that it can also be a - callable. For backwards compatibility, ``klass`` is still supported. +Backwards-incompatible changes +.............................. + +.. admonition:: ``InvalidStatus`` is replaced + by :class:`~exceptions.InvalidStatusCode`. + :class: note + + This exception is raised when :func:`~client.connect` receives an invalid + response status code from the server. + +New features +............ * :func:`~server.serve` can be used as an asynchronous context manager on Python ≥ 3.5.1. @@ -588,57 +665,85 @@ Also: * Made read and write buffer sizes configurable. +Improvements +............ + +* Renamed :func:`~server.serve` and :func:`~client.connect`'s + ``klass`` argument to ``create_protocol`` to reflect that it can also be a + callable. For backwards compatibility, ``klass`` is still supported. + * Rewrote HTTP handling for simplicity and performance. * Added an optional C extension to speed up low-level operations. -* An invalid response status code during :func:`~client.connect` now - raises :class:`~exceptions.InvalidStatusCode`. +Bug fixes +......... * Providing a ``sock`` argument to :func:`~client.connect` no longer crashes. 3.3 -... +--- *March 29, 2017* +New features +............ + * Ensured compatibility with Python 3.6. +Improvements +............ + * Reduced noise in logs caused by connection resets. +Bug fixes +......... + * Avoided crashing on concurrent writes on slow connections. 3.2 -... +--- *August 17, 2016* +New features +............ + * Added ``timeout``, ``max_size``, and ``max_queue`` arguments to :func:`~client.connect` and :func:`~server.serve`. +Improvements +............ + * Made server shutdown more robust. 3.1 -... +--- *April 21, 2016* -* Avoided a warning when closing a connection before the opening handshake. +New features +............ * Added flow control for incoming data. +Bug fixes +......... + +* Avoided a warning when closing a connection before the opening handshake. + 3.0 -... +--- *December 25, 2015* -.. warning:: - - **Version 3.0 introduces a backwards-incompatible change in the** - :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` **API.** +Backwards-incompatible changes +.............................. - **If you're upgrading from 2.x or earlier, please read this carefully.** +.. admonition:: :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` now + raises an exception when the connection is closed. + :class: caution :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` used to return :obj:`None` when the connection was closed. This required checking the @@ -653,61 +758,77 @@ Also: message = await websocket.recv() - When implementing a server, which is the more popular use case, there's no - strong reason to handle such exceptions. Let them bubble up, terminate the - handler coroutine, and the server will simply ignore them. + When implementing a server, there's no strong reason to handle such + exceptions. Let them bubble up, terminate the handler coroutine, and the + server will simply ignore them. In order to avoid stranding projects built upon an earlier version, the previous behavior can be restored by passing ``legacy_recv=True`` to :func:`~server.serve`, :func:`~client.connect`, :class:`~server.WebSocketServerProtocol`, or - :class:`~client.WebSocketClientProtocol`. ``legacy_recv`` isn't - documented in their signatures but isn't scheduled for deprecation either. + :class:`~client.WebSocketClientProtocol`. -Also: +New features +............ * :func:`~client.connect` can be used as an asynchronous context manager on Python ≥ 3.5.1. -* Updated documentation with ``await`` and ``async`` syntax from Python 3.5. - * :meth:`~legacy.protocol.WebSocketCommonProtocol.ping` and :meth:`~legacy.protocol.WebSocketCommonProtocol.pong` support data passed as :class:`str` in addition to :class:`bytes`. +* Made ``state_name`` attribute on protocols a public API. + +Improvements +............ + +* Updated documentation with ``await`` and ``async`` syntax from Python 3.5. + * Worked around an :mod:`asyncio` bug affecting connection termination under load. -* Made ``state_name`` attribute on protocols a public API. - * Improved documentation. 2.7 -... +--- *November 18, 2015* +New features +............ + * Added compatibility with Python 3.5. +Improvements +............ + * Refreshed documentation. 2.6 -... +--- *August 18, 2015* +New features +............ + * Added ``local_address`` and ``remote_address`` attributes on protocols. * Closed open connections with code 1001 when a server shuts down. +Bug fixes +......... + * Avoided TCP fragmentation of small frames. 2.5 -... +--- *July 28, 2015* -* Improved documentation. +New features +............ * Provided access to handshake request and response HTTP headers. @@ -715,49 +836,66 @@ Also: * Added support for running on a non-default event loop. -* Returned a 403 status code instead of 400 when the request Origin isn't - allowed. +Improvements +............ -* Canceling :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` no longer - drops the next message. +* Improved documentation. + +* Sent a 403 status code instead of 400 when request Origin isn't allowed. * Clarified that the closing handshake can be initiated by the client. * Set the close code and reason more consistently. -* Strengthened connection termination by simplifying the implementation. +* Strengthened connection termination. + +Bug fixes +......... -* Improved tests, added tox configuration, and enforced 100% branch coverage. +* Canceling :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` no longer + drops the next message. 2.4 -... +--- *January 31, 2015* +New features +............ + * Added support for subprotocols. * Added ``loop`` argument to :func:`~client.connect` and :func:`~server.serve`. 2.3 -... +--- *November 3, 2014* +Improvements +............ + * Improved compliance of close codes. 2.2 -... +--- *July 28, 2014* +New features +............ + * Added support for limiting message size. 2.1 -... +--- *April 26, 2014* +New features +............ + * Added ``host``, ``port`` and ``secure`` attributes on protocols. * Added support for providing and checking Origin_. @@ -765,36 +903,39 @@ Also: .. _Origin: https://www.rfc-editor.org/rfc/rfc6455.html#section-10.2 2.0 -... +--- *February 16, 2014* -.. warning:: +Backwards-incompatible changes +.............................. - **Version 2.0 introduces a backwards-incompatible change in the** - :meth:`~legacy.protocol.WebSocketCommonProtocol.send`, +.. admonition:: :meth:`~legacy.protocol.WebSocketCommonProtocol.send`, :meth:`~legacy.protocol.WebSocketCommonProtocol.ping`, and - :meth:`~legacy.protocol.WebSocketCommonProtocol.pong` **APIs.** - - **If you're upgrading from 1.x or earlier, please read this carefully.** + :meth:`~legacy.protocol.WebSocketCommonProtocol.pong` are now coroutines. + :class: caution - These APIs used to be functions. Now they're coroutines. + They used to be functions. Instead of:: websocket.send(message) - you must now write:: + you must write:: await websocket.send(message) -Also: +New features +............ * Added flow control for outgoing data. 1.0 -... +--- *November 14, 2013* +New features +............ + * Initial public release. From ab0e3b9114f65158104a9cdc1b83ee3357438390 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 5 Sep 2021 09:03:39 +0200 Subject: [PATCH 152/762] Remove loop parameter from WebSocketServer. --- docs/project/changelog.rst | 3 +++ src/websockets/legacy/server.py | 25 +++++++++++-------------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 9161176fc..66d8d541d 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -44,6 +44,9 @@ Backwards-incompatible changes This reflects a decision made in Python 3.8. See the release notes of Python 3.10 for details. + The ``loop`` parameter is also removed + from :class:`~server.WebSocketServer`. This should be transparent. + .. admonition:: :func:`~client.connect` times out after 10 seconds by default. :class: note diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 673888c3d..4399f0782 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -655,14 +655,7 @@ class WebSocketServer: """ - def __init__( - self, - loop: asyncio.AbstractEventLoop, - logger: Optional[LoggerLike] = None, - ) -> None: - # Store a reference to loop to avoid relying on self.server._loop. - self.loop = loop - + def __init__(self, logger: Optional[LoggerLike] = None): if logger is None: logger = logging.getLogger("websockets.server") self.logger = logger @@ -674,7 +667,7 @@ def __init__( self.close_task: Optional[asyncio.Task[None]] = None # Completed when the server is closed and connections are terminated. - self.closed_waiter: asyncio.Future[None] = loop.create_future() + self.closed_waiter: asyncio.Future[None] def wrap(self, server: asyncio.AbstractServer) -> None: """ @@ -705,6 +698,10 @@ def wrap(self, server: asyncio.AbstractServer) -> None: name = str(sock.getsockname()) self.logger.info("server listening on %s", name) + # Initialized here because we need a reference to the event loop. + # This should be moved back to __init__ in Python 3.10. + self.closed_waiter = server.get_loop().create_future() + def register(self, protocol: WebSocketServerProtocol) -> None: """ Register a connection with this server. @@ -743,7 +740,7 @@ def close(self) -> None: """ if self.close_task is None: - self.close_task = self.loop.create_task(self._close()) + self.close_task = self.server.get_loop().create_task(self._close()) async def _close(self) -> None: """ @@ -764,7 +761,7 @@ async def _close(self) -> None: # Wait until all accepted connections reach connection_made() and call # register(). See https://bugs.python.org/issue34852 for details. - await asyncio.sleep(0, **loop_if_py_lt_38(self.loop)) + await asyncio.sleep(0, **loop_if_py_lt_38(self.server.get_loop())) # Close OPEN connections with status code 1001. Since the server was # closed, handshake() closes OPENING connections with a HTTP 503 @@ -777,7 +774,7 @@ async def _close(self) -> None: asyncio.create_task(websocket.close(1001)) for websocket in self.websockets ], - **loop_if_py_lt_38(self.loop), + **loop_if_py_lt_38(self.server.get_loop()), ) # Wait until all connection handlers are complete. @@ -786,7 +783,7 @@ async def _close(self) -> None: if self.websockets: await asyncio.wait( [websocket.handler_task for websocket in self.websockets], - **loop_if_py_lt_38(self.loop), + **loop_if_py_lt_38(self.server.get_loop()), ) # Tell wait_closed() to return. @@ -968,7 +965,7 @@ def __init__( loop = _loop warnings.warn("remove loop argument", DeprecationWarning) - ws_server = WebSocketServer(logger=logger, loop=loop) + ws_server = WebSocketServer(logger=logger) secure = kwargs.get("ssl") is not None From 6f5e609b3d7644a42367946f0a8d33afcda854f4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 5 Sep 2021 09:07:53 +0200 Subject: [PATCH 153/762] Deprecated legacy_recv. It's been more than 5 years. I would like to get rid of it at some point in the future -- 5 years from now if I apply the backwards-compatibility policy strictly :-/ --- docs/project/changelog.rst | 4 ++++ src/websockets/legacy/protocol.py | 3 +++ 2 files changed, 7 insertions(+) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 66d8d541d..8b6325a1a 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -53,6 +53,10 @@ Backwards-incompatible changes You can adjust the timeout with the ``open_timeout`` parameter. Set it to :obj:`None` to disable the timeout entirely. +.. admonition:: The ``legacy_recv`` option is deprecated. + + See the release notes of websockets 3.0 for details. + .. admonition:: The signature of :exc:`~exceptions.ConnectionClosed` changed. :class: note diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 4631151e6..618e451c2 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -183,6 +183,9 @@ def __init__( loop: Optional[asyncio.AbstractEventLoop] = None, timeout: Optional[float] = None, ) -> None: + if legacy_recv: # pragma: no cover + warnings.warn("legacy_recv is deprecated", DeprecationWarning) + # Backwards compatibility: close_timeout used to be called timeout. if timeout is None: timeout = 10 From 6680923c0f253ed85f885f983bb9f29efcd86b30 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 5 Sep 2021 16:44:19 +0200 Subject: [PATCH 154/762] Fix typo. --- docs/topics/authentication.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/topics/authentication.rst b/docs/topics/authentication.rst index 44c9c6151..8beeea9aa 100644 --- a/docs/topics/authentication.rst +++ b/docs/topics/authentication.rst @@ -120,7 +120,7 @@ WebSocket server. the old credentials, which may be expired, resulting in an HTTP 401. Then the next connection succeeds. Perhaps errors clear the cache. - When tokens are short-lived on single-use, this bug produces an + When tokens are short-lived or single-use, this bug produces an interesting effect: every other WebSocket connection fails. * Safari 14 ignores credentials entirely. From 903fd24d8cb163e7a8836f3a1faf03fa8869d969 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 5 Sep 2021 17:01:04 +0200 Subject: [PATCH 155/762] Minor formatting fixes in FAQ. --- docs/howto/faq.rst | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/docs/howto/faq.rst b/docs/howto/faq.rst index ec657bb06..a32e5ec1b 100644 --- a/docs/howto/faq.rst +++ b/docs/howto/faq.rst @@ -11,9 +11,6 @@ FAQ .. _developing with asyncio: https://docs.python.org/3/library/asyncio-dev.html -.. contents:: - :local: - Server side ----------- @@ -221,11 +218,13 @@ Putting an ``await`` statement in a ``for`` or a ``while`` loop isn't enough to yield control. Awaiting a coroutine may yield control, but there's no guarantee that it will. -For example, ``send()`` only yields control when send buffers are full, which -never happens in most practical cases. +For example, :meth:`~legacy.protocol.WebSocketCommonProtocol.send` only yields +control when send buffers are full, which never happens in most practical +cases. -If you run a loop that contains only synchronous operations and a ``send()`` -call, you must yield control explicitly with :func:`asyncio.sleep`:: +If you run a loop that contains only synchronous operations and +a :meth:`~legacy.protocol.WebSocketCommonProtocol.send` call, you must yield +control explicitly with :func:`asyncio.sleep`:: async def producer(websocket): message = generate_next_message() @@ -248,7 +247,8 @@ blocks the event loop and prevents asyncio from operating normally. This may lead to messages getting send but not received, to connection timeouts, and to unexpected results of shotgun debugging e.g. adding an -unnecessary call to ``send()`` makes the program functional. +unnecessary call to :meth:`~legacy.protocol.WebSocketCommonProtocol.send` +makes the program functional. Both sides ---------- @@ -352,8 +352,8 @@ See the discussion of :doc:`timeouts <../topics/timeouts>` for details. If websockets' default timeout of 20 seconds is too short for your use case, you can adjust it with the ``ping_timeout`` argument. -How do I set a timeout on ``recv()``? -..................................... +How do I set a timeout on :meth:`~legacy.protocol.WebSocketCommonProtocol.recv`? +................................................................................ Use :func:`~asyncio.wait_for`:: From b12adc59e74dc521710973894576d13d03dff869 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 5 Sep 2021 17:02:20 +0200 Subject: [PATCH 156/762] Reformat conf.py with black. --- docs/conf.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index d9e3cd598..ffe61f7ba 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -74,8 +74,9 @@ # Workaround for https://github.com/sphinx-doc/sphinx/issues/9560 from sphinx.domains.python import PythonDomain -assert PythonDomain.object_types['data'].roles == ('data', 'obj') -PythonDomain.object_types['data'].roles = ('data', 'class', 'obj') + +assert PythonDomain.object_types["data"].roles == ("data", "obj") +PythonDomain.object_types["data"].roles = ("data", "class", "obj") intersphinx_mapping = {"python": ("https://docs.python.org/3", None)} From 788f8e149c8c76570fa485a193fdb4191beff69b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 5 Sep 2021 17:05:37 +0200 Subject: [PATCH 157/762] Add missing items to changelog. --- docs/project/changelog.rst | 34 +++++++++++++++++++++++++++++----- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 8b6325a1a..84ab6a314 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -81,23 +81,47 @@ New features * Added ``open_timeout`` to :func:`~client.connect`. +* Documented how to integrate with `Django `_. + +* Documented how to deploy websockets in production, with several options. + +* Documented how to authenticate connections. + +* Documented how to broadcast messages to many connections. + Improvements ............ -* Improved logging. - -* Provided additional information in :exc:`~exceptions.ConnectionClosed` - exceptions. +* Improved logging. See the :doc:`logging guide <../topics/logging>`. * Optimized default compression settings to reduce memory usage. +* Optimized processing of client-to-server messages when the C extension isn't + available. + +* Supported relative redirects in :func:`~client.connect`. + +* Handled TCP connection drops during the opening handshake. + * Made it easier to customize authentication with :meth:`~auth.BasicAuthWebSocketServerProtocol.check_credentials`. -* Supported relative redirects in :func:`~client.connect`. +* Provided additional information in :exc:`~exceptions.ConnectionClosed` + exceptions. + +* Clarified several exceptions or log messages. + +* Restructured documentation. * Improved API documentation. +* Extended FAQ. + +Bug fixes +......... + +* Avoided a crash when receiving a ping while the connection is closing. + 9.1 --- From add0d464b9721a94195f7481aa1a8bbffeed3f98 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 5 Sep 2021 19:05:31 +0200 Subject: [PATCH 158/762] Make the logger attribute a public API. --- docs/reference/client.rst | 2 ++ docs/reference/common.rst | 2 ++ docs/reference/server.rst | 2 ++ src/websockets/legacy/protocol.py | 3 ++- 4 files changed, 8 insertions(+), 1 deletion(-) diff --git a/docs/reference/client.rst b/docs/reference/client.rst index eaa6cdd76..dc31dc032 100644 --- a/docs/reference/client.rst +++ b/docs/reference/client.rst @@ -33,6 +33,8 @@ Client .. autoattribute:: id + .. autoattribute:: logger + .. autoproperty:: local_address .. autoproperty:: remote_address diff --git a/docs/reference/common.rst b/docs/reference/common.rst index 3b9f34a57..f5422bc35 100644 --- a/docs/reference/common.rst +++ b/docs/reference/common.rst @@ -24,6 +24,8 @@ Both sides .. autoattribute:: id + .. autoattribute:: logger + .. autoproperty:: local_address .. autoproperty:: remote_address diff --git a/docs/reference/server.rst b/docs/reference/server.rst index 0a5a060f3..1864594f5 100644 --- a/docs/reference/server.rst +++ b/docs/reference/server.rst @@ -50,6 +50,8 @@ Server .. autoattribute:: id + .. autoattribute:: logger + .. autoproperty:: local_address .. autoproperty:: remote_address diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 618e451c2..a31a5c7c8 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -218,7 +218,8 @@ def __init__( logger = logging.getLogger("websockets.protocol") # https://github.com/python/typeshed/issues/5561 logger = cast(logging.Logger, logger) - self.logger = logging.LoggerAdapter(logger, {"websocket": self}) + self.logger: LoggerLike = logging.LoggerAdapter(logger, {"websocket": self}) + """Logger for this connection.""" # Track if DEBUG is enabled. Shortcut logging calls if it isn't. self.debug = logger.isEnabledFor(logging.DEBUG) From 0a935b8ec16f4430ffe638cdbfbe45f6f9d7f684 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 5 Sep 2021 20:53:39 +0200 Subject: [PATCH 159/762] Prevent unlimited reads. This can mitigate some denial of service scenarios. --- src/websockets/http11.py | 40 +++++++++++++++++------- src/websockets/streams.py | 21 +++++++++++-- tests/test_http11.py | 25 +++++++++++++-- tests/test_streams.py | 65 +++++++++++++++++++++++++++++++-------- 4 files changed, 123 insertions(+), 28 deletions(-) diff --git a/src/websockets/http11.py b/src/websockets/http11.py index b82a0bfdc..052719c67 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -7,8 +7,18 @@ from . import datastructures, exceptions +# Maximum total size of headers is around 256 * 4 KiB = 1 MiB MAX_HEADERS = 256 -MAX_LINE = 4110 + +# We can use the same limit for the request line and header lines: +# "GET <4096 bytes> HTTP/1.1\r\n" = 4111 bytes +# "Set-Cookie: <4097 bytes>\r\n" = 4111 bytes +# (RFC requires 4096 bytes; for some reason Firefox supports 4097 bytes.) +MAX_LINE = 4111 + +# Support for HTTP response bodies is intended to read an error message +# returned by a server. It isn't designed to perform large file transfers. +MAX_BODY = 2 ** 20 # 1 MiB def d(value: bytes) -> str: @@ -60,7 +70,7 @@ class Request: @classmethod def parse( cls, - read_line: Callable[[], Generator[None, None, bytes]], + read_line: Callable[[int], Generator[None, None, bytes]], ) -> Generator[None, None, Request]: """ Parse a WebSocket handshake request. @@ -157,9 +167,9 @@ class Response: @classmethod def parse( cls, - read_line: Callable[[], Generator[None, None, bytes]], + read_line: Callable[[int], Generator[None, None, bytes]], read_exact: Callable[[int], Generator[None, None, bytes]], - read_to_eof: Callable[[], Generator[None, None, bytes]], + read_to_eof: Callable[[int], Generator[None, None, bytes]], ) -> Generator[None, None, Response]: """ Parse a WebSocket handshake response. @@ -234,7 +244,16 @@ def parse( content_length = int(raw_content_length) if content_length is None: - body = yield from read_to_eof() + try: + body = yield from read_to_eof(MAX_BODY) + except RuntimeError: + raise exceptions.SecurityError( + f"body too large: over {MAX_BODY} bytes" + ) + elif content_length > MAX_BODY: + raise exceptions.SecurityError( + f"body too large: {content_length} bytes" + ) else: body = yield from read_exact(content_length) @@ -255,7 +274,7 @@ def serialize(self) -> bytes: def parse_headers( - read_line: Callable[[], Generator[None, None, bytes]], + read_line: Callable[[int], Generator[None, None, bytes]], ) -> Generator[None, None, datastructures.Headers]: """ Parse HTTP headers. @@ -306,7 +325,7 @@ def parse_headers( def parse_line( - read_line: Callable[[], Generator[None, None, bytes]], + read_line: Callable[[int], Generator[None, None, bytes]], ) -> Generator[None, None, bytes]: """ Parse a single line. @@ -322,10 +341,9 @@ def parse_line( SecurityError: if the response exceeds a security limit. """ - # Security: TODO: add a limit here - line = yield from read_line() - # Security: this guarantees header values are small (hard-coded = 4 KiB) - if len(line) > MAX_LINE: + try: + line = yield from read_line(MAX_LINE) + except RuntimeError: raise exceptions.SecurityError("line too long") # Not mandatory but safe - https://www.rfc-editor.org/rfc/rfc7230.html#section-3.5 if not line.endswith(b"\r\n"): diff --git a/src/websockets/streams.py b/src/websockets/streams.py index 094cbb53a..f861d4bd2 100644 --- a/src/websockets/streams.py +++ b/src/websockets/streams.py @@ -17,7 +17,7 @@ def __init__(self) -> None: self.buffer = bytearray() self.eof = False - def read_line(self) -> Generator[None, None, bytes]: + def read_line(self, m: int) -> Generator[None, None, bytes]: """ Read a LF-terminated line from the stream. @@ -25,8 +25,12 @@ def read_line(self) -> Generator[None, None, bytes]: The return value includes the LF character. + Args: + m: maximum number bytes to read; this is a security limit. + Raises: EOFError: if the stream ends without a LF. + RuntimeError: if the stream ends in more than ``m`` bytes. """ n = 0 # number of bytes to read @@ -36,9 +40,13 @@ def read_line(self) -> Generator[None, None, bytes]: if n > 0: break p = len(self.buffer) + if p > m: + raise RuntimeError(f"read {p} bytes, expected no more than {m} bytes") if self.eof: raise EOFError(f"stream ends after {p} bytes, before end of line") yield + if n > m: + raise RuntimeError(f"read {n} bytes, expected no more than {m} bytes") r = self.buffer[:n] del self.buffer[:n] return r @@ -66,14 +74,23 @@ def read_exact(self, n: int) -> Generator[None, None, bytes]: del self.buffer[:n] return r - def read_to_eof(self) -> Generator[None, None, bytes]: + def read_to_eof(self, m: int) -> Generator[None, None, bytes]: """ Read all bytes from the stream. This is a generator-based coroutine. + Args: + m: maximum number bytes to read; this is a security limit. + + Raises: + RuntimeError: if the stream ends in more than ``m`` bytes. + """ while not self.eof: + p = len(self.buffer) + if p > m: + raise RuntimeError(f"read {p} bytes, expected no more than {m} bytes") yield r = self.buffer[:] del self.buffer[:] diff --git a/tests/test_http11.py b/tests/test_http11.py index afd85f64a..61d377925 100644 --- a/tests/test_http11.py +++ b/tests/test_http11.py @@ -158,7 +158,8 @@ def test_parse_empty(self): with self.assertRaises(EOFError) as raised: next(self.parse()) self.assertEqual( - str(raised.exception), "connection closed while reading HTTP status line" + str(raised.exception), + "connection closed while reading HTTP status line", ) def test_parse_invalid_status_line(self): @@ -230,6 +231,24 @@ def test_parse_body_without_content_length(self): response = self.assertGeneratorReturns(gen) self.assertEqual(response.body, b"Hello world!\n") + def test_parse_body_with_content_length_too_long(self): + self.reader.feed_data(b"HTTP/1.1 200 OK\r\nContent-Length: 1048577\r\n\r\n") + with self.assertRaises(SecurityError) as raised: + next(self.parse()) + self.assertEqual( + str(raised.exception), + "body too large: 1048577 bytes", + ) + + def test_parse_body_without_content_length_too_long(self): + self.reader.feed_data(b"HTTP/1.1 200 OK\r\n\r\n" + b"a" * 1048577) + with self.assertRaises(SecurityError) as raised: + next(self.parse()) + self.assertEqual( + str(raised.exception), + "body too large: over 1048576 bytes", + ) + def test_parse_body_with_transfer_encoding(self): self.reader.feed_data(b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n") with self.assertRaises(NotImplementedError) as raised: @@ -314,8 +333,8 @@ def test_parse_too_long_value(self): next(self.parse_headers()) def test_parse_too_long_line(self): - # Header line contains 5 + 4104 + 2 = 4111 bytes. - self.reader.feed_data(b"foo: " + b"a" * 4104 + b"\r\n\r\n") + # Header line contains 5 + 4105 + 2 = 4112 bytes. + self.reader.feed_data(b"foo: " + b"a" * 4105 + b"\r\n\r\n") with self.assertRaises(SecurityError): next(self.parse_headers()) diff --git a/tests/test_streams.py b/tests/test_streams.py index 8abefbcc9..fd7c66a0b 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -10,24 +10,24 @@ def setUp(self): def test_read_line(self): self.reader.feed_data(b"spam\neggs\n") - gen = self.reader.read_line() + gen = self.reader.read_line(32) line = self.assertGeneratorReturns(gen) self.assertEqual(line, b"spam\n") - gen = self.reader.read_line() + gen = self.reader.read_line(32) line = self.assertGeneratorReturns(gen) self.assertEqual(line, b"eggs\n") def test_read_line_need_more_data(self): self.reader.feed_data(b"spa") - gen = self.reader.read_line() + gen = self.reader.read_line(32) self.assertGeneratorRunning(gen) self.reader.feed_data(b"m\neg") line = self.assertGeneratorReturns(gen) self.assertEqual(line, b"spam\n") - gen = self.reader.read_line() + gen = self.reader.read_line(32) self.assertGeneratorRunning(gen) self.reader.feed_data(b"gs\n") line = self.assertGeneratorReturns(gen) @@ -37,11 +37,34 @@ def test_read_line_not_enough_data(self): self.reader.feed_data(b"spa") self.reader.feed_eof() - gen = self.reader.read_line() + gen = self.reader.read_line(32) with self.assertRaises(EOFError) as raised: next(gen) self.assertEqual( - str(raised.exception), "stream ends after 3 bytes, before end of line" + str(raised.exception), + "stream ends after 3 bytes, before end of line", + ) + + def test_read_line_too_long(self): + self.reader.feed_data(b"spam\neggs\n") + + gen = self.reader.read_line(2) + with self.assertRaises(RuntimeError) as raised: + next(gen) + self.assertEqual( + str(raised.exception), + "read 5 bytes, expected no more than 2 bytes", + ) + + def test_read_line_too_long_need_more_data(self): + self.reader.feed_data(b"spa") + + gen = self.reader.read_line(2) + with self.assertRaises(RuntimeError) as raised: + next(gen) + self.assertEqual( + str(raised.exception), + "read 3 bytes, expected no more than 2 bytes", ) def test_read_exact(self): @@ -78,11 +101,12 @@ def test_read_exact_not_enough_data(self): with self.assertRaises(EOFError) as raised: next(gen) self.assertEqual( - str(raised.exception), "stream ends after 3 bytes, expected 4 bytes" + str(raised.exception), + "stream ends after 3 bytes, expected 4 bytes", ) def test_read_to_eof(self): - gen = self.reader.read_to_eof() + gen = self.reader.read_to_eof(32) self.reader.feed_data(b"spam") self.assertGeneratorRunning(gen) @@ -94,10 +118,21 @@ def test_read_to_eof(self): def test_read_to_eof_at_eof(self): self.reader.feed_eof() - gen = self.reader.read_to_eof() + gen = self.reader.read_to_eof(32) data = self.assertGeneratorReturns(gen) self.assertEqual(data, b"") + def test_read_to_eof_too_long(self): + gen = self.reader.read_to_eof(2) + + self.reader.feed_data(b"spam") + with self.assertRaises(RuntimeError) as raised: + next(gen) + self.assertEqual( + str(raised.exception), + "read 4 bytes, expected no more than 2 bytes", + ) + def test_at_eof_after_feed_data(self): gen = self.reader.at_eof() self.assertGeneratorRunning(gen) @@ -137,16 +172,22 @@ def test_feed_data_after_feed_eof(self): self.reader.feed_eof() with self.assertRaises(EOFError) as raised: self.reader.feed_data(b"spam") - self.assertEqual(str(raised.exception), "stream ended") + self.assertEqual( + str(raised.exception), + "stream ended", + ) def test_feed_eof_after_feed_eof(self): self.reader.feed_eof() with self.assertRaises(EOFError) as raised: self.reader.feed_eof() - self.assertEqual(str(raised.exception), "stream ended") + self.assertEqual( + str(raised.exception), + "stream ended", + ) def test_discard(self): - gen = self.reader.read_to_eof() + gen = self.reader.read_to_eof(32) self.reader.feed_data(b"spam") self.reader.discard() From cf070908da6f0fdb29035458a041cea7a64e9e3c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 5 Sep 2021 21:48:28 +0200 Subject: [PATCH 160/762] Prepare removal of generator-based coroutines. --- src/websockets/legacy/client.py | 2 +- src/websockets/legacy/server.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index e5743cc0e..57fe7e2c4 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -663,7 +663,7 @@ async def __await_impl__(self) -> WebSocketClientProtocol: else: raise SecurityError("too many redirects") - # ... = yield from connect(...) + # ... = yield from connect(...) - remove when dropping Python < 3.10 __iter__ = __await__ diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 4399f0782..9bda5cdd8 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -1043,7 +1043,7 @@ async def __await_impl__(self) -> WebSocketServer: self.ws_server.wrap(server) return self.ws_server - # yield from serve(...) + # yield from serve(...) - remove when dropping Python < 3.10 __iter__ = __await__ From d084f4b457eeb9cf79102b8bd73e8018b769e03f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 5 Sep 2021 22:15:11 +0200 Subject: [PATCH 161/762] Tweak constants. For aesthetic reasons. --- src/websockets/legacy/client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 57fe7e2c4..ffc54453e 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -584,9 +584,9 @@ def handle_redirect(self, uri: str) -> None: # async for ... in connect(...): - BACKOFF_MIN = 2.0 + BACKOFF_MIN = 1.92 BACKOFF_MAX = 60.0 - BACKOFF_FACTOR = 1.5 + BACKOFF_FACTOR = 1.618 async def __aiter__(self) -> AsyncIterator[WebSocketClientProtocol]: backoff_delay = self.BACKOFF_MIN From e7b0c0ff8caecee6f1f2b818940db3d8a6b87027 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 5 Sep 2021 22:23:23 +0200 Subject: [PATCH 162/762] Add random delay before first reconnection attempt. --- src/websockets/legacy/client.py | 27 ++++++++++++++++++++------- tests/legacy/test_client_server.py | 12 +++++++++--- 2 files changed, 29 insertions(+), 10 deletions(-) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index ffc54453e..63b973ecb 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -3,6 +3,7 @@ import asyncio import functools import logging +import random import urllib.parse import warnings from types import TracebackType @@ -587,6 +588,7 @@ def handle_redirect(self, uri: str) -> None: BACKOFF_MIN = 1.92 BACKOFF_MAX = 60.0 BACKOFF_FACTOR = 1.618 + BACKOFF_INITIAL = 5 async def __aiter__(self) -> AsyncIterator[WebSocketClientProtocol]: backoff_delay = self.BACKOFF_MIN @@ -599,15 +601,26 @@ async def __aiter__(self) -> AsyncIterator[WebSocketClientProtocol]: except asyncio.CancelledError: # pragma: no cover raise except Exception: - # Connection timed out - increase backoff delay + # Add a random initial delay between 0 and 5 seconds. + # See 7.2.3. Recovering from Abnormal Closure in RFC 6544. + if backoff_delay == self.BACKOFF_MIN: + initial_delay = random.random() * self.BACKOFF_INITIAL + self.logger.info( + "! connect failed; reconnecting in %.1f seconds", + initial_delay, + exc_info=True, + ) + await asyncio.sleep(initial_delay) + else: + self.logger.info( + "! connect failed again; retrying in %d seconds", + int(backoff_delay), + exc_info=True, + ) + await asyncio.sleep(int(backoff_delay)) + # Increase delay with truncated exponential backoff. backoff_delay = backoff_delay * self.BACKOFF_FACTOR backoff_delay = min(backoff_delay, self.BACKOFF_MAX) - self.logger.info( - "! connect failed; retrying in %d seconds", - int(backoff_delay), - exc_info=True, - ) - await asyncio.sleep(backoff_delay) continue else: # Connection succeeded - reset backoff delay diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 016f08e73..482b2cd0c 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -1418,7 +1418,8 @@ async def run_client(): iteration = 0 connect_inst = connect(get_server_uri(self.server)) connect_inst.BACKOFF_MIN = 10 * MS - connect_inst.BACKOFF_MAX = 200 * MS + connect_inst.BACKOFF_MAX = 99 * MS + connect_inst.BACKOFF_INITIAL = 0 async for ws in connect_inst: await ws.send("spam") msg = await ws.recv() @@ -1468,9 +1469,14 @@ async def run_client(): [ "connection failed (503 Service Unavailable)", "connection closed", - "! connect failed; retrying in 0 seconds", + "! connect failed; reconnecting in 0.0 seconds", ] - * ((len(logs.records) - 5) // 3) + + [ + "connection failed (503 Service Unavailable)", + "connection closed", + "! connect failed again; retrying in 0 seconds", + ] + * ((len(logs.records) - 8) // 3) + [ "connection open", "connection closed", From 35a7ddcdbd29e1c48216a53c7f1d0ff395595ffc Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 6 Sep 2021 21:45:28 +0200 Subject: [PATCH 163/762] Standardize documentation of dataclasses. --- src/websockets/frames.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 9a97f2530..c31796cbd 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -99,9 +99,9 @@ class Frame: """ WebSocket frame. - Args: - opcode: opcode. - data: payload data. + Attributes: + opcode: Opcode. + data: Payload data. fin: FIN bit. rsv1: RSV1 bit. rsv2: RSV2 bit. @@ -368,6 +368,10 @@ class Close: """ WebSocket close code and reason. + Attributes: + code: Close code. + reason: Close reason. + """ code: int From dcd64046b39e874cb5c2e8f9629bdacf14cfe2af Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 6 Sep 2021 21:45:48 +0200 Subject: [PATCH 164/762] Refactor WebSocketURI. --- src/websockets/uri.py | 71 ++++++++++++++++++++++++++++--------------- tests/test_uri.py | 54 ++++++++++++++++++++++++++------ 2 files changed, 91 insertions(+), 34 deletions(-) diff --git a/src/websockets/uri.py b/src/websockets/uri.py index 3d8f7cd95..fff0c3806 100644 --- a/src/websockets/uri.py +++ b/src/websockets/uri.py @@ -17,11 +17,12 @@ class WebSocketURI: Attributes: secure: :obj:`True` for a ``wss`` URI, :obj:`False` for a ``ws`` URI. - host: Host, normalized to lower case. - port: Port, always set even if it's the default. - resource_name: Path and optional query. - user_info: ``(username, password)`` when the URI contains - `User Information`_, else :obj:`None`. + host: Normalized to lower case. + port: Always set even if it's the default. + path: May be empty. + query: May be empty if the URI doesn't include a query component. + username: Available when the URI contains `User Information`_. + password: Available when the URI contains `User Information`_. .. _User Information: https://www.rfc-editor.org/rfc/rfc3986.html#section-3.2.1 @@ -30,8 +31,27 @@ class WebSocketURI: secure: bool host: str port: int - resource_name: str - user_info: Optional[Tuple[str, str]] = None + path: str + query: str + username: Optional[str] + password: Optional[str] + + @property + def resource_name(self) -> str: + if self.path: + resource_name = self.path + else: + resource_name = "/" + if self.query: + resource_name += "?" + self.query + return resource_name + + @property + def user_info(self) -> Optional[Tuple[str, str]]: + if self.username is None: + return None + assert self.password is not None + return (self.username, self.password) # All characters from the gen-delims and sub-delims sets in RFC 3987. @@ -45,6 +65,9 @@ def parse_uri(uri: str) -> WebSocketURI: Args: uri: WebSocket URI. + Returns: + WebSocketURI: Parsed WebSocket URI. + Raises: InvalidURI: if ``uri`` isn't a valid WebSocket URI. @@ -60,16 +83,14 @@ def parse_uri(uri: str) -> WebSocketURI: secure = parsed.scheme == "wss" host = parsed.hostname port = parsed.port or (443 if secure else 80) - resource_name = parsed.path or "/" - if parsed.query: - resource_name += "?" + parsed.query - user_info = None - if parsed.username is not None: - # urllib.parse.urlparse accepts URLs with a username but without a - # password. This doesn't make sense for HTTP Basic Auth credentials. - if parsed.password is None: - raise exceptions.InvalidURI(uri, "username provided without password") - user_info = (parsed.username, parsed.password) + path = parsed.path + query = parsed.query + username = parsed.username + password = parsed.password + # urllib.parse.urlparse accepts URLs with a username but without a + # password. This doesn't make sense for HTTP Basic Auth credentials. + if username is not None and password is None: + raise exceptions.InvalidURI(uri, "username provided without password") try: uri.encode("ascii") @@ -77,11 +98,11 @@ def parse_uri(uri: str) -> WebSocketURI: # Input contains non-ASCII characters. # It must be an IRI. Convert it to a URI. host = host.encode("idna").decode() - resource_name = urllib.parse.quote(resource_name, safe=DELIMS) - if user_info is not None: - user_info = ( - urllib.parse.quote(user_info[0], safe=DELIMS), - urllib.parse.quote(user_info[1], safe=DELIMS), - ) - - return WebSocketURI(secure, host, port, resource_name, user_info) + path = urllib.parse.quote(path, safe=DELIMS) + query = urllib.parse.quote(query, safe=DELIMS) + if username is not None: + assert password is not None + username = urllib.parse.quote(username, safe=DELIMS) + password = urllib.parse.quote(password, safe=DELIMS) + + return WebSocketURI(secure, host, port, path, query, username, password) diff --git a/tests/test_uri.py b/tests/test_uri.py index f937d2949..8acc01c18 100644 --- a/tests/test_uri.py +++ b/tests/test_uri.py @@ -7,33 +7,46 @@ VALID_URIS = [ ( "ws://localhost/", - WebSocketURI(False, "localhost", 80, "/", None), + WebSocketURI(False, "localhost", 80, "/", "", None, None), ), ( "wss://localhost/", - WebSocketURI(True, "localhost", 443, "/", None), + WebSocketURI(True, "localhost", 443, "/", "", None, None), + ), + ( + "ws://localhost", + WebSocketURI(False, "localhost", 80, "", "", None, None), ), ( "ws://localhost/path?query", - WebSocketURI(False, "localhost", 80, "/path?query", None), + WebSocketURI(False, "localhost", 80, "/path", "query", None, None), ), ( "ws://localhost/path;params", - WebSocketURI(False, "localhost", 80, "/path;params", None), + WebSocketURI(False, "localhost", 80, "/path;params", "", None, None), ), ( "WS://LOCALHOST/PATH?QUERY", - WebSocketURI(False, "localhost", 80, "/PATH?QUERY", None), + WebSocketURI(False, "localhost", 80, "/PATH", "QUERY", None, None), ), ( "ws://user:pass@localhost/", - WebSocketURI(False, "localhost", 80, "/", ("user", "pass")), + WebSocketURI(False, "localhost", 80, "/", "", "user", "pass"), + ), + ( + "ws://høst/", + WebSocketURI(False, "xn--hst-0na", 80, "/", "", None, None), ), - ("ws://høst/", WebSocketURI(False, "xn--hst-0na", 80, "/", None)), ( - "ws://üser:påss@høst/πass", + "ws://üser:påss@høst/πass?qùéry", WebSocketURI( - False, "xn--hst-0na", 80, "/%CF%80ass", ("%C3%BCser", "p%C3%A5ss") + False, + "xn--hst-0na", + 80, + "/%CF%80ass", + "q%C3%B9%C3%A9ry", + "%C3%BCser", + "p%C3%A5ss", ), ), ] @@ -46,6 +59,19 @@ "ws:///path", ] +RESOURCE_NAMES = [ + ("ws://localhost/", "/"), + ("ws://localhost", "/"), + ("ws://localhost/path?query", "/path?query"), + ("ws://høst/πass?qùéry", "/%CF%80ass?q%C3%B9%C3%A9ry"), +] + +USER_INFOS = [ + ("ws://localhost/", None), + ("ws://user:pass@localhost/", ("user", "pass")), + ("ws://üser:påss@høst/", ("%C3%BCser", "p%C3%A5ss")), +] + class URITests(unittest.TestCase): def test_success(self): @@ -58,3 +84,13 @@ def test_error(self): with self.subTest(uri=uri): with self.assertRaises(InvalidURI): parse_uri(uri) + + def test_resource_name(self): + for uri, resource_name in RESOURCE_NAMES: + with self.subTest(uri=uri): + self.assertEqual(parse_uri(uri).resource_name, resource_name) + + def test_user_info(self): + for uri, user_info in USER_INFOS: + with self.subTest(uri=uri): + self.assertEqual(parse_uri(uri).user_info, user_info) From 4a22bdf8d570ff3eba0b797a80435f75f4505e66 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 6 Sep 2021 21:46:37 +0200 Subject: [PATCH 165/762] Make websockets.uri a public API (again!) --- docs/reference/utilities.rst | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/docs/reference/utilities.rst b/docs/reference/utilities.rst index dc6333847..e3e53f2d3 100644 --- a/docs/reference/utilities.rst +++ b/docs/reference/utilities.rst @@ -35,3 +35,12 @@ HTTP events .. automethod:: raw_items .. autoexception:: MultipleValuesError + +URIs +---- + +.. automodule:: websockets.uri + + .. autofunction:: parse_uri + + .. autoclass:: WebSocketURI From abd297b832a74fe23825aacc28bf30e8b00f122d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 6 Sep 2021 21:47:41 +0200 Subject: [PATCH 166/762] Expect a WebSocketURI in ClientConnection. Parsing the URI to get the host and port and opening the connection before initializing a ClientConnection feels more natural. --- src/websockets/client.py | 6 +-- tests/test_client.py | 98 ++++++++++++++++++++++++---------------- 2 files changed, 62 insertions(+), 42 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 13c8a8ad0..42c7aeee9 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -32,7 +32,7 @@ Subprotocol, UpgradeProtocol, ) -from .uri import parse_uri +from .uri import WebSocketURI from .utils import accept_key, generate_key @@ -46,7 +46,7 @@ class ClientConnection(Connection): def __init__( self, - uri: str, + wsuri: WebSocketURI, origin: Optional[Origin] = None, extensions: Optional[Sequence[ClientExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, @@ -60,7 +60,7 @@ def __init__( max_size=max_size, logger=logger, ) - self.wsuri = parse_uri(uri) + self.wsuri = wsuri self.origin = origin self.available_extensions = extensions self.available_subprotocols = subprotocols diff --git a/tests/test_client.py b/tests/test_client.py index 015b93b3f..1c1452d41 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -9,6 +9,7 @@ from websockets.frames import OP_TEXT, Frame from websockets.http import USER_AGENT from websockets.http11 import Request, Response +from websockets.uri import parse_uri from websockets.utils import accept_key from .extensions.utils import ( @@ -24,7 +25,7 @@ class ConnectTests(unittest.TestCase): def test_send_connect(self): with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientConnection("wss://example.com/test") + client = ClientConnection(parse_uri("wss://example.com/test")) request = client.connect() self.assertIsInstance(request, Request) client.send_request(request) @@ -44,7 +45,7 @@ def test_send_connect(self): def test_connect_request(self): with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientConnection("wss://example.com/test") + client = ClientConnection(parse_uri("wss://example.com/test")) request = client.connect() self.assertEqual(request.path, "/test") self.assertEqual( @@ -62,7 +63,7 @@ def test_connect_request(self): ) def test_path(self): - client = ClientConnection("wss://example.com/endpoint?test=1") + client = ClientConnection(parse_uri("wss://example.com/endpoint?test=1")) request = client.connect() self.assertEqual(request.path, "/endpoint?test=1") @@ -77,33 +78,40 @@ def test_port(self): ("wss://example.com:8443/", "example.com:8443"), ]: with self.subTest(uri=uri): - client = ClientConnection(uri) + client = ClientConnection(parse_uri(uri)) request = client.connect() self.assertEqual(request.headers["Host"], host) def test_user_info(self): - client = ClientConnection("wss://hello:iloveyou@example.com/") + client = ClientConnection(parse_uri("wss://hello:iloveyou@example.com/")) request = client.connect() self.assertEqual(request.headers["Authorization"], "Basic aGVsbG86aWxvdmV5b3U=") def test_origin(self): - client = ClientConnection("wss://example.com/", origin="https://example.com") + client = ClientConnection( + parse_uri("wss://example.com/"), + origin="https://example.com", + ) request = client.connect() self.assertEqual(request.headers["Origin"], "https://example.com") def test_extensions(self): client = ClientConnection( - "wss://example.com/", extensions=[ClientOpExtensionFactory()] + parse_uri("wss://example.com/"), + extensions=[ClientOpExtensionFactory()], ) request = client.connect() self.assertEqual(request.headers["Sec-WebSocket-Extensions"], "x-op; op") def test_subprotocols(self): - client = ClientConnection("wss://example.com/", subprotocols=["chat"]) + client = ClientConnection( + parse_uri("wss://example.com/"), + subprotocols=["chat"], + ) request = client.connect() self.assertEqual(request.headers["Sec-WebSocket-Protocol"], "chat") @@ -112,7 +120,7 @@ def test_subprotocols(self): class AcceptRejectTests(unittest.TestCase): def test_receive_accept(self): with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientConnection("ws://example.com/test") + client = ClientConnection(parse_uri("ws://example.com/test")) client.connect() client.receive_data( ( @@ -131,7 +139,7 @@ def test_receive_accept(self): def test_receive_reject(self): with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientConnection("ws://example.com/test") + client = ClientConnection(parse_uri("ws://example.com/test")) client.connect() client.receive_data( ( @@ -151,7 +159,7 @@ def test_receive_reject(self): def test_accept_response(self): with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientConnection("ws://example.com/test") + client = ClientConnection(parse_uri("ws://example.com/test")) client.connect() client.receive_data( ( @@ -183,7 +191,7 @@ def test_accept_response(self): def test_reject_response(self): with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientConnection("ws://example.com/test") + client = ClientConnection(parse_uri("ws://example.com/test")) client.connect() client.receive_data( ( @@ -231,7 +239,7 @@ def make_accept_response(self, client): ) def test_basic(self): - client = ClientConnection("wss://example.com/") + client = ClientConnection(parse_uri("wss://example.com/")) response = self.make_accept_response(client) client.receive_data(response.serialize()) [response] = client.events_received() @@ -239,7 +247,7 @@ def test_basic(self): self.assertEqual(client.state, OPEN) def test_missing_connection(self): - client = ClientConnection("wss://example.com/") + client = ClientConnection(parse_uri("wss://example.com/")) response = self.make_accept_response(client) del response.headers["Connection"] client.receive_data(response.serialize()) @@ -251,7 +259,7 @@ def test_missing_connection(self): self.assertEqual(str(raised.exception), "missing Connection header") def test_invalid_connection(self): - client = ClientConnection("wss://example.com/") + client = ClientConnection(parse_uri("wss://example.com/")) response = self.make_accept_response(client) del response.headers["Connection"] response.headers["Connection"] = "close" @@ -264,7 +272,7 @@ def test_invalid_connection(self): self.assertEqual(str(raised.exception), "invalid Connection header: close") def test_missing_upgrade(self): - client = ClientConnection("wss://example.com/") + client = ClientConnection(parse_uri("wss://example.com/")) response = self.make_accept_response(client) del response.headers["Upgrade"] client.receive_data(response.serialize()) @@ -276,7 +284,7 @@ def test_missing_upgrade(self): self.assertEqual(str(raised.exception), "missing Upgrade header") def test_invalid_upgrade(self): - client = ClientConnection("wss://example.com/") + client = ClientConnection(parse_uri("wss://example.com/")) response = self.make_accept_response(client) del response.headers["Upgrade"] response.headers["Upgrade"] = "h2c" @@ -289,7 +297,7 @@ def test_invalid_upgrade(self): self.assertEqual(str(raised.exception), "invalid Upgrade header: h2c") def test_missing_accept(self): - client = ClientConnection("wss://example.com/") + client = ClientConnection(parse_uri("wss://example.com/")) response = self.make_accept_response(client) del response.headers["Sec-WebSocket-Accept"] client.receive_data(response.serialize()) @@ -301,7 +309,7 @@ def test_missing_accept(self): self.assertEqual(str(raised.exception), "missing Sec-WebSocket-Accept header") def test_multiple_accept(self): - client = ClientConnection("wss://example.com/") + client = ClientConnection(parse_uri("wss://example.com/")) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Accept"] = ACCEPT client.receive_data(response.serialize()) @@ -317,7 +325,7 @@ def test_multiple_accept(self): ) def test_invalid_accept(self): - client = ClientConnection("wss://example.com/") + client = ClientConnection(parse_uri("wss://example.com/")) response = self.make_accept_response(client) del response.headers["Sec-WebSocket-Accept"] response.headers["Sec-WebSocket-Accept"] = ACCEPT @@ -332,7 +340,7 @@ def test_invalid_accept(self): ) def test_no_extensions(self): - client = ClientConnection("wss://example.com/") + client = ClientConnection(parse_uri("wss://example.com/")) response = self.make_accept_response(client) client.receive_data(response.serialize()) [response] = client.events_received() @@ -342,7 +350,8 @@ def test_no_extensions(self): def test_no_extension(self): client = ClientConnection( - "wss://example.com/", extensions=[ClientOpExtensionFactory()] + parse_uri("wss://example.com/"), + extensions=[ClientOpExtensionFactory()], ) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Extensions"] = "x-op; op" @@ -354,7 +363,8 @@ def test_no_extension(self): def test_extension(self): client = ClientConnection( - "wss://example.com/", extensions=[ClientRsv2ExtensionFactory()] + parse_uri("wss://example.com/"), + extensions=[ClientRsv2ExtensionFactory()], ) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Extensions"] = "x-rsv2" @@ -365,7 +375,7 @@ def test_extension(self): self.assertEqual(client.extensions, [Rsv2Extension()]) def test_unexpected_extension(self): - client = ClientConnection("wss://example.com/") + client = ClientConnection(parse_uri("wss://example.com/")) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Extensions"] = "x-op; op" client.receive_data(response.serialize()) @@ -378,7 +388,8 @@ def test_unexpected_extension(self): def test_unsupported_extension(self): client = ClientConnection( - "wss://example.com/", extensions=[ClientRsv2ExtensionFactory()] + parse_uri("wss://example.com/"), + extensions=[ClientRsv2ExtensionFactory()], ) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Extensions"] = "x-op; op" @@ -395,7 +406,8 @@ def test_unsupported_extension(self): def test_supported_extension_parameters(self): client = ClientConnection( - "wss://example.com/", extensions=[ClientOpExtensionFactory("this")] + parse_uri("wss://example.com/"), + extensions=[ClientOpExtensionFactory("this")], ) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Extensions"] = "x-op; op=this" @@ -407,7 +419,8 @@ def test_supported_extension_parameters(self): def test_unsupported_extension_parameters(self): client = ClientConnection( - "wss://example.com/", extensions=[ClientOpExtensionFactory("this")] + parse_uri("wss://example.com/"), + extensions=[ClientOpExtensionFactory("this")], ) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Extensions"] = "x-op; op=that" @@ -424,7 +437,7 @@ def test_unsupported_extension_parameters(self): def test_multiple_supported_extension_parameters(self): client = ClientConnection( - "wss://example.com/", + parse_uri("wss://example.com/"), extensions=[ ClientOpExtensionFactory("this"), ClientOpExtensionFactory("that"), @@ -440,7 +453,7 @@ def test_multiple_supported_extension_parameters(self): def test_multiple_extensions(self): client = ClientConnection( - "wss://example.com/", + parse_uri("wss://example.com/"), extensions=[ClientOpExtensionFactory(), ClientRsv2ExtensionFactory()], ) response = self.make_accept_response(client) @@ -454,7 +467,7 @@ def test_multiple_extensions(self): def test_multiple_extensions_order(self): client = ClientConnection( - "wss://example.com/", + parse_uri("wss://example.com/"), extensions=[ClientOpExtensionFactory(), ClientRsv2ExtensionFactory()], ) response = self.make_accept_response(client) @@ -467,7 +480,7 @@ def test_multiple_extensions_order(self): self.assertEqual(client.extensions, [Rsv2Extension(), OpExtension()]) def test_no_subprotocols(self): - client = ClientConnection("wss://example.com/") + client = ClientConnection(parse_uri("wss://example.com/")) response = self.make_accept_response(client) client.receive_data(response.serialize()) [response] = client.events_received() @@ -476,7 +489,9 @@ def test_no_subprotocols(self): self.assertIsNone(client.subprotocol) def test_no_subprotocol(self): - client = ClientConnection("wss://example.com/", subprotocols=["chat"]) + client = ClientConnection( + parse_uri("wss://example.com/"), subprotocols=["chat"] + ) response = self.make_accept_response(client) client.receive_data(response.serialize()) [response] = client.events_received() @@ -485,7 +500,9 @@ def test_no_subprotocol(self): self.assertIsNone(client.subprotocol) def test_subprotocol(self): - client = ClientConnection("wss://example.com/", subprotocols=["chat"]) + client = ClientConnection( + parse_uri("wss://example.com/"), subprotocols=["chat"] + ) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Protocol"] = "chat" client.receive_data(response.serialize()) @@ -495,7 +512,7 @@ def test_subprotocol(self): self.assertEqual(client.subprotocol, "chat") def test_unexpected_subprotocol(self): - client = ClientConnection("wss://example.com/") + client = ClientConnection(parse_uri("wss://example.com/")) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Protocol"] = "chat" client.receive_data(response.serialize()) @@ -508,7 +525,8 @@ def test_unexpected_subprotocol(self): def test_multiple_subprotocols(self): client = ClientConnection( - "wss://example.com/", subprotocols=["superchat", "chat"] + parse_uri("wss://example.com/"), + subprotocols=["superchat", "chat"], ) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Protocol"] = "superchat" @@ -525,7 +543,8 @@ def test_multiple_subprotocols(self): def test_supported_subprotocol(self): client = ClientConnection( - "wss://example.com/", subprotocols=["superchat", "chat"] + parse_uri("wss://example.com/"), + subprotocols=["superchat", "chat"], ) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Protocol"] = "chat" @@ -537,7 +556,8 @@ def test_supported_subprotocol(self): def test_unsupported_subprotocol(self): client = ClientConnection( - "wss://example.com/", subprotocols=["superchat", "chat"] + parse_uri("wss://example.com/"), + subprotocols=["superchat", "chat"], ) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Protocol"] = "otherchat" @@ -552,7 +572,7 @@ def test_unsupported_subprotocol(self): class MiscTests(unittest.TestCase): def test_bypass_handshake(self): - client = ClientConnection("ws://example.com/test", state=OPEN) + client = ClientConnection(parse_uri("ws://example.com/test"), state=OPEN) client.receive_data(b"\x81\x06Hello!") [frame] = client.events_received() self.assertEqual(frame, Frame(OP_TEXT, b"Hello!")) @@ -560,5 +580,5 @@ def test_bypass_handshake(self): def test_custom_logger(self): logger = logging.getLogger("test") with self.assertLogs("test", logging.DEBUG) as logs: - ClientConnection("wss://example.com/test", logger=logger) + ClientConnection(parse_uri("wss://example.com/test"), logger=logger) self.assertEqual(len(logs.records), 1) From eba7b56eb652a1bfff2d536b5d348499a1e434d5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 6 Sep 2021 22:33:10 +0200 Subject: [PATCH 167/762] Improve docs of Frame and Close. --- docs/reference/utilities.rst | 12 ++++++++++++ src/websockets/frames.py | 4 +++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/docs/reference/utilities.rst b/docs/reference/utilities.rst index e3e53f2d3..6b5d402fc 100644 --- a/docs/reference/utilities.rst +++ b/docs/reference/utilities.rst @@ -15,6 +15,18 @@ WebSocket events .. autoclass:: Opcode + .. autoattribute:: CONT + + .. autoattribute:: TEXT + + .. autoattribute:: BINARY + + .. autoattribute:: CLOSE + + .. autoattribute:: PING + + .. autoattribute:: PONG + .. autoclass:: Close HTTP events diff --git a/src/websockets/frames.py b/src/websockets/frames.py index c31796cbd..82b4a1403 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -35,6 +35,8 @@ class Opcode(enum.IntEnum): + """Opcode values for WebSocket frames.""" + CONT, TEXT, BINARY = 0x00, 0x01, 0x02 CLOSE, PING, PONG = 0x08, 0x09, 0x0A @@ -366,7 +368,7 @@ def prepare_ctrl(data: Data) -> bytes: @dataclasses.dataclass class Close: """ - WebSocket close code and reason. + Code and reason for WebSocket close frames. Attributes: code: Close code. From 5fc6fa832c11be7c42739f901b3a893285bddec0 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 6 Sep 2021 22:40:53 +0200 Subject: [PATCH 168/762] Clarify comment. Fix #1032. --- src/websockets/datastructures.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/websockets/datastructures.py b/src/websockets/datastructures.py index 1ff586abd..5afe86931 100644 --- a/src/websockets/datastructures.py +++ b/src/websockets/datastructures.py @@ -88,7 +88,7 @@ def copy(self) -> Headers: return copy def serialize(self) -> bytes: - # Headers only contain ASCII characters. + # Since headers only contain ASCII characters, we can keep this simple. return str(self).encode() # Collection methods From a8eb9738244f9f3ee3d2471baa09c3010cf45afc Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 8 Sep 2021 21:27:43 +0200 Subject: [PATCH 169/762] Avoid creating __doc__ attributes. Sphinx still finds the docstring-that-isn't-really-a-docstring. --- src/websockets/datastructures.py | 2 +- src/websockets/typing.py | 18 +++++++++--------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/websockets/datastructures.py b/src/websockets/datastructures.py index 5afe86931..d5c061cf8 100644 --- a/src/websockets/datastructures.py +++ b/src/websockets/datastructures.py @@ -165,4 +165,4 @@ def raw_items(self) -> Iterator[Tuple[str, str]]: HeadersLike = Union[Headers, Mapping[str, str], Iterable[Tuple[str, str]]] -HeadersLike.__doc__ = """Types accepted where :class:`Headers` is expected""" +"""Types accepted where :class:`Headers` is expected.""" diff --git a/src/websockets/typing.py b/src/websockets/typing.py index dadee7aba..e672ba006 100644 --- a/src/websockets/typing.py +++ b/src/websockets/typing.py @@ -17,7 +17,7 @@ # Public types used in the signature of public APIs Data = Union[str, bytes] -Data.__doc__ = """Types supported in a WebSocket message: +"""Types supported in a WebSocket message: :class:`str` for a Text_ frame, :class:`bytes` for a Binary_. .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 @@ -27,34 +27,34 @@ LoggerLike = Union[logging.Logger, logging.LoggerAdapter] -LoggerLike.__doc__ = """Types accepted where a :class:`~logging.Logger` is expected.""" +"""Types accepted where a :class:`~logging.Logger` is expected.""" Origin = NewType("Origin", str) -Origin.__doc__ = """Value of a ``Origin`` header.""" +"""Value of a ``Origin`` header.""" Subprotocol = NewType("Subprotocol", str) -Subprotocol.__doc__ = """Subprotocol in a ``Sec-WebSocket-Protocol`` header.""" +"""Subprotocol in a ``Sec-WebSocket-Protocol`` header.""" ExtensionName = NewType("ExtensionName", str) -ExtensionName.__doc__ = """Name of a WebSocket extension.""" +"""Name of a WebSocket extension.""" ExtensionParameter = Tuple[str, Optional[str]] -ExtensionParameter.__doc__ = """Parameter of a WebSocket extension.""" +"""Parameter of a WebSocket extension.""" # Private types ExtensionHeader = Tuple[ExtensionName, List[ExtensionParameter]] -ExtensionHeader.__doc__ = """Extension in a ``Sec-WebSocket-Extensions`` header.""" +"""Extension in a ``Sec-WebSocket-Extensions`` header.""" ConnectionOption = NewType("ConnectionOption", str) -ConnectionOption.__doc__ = """Connection option in a ``Connection`` header.""" +"""Connection option in a ``Connection`` header.""" UpgradeProtocol = NewType("UpgradeProtocol", str) -UpgradeProtocol.__doc__ = """Upgrade protocol in an ``Upgrade`` header.""" +"""Upgrade protocol in an ``Upgrade`` header.""" From 0cf844146b4764c54943691ce5983596cd9e3272 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 8 Sep 2021 22:59:09 +0200 Subject: [PATCH 170/762] Remove unnecessary parameters from reject(). --- src/websockets/server.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/src/websockets/server.py b/src/websockets/server.py index 5f7bec30d..bb91bdc77 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -99,7 +99,7 @@ def accept(self, request: Request) -> Response: request.exception = exc if self.debug: self.logger.debug("! invalid upgrade", exc_info=True) - return self.reject( + response = self.reject( http.HTTPStatus.UPGRADE_REQUIRED, ( f"Failed to open a WebSocket connection: {exc}.\n" @@ -107,8 +107,9 @@ def accept(self, request: Request) -> Response: f"You cannot access a WebSocket server directly " f"with a browser. You need a WebSocket client.\n" ), - headers=Headers([("Upgrade", "websocket")]), ) + response.headers["Upgrade"] = "websocket" + return response except InvalidHandshake as exc: request.exception = exc if self.debug: @@ -415,8 +416,6 @@ def reject( self, status: http.HTTPStatus, text: str, - headers: Optional[Headers] = None, - exception: Optional[Exception] = None, ) -> Response: """ Create a HTTP response event to reject the connection. @@ -429,13 +428,15 @@ def reject( """ body = text.encode() - if headers is None: - headers = Headers() - headers.setdefault("Date", email.utils.formatdate(usegmt=True)) - headers.setdefault("Connection", "close") - headers.setdefault("Content-Length", str(len(body))) - headers.setdefault("Content-Type", "text/plain; charset=utf-8") - headers.setdefault("Server", USER_AGENT) + headers = Headers( + [ + ("Date", email.utils.formatdate(usegmt=True)), + ("Connection", "close"), + ("Content-Length", str(len(body))), + ("Content-Type", "text/plain; charset=utf-8"), + ("Server", USER_AGENT), + ] + ) self.logger.info("connection failed (%d %s)", status.value, status.phrase) return Response(status.value, status.phrase, headers, body) From 724408ea72c2aaf7598174a8e276c95738c3d996 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Jun 2021 15:43:52 +0200 Subject: [PATCH 171/762] Add Sans-I/O howto guide. --- docs/howto/index.rst | 8 + docs/howto/sansio.rst | 318 ++++++++++++++++++++++++++++++++++++++ docs/topics/data-flow.svg | 63 ++++++++ 3 files changed, 389 insertions(+) create mode 100644 docs/howto/sansio.rst create mode 100644 docs/topics/data-flow.svg diff --git a/docs/howto/index.rst b/docs/howto/index.rst index d7c83dd7a..5deb7c767 100644 --- a/docs/howto/index.rst +++ b/docs/howto/index.rst @@ -36,3 +36,11 @@ Once your application is ready, learn how to deploy it on various platforms. supervisor nginx haproxy + +If you're integrating the Sans-I/O layer of websockets into a library, rather +than building an application with websockets, follow this guide. + +.. toctree:: + :maxdepth: 2 + + sansio diff --git a/docs/howto/sansio.rst b/docs/howto/sansio.rst new file mode 100644 index 000000000..83496bff2 --- /dev/null +++ b/docs/howto/sansio.rst @@ -0,0 +1,318 @@ +Integrate the Sans-I/O layer +============================ + +.. currentmodule:: websockets + +This guide explains how to integrate the `Sans-I/O`_ layer of websockets to +add support for WebSocket in another library. + +.. _Sans-I/O: https://sans-io.readthedocs.io/ + +As a prerequisite, you should decide how you will handle network I/O and +asynchronous control flow. + +Your integration layer will provide an API for the application on one side, +will talk to the network on the other side, and will rely on websockets to +implement the protocol in the middle. + +.. image:: ../topics/data-flow.svg + :align: center + +Opening a connection +-------------------- + +Client-side +........... + +If you're building a client, parse the URI you'd like to connect to:: + + from websockets.uri import parse_uri + + wsuri = parse_uri("ws://example.com/") + +Open a TCP connection to ``(wsuri.host, wsuri.port)`` and perform a TLS +handshake if ``wsuri.secure`` is :obj:`True`. + +Initialize a :class:`~client.ClientConnection`:: + + from websockets.client import ClientConnection + + connection = ClientConnection(wsuri) + +Create a WebSocket handshake request +with :meth:`~client.ClientConnection.connect` and send it +with :meth:`~client.ClientConnection.send_request`:: + + request = connection.connect() + connection.send_request(request) + +Then, call :meth:`~connection.Connection.data_to_send` and send its output to +the network, as described in `Send data`_ below. + +The first event returned by :meth:`~connection.Connection.events_received` is +the WebSocket handshake response. + +When the handshake fails, the reason is available in ``response.exception``:: + + if response.exception is not None: + raise response.exception + +Else, the WebSocket connection is open. + +A WebSocket client API usually performs the handshake then returns a wrapper +around the network connection and the :class:`~client.ClientConnection`. + +Server-side +........... + +If you're building a server, accept network connections from clients and +perform a TLS handshake if desired. + +For each connection, initialize a :class:`~server.ServerConnection`:: + + from websockets.server import ServerConnection + + connection = ServerConnection() + +The first event returned by :meth:`~connection.Connection.events_received` is +the WebSocket handshake request. + +Create a WebSocket handshake response +with :meth:`~server.ServerConnection.accept` and send it +with :meth:`~server.ServerConnection.send_response`:: + + response = connection.accept(request) + connection.send_response(response) + +Alternatively, you may reject the WebSocket handshake and return a HTTP +response with :meth:`~server.ServerConnection.reject`:: + + response = connection.reject(status, explanation) + connection.send_response(response) + +Then, call :meth:`~connection.Connection.data_to_send` and send its output to +the network, as described in `Send data`_ below. + +Even when you call :meth:`~server.ServerConnection.accept`, the WebSocket +handshake may fail if the request is incorrect or unsupported. + +When the handshake fails, the reason is available in ``request.exception``:: + + if request.exception is not None: + raise request.exception + +Else, the WebSocket connection is open. + +A WebSocket server API usually builds a wrapper around the network connection +and the :class:`~server.ServerConnection`. Then it invokes a connection +handler that accepts the wrapper in argument. + +It may also provide a way to close all connections and to shut down the server +gracefully. + +Going forwards, this guide focuses on handling an individual connection. + +From the network to the application +----------------------------------- + +Go through the five steps below until you reach the end of the data stream. + +Receive data +............ + +When receiving data from the network, feed it to the connection's +:meth:`~connection.Connection.receive_data` method. + +When reaching the end of the data stream, call the connection's +:meth:`~connection.Connection.receive_eof` method. + +For example, if ``sock`` is a :obj:`~socket.socket`:: + + try: + data = sock.recv(4096) + except OSError: # socket closed + data = b"" + if data: + connection.receive_data(data) + else: + connection.receive_eof() + +These methods aren't expected to raise exceptions — unless you call them again +after calling :meth:`~connection.Connection.receive_eof`, which is an error. +(If you get an exception, please file a bug!) + +Send data +......... + +Then, call :meth:`~connection.Connection.data_to_send` and send its output to +the network:: + + for data in connection.data_to_send(): + if data: + sock.sendall(data) + else: + sock.shutdown(socket.SHUT_WR) + +The empty bytestring signals the end of the data stream. When you see it, you +must half-close the TCP connection. + +Sending data right after receiving data is necessary because websockets +responds to ping frames, close frames, and incorrect inputs automatically. + +Expect TCP connection to close +.............................. + +Closing a WebSocket connection normally involves a two-way WebSocket closing +handshake. Then, regardless of whether the closure is normal or abnormal, the +server starts the four-way TCP closing handshake. If the network fails at the +wrong point, you can end up waiting until the TCP timeout, which is very long. + +To prevent dangling TCP connections when you expect the end of the data stream +but you never reach it, call :meth:`~connection.Connection.close_expected` +and, if it returns :obj:`True`, schedule closing the TCP connection after a +short timeout:: + + # start a new execution thread to run this code + sleep(10) + sock.close() # does nothing if the socket is already closed + +If the connection is still open when the timeout elapses, closing the socket +makes the execution thread that reads from the socket reach the end of the +data stream, possibly with an exception. + +Close TCP connection +.................... + +If you called :meth:`~connection.Connection.receive_eof`, close the TCP +connection now. This is a clean closure because the receive buffer is empty. + +After :meth:`~connection.Connection.receive_eof` signals the end of the read +stream, :meth:`~connection.Connection.data_to_send` always signals the end of +the write stream, unless it already ended. So, at this point, the TCP +connection is already half-closed. The only reason for closing it now is to +release resources related to the socket. + +Now you can exit the loop relaying data from the network to the application. + +Receive events +.............. + +Finally, call :meth:`~connection.Connection.events_received` to obtain events +parsed from the data provided to :meth:`~connection.Connection.receive_data`:: + + events = connection.events_received() + +The first event will be the WebSocket opening handshake request or response. +See `Opening a connection`_ above for details. + +All later events are WebSocket frames. There are two types of frames: + +* Data frames contain messages transferred over the WebSocket connections. You + should provide them to the application. See `Fragmentation`_ below for + how to reassemble messages from frames. +* Control frames provide information about the connection's state. The main + use case is to expose an abstraction over ping and pong to the application. + Keep in mind that websockets responds to ping frames and close frames + automatically. Don't duplicate this functionality! + +From the application to the network +----------------------------------- + +The connection object provides one method for each type of WebSocket frame. + +For sending a data frame: + +* :meth:`~connection.Connection.send_continuation` +* :meth:`~connection.Connection.send_text` +* :meth:`~connection.Connection.send_binary` + +These methods raise :exc:`~exceptions.ProtocolError` if you don't set +the :attr:`FIN ` bit correctly in fragmented +messages. + +For sending a control frame: + +* :meth:`~connection.Connection.send_close` +* :meth:`~connection.Connection.send_ping` +* :meth:`~connection.Connection.send_pong` + +:meth:`~connection.Connection.send_close` initiates the closing handshake. +See `Closing a connection`_ below for details. + +If you encounter an unrecoverable error and you must fail the WebSocket +connection, call :meth:`~connection.Connection.fail`. + +After any of the above, call :meth:`~connection.Connection.data_to_send` and +send its output to the network, as shown in `Send data`_ above. + +If you called :meth:`~connection.Connection.send_close` +or :meth:`~connection.Connection.fail`, you expect the end of the data +stream. You should follow the process described in `Close TCP connection`_ +above in order to prevent dangling TCP connections. + +Closing a connection +-------------------- + +Under normal circumstances, when a server wants to close the TCP connection: + +* it closes the write side; +* it reads until the end of the stream, because it expects the client to close + the read side; +* it closes the socket. + +When a client wants to close the TCP connection: + +* it reads until the end of the stream, because it expects the server to close + the read side; +* it closes the write side; +* it closes the socket. + +Applying the rules described earlier in this document gives the intended +result. As a reminder, the rules are: + +* When :meth:`~connection.Connection.data_to_send` returns the empty + bytestring, close the write side of the TCP connection. +* When you reach the end of the read stream, close the TCP connection. +* When :meth:`~connection.Connection.close_expected` returns :obj:`True`, if + you don't reach the end of the read stream quickly, close the TCP connection. + +Fragmentation +------------- + +WebSocket messages may be fragmented. Since this is a protocol-level concern, +you may choose to reassemble fragmented messages before handing them over to +the application. + +To reassemble a message, read data frames until you get a frame where +the :attr:`FIN ` bit is set, then concatenate +the payloads of all frames. + +You will never receive an inconsistent sequence of frames because websockets +raises a :exc:`~exceptions.ProtocolError` and fails the connection when this +happens. However, you may receive an incomplete sequence if the connection +drops in the middle of a fragmented message. + +Tips +---- + +Serialize operations +.................... + +The Sans-I/O layer expects to run sequentially. If your interact with it from +multiple threads or coroutines, you must ensure correct serialization. This +should happen automatically in a cooperative multitasking environment. + +However, you still have to make sure you don't break this property by +accident. For example, serialize writes to the network +when :meth:`~connection.Connection.data_to_send` returns multiple values to +prevent concurrent writes from interleaving incorrectly. + +Avoid buffers +............. + +The Sans-I/O layer doesn't do any buffering. It makes events available in +:meth:`~connection.Connection.events_received` as soon as they're received. + +You should make incoming messages available to the application immediately and +stop further processing until the application fetches them. This will usually +result in the best performance. diff --git a/docs/topics/data-flow.svg b/docs/topics/data-flow.svg new file mode 100644 index 000000000..749d9d482 --- /dev/null +++ b/docs/topics/data-flow.svg @@ -0,0 +1,63 @@ +Integration layerSans-I/O layerApplicationreceivemessagessendmessagesNetworksenddatareceivedatareceivebytessendbytessendeventsreceiveevents \ No newline at end of file From be1203b8f6903905024c8e1e3b0b6a5c4e290c99 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 5 Sep 2021 17:31:54 +0200 Subject: [PATCH 172/762] Add API documentation for the Sans-I/O layer. --- docs/reference/client.rst | 112 +++++++++++++------ docs/reference/common.rst | 113 ++++++++++++++----- docs/reference/server.rst | 150 +++++++++++++++++--------- docs/reference/types.rst | 4 +- src/websockets/client.py | 33 +++++- src/websockets/connection.py | 204 ++++++++++++++++++++++++++--------- src/websockets/server.py | 48 +++++++-- 7 files changed, 496 insertions(+), 168 deletions(-) diff --git a/docs/reference/client.rst b/docs/reference/client.rst index dc31dc032..daf01ef58 100644 --- a/docs/reference/client.rst +++ b/docs/reference/client.rst @@ -3,60 +3,108 @@ Client .. automodule:: websockets.client - Opening a connection - -------------------- +asyncio +------- - .. autofunction:: connect(uri, *, create_protocol=None, logger=None, compression="deflate", origin=None, extensions=None, subprotocols=None, extra_headers=None, open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) - :async: +Opening a connection +.................... - .. autofunction:: unix_connect(path, uri="ws://localhost/", *, create_protocol=None, logger=None, compression="deflate", origin=None, extensions=None, subprotocols=None, extra_headers=None, open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) - :async: +.. autofunction:: connect(uri, *, create_protocol=None, logger=None, compression="deflate", origin=None, extensions=None, subprotocols=None, extra_headers=None, open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) + :async: - Using a connection - ------------------ +.. autofunction:: unix_connect(path, uri="ws://localhost/", *, create_protocol=None, logger=None, compression="deflate", origin=None, extensions=None, subprotocols=None, extra_headers=None, open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) + :async: - .. autoclass:: WebSocketClientProtocol(*, logger=None, origin=None, extensions=None, subprotocols=None, extra_headers=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) +Using a connection +.................. - .. automethod:: recv +.. autoclass:: WebSocketClientProtocol(*, logger=None, origin=None, extensions=None, subprotocols=None, extra_headers=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) - .. automethod:: send + .. automethod:: recv - .. automethod:: close + .. automethod:: send - .. automethod:: wait_closed + .. automethod:: close - .. automethod:: ping + .. automethod:: wait_closed - .. automethod:: pong + .. automethod:: ping - WebSocket connection objects also provide these attributes: + .. automethod:: pong - .. autoattribute:: id + WebSocket connection objects also provide these attributes: - .. autoattribute:: logger + .. autoattribute:: id - .. autoproperty:: local_address + .. autoattribute:: logger - .. autoproperty:: remote_address + .. autoproperty:: local_address - .. autoproperty:: open + .. autoproperty:: remote_address - .. autoproperty:: closed + .. autoproperty:: open - The following attributes are available after the opening handshake, - once the WebSocket connection is open: + .. autoproperty:: closed - .. autoattribute:: path + The following attributes are available after the opening handshake, + once the WebSocket connection is open: - .. autoattribute:: request_headers + .. autoattribute:: path - .. autoattribute:: response_headers + .. autoattribute:: request_headers - .. autoattribute:: subprotocol + .. autoattribute:: response_headers - The following attributes are available after the closing handshake, - once the WebSocket connection is closed: + .. autoattribute:: subprotocol - .. autoproperty:: close_code + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: - .. autoproperty:: close_reason + .. autoproperty:: close_code + + .. autoproperty:: close_reason + +Sans-I/O +-------- + +.. autoclass:: ClientConnection(wsuri, origin=None, extensions=None, subprotocols=None, state=State.CONNECTING, max_size=2 ** 20, logger=None) + + .. automethod:: receive_data + + .. automethod:: receive_eof + + .. automethod:: connect + + .. automethod:: send_request + + .. automethod:: send_continuation + + .. automethod:: send_text + + .. automethod:: send_binary + + .. automethod:: send_close + + .. automethod:: send_ping + + .. automethod:: send_pong + + .. automethod:: fail + + .. automethod:: events_received + + .. automethod:: data_to_send + + .. automethod:: close_expected + + .. autoattribute:: id + + .. autoattribute:: logger + + .. autoproperty:: state + + .. autoproperty:: close_code + + .. autoproperty:: close_reason + + .. autoproperty:: close_exc diff --git a/docs/reference/common.rst b/docs/reference/common.rst index f5422bc35..f2683bc77 100644 --- a/docs/reference/common.rst +++ b/docs/reference/common.rst @@ -1,53 +1,114 @@ Both sides ========== +asyncio +------- + .. automodule:: websockets.legacy.protocol - Using a connection - ------------------ +.. autoclass:: WebSocketCommonProtocol(*, logger=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) + + .. automethod:: recv + + .. automethod:: send + + .. automethod:: close + + .. automethod:: wait_closed + + .. automethod:: ping + + .. automethod:: pong + + WebSocket connection objects also provide these attributes: + + .. autoattribute:: id + + .. autoattribute:: logger + + .. autoproperty:: local_address + + .. autoproperty:: remote_address + + .. autoproperty:: open + + .. autoproperty:: closed + + The following attributes are available after the opening handshake, + once the WebSocket connection is open: + + .. autoattribute:: path + + .. autoattribute:: request_headers + + .. autoattribute:: response_headers + + .. autoattribute:: subprotocol + + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: + + .. autoproperty:: close_code + + .. autoproperty:: close_reason + +Sans-I/O +-------- + +.. automodule:: websockets.connection + +.. autoclass:: Connection(side, state=State.OPEN, max_size=2 ** 20, logger=None) + + .. automethod:: receive_data + + .. automethod:: receive_eof + + .. automethod:: send_continuation + + .. automethod:: send_text + + .. automethod:: send_binary - .. autoclass:: WebSocketCommonProtocol(*, logger=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) + .. automethod:: send_close - .. automethod:: recv + .. automethod:: send_ping - .. automethod:: send + .. automethod:: send_pong - .. automethod:: close + .. automethod:: fail - .. automethod:: wait_closed + .. automethod:: events_received - .. automethod:: ping + .. automethod:: data_to_send - .. automethod:: pong + .. automethod:: close_expected - WebSocket connection objects also provide these attributes: + .. autoattribute:: id - .. autoattribute:: id + .. autoattribute:: logger - .. autoattribute:: logger + .. autoproperty:: state - .. autoproperty:: local_address + .. autoproperty:: close_code - .. autoproperty:: remote_address + .. autoproperty:: close_reason - .. autoproperty:: open + .. autoproperty:: close_exc - .. autoproperty:: closed +.. autoclass:: Side - The following attributes are available after the opening handshake, - once the WebSocket connection is open: + .. autoattribute:: SERVER - .. autoattribute:: path + .. autoattribute:: CLIENT - .. autoattribute:: request_headers +.. autoclass:: State - .. autoattribute:: response_headers + .. autoattribute:: CONNECTING - .. autoattribute:: subprotocol + .. autoattribute:: OPEN - The following attributes are available after the closing handshake, - once the WebSocket connection is closed: + .. autoattribute:: CLOSING - .. autoproperty:: close_code + .. autoattribute:: CLOSED - .. autoproperty:: close_reason +.. autodata:: SEND_EOF diff --git a/docs/reference/server.rst b/docs/reference/server.rst index 1864594f5..a8eb9c5b4 100644 --- a/docs/reference/server.rst +++ b/docs/reference/server.rst @@ -3,96 +3,148 @@ Server .. automodule:: websockets.server - Starting a server - ----------------- +asyncio +------- - .. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, logger=None, compression="deflate", origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) - :async: +Starting a server +................. - .. autofunction:: unix_serve(ws_handler, path=None, *, create_protocol=None, logger=None, compression="deflate", origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) - :async: +.. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, logger=None, compression="deflate", origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) + :async: - Stopping a server - ----------------- +.. autofunction:: unix_serve(ws_handler, path=None, *, create_protocol=None, logger=None, compression="deflate", origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) + :async: - .. autoclass:: WebSocketServer +Stopping a server +................. - .. automethod:: close +.. autoclass:: WebSocketServer - .. automethod:: wait_closed + .. automethod:: close - .. autoattribute:: sockets + .. automethod:: wait_closed - Using a connection - ------------------ + .. autoattribute:: sockets - .. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, logger=None, origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) +Using a connection +.................. - .. automethod:: recv +.. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, logger=None, origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) - .. automethod:: send + .. automethod:: recv - .. automethod:: close + .. automethod:: send - .. automethod:: wait_closed + .. automethod:: close - .. automethod:: ping + .. automethod:: wait_closed - .. automethod:: pong + .. automethod:: ping - You can customize the opening handshake in a subclass by overriding these methods: + .. automethod:: pong - .. automethod:: process_request + You can customize the opening handshake in a subclass by overriding these methods: - .. automethod:: select_subprotocol + .. automethod:: process_request - WebSocket connection objects also provide these attributes: + .. automethod:: select_subprotocol - .. autoattribute:: id + WebSocket connection objects also provide these attributes: - .. autoattribute:: logger + .. autoattribute:: id - .. autoproperty:: local_address + .. autoattribute:: logger - .. autoproperty:: remote_address + .. autoproperty:: local_address - .. autoproperty:: open + .. autoproperty:: remote_address - .. autoproperty:: closed + .. autoproperty:: open - The following attributes are available after the opening handshake, - once the WebSocket connection is open: + .. autoproperty:: closed - .. autoattribute:: path + The following attributes are available after the opening handshake, + once the WebSocket connection is open: - .. autoattribute:: request_headers + .. autoattribute:: path - .. autoattribute:: response_headers + .. autoattribute:: request_headers - .. autoattribute:: subprotocol + .. autoattribute:: response_headers - The following attributes are available after the closing handshake, - once the WebSocket connection is closed: + .. autoattribute:: subprotocol - .. autoproperty:: close_code + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: - .. autoproperty:: close_reason + .. autoproperty:: close_code + + .. autoproperty:: close_reason Basic authentication --------------------- +.................... .. automodule:: websockets.auth - websockets supports HTTP Basic Authentication according to - :rfc:`7235` and :rfc:`7617`. +websockets supports HTTP Basic Authentication according to +:rfc:`7235` and :rfc:`7617`. + +.. autofunction:: basic_auth_protocol_factory + +.. autoclass:: BasicAuthWebSocketServerProtocol + + .. autoattribute:: realm + + .. autoattribute:: username + + .. automethod:: check_credentials + +.. currentmodule:: websockets.server + +Sans-I/O +-------- + +.. autoclass:: ServerConnection(origins=None, extensions=None, subprotocols=None, state=State.CONNECTING, max_size=2 ** 20, logger=None) + + .. automethod:: receive_data + + .. automethod:: receive_eof + + .. automethod:: accept + + .. automethod:: reject + + .. automethod:: send_response + + .. automethod:: send_continuation + + .. automethod:: send_text + + .. automethod:: send_binary + + .. automethod:: send_close + + .. automethod:: send_ping + + .. automethod:: send_pong + + .. automethod:: fail + + .. automethod:: events_received + + .. automethod:: data_to_send + + .. automethod:: close_expected + + .. autoattribute:: id - .. autofunction:: basic_auth_protocol_factory + .. autoattribute:: logger - .. autoclass:: BasicAuthWebSocketServerProtocol + .. autoproperty:: state - .. autoattribute:: realm + .. autoproperty:: close_code - .. autoattribute:: username + .. autoproperty:: close_reason - .. automethod:: check_credentials + .. autoproperty:: close_exc diff --git a/docs/reference/types.rst b/docs/reference/types.rst index 3dab553af..4b7952553 100644 --- a/docs/reference/types.rst +++ b/docs/reference/types.rst @@ -1,8 +1,6 @@ Types ===== -.. autodata:: websockets.datastructures.HeadersLike - .. automodule:: websockets.typing .. autodata:: Data @@ -17,4 +15,6 @@ Types .. autodata:: ExtensionParameter +.. autodata:: websockets.connection.Event +.. autodata:: websockets.datastructures.HeadersLike diff --git a/src/websockets/client.py b/src/websockets/client.py index 42c7aeee9..6f2da5a6e 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -44,6 +44,28 @@ class ClientConnection(Connection): + """ + Sans-I/O implementation of a WebSocket client connection. + + Args: + wsuri: URI of the WebSocket server, parsed + with :func:`~websockets.uri.parse_uri`. + origin: value of the ``Origin`` header. This is useful when connecting + to a server that validates the ``Origin`` header to defend against + Cross-Site WebSocket Hijacking attacks. + extensions: list of supported extensions, in order in which they + should be tried. + subprotocols: list of supported subprotocols, in order of decreasing + preference. + state: initial state of the WebSocket connection. + max_size: maximum size of incoming messages in bytes; + :obj:`None` to disable the limit. + logger: logger for this connection; + defaults to ``logging.getLogger("websockets.client")``; + see the :doc:`logging guide <../topics/logging>` for details. + + """ + def __init__( self, wsuri: WebSocketURI, @@ -68,7 +90,14 @@ def __init__( def connect(self) -> Request: # noqa: F811 """ - Create a WebSocket handshake request event to open a connection. + Create a handshake request to open a connection. + + You must send the handshake request with :meth:`send_request`. + + You can modify it before sending it, for example to add HTTP headers. + + Returns: + Request: WebSocket handshake request event to send to the server. """ headers = Headers() @@ -275,7 +304,7 @@ def send_request(self, request: Request) -> None: Send a handshake request to the server. Args: - request: WebSocket handshake request event to send. + request: WebSocket handshake request event. """ if self.debug: diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 684664860..8661a148b 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -38,12 +38,12 @@ ] Event = Union[Request, Response, Frame] - - -# A WebSocket connection is either a server or a client. +"""Events that :meth:`~Connection.events_received` may return.""" class Side(enum.IntEnum): + """A WebSocket connection is either a server or a client.""" + SERVER, CLIENT = range(2) @@ -51,10 +51,9 @@ class Side(enum.IntEnum): CLIENT = Side.CLIENT -# A WebSocket connection goes through the following four states, in order: - - class State(enum.IntEnum): + """A WebSocket connection is in one of these four states.""" + CONNECTING, OPEN, CLOSING, CLOSED = range(4) @@ -64,12 +63,26 @@ class State(enum.IntEnum): CLOSED = State.CLOSED -# Sentinel to signal that the connection should be closed. - SEND_EOF = b"" +"""Sentinel signaling that the TCP connection must be half-closed.""" class Connection: + """ + Sans-I/O implementation of a WebSocket connection. + + Args: + side: :attr:`~Side.CLIENT` or :attr:`~Side.SERVER`. + state: initial state of the WebSocket connection. + max_size: maximum size of incoming messages in bytes; + :obj:`None` to disable the limit. + logger: logger for this connection; depending on ``side``, + defaults to ``logging.getLogger("websockets.client")`` + or ``logging.getLogger("websockets.server")``; + see the :doc:`logging guide <../topics/logging>` for details. + + """ + def __init__( self, side: Side, @@ -78,12 +91,14 @@ def __init__( logger: Optional[LoggerLike] = None, ) -> None: # Unique identifier. For logs. - self.id = uuid.uuid4() + self.id: uuid.UUID = uuid.uuid4() + """Unique identifier of the connection. Useful in logs.""" # Logger or LoggerAdapter for this connection. if logger is None: logger = logging.getLogger(f"websockets.{side.name.lower()}") - self.logger = logger + self.logger: LoggerLike = logger + """Logger for this connection.""" # Track if DEBUG is enabled. Shortcut logging calls if it isn't. self.debug = logger.isEnabledFor(logging.DEBUG) @@ -126,12 +141,12 @@ def __init__( next(self.parser) # start coroutine self.parser_exc: Optional[Exception] = None - # Public attributes - @property def state(self) -> State: """ - Connection State defined in 4.1, 4.2, 7.1.3, and 7.1.4 of :rfc:`6455`. + WebSocket connection state. + + Defined in 4.1, 4.2, 7.1.3, and 7.1.4 of :rfc:`6455`. """ return self._state @@ -145,9 +160,12 @@ def state(self, state: State) -> None: @property def close_code(self) -> Optional[int]: """ - Connection Close Code defined in 7.1.5 of :rfc:`6455`. + `WebSocket close code`_. - Available once the connection is closed. + .. _WebSocket close code: + https://www.rfc-editor.org/rfc/rfc6455.html#section-7.1.5 + + :obj:`None` if the connection isn't closed yet. """ if self.state is not CLOSED: @@ -160,9 +178,12 @@ def close_code(self) -> Optional[int]: @property def close_reason(self) -> Optional[str]: """ - Connection Close Reason defined in 7.1.6 of :rfc:`6455`. + `WebSocket close reason`_. + + .. _WebSocket close reason: + https://www.rfc-editor.org/rfc/rfc6455.html#section-7.1.6 - Available once the connection is closed. + :obj:`None` if the connection isn't closed yet. """ if self.state is not CLOSED: @@ -175,13 +196,20 @@ def close_reason(self) -> Optional[str]: @property def close_exc(self) -> ConnectionClosed: """ - Exception raised when trying to interact with a closed connection. + Exception to raise when trying to interact with a closed connection. + + Don't raise this exception while the connection :attr:`state` + is :attr:`~websockets.connection.State.CLOSING`; wait until + it's :attr:`~websockets.connection.State.CLOSED`. - Available once the connection is closed. If you need to raise this - exception while the connection is closing, wait until it's closed. + Indeed, the exception includes the close code and reason, which are + known only once the connection is closed. + + Raises: + AssertionError: if the connection isn't closed yet. """ - assert self.state is CLOSED + assert self.state is CLOSED, "connection isn't closed yet" exc_type: Type[ConnectionClosed] if ( self.close_rcvd is not None @@ -205,15 +233,15 @@ def close_exc(self) -> ConnectionClosed: def receive_data(self, data: bytes) -> None: """ - Receive data from the connection. + Receive data from the network. After calling this method: - - You must call :meth:`data_to_send` and send this data. - - You should call :meth:`events_received` and process these events. + - You must call :meth:`data_to_send` and send this data to the network. + - You should call :meth:`events_received` and process resulting events. Raises: - EOFError: if :meth:`receive_eof` was called before. + EOFError: if :meth:`receive_eof` was called earlier. """ self.reader.feed_data(data) @@ -221,16 +249,16 @@ def receive_data(self, data: bytes) -> None: def receive_eof(self) -> None: """ - Receive the end of the data stream from the connection. + Receive the end of the data stream from the network. After calling this method: - - You must call :meth:`data_to_send` and send this data. - - You aren't exepcted to call :meth:`events_received` as it won't - return any new events. + - You must call :meth:`data_to_send` and send this data to the network. + - You aren't expected to call :meth:`events_received`; it won't return + any new events. Raises: - EOFError: if :meth:`receive_eof` was called before. + EOFError: if :meth:`receive_eof` was called earlier. """ self.reader.feed_eof() @@ -240,7 +268,19 @@ def receive_eof(self) -> None: def send_continuation(self, data: bytes, fin: bool) -> None: """ - Send a continuation frame. + Send a `Continuation frame`_. + + .. _Continuation frame: + https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + + Parameters: + data: payload containing the same kind of data + as the initial frame. + fin: FIN bit; set it to :obj:`True` if this is the last frame + of a fragmented message and to :obj:`False` otherwise. + + Raises: + ProtocolError: if a fragmented message isn't in progress. """ if not self.expect_continuation_frame: @@ -250,7 +290,18 @@ def send_continuation(self, data: bytes, fin: bool) -> None: def send_text(self, data: bytes, fin: bool = True) -> None: """ - Send a text frame. + Send a `Text frame`_. + + .. _Text frame: + https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + + Parameters: + data: payload containing text encoded with UTF-8. + fin: FIN bit; set it to :obj:`False` if this is the first frame of + a fragmented message. + + Raises: + ProtocolError: if a fragmented message is in progress. """ if self.expect_continuation_frame: @@ -260,7 +311,18 @@ def send_text(self, data: bytes, fin: bool = True) -> None: def send_binary(self, data: bytes, fin: bool = True) -> None: """ - Send a binary frame. + Send a `Binary frame`_. + + .. _Binary frame: + https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + + Parameters: + data: payload containing arbitrary binary data. + fin: FIN bit; set it to :obj:`False` if this is the first frame of + a fragmented message. + + Raises: + ProtocolError: if a fragmented message is in progress. """ if self.expect_continuation_frame: @@ -270,7 +332,18 @@ def send_binary(self, data: bytes, fin: bool = True) -> None: def send_close(self, code: Optional[int] = None, reason: str = "") -> None: """ - Send a connection close frame. + Send a `Close frame`_. + + .. _Close frame: + https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.1 + + Parameters: + code: close code. + reason: close reason. + + Raises: + ProtocolError: if a fragmented message is being sent, if the code + isn't valid, or if a reason is provided without a code """ if self.expect_continuation_frame: @@ -291,22 +364,43 @@ def send_close(self, code: Optional[int] = None, reason: str = "") -> None: def send_ping(self, data: bytes) -> None: """ - Send a ping frame. + Send a `Ping frame`_. + + .. _Ping frame: + https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.2 + + Parameters: + data: payload containing arbitrary binary data. """ self.send_frame(Frame(OP_PING, data)) def send_pong(self, data: bytes) -> None: """ - Send a pong frame. + Send a `Pong frame`_. + + .. _Pong frame: + https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3 + + Parameters: + data: payload containing arbitrary binary data. """ self.send_frame(Frame(OP_PONG, data)) def fail(self, code: int, reason: str = "") -> None: """ - Fail the WebSocket connection. + `Fail the WebSocket connection`_. + .. _Fail the WebSocket connection: + https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.7 + + Parameters: + code: close code + reason: close reason + + Raises: + ProtocolError: if the code isn't valid. """ # 7.1.7. Fail the WebSocket Connection @@ -339,11 +433,14 @@ def fail(self, code: int, reason: str = "") -> None: def events_received(self) -> List[Event]: """ - Return events read from the connection. + Fetch events generated from data received from the network. - Call this method immediately after calling any of the ``receive_*()`` - methods and process the events. + Call this method immediately after any of the ``receive_*()`` methods. + Process resulting events, likely by passing them to the application. + + Returns: + List[Event]: Events read from the connection. """ events, self.events = self.events, [] return events @@ -352,13 +449,19 @@ def events_received(self) -> List[Event]: def data_to_send(self) -> List[bytes]: """ - Return data to write to the connection. + Obtain data to send to the network. + + Call this method immediately after any of the ``receive_*()``, + ``send_*()``, or :meth:`fail` methods. + + Write resulting data to the connection. - Call this method immediately after calling any of the ``receive_*()``, - ``send_*()``, or ``fail()`` methods and write the data to the + The empty bytestring :data:`~websockets.connection.SEND_EOF` signals + the end of the data stream. When you receive it, half-close the TCP connection. - The empty bytestring signals the end of the data stream. + Returns: + List[bytes]: Data to write to the connection. """ writes, self.writes = self.writes, [] @@ -366,11 +469,16 @@ def data_to_send(self) -> List[bytes]: def close_expected(self) -> bool: """ - Tell whether the TCP connection is expected to close soon. + Tell if the TCP connection is expected to close soon. + + Call this method immediately after any of the ``receive_*()`` or + :meth:`fail` methods. + + If it returns :obj:`True`, schedule closing the TCP connection after a + short timeout if the other side hasn't already closed it. - Call this method immediately after calling any of the ``receive_*()`` - or ``fail_*()`` methods and, if it returns :obj:`True`, schedule - closing the TCP connection after a short timeout. + Returns: + bool: Whether the TCP connection is expected to close soon. """ # We already got a TCP Close if and only if the state is CLOSED. diff --git a/src/websockets/server.py b/src/websockets/server.py index bb91bdc77..2d9b4f9a8 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -45,8 +45,26 @@ class ServerConnection(Connection): - - side = SERVER + """ + Sans-I/O implementation of a WebSocket server connection. + + Args: + origins: acceptable values of the ``Origin`` header; include + :obj:`None` in the list if the lack of an origin is acceptable. + This is useful for defending against Cross-Site WebSocket + Hijacking attacks. + extensions: list of supported extensions, in order in which they + should be tried. + subprotocols: list of supported subprotocols, in order of decreasing + preference. + state: initial state of the WebSocket connection. + max_size: maximum size of incoming messages in bytes; + :obj:`None` to disable the limit. + logger: logger for this connection; + defaults to ``logging.getLogger("websockets.client")``; + see the :doc:`logging guide <../topics/logging>` for details. + + """ def __init__( self, @@ -69,16 +87,20 @@ def __init__( def accept(self, request: Request) -> Response: """ - Create a WebSocket handshake response event to accept the connection. + Create a handshake response to accept the connection. + + If the connection cannot be established, the handshake response + actually rejects the handshake. - If the connection cannot be established, create a HTTP response event - to reject the handshake. + You must send the handshake response with :meth:`send_response`. + + You can modify it before sending it, for example to add HTTP headers. Args: - request: handshake request event received from the client. + request: WebSocket handshake request event received from the client. Returns: - Response: handshake response event to send to the client. + Response: WebSocket handshake response event to send to the client. """ try: @@ -418,13 +440,21 @@ def reject( text: str, ) -> Response: """ - Create a HTTP response event to reject the connection. + Create a handshake response to reject the connection. A short plain text response is the best fallback when failing to establish a WebSocket connection. + You must send the handshake response with :meth:`send_response`. + + You can modify it before sending it, for example to alter HTTP headers. + + Args: + status: HTTP status code. + text: HTTP response body; will be encoded to UTF-8. + Returns: - Response: HTTP handshake response to send to the client. + Response: WebSocket handshake response event to send to the client. """ body = text.encode() From a04bfdb8f7eaa0071f3b37efe83960763311fa6f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 9 Sep 2021 07:45:58 +0200 Subject: [PATCH 173/762] Add changelog. --- docs/project/changelog.rst | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 84ab6a314..736dfde41 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -54,6 +54,7 @@ Backwards-incompatible changes :obj:`None` to disable the timeout entirely. .. admonition:: The ``legacy_recv`` option is deprecated. + :class: note See the release notes of websockets 3.0 for details. @@ -72,6 +73,14 @@ Backwards-incompatible changes New features ............ +.. admonition:: websockets 10.0 introduces a `Sans-I/O API + `_ for easier integration + in third-party libraries. + :class: important + + If you're integrating websockets in a library, rather than just using it, + look at the :doc:`Sans-I/O integration guide <../howto/sansio>`. + * Added compatibility with Python 3.10. * Added :func:`~websockets.broadcast` to send a message to many clients. From 13eff12bb4c995b50154fdc250281c92ddccaca0 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 9 Sep 2021 07:55:31 +0200 Subject: [PATCH 174/762] Bump version number. --- docs/conf.py | 2 +- docs/project/changelog.rst | 2 +- src/websockets/version.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index ffe61f7ba..d22b85a82 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -26,7 +26,7 @@ author = "Aymeric Augustin" # The full version, including alpha/beta/rc tags -release = "9.1" +release = "10.0" # -- General configuration --------------------------------------------------- diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 736dfde41..90e08728b 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -28,7 +28,7 @@ They may change at any time. 10.0 ---- -*In development* +*September 9, 2021* Backwards-incompatible changes .............................. diff --git a/src/websockets/version.py b/src/websockets/version.py index a7901ef92..168f8b054 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -1 +1 @@ -version = "9.1" +version = "10.0" From 32d9a52a18960004780a735cabb2d881969cb6ee Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 9 Sep 2021 20:44:18 +0200 Subject: [PATCH 175/762] Remove references to Python 2. RIP --- README.rst | 2 -- docs/howto/faq.rst | 9 --------- 2 files changed, 11 deletions(-) diff --git a/README.rst b/README.rst index 6b9b67672..48d2637c1 100644 --- a/README.rst +++ b/README.rst @@ -125,8 +125,6 @@ Why shouldn't I use ``websockets``? at being an excellent implementation of :rfc:`6455`: The WebSocket Protocol and :rfc:`7692`: Compression Extensions for WebSocket. Its support for HTTP is minimal — just enough for a HTTP health check. -* If you want to use Python 2: ``websockets`` builds upon ``asyncio`` which - only works on Python 3. ``websockets`` requires Python ≥ 3.7. What else? ---------- diff --git a/docs/howto/faq.rst b/docs/howto/faq.rst index a32e5ec1b..62c017a88 100644 --- a/docs/howto/faq.rst +++ b/docs/howto/faq.rst @@ -428,15 +428,6 @@ coroutines make it easier to manage control flow in concurrent code. If you prefer callback-based APIs, you should use another library. -Is there a Python 2 version? -............................ - -No, there isn't. - -Python 2 reached end of life on January 1st, 2020. - -Before that date, websockets required asyncio and therefore Python 3. - Why do I get the error: ``module 'websockets' has no attribute '...'``? ....................................................................... From c8428ced9850b0838edd185605b076b4b28ad406 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 9 Sep 2021 22:01:09 +0200 Subject: [PATCH 176/762] Add security policy --- SECURITY.md | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 SECURITY.md diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 000000000..556217a4d --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,5 @@ +# Security policy + +Only the latest version receives security updates. + +Please report vulnerabilities [via Tidelift](https://tidelift.com/docs/security). From 9b8a8d1cb560d292aecde52252289e3560760167 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 9 Sep 2021 21:40:10 +0200 Subject: [PATCH 177/762] Remove path argument from connection handlers. Ensure backwards-compatibility. Ref #1038. --- README.rst | 2 +- compliance/test_server.py | 2 +- docs/howto/faq.rst | 18 ++++---- docs/intro/index.rst | 12 +++--- docs/project/changelog.rst | 13 ++++++ docs/topics/authentication.rst | 8 ++-- docs/topics/broadcast.rst | 8 ++-- example/basic_auth_server.py | 2 +- example/counter.py | 2 +- example/deployment/haproxy/app.py | 2 +- example/deployment/heroku/app.py | 2 +- example/deployment/kubernetes/app.py | 2 +- example/deployment/nginx/app.py | 2 +- example/deployment/supervisor/app.py | 2 +- example/django/authentication.py | 2 +- example/django/notifications.py | 2 +- example/echo.py | 2 +- example/health_check_server.py | 2 +- example/secure_server.py | 2 +- example/server.py | 2 +- example/show_time.py | 2 +- example/shutdown_server.py | 2 +- example/unix_server.py | 2 +- experiments/authentication/app.py | 10 ++--- experiments/broadcast/server.py | 2 +- experiments/compression/server.py | 2 +- src/websockets/legacy/server.py | 64 ++++++++++++++++++++++------ tests/legacy/test_client_server.py | 39 +++++++++++------ 28 files changed, 140 insertions(+), 72 deletions(-) diff --git a/README.rst b/README.rst index 48d2637c1..99e477867 100644 --- a/README.rst +++ b/README.rst @@ -63,7 +63,7 @@ And here's an echo server: import asyncio from websockets import serve - async def echo(websocket, path): + async def echo(websocket): async for message in websocket: await websocket.send(message) diff --git a/compliance/test_server.py b/compliance/test_server.py index 14ac90fe6..92f895d92 100644 --- a/compliance/test_server.py +++ b/compliance/test_server.py @@ -13,7 +13,7 @@ HOST, PORT = "127.0.0.1", 8642 -async def echo(ws, path): +async def echo(ws): async for msg in ws: await ws.send(msg) diff --git a/docs/howto/faq.rst b/docs/howto/faq.rst index 62c017a88..bef61bdab 100644 --- a/docs/howto/faq.rst +++ b/docs/howto/faq.rst @@ -22,12 +22,12 @@ before returning. For example, if your handler has a structure similar to:: - async def handler(websocket, path): + async def handler(websocket): asyncio.create_task(do_some_work()) change it to:: - async def handler(websocket, path): + async def handler(websocket): await do_some_work() Why does the server close the connection after processing one message? @@ -38,12 +38,12 @@ process multiple messages. For example, if your handler looks like this:: - async def handler(websocket, path): + async def handler(websocket): print(websocket.recv()) change it like this:: - async def handler(websocket, path): + async def handler(websocket): async for message in websocket: print(message) @@ -58,12 +58,12 @@ Any call that may take some time must be asynchronous. For example, if you have:: - async def handler(websocket, path): + async def handler(websocket): time.sleep(1) change it to:: - async def handler(websocket, path): + async def handler(websocket): await asyncio.sleep(1) This is part of learning asyncio. It isn't specific to websockets. @@ -82,7 +82,7 @@ You can bind additional arguments to the connection handler with import functools import websockets - async def handler(websocket, path, extra_argument): + async def handler(websocket, extra_argument): ... bound_handler = functools.partial(handler, extra_argument='spam') @@ -104,7 +104,7 @@ To access HTTP headers during the WebSocket handshake, you can override Once the connection is established, they're available in :attr:`~server.WebSocketServerProtocol.request_headers`:: - async def handler(websocket, path): + async def handler(websocket): cookies = websocket.request_headers["Cookie"] How do I get the IP address of the client connecting to my server? @@ -112,7 +112,7 @@ How do I get the IP address of the client connecting to my server? It's available in :attr:`~legacy.protocol.WebSocketCommonProtocol.remote_address`:: - async def handler(websocket, path): + async def handler(websocket): remote_ip = websocket.remote_address[0] How do I set which IP addresses my server listens to? diff --git a/docs/intro/index.rst b/docs/intro/index.rst index c8426719c..bd7c48f81 100644 --- a/docs/intro/index.rst +++ b/docs/intro/index.rst @@ -131,7 +131,7 @@ Consumer For receiving messages and passing them to a ``consumer`` coroutine:: - async def consumer_handler(websocket, path): + async def consumer_handler(websocket): async for message in websocket: await consumer(message) @@ -145,7 +145,7 @@ Producer For getting messages from a ``producer`` coroutine and sending them:: - async def producer_handler(websocket, path): + async def producer_handler(websocket): while True: message = await producer() await websocket.send(message) @@ -163,11 +163,11 @@ Both sides You can read and write messages on the same connection by combining the two patterns shown above and running the two tasks in parallel:: - async def handler(websocket, path): + async def handler(websocket): consumer_task = asyncio.ensure_future( - consumer_handler(websocket, path)) + consumer_handler(websocket)) producer_task = asyncio.ensure_future( - producer_handler(websocket, path)) + producer_handler(websocket)) done, pending = await asyncio.wait( [consumer_task, producer_task], return_when=asyncio.FIRST_COMPLETED, @@ -186,7 +186,7 @@ unregister them when they disconnect. connected = set() - async def handler(websocket, path): + async def handler(websocket): # Register. connected.add(websocket) try: diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 90e08728b..2a714f945 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -25,6 +25,19 @@ fixing regressions shortly after a release. Only documented APIs are public. Undocumented APIs are considered private. They may change at any time. +10.1 +---- + +*In development* + +Improvements +............ + +* Made the second parameter of connection handlers optional. It will be + deprecated in the next major release. The request path is available in + the :attr:`~legacy.protocol.WebSocketCommonProtocol.path` attribute of + the first argument. + 10.0 ---- diff --git a/docs/topics/authentication.rst b/docs/topics/authentication.rst index 8beeea9aa..4d702b2f6 100644 --- a/docs/topics/authentication.rst +++ b/docs/topics/authentication.rst @@ -185,7 +185,7 @@ connection: .. code:: python - async def first_message_handler(websocket, path): + async def first_message_handler(websocket): token = await websocket.recv() user = get_user(token) if user is None: @@ -224,7 +224,7 @@ the user. If authentication fails, it returns a HTTP 401: self.user = user - async def query_param_handler(websocket, path): + async def query_param_handler(websocket): user = websocket.user ... @@ -273,7 +273,7 @@ the user. If authentication fails, it returns a HTTP 401: self.user = user - async def cookie_handler(websocket, path): + async def cookie_handler(websocket): user = websocket.user ... @@ -311,7 +311,7 @@ the user. If authentication fails, it returns a HTTP 401: self.user = user return True - async def user_info_handler(websocket, path): + async def user_info_handler(websocket): user = websocket.user ... diff --git a/docs/topics/broadcast.rst b/docs/topics/broadcast.rst index a90cc2d70..531d8ca12 100644 --- a/docs/topics/broadcast.rst +++ b/docs/topics/broadcast.rst @@ -27,7 +27,7 @@ Integrating them is left as an exercise for the reader. You could start with:: import asyncio import websockets - async def handler(websocket, path): + async def handler(websocket): ... async def broadcast(message): @@ -64,7 +64,7 @@ Here's a connection handler that registers clients in a global variable:: CLIENTS = set() - async def handler(websocket, path): + async def handler(websocket): CLIENTS.add(websocket) try: await websocket.wait_closed() @@ -241,7 +241,7 @@ run a task that gets messages from the queue and sends them to the client:: message = await queue.get() await websocket.send(message) - async def handler(websocket, path): + async def handler(websocket): queue = asyncio.Queue() relay_task = asyncio.create_task(relay(queue, websocket)) CLIENTS.add(queue) @@ -297,7 +297,7 @@ no references left, therefore the garbage collector deletes it. The connection handler subscribes to the stream and sends messages:: - async def handler(websocket, path): + async def handler(websocket): async for message in PUBSUB: await websocket.send(message) diff --git a/example/basic_auth_server.py b/example/basic_auth_server.py index 532c5bc51..d2efeb7e5 100755 --- a/example/basic_auth_server.py +++ b/example/basic_auth_server.py @@ -5,7 +5,7 @@ import asyncio import websockets -async def hello(websocket, path): +async def hello(websocket): greeting = f"Hello {websocket.username}!" await websocket.send(greeting) diff --git a/example/counter.py b/example/counter.py index e41f6fabb..6e33b3afc 100755 --- a/example/counter.py +++ b/example/counter.py @@ -22,7 +22,7 @@ def users_event(): return json.dumps({"type": "users", "count": len(USERS)}) -async def counter(websocket, path): +async def counter(websocket): try: # Register user USERS.add(websocket) diff --git a/example/deployment/haproxy/app.py b/example/deployment/haproxy/app.py index 2b24790dd..360479b8e 100644 --- a/example/deployment/haproxy/app.py +++ b/example/deployment/haproxy/app.py @@ -7,7 +7,7 @@ import websockets -async def echo(websocket, path): +async def echo(websocket): async for message in websocket: await websocket.send(message) diff --git a/example/deployment/heroku/app.py b/example/deployment/heroku/app.py index ff9ba2775..d4ba3edb5 100644 --- a/example/deployment/heroku/app.py +++ b/example/deployment/heroku/app.py @@ -7,7 +7,7 @@ import websockets -async def echo(websocket, path): +async def echo(websocket): async for message in websocket: await websocket.send(message) diff --git a/example/deployment/kubernetes/app.py b/example/deployment/kubernetes/app.py index dcc29bd1c..a8bcef688 100755 --- a/example/deployment/kubernetes/app.py +++ b/example/deployment/kubernetes/app.py @@ -9,7 +9,7 @@ import websockets -async def slow_echo(websocket, path): +async def slow_echo(websocket): async for message in websocket: # Block the event loop! This allows saturating a single asyncio # process without opening an impractical number of connections. diff --git a/example/deployment/nginx/app.py b/example/deployment/nginx/app.py index ad42a8b3e..24e608975 100644 --- a/example/deployment/nginx/app.py +++ b/example/deployment/nginx/app.py @@ -7,7 +7,7 @@ import websockets -async def echo(websocket, path): +async def echo(websocket): async for message in websocket: await websocket.send(message) diff --git a/example/deployment/supervisor/app.py b/example/deployment/supervisor/app.py index 484566bc8..bf61983ef 100644 --- a/example/deployment/supervisor/app.py +++ b/example/deployment/supervisor/app.py @@ -6,7 +6,7 @@ import websockets -async def echo(websocket, path): +async def echo(websocket): async for message in websocket: await websocket.send(message) diff --git a/example/django/authentication.py b/example/django/authentication.py index bbb3db02a..7f60f8275 100644 --- a/example/django/authentication.py +++ b/example/django/authentication.py @@ -10,7 +10,7 @@ from sesame.utils import get_user -async def handler(websocket, path): +async def handler(websocket): sesame = await websocket.recv() user = await asyncio.to_thread(get_user, sesame) if user is None: diff --git a/example/django/notifications.py b/example/django/notifications.py index 41fb719dc..7275a1ef7 100644 --- a/example/django/notifications.py +++ b/example/django/notifications.py @@ -28,7 +28,7 @@ def get_content_types(user): } -async def handler(websocket, path): +async def handler(websocket): """Authenticate user and register connection in CONNECTIONS.""" sesame = await websocket.recv() user = await asyncio.to_thread(get_user, sesame) diff --git a/example/echo.py b/example/echo.py index 024f8d8ac..4b673cb17 100755 --- a/example/echo.py +++ b/example/echo.py @@ -3,7 +3,7 @@ import asyncio import websockets -async def echo(websocket, path): +async def echo(websocket): async for message in websocket: await websocket.send(message) diff --git a/example/health_check_server.py b/example/health_check_server.py index 2ca185cde..7b8bded77 100755 --- a/example/health_check_server.py +++ b/example/health_check_server.py @@ -8,7 +8,7 @@ async def health_check(path, request_headers): if path == "/healthz": return http.HTTPStatus.OK, [], b"OK\n" -async def echo(websocket, path): +async def echo(websocket): async for message in websocket: await websocket.send(message) diff --git a/example/secure_server.py b/example/secure_server.py index cd8ee0cc1..f0231bc16 100755 --- a/example/secure_server.py +++ b/example/secure_server.py @@ -7,7 +7,7 @@ import ssl import websockets -async def hello(websocket, path): +async def hello(websocket): name = await websocket.recv() print(f"<<< {name}") diff --git a/example/server.py b/example/server.py index 4dcf317f5..7fd7bdf4c 100755 --- a/example/server.py +++ b/example/server.py @@ -5,7 +5,7 @@ import asyncio import websockets -async def hello(websocket, path): +async def hello(websocket): name = await websocket.recv() print(f"<<< {name}") diff --git a/example/show_time.py b/example/show_time.py index 8e39f1776..b5a153b71 100755 --- a/example/show_time.py +++ b/example/show_time.py @@ -7,7 +7,7 @@ import random import websockets -async def time(websocket, path): +async def time(websocket): while True: now = datetime.datetime.utcnow().isoformat() + "Z" await websocket.send(now) diff --git a/example/shutdown_server.py b/example/shutdown_server.py index cabba4014..1bcc9c90b 100755 --- a/example/shutdown_server.py +++ b/example/shutdown_server.py @@ -4,7 +4,7 @@ import signal import websockets -async def echo(websocket, path): +async def echo(websocket): async for message in websocket: await websocket.send(message) diff --git a/example/unix_server.py b/example/unix_server.py index 192f31bb0..335039c35 100755 --- a/example/unix_server.py +++ b/example/unix_server.py @@ -6,7 +6,7 @@ import os.path import websockets -async def hello(websocket, path): +async def hello(websocket): name = await websocket.recv() print(f"<<< {name}") diff --git a/experiments/authentication/app.py b/experiments/authentication/app.py index c3c9a0557..6b3b2ae3f 100644 --- a/experiments/authentication/app.py +++ b/experiments/authentication/app.py @@ -84,14 +84,14 @@ async def serve_html(path, request_headers): return http.HTTPStatus.NOT_FOUND, {}, b"Not found\n" -async def noop_handler(websocket, path): +async def noop_handler(websocket): pass # Send credentials as the first message in the WebSocket connection -async def first_message_handler(websocket, path): +async def first_message_handler(websocket): token = await websocket.recv() user = get_user(token) if user is None: @@ -119,7 +119,7 @@ async def process_request(self, path, headers): self.user = user -async def query_param_handler(websocket, path): +async def query_param_handler(websocket): user = websocket.user await websocket.send(f"Hello {user}!") @@ -149,7 +149,7 @@ async def process_request(self, path, headers): self.user = user -async def cookie_handler(websocket, path): +async def cookie_handler(websocket): user = websocket.user await websocket.send(f"Hello {user}!") @@ -173,7 +173,7 @@ async def check_credentials(self, username, password): return True -async def user_info_handler(websocket, path): +async def user_info_handler(websocket): user = websocket.user await websocket.send(f"Hello {user}!") diff --git a/experiments/broadcast/server.py b/experiments/broadcast/server.py index 355d6b5e9..9c9907b7f 100644 --- a/experiments/broadcast/server.py +++ b/experiments/broadcast/server.py @@ -45,7 +45,7 @@ async def subscribe(self): PUBSUB = PubSub() -async def handler(websocket, path, method=None): +async def handler(websocket, method=None): if method in ["default", "naive", "task", "wait"]: CLIENTS.add(websocket) try: diff --git a/experiments/compression/server.py b/experiments/compression/server.py index f7b147006..8d1ee3cd7 100644 --- a/experiments/compression/server.py +++ b/experiments/compression/server.py @@ -18,7 +18,7 @@ MEM_SIZE = [] -async def handler(ws, path): +async def handler(ws): msg = await ws.recv() await ws.send(msg) diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 9bda5cdd8..4de4959b9 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -4,6 +4,7 @@ import email.utils import functools import http +import inspect import logging import socket import warnings @@ -95,7 +96,10 @@ class WebSocketServerProtocol(WebSocketCommonProtocol): def __init__( self, - ws_handler: Callable[[WebSocketServerProtocol, str], Awaitable[Any]], + ws_handler: Union[ + Callable[[WebSocketServerProtocol], Awaitable[Any]], + Callable[[WebSocketServerProtocol, str], Awaitable[Any]], # deprecated + ], ws_server: WebSocketServer, *, logger: Optional[LoggerLike] = None, @@ -118,7 +122,10 @@ def __init__( if origins is not None and "" in origins: warnings.warn("use None instead of '' in origins", DeprecationWarning) origins = [None if origin == "" else origin for origin in origins] - self.ws_handler = ws_handler + # For backwards compatibility with 10.0 or earlier. Done here in + # addition to serve to trigger the deprecation warning on direct + # use of WebSocketServerProtocol. + self.ws_handler = remove_path_argument(ws_handler) self.ws_server = ws_server self.origins = origins self.available_extensions = extensions @@ -152,7 +159,7 @@ async def handler(self) -> None: try: try: - path = await self.handshake( + await self.handshake( origins=self.origins, available_extensions=self.available_extensions, available_subprotocols=self.available_subprotocols, @@ -221,7 +228,7 @@ async def handler(self) -> None: return try: - await self.ws_handler(self, path) + await self.ws_handler(self) except Exception: self.logger.error("connection handler failed", exc_info=True) if not self.closed: @@ -555,8 +562,6 @@ async def handshake( """ Perform the server side of the opening handshake. - Return the path of the URI of the request. - Args: origins: list of acceptable values of the Origin HTTP header; include :obj:`None` if the lack of an origin is acceptable. @@ -567,6 +572,9 @@ async def handshake( extra_headers: arbitrary HTTP headers to add to the response when the handshake succeeds. + Returns: + str: path of the URI of the request. + Raises: InvalidHandshake: if the handshake fails. @@ -851,9 +859,8 @@ class Serve: The server is shut down automatically when exiting the context. Args: - ws_handler: connection handler. It must be a coroutine accepting - two arguments: the WebSocket connection, which is a - :class:`WebSocketServerProtocol`, and the path of the request. + ws_handler: connection handler. It receives the WebSocket connection, + which is a :class:`WebSocketServerProtocol`, in argument. host: network interfaces the server is bound to; see :meth:`~asyncio.loop.create_server` for details. port: TCP port the server listens on; @@ -908,7 +915,10 @@ class Serve: def __init__( self, - ws_handler: Callable[[WebSocketServerProtocol, str], Awaitable[Any]], + ws_handler: Union[ + Callable[[WebSocketServerProtocol], Awaitable[Any]], + Callable[[WebSocketServerProtocol, str], Awaitable[Any]], # deprecated + ], host: Optional[Union[str, Sequence[str]]] = None, port: Optional[int] = None, *, @@ -979,7 +989,10 @@ def __init__( factory = functools.partial( create_protocol, - ws_handler, + # For backwards compatibility with 10.0 or earlier. Done here in + # addition to WebSocketServerProtocol to trigger the deprecation + # warning once per serve() call rather than once per connection. + remove_path_argument(ws_handler), ws_server, host=host, port=port, @@ -1052,7 +1065,10 @@ async def __await_impl__(self) -> WebSocketServer: def unix_serve( - ws_handler: Callable[[WebSocketServerProtocol, str], Awaitable[Any]], + ws_handler: Union[ + Callable[[WebSocketServerProtocol], Awaitable[Any]], + Callable[[WebSocketServerProtocol, str], Awaitable[Any]], # deprecated + ], path: Optional[str] = None, **kwargs: Any, ) -> Serve: @@ -1071,3 +1087,27 @@ def unix_serve( """ return serve(ws_handler, path=path, unix=True, **kwargs) + + +def remove_path_argument( + ws_handler: Union[ + Callable[[WebSocketServerProtocol], Awaitable[Any]], + Callable[[WebSocketServerProtocol, str], Awaitable[Any]], + ] +) -> Callable[[WebSocketServerProtocol], Awaitable[Any]]: + if len(inspect.signature(ws_handler).parameters) == 2: + # Enable deprecation warning and announce deprecation in 11.0. + # warnings.warn("remove second argument of ws_handler", DeprecationWarning) + + async def _ws_handler(websocket: WebSocketServerProtocol) -> Any: + return await cast( + Callable[[WebSocketServerProtocol, str], Awaitable[Any]], + ws_handler, + )(websocket, websocket.path) + + return _ws_handler + + return cast( + Callable[[WebSocketServerProtocol], Awaitable[Any]], + ws_handler, + ) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 482b2cd0c..fea2b4178 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -53,22 +53,22 @@ testcert = bytes(pathlib.Path(__file__).parent.with_name("test_localhost.pem")) -async def default_handler(ws, path): - if path == "/deprecated_attributes": +async def default_handler(ws): + if ws.path == "/deprecated_attributes": await ws.recv() # delay that allows catching warnings await ws.send(repr((ws.host, ws.port, ws.secure))) - elif path == "/close_timeout": + elif ws.path == "/close_timeout": await ws.send(repr(ws.close_timeout)) - elif path == "/path": + elif ws.path == "/path": await ws.send(str(ws.path)) - elif path == "/headers": + elif ws.path == "/headers": await ws.send(repr(ws.request_headers)) await ws.send(repr(ws.response_headers)) - elif path == "/extensions": + elif ws.path == "/extensions": await ws.send(repr(ws.extensions)) - elif path == "/subprotocol": + elif ws.path == "/subprotocol": await ws.send(repr(ws.subprotocol)) - elif path == "/slow_stop": + elif ws.path == "/slow_stop": await ws.wait_closed() await asyncio.sleep(2 * MS) else: @@ -476,6 +476,21 @@ def test_unix_socket(self): finally: self.stop_server() + def test_ws_handler_argument_backwards_compatibility(self): + async def handler_with_path(ws, path): + await ws.send(path) + + with self.temp_server( + handler=handler_with_path, + # Enable deprecation warning and announce deprecation in 11.0. + # deprecation_warnings=["remove second argument of ws_handler"], + ): + with self.temp_client("/path"): + self.assertEqual( + self.loop.run_until_complete(self.client.recv()), + "/path", + ) + async def process_request_OK(path, request_headers): return http.HTTPStatus.OK, [], b"OK\n" @@ -1336,7 +1351,7 @@ class AsyncIteratorTests(ClientServerTestsMixin, AsyncioTestCase): MESSAGES = ["3", "2", "1", "Fire!"] - async def echo_handler(ws, path): + async def echo_handler(ws): for message in AsyncIteratorTests.MESSAGES: await ws.send(message) @@ -1354,7 +1369,7 @@ async def run_client(): self.assertEqual(messages, self.MESSAGES) - async def echo_handler_1001(ws, path): + async def echo_handler_1001(ws): for message in AsyncIteratorTests.MESSAGES: await ws.send(message) await ws.close(1001) @@ -1373,7 +1388,7 @@ async def run_client(): self.assertEqual(messages, self.MESSAGES) - async def echo_handler_1011(ws, path): + async def echo_handler_1011(ws): for message in AsyncIteratorTests.MESSAGES: await ws.send(message) await ws.close(1011) @@ -1395,7 +1410,7 @@ async def run_client(): class ReconnectionTests(ClientServerTestsMixin, AsyncioTestCase): - async def echo_handler(ws, path): + async def echo_handler(ws): async for msg in ws: await ws.send(msg) From c439f1d52aafc05064cc11702d1c3014046799b0 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 9 Sep 2021 22:30:05 +0200 Subject: [PATCH 178/762] Mirror full asyncio.Server API in WebSocketServer. --- docs/project/changelog.rst | 3 ++ docs/reference/server.rst | 8 +++++ src/websockets/legacy/server.py | 59 +++++++++++++++++++++++++-------- 3 files changed, 56 insertions(+), 14 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 2a714f945..a2c9991d5 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -38,6 +38,9 @@ Improvements the :attr:`~legacy.protocol.WebSocketCommonProtocol.path` attribute of the first argument. +* Mirrored the entire :class:`~asyncio.Server` API + in :class:`~server.WebSocketServer`. + 10.0 ---- diff --git a/docs/reference/server.rst b/docs/reference/server.rst index a8eb9c5b4..97bf320b6 100644 --- a/docs/reference/server.rst +++ b/docs/reference/server.rst @@ -24,6 +24,14 @@ Stopping a server .. automethod:: wait_closed + .. automethod:: get_loop + + .. automethod:: is_serving + + .. automethod:: start_serving + + .. automethod:: serve_forever + .. autoattribute:: sockets Using a connection diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 4de4959b9..98712ff86 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -724,13 +724,6 @@ def unregister(self, protocol: WebSocketServerProtocol) -> None: """ self.websockets.remove(protocol) - def is_serving(self) -> bool: - """ - Tell whether the server is accepting new connections or shutting down. - - """ - return self.server.is_serving() - def close(self) -> None: """ Close the server. @@ -748,7 +741,7 @@ def close(self) -> None: """ if self.close_task is None: - self.close_task = self.server.get_loop().create_task(self._close()) + self.close_task = self.get_loop().create_task(self._close()) async def _close(self) -> None: """ @@ -769,7 +762,7 @@ async def _close(self) -> None: # Wait until all accepted connections reach connection_made() and call # register(). See https://bugs.python.org/issue34852 for details. - await asyncio.sleep(0, **loop_if_py_lt_38(self.server.get_loop())) + await asyncio.sleep(0, **loop_if_py_lt_38(self.get_loop())) # Close OPEN connections with status code 1001. Since the server was # closed, handshake() closes OPENING connections with a HTTP 503 @@ -782,7 +775,7 @@ async def _close(self) -> None: asyncio.create_task(websocket.close(1001)) for websocket in self.websockets ], - **loop_if_py_lt_38(self.server.get_loop()), + **loop_if_py_lt_38(self.get_loop()), ) # Wait until all connection handlers are complete. @@ -791,7 +784,7 @@ async def _close(self) -> None: if self.websockets: await asyncio.wait( [websocket.handler_task for websocket in self.websockets], - **loop_if_py_lt_38(self.server.get_loop()), + **loop_if_py_lt_38(self.get_loop()), ) # Tell wait_closed() to return. @@ -820,16 +813,54 @@ async def wait_closed(self) -> None: """ await asyncio.shield(self.closed_waiter) + def get_loop(self) -> asyncio.AbstractEventLoop: + """ + See :meth:`asyncio.Server.get_loop`. + + """ + return self.server.get_loop() + + def is_serving(self) -> bool: + """ + See :meth:`asyncio.Server.is_serving`. + + """ + return self.server.is_serving() + + async def start_serving(self) -> None: + """ + See :meth:`asyncio.Server.start_serving`. + + """ + await self.server.start_serving() # pragma: no cover + + async def serve_forever(self) -> None: + """ + See :meth:`asyncio.Server.serve_forever`. + + """ + await self.server.serve_forever() # pragma: no cover + @property def sockets(self) -> Optional[List[socket.socket]]: """ - List of :obj:`~socket.socket` objects the server is listening on. - - :obj:`None` if the server is closed. + See :attr:`asyncio.Server.sockets`. """ return self.server.sockets + async def __aenter__(self) -> WebSocketServer: + return self # pragma: no cover + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + self.close() # pragma: no cover + await self.wait_closed() # pragma: no cover + class Serve: """ From 1b16b57ba94c00e20ab08e5596a98511bca95070 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 9 Sep 2021 22:34:19 +0200 Subject: [PATCH 179/762] Remove obsolete workaround. --- src/websockets/legacy/client.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 63b973ecb..595b13a73 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -651,8 +651,6 @@ async def __await_impl_timeout__(self) -> WebSocketClientProtocol: async def __await_impl__(self) -> WebSocketClientProtocol: for redirects in range(self.MAX_REDIRECTS_ALLOWED): transport, protocol = await self._create_connection() - # https://github.com/python/typeshed/pull/2756 - transport = cast(asyncio.Transport, transport) protocol = cast(WebSocketClientProtocol, protocol) try: From 541f95cdab5d4dd953fc9428c47421b7168ae0b1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 9 Sep 2021 22:48:52 +0200 Subject: [PATCH 180/762] Move build_host to headers module. This is where similar functions are defined. --- src/websockets/client.py | 3 ++- src/websockets/headers.py | 25 +++++++++++++++++++++++++ src/websockets/http.py | 26 +------------------------- src/websockets/legacy/client.py | 3 ++- tests/test_headers.py | 22 ++++++++++++++++++++++ tests/test_http.py | 28 +--------------------------- 6 files changed, 53 insertions(+), 54 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 6f2da5a6e..34732b3a6 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -16,13 +16,14 @@ from .headers import ( build_authorization_basic, build_extension, + build_host, build_subprotocol, parse_connection, parse_extension, parse_subprotocol, parse_upgrade, ) -from .http import USER_AGENT, build_host +from .http import USER_AGENT from .http11 import Request, Response from .typing import ( ConnectionOption, diff --git a/src/websockets/headers.py b/src/websockets/headers.py index a2fdfdd30..9ae3035a5 100644 --- a/src/websockets/headers.py +++ b/src/websockets/headers.py @@ -2,6 +2,7 @@ import base64 import binascii +import ipaddress import re from typing import Callable, List, Optional, Sequence, Tuple, TypeVar, cast @@ -17,6 +18,7 @@ __all__ = [ + "build_host", "parse_connection", "parse_upgrade", "parse_extension", @@ -33,6 +35,29 @@ T = TypeVar("T") +def build_host(host: str, port: int, secure: bool) -> str: + """ + Build a ``Host`` header. + + """ + # https://www.rfc-editor.org/rfc/rfc3986.html#section-3.2.2 + # IPv6 addresses must be enclosed in brackets. + try: + address = ipaddress.ip_address(host) + except ValueError: + # host is a hostname + pass + else: + # host is an IP address + if address.version == 6: + host = f"[{host}]" + + if port != (443 if secure else 80): + host = f"{host}:{port}" + + return host + + # To avoid a dependency on a parsing library, we implement manually the ABNF # described in https://www.rfc-editor.org/rfc/rfc6455.html#section-9.1 and # https://www.rfc-editor.org/rfc/rfc7230.html#appendix-B. diff --git a/src/websockets/http.py b/src/websockets/http.py index 38848b56d..b14fa94bd 100644 --- a/src/websockets/http.py +++ b/src/websockets/http.py @@ -1,6 +1,5 @@ from __future__ import annotations -import ipaddress import sys from .imports import lazy_import @@ -24,31 +23,8 @@ ) -__all__ = ["USER_AGENT", "build_host"] +__all__ = ["USER_AGENT"] PYTHON_VERSION = "{}.{}".format(*sys.version_info) USER_AGENT = f"Python/{PYTHON_VERSION} websockets/{websockets_version}" - - -def build_host(host: str, port: int, secure: bool) -> str: - """ - Build a ``Host`` header. - - """ - # https://www.rfc-editor.org/rfc/rfc3986.html#section-3.2.2 - # IPv6 addresses must be enclosed in brackets. - try: - address = ipaddress.ip_address(host) - except ValueError: - # host is a hostname - pass - else: - # host is an IP address - if address.version == 6: - host = f"[{host}]" - - if port != (443 if secure else 80): - host = f"{host}:{port}" - - return host diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 595b13a73..6704d16ce 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -35,12 +35,13 @@ from ..headers import ( build_authorization_basic, build_extension, + build_host, build_subprotocol, parse_extension, parse_subprotocol, validate_subprotocols, ) -from ..http import USER_AGENT, build_host +from ..http import USER_AGENT from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol from ..uri import WebSocketURI, parse_uri from .handshake import build_request, check_response diff --git a/tests/test_headers.py b/tests/test_headers.py index badec5a86..a2d51fc6a 100644 --- a/tests/test_headers.py +++ b/tests/test_headers.py @@ -5,6 +5,28 @@ class HeadersTests(unittest.TestCase): + def test_build_host(self): + for (host, port, secure), result in [ + (("localhost", 80, False), "localhost"), + (("localhost", 8000, False), "localhost:8000"), + (("localhost", 443, True), "localhost"), + (("localhost", 8443, True), "localhost:8443"), + (("example.com", 80, False), "example.com"), + (("example.com", 8000, False), "example.com:8000"), + (("example.com", 443, True), "example.com"), + (("example.com", 8443, True), "example.com:8443"), + (("127.0.0.1", 80, False), "127.0.0.1"), + (("127.0.0.1", 8000, False), "127.0.0.1:8000"), + (("127.0.0.1", 443, True), "127.0.0.1"), + (("127.0.0.1", 8443, True), "127.0.0.1:8443"), + (("::1", 80, False), "[::1]"), + (("::1", 8000, False), "[::1]:8000"), + (("::1", 443, True), "[::1]"), + (("::1", 8443, True), "[::1]:8443"), + ]: + with self.subTest(host=host, port=port, secure=secure): + self.assertEqual(build_host(host, port, secure), result) + def test_parse_connection(self): for header, parsed in [ # Realistic use cases diff --git a/tests/test_http.py b/tests/test_http.py index ca7c1c0a4..16bec9468 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -1,27 +1 @@ -import unittest - -from websockets.http import * - - -class HTTPTests(unittest.TestCase): - def test_build_host(self): - for (host, port, secure), result in [ - (("localhost", 80, False), "localhost"), - (("localhost", 8000, False), "localhost:8000"), - (("localhost", 443, True), "localhost"), - (("localhost", 8443, True), "localhost:8443"), - (("example.com", 80, False), "example.com"), - (("example.com", 8000, False), "example.com:8000"), - (("example.com", 443, True), "example.com"), - (("example.com", 8443, True), "example.com:8443"), - (("127.0.0.1", 80, False), "127.0.0.1"), - (("127.0.0.1", 8000, False), "127.0.0.1:8000"), - (("127.0.0.1", 443, True), "127.0.0.1"), - (("127.0.0.1", 8443, True), "127.0.0.1:8443"), - (("::1", 80, False), "[::1]"), - (("::1", 8000, False), "[::1]:8000"), - (("::1", 443, True), "[::1]"), - (("::1", 8443, True), "[::1]:8443"), - ]: - with self.subTest(host=host, port=port, secure=secure): - self.assertEqual(build_host(host, port, secure), result) +from websockets.http import * # noqa From dc0e79d6ca2ce6da4906c861443b8c696bf93dff Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 10 Sep 2021 22:54:12 +0200 Subject: [PATCH 181/762] Fix excessive escaping in debug logs. --- src/websockets/legacy/protocol.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index a31a5c7c8..a52f9f71d 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -1241,7 +1241,7 @@ async def keepalive_ping(self) -> None: # ping() raises ConnectionClosed if the connection is lost, # when connection_lost() calls abort_pings(). - self.logger.debug("%% sending keepalive ping") + self.logger.debug("% sending keepalive ping") pong_waiter = await self.ping() if self.ping_timeout is not None: @@ -1251,7 +1251,7 @@ async def keepalive_ping(self) -> None: self.ping_timeout, **loop_if_py_lt_38(self.loop), ) - self.logger.debug("%% received keepalive pong") + self.logger.debug("% received keepalive pong") except asyncio.TimeoutError: if self.debug: self.logger.debug("! timed out waiting for keepalive pong") From 5d67724fa80765ef95b21b0b04a3d22b1014649a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 18 Sep 2021 15:55:04 +0200 Subject: [PATCH 182/762] Fix indentation issue. --- docs/howto/django.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/howto/django.rst b/docs/howto/django.rst index 67bf582f1..eef1b8f47 100644 --- a/docs/howto/django.rst +++ b/docs/howto/django.rst @@ -39,7 +39,7 @@ Generate tokens We want secure, short-lived tokens containing the user ID. We'll rely on `django-sesame`_, a small library designed exactly for this purpose. - .. _django-sesame: https://github.com/aaugustin/django-sesame +.. _django-sesame: https://github.com/aaugustin/django-sesame Add django-sesame to the dependencies of your Django project, install it, and configure it in the settings of the project: From 397eda4952e1f2b044e352ddbf2b45b336881583 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 18 Sep 2021 17:11:20 +0200 Subject: [PATCH 183/762] Mention crypto policy in issue template. --- .github/ISSUE_TEMPLATE/issue.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/ISSUE_TEMPLATE/issue.md b/.github/ISSUE_TEMPLATE/issue.md index 8efabc617..ee32cf0ee 100644 --- a/.github/ISSUE_TEMPLATE/issue.md +++ b/.github/ISSUE_TEMPLATE/issue.md @@ -20,6 +20,9 @@ https://docs.python.org/3/library/asyncio-dev.html Did you look for similar issues? Please keep the discussion in one place :-) https://github.com/aaugustin/websockets/issues?q=is%3Aissue +Is your issue related to cryptocurrency in any way? Please don't file it. +https://websockets.readthedocs.io/en/stable/project/contributing.html#cryptocurrency-users + For bugs, providing a reproduction helps a lot. Take an existing example and tweak it! https://github.com/aaugustin/websockets/tree/main/example From d8a436fc0eec66e6e9c4b4631124c0039ae63a25 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 18 Sep 2021 17:16:13 +0200 Subject: [PATCH 184/762] Crypto keeps getting worse. --- docs/project/contributing.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/project/contributing.rst b/docs/project/contributing.rst index c3e8dfc4c..43fd58dc8 100644 --- a/docs/project/contributing.rst +++ b/docs/project/contributing.rst @@ -55,12 +55,12 @@ websockets appears to be quite popular for interfacing with Bitcoin or other cryptocurrency trackers. I'm strongly opposed to Bitcoin's carbon footprint. I'm aware of efforts to build proof-of-stake models. I'll care once the total -carbon footprint of all cryptocurrencies drops to a non-bullshit level. +energy consumption of all cryptocurrencies drops to a non-bullshit level. -Please stop heating the planet where my children are supposed to live, thanks. +You already negated all of humanity's efforts to develop renewable energy. +Please stop heating the planet where my children will have to live. Since websockets is released under an open-source license, you can use it for any purpose you like. However, I won't spend any of my time to help you. I will summarily close issues related to Bitcoin or cryptocurrency in any way. -Thanks for your understanding. From 744482a2c2ab33e86ec877441f6f8d44ce03b282 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Sep 2021 21:20:41 +0200 Subject: [PATCH 185/762] Pin Python 3.9.6 in CI. Ref #1051. --- .github/workflows/tests.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 0042c302c..9d0792c7b 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -18,7 +18,7 @@ jobs: - name: Install Python 3.x uses: actions/setup-python@v2 with: - python-version: 3.x + python-version: 3.9.6 - name: Install tox run: pip install tox - name: Run tests with coverage @@ -38,7 +38,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python: [3.7, 3.8, 3.9] + python: [3.7, 3.8, 3.9.6] steps: - name: Check out repository uses: actions/checkout@v2 From a2a61ead84cfaf60752452f245e111eecf4a6e53 Mon Sep 17 00:00:00 2001 From: Oliver Zehentleitner <47597331+oliver-zehentleitner@users.noreply.github.com> Date: Mon, 20 Sep 2021 21:43:31 +0200 Subject: [PATCH 186/762] Fix link to tutorial --- README.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.rst b/README.rst index 99e477867..8603d91ee 100644 --- a/README.rst +++ b/README.rst @@ -75,7 +75,7 @@ And here's an echo server: Does that look good? -`Get started with the tutorial! `_ +`Get started with the tutorial! `_ .. raw:: html From 17af113f028b8a04e1ff8ba00381e9b57b386cfc Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 20 Sep 2021 22:20:54 +0200 Subject: [PATCH 187/762] Fix link to FAQ --- .github/ISSUE_TEMPLATE/issue.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ISSUE_TEMPLATE/issue.md b/.github/ISSUE_TEMPLATE/issue.md index ee32cf0ee..f2704152c 100644 --- a/.github/ISSUE_TEMPLATE/issue.md +++ b/.github/ISSUE_TEMPLATE/issue.md @@ -12,7 +12,7 @@ assignees: '' Thanks for taking the time to report an issue! Did you check the FAQ? Perhaps you'll find the answer you need: -https://websockets.readthedocs.io/en/stable/faq.html +https://websockets.readthedocs.io/en/stable/howto/faq.html Is your question really about asyncio? Perhaps the dev guide will help: https://docs.python.org/3/library/asyncio-dev.html From 54d59f2f0e2af947714a86419b78a462063e75af Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 20 Sep 2021 22:22:02 +0200 Subject: [PATCH 188/762] Use the default documentation version. --- docs/howto/django.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/howto/django.rst b/docs/howto/django.rst index eef1b8f47..a3d4a2ceb 100644 --- a/docs/howto/django.rst +++ b/docs/howto/django.rst @@ -12,7 +12,7 @@ WebSocket, you have two main options. technique is well suited when you need to add a small set of real-time features — maybe a notification service — to a HTTP application. -.. _Channels: https://channels.readthedocs.io/en/latest/ +.. _Channels: https://channels.readthedocs.io/ This guide shows how to implement the second technique with websockets. It assumes familiarity with Django. From 737ea76b3f2697da9e69f2736b0c868430c18219 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 20 Sep 2021 22:24:35 +0200 Subject: [PATCH 189/762] Remove stale reference. --- README.rst | 1 - 1 file changed, 1 deletion(-) diff --git a/README.rst b/README.rst index 8603d91ee..089e1bfa9 100644 --- a/README.rst +++ b/README.rst @@ -113,7 +113,6 @@ Docs`_ and see for yourself. .. _Read the Docs: https://websockets.readthedocs.io/ .. _handle backpressure correctly: https://vorpus.org/blog/some-thoughts-on-asynchronous-api-design-in-a-post-asyncawait-world/#websocket-servers -.. _Autobahn Testsuite: https://github.com/aaugustin/websockets/blob/main/compliance/README.rst Why shouldn't I use ``websockets``? ----------------------------------- From 20960c2792353dff7569554fcbe956111d772ba0 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 20 Sep 2021 22:39:11 +0200 Subject: [PATCH 190/762] Add python -m websockets --version. --- src/websockets/__main__.py | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/src/websockets/__main__.py b/src/websockets/__main__.py index 860e4b1fa..c562d21b5 100644 --- a/src/websockets/__main__.py +++ b/src/websockets/__main__.py @@ -11,6 +11,7 @@ from .exceptions import ConnectionClosed from .frames import Close from .legacy.client import connect +from .version import version as websockets_version if sys.platform == "win32": @@ -152,6 +153,24 @@ async def run_client( def main() -> None: + # Parse command line arguments. + parser = argparse.ArgumentParser( + prog="python -m websockets", + description="Interactive WebSocket client.", + add_help=False, + ) + group = parser.add_mutually_exclusive_group() + group.add_argument("--version", action="store_true") + group.add_argument("uri", metavar="", nargs="?") + args = parser.parse_args() + + if args.version: + print(f"websockets {websockets_version}") + return + + if args.uri is None: + parser.error("the following arguments are required: ") + # If we're on Windows, enable VT100 terminal support. if sys.platform == "win32": try: @@ -169,15 +188,6 @@ def main() -> None: except ImportError: # Windows has no `readline` normally pass - # Parse command line arguments. - parser = argparse.ArgumentParser( - prog="python -m websockets", - description="Interactive WebSocket client.", - add_help=False, - ) - parser.add_argument("uri", metavar="") - args = parser.parse_args() - # Create an event loop that will run in a background thread. loop = asyncio.new_event_loop() From 9f77f4e492fa0030ae3e5c388c606931c130065e Mon Sep 17 00:00:00 2001 From: Vladislav Smirnov Date: Tue, 21 Sep 2021 22:35:04 +0300 Subject: [PATCH 191/762] Fix typo in FAQ --- docs/howto/faq.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/howto/faq.rst b/docs/howto/faq.rst index bef61bdab..60d80677d 100644 --- a/docs/howto/faq.rst +++ b/docs/howto/faq.rst @@ -208,8 +208,8 @@ achieve this with :func:`asyncio.gather` or :func:`asyncio.create_task`. Keep track of the tasks and make sure they terminate or you cancel them when the connection terminates. -Why does my program never receives any messages? -................................................ +Why does my program never receive any messages? +............................................... Your program runs a coroutine that never yields control to the event loop. The coroutine that receives messages never gets a chance to run. From 37ef247bc4bb8692d2de187b9e153ba2dc64a9e9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 26 Sep 2021 11:18:51 +0200 Subject: [PATCH 192/762] Reformat notes. --- docs/howto/faq.rst | 9 +++++---- docs/topics/broadcast.rst | 7 +++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/docs/howto/faq.rst b/docs/howto/faq.rst index 60d80677d..f97ea73a4 100644 --- a/docs/howto/faq.rst +++ b/docs/howto/faq.rst @@ -3,11 +3,12 @@ FAQ .. currentmodule:: websockets -.. note:: +.. admonition:: Many questions asked in websockets' issue tracker are really + about :mod:`asyncio`. + :class: seealso - Many questions asked in websockets' issue tracker are actually - about :mod:`asyncio`. Python's documentation about `developing with - asyncio`_ is a good complement. + Python's documentation about `developing with asyncio`_ is a good + complement. .. _developing with asyncio: https://docs.python.org/3/library/asyncio-dev.html diff --git a/docs/topics/broadcast.rst b/docs/topics/broadcast.rst index 531d8ca12..6c7ced8b0 100644 --- a/docs/topics/broadcast.rst +++ b/docs/topics/broadcast.rst @@ -4,10 +4,9 @@ Broadcasting messages .. currentmodule:: websockets -.. note:: - - If you just want to send a message to all connected clients, use - :func:`broadcast`. +.. admonition:: If you just want to send a message to all connected clients, + use :func:`broadcast`. + :class: tip If you want to learn about its design in depth, continue reading this document. From c70c7d6703ab8bef35cc89965becc67e8a1afe7c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 26 Sep 2021 11:22:06 +0200 Subject: [PATCH 193/762] Document how to autoreload in dev. --- docs/howto/autoreload.rst | 31 +++++++++++++++++++++++++++++++ docs/howto/index.rst | 3 ++- docs/project/changelog.rst | 2 ++ 3 files changed, 35 insertions(+), 1 deletion(-) create mode 100644 docs/howto/autoreload.rst diff --git a/docs/howto/autoreload.rst b/docs/howto/autoreload.rst new file mode 100644 index 000000000..4f650acb3 --- /dev/null +++ b/docs/howto/autoreload.rst @@ -0,0 +1,31 @@ +Reload on code changes +====================== + +When developing a websockets server, you may run it locally to test changes. +Unfortunately, whenever you want to try a new version of the code, you must +stop the server and restart it, which slows down your development process. + +Web frameworks such as Django or Flask provide a development server that +reloads the application automatically when you make code changes. There is no +such functionality in websockets because it's designed for production rather +than development. + +However, you can achieve the same result easily. + +Install watchdog_ with the ``watchmedo`` shell utility: + +.. code:: console + + $ pip install watchdog[watchmedo] + +.. _watchdog: https://pypi.org/project/watchdog/ + +Run your server with ``watchmedo auto-restart``: + +.. code:: console + + $ watchmedo auto-restart --pattern "*.py" --recursive --signal SIGTERM \ + python app.py + +This example assumes that the server is defined in a script called ``app.py``. +Adapt it as necessary. diff --git a/docs/howto/index.rst b/docs/howto/index.rst index 5deb7c767..d11789d48 100644 --- a/docs/howto/index.rst +++ b/docs/howto/index.rst @@ -1,13 +1,14 @@ How-to guides ============= -If you're stuck, perhaps you'll find the answer in the FAQ or the cheat sheet. +If you're stuck, perhaps you'll find the answer here. .. toctree:: :titlesonly: faq cheatsheet + autoreload This guide will help you integrate websockets into a broader system. diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index a2c9991d5..3209b39db 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -41,6 +41,8 @@ Improvements * Mirrored the entire :class:`~asyncio.Server` API in :class:`~server.WebSocketServer`. +* Documented how to auto-reload on code changes in development. + 10.0 ---- From 27e861fd8566d056d4382992c659868bc57ef01d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 14 Aug 2021 11:34:06 +0200 Subject: [PATCH 194/762] Enable Python 3.10 in CI. Fix #935. --- .github/workflows/tests.yml | 4 ++-- .github/workflows/wheels.yml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 9d0792c7b..d8a95a70d 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -18,7 +18,7 @@ jobs: - name: Install Python 3.x uses: actions/setup-python@v2 with: - python-version: 3.9.6 + python-version: "3.9.6" - name: Install tox run: pip install tox - name: Run tests with coverage @@ -38,7 +38,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python: [3.7, 3.8, 3.9.6] + python: ["3.7", "3.8", "3.9.6", "3.10"] steps: - name: Check out repository uses: actions/checkout@v2 diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 249cd36ce..7e62237b3 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -50,7 +50,7 @@ jobs: uses: joerick/cibuildwheel@v1.11.0 env: CIBW_ARCHS_LINUX: auto aarch64 - CIBW_BUILD: cp37-* cp38-* cp39-* + CIBW_SKIP: cp36-* - name: Save wheels uses: actions/upload-artifact@v2 with: From 5ba529bf55e271040a122b999f756f5c1919cd11 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 24 Oct 2021 21:20:46 +0200 Subject: [PATCH 195/762] Make apply_mask twice as fast on ARM. --- docs/project/changelog.rst | 2 ++ src/websockets/speedups.c | 22 +++++++++++++++++++--- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 3209b39db..69955ee1d 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -41,6 +41,8 @@ Improvements * Mirrored the entire :class:`~asyncio.Server` API in :class:`~server.WebSocketServer`. +* Improved performance for large messages on ARM processors. + * Documented how to auto-reload on code changes in development. 10.0 diff --git a/src/websockets/speedups.c b/src/websockets/speedups.c index fc328e528..f8d24ec7a 100644 --- a/src/websockets/speedups.c +++ b/src/websockets/speedups.c @@ -2,9 +2,11 @@ #define PY_SSIZE_T_CLEAN #include -#include /* uint32_t, uint64_t */ +#include /* uint8_t, uint32_t, uint64_t */ -#if __SSE2__ +#if __ARM_NEON +#include +#elif __SSE2__ #include #endif @@ -128,7 +130,21 @@ apply_mask(PyObject *self, PyObject *args, PyObject *kwds) // We need a new scope for MSVC 2010 (non C99 friendly) { -#if __SSE2__ +#if __ARM_NEON + + // With NEON support, XOR by blocks of 16 bytes = 128 bits. + + Py_ssize_t input_len_128 = input_len & ~15; + uint8x16_t mask_128 = vreinterpretq_u8_u32(vdupq_n_u32(*(uint32_t *)mask)); + + for (; i < input_len_128; i += 16) + { + uint8x16_t in_128 = vld1q_u8((uint8_t *)(input + i)); + uint8x16_t out_128 = veorq_u8(in_128, mask_128); + vst1q_u8((uint8_t *)(output + i), out_128); + } + +#elif __SSE2__ // With SSE2 support, XOR by blocks of 16 bytes = 128 bits. From fc3ade70aeafa5535e433b2c6d05f6fb7c70e9e3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 28 Oct 2021 21:17:23 +0200 Subject: [PATCH 196/762] Recommend Sanic for mixing HTTP and WebSocket. Fix #1073. --- README.rst | 7 +++++++ docs/howto/faq.rst | 7 ++++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/README.rst b/README.rst index 089e1bfa9..82f89f952 100644 --- a/README.rst +++ b/README.rst @@ -120,11 +120,18 @@ Why shouldn't I use ``websockets``? * If you prefer callbacks over coroutines: ``websockets`` was created to provide the best coroutine-based API to manage WebSocket connections in Python. Pick another library for a callback-based API. + * If you're looking for a mixed HTTP / WebSocket library: ``websockets`` aims at being an excellent implementation of :rfc:`6455`: The WebSocket Protocol and :rfc:`7692`: Compression Extensions for WebSocket. Its support for HTTP is minimal — just enough for a HTTP health check. + If you want do to both in the same server, look at HTTP frameworks that + build on top of ``websockets`` to support WebSocket connections, like + Sanic_. + +.. _Sanic: https://sanicframework.org/en/ + What else? ---------- diff --git a/docs/howto/faq.rst b/docs/howto/faq.rst index f97ea73a4..20b957745 100644 --- a/docs/howto/faq.rst +++ b/docs/howto/faq.rst @@ -135,7 +135,7 @@ How do I run a HTTP server and WebSocket server on the same port? You don't. HTTP and WebSockets have widely different operational characteristics. -Running them on the same server is a bad idea. +Running them with the same server becomes inconvenient when you scale. Providing a HTTP server is out of scope for websockets. It only aims at providing a WebSocket server. @@ -145,6 +145,11 @@ There's limited support for returning HTTP responses with the If you need more, pick a HTTP server and run it separately. +Alternatively, pick a HTTP framework that builds on top of ``websockets`` to +support WebSocket connections, like Sanic_. + +.. _Sanic: https://sanicframework.org/en/ + Client side ----------- From ed9a7b446c7147f6f88dbeb1d86546ad754e435e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 8 Oct 2021 22:18:24 +0200 Subject: [PATCH 197/762] Disable memory optimization on the client side. Also clarify compression docs. Fix #1065. --- docs/project/changelog.rst | 5 ++ docs/topics/compression.rst | 75 ++++++++++++++----- docs/topics/memory.rst | 3 + .../extensions/permessage_deflate.py | 2 - 4 files changed, 65 insertions(+), 20 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 69955ee1d..a4176d572 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -38,6 +38,11 @@ Improvements the :attr:`~legacy.protocol.WebSocketCommonProtocol.path` attribute of the first argument. +* Reverted optimization of default compression settings for clients, mainly to + avoid triggering bugs in poorly implemented servers like `AWS API Gateway`_. + + .. _AWS API Gateway: https://github.com/aaugustin/websockets/issues/1065 + * Mirrored the entire :class:`~asyncio.Server` API in :class:`~server.WebSocketServer`. diff --git a/docs/topics/compression.rst b/docs/topics/compression.rst index d40c4257d..0f264dc66 100644 --- a/docs/topics/compression.rst +++ b/docs/topics/compression.rst @@ -59,35 +59,46 @@ Deflate extension explicitly with :class:`ClientPerMessageDeflateFactory` or ], ) -The Window Bits and Memory Level values in these examples reduce memory usage at the expense of compression rate. +The Window Bits and Memory Level values in these examples reduce memory usage +at the expense of compression rate. Compression settings -------------------- When a client and a server enable the Per-Message Deflate extension, they negotiate two parameters to guarantee compatibility between compression and -decompression. This affects the trade-off between compression rate and memory -usage for both sides. +decompression. These parameters affect the trade-off between compression rate +and memory usage for both sides. * **Context Takeover** means that the compression context is retained between messages. In other words, compression is applied to the stream of messages - rather than to each message individually. Context takeover should remain - enabled to get good performance on applications that send a stream of - messages with the same structure, that is, most applications. + rather than to each message individually. + + Context takeover should remain enabled to get good performance on + applications that send a stream of messages with similar structure, + that is, most applications. + + This requires retaining the compression context and state between messages, + which increases the memory footprint of a connection. * **Window Bits** controls the size of the compression context. It must be an integer between 9 (lowest memory usage) and 15 (best compression). - websockets defaults to 12. Setting it to 8 is possible but rejected by some - versions of zlib. + Setting it to 8 is possible but rejected by some versions of zlib. + + On the server side, websockets defaults to 12. On the client side, it lets + the server pick a suitable value, which is the same as defaulting to 15. :mod:`zlib` offers additional parameters for tuning compression. They control -the trade-off between compression rate and CPU and memory usage for the -compression side, transparently for the decompression side. +the trade-off between compression rate, memory usage, and CPU usage only for +compressing. They're transparent for decompressing. Unless mentioned +otherwise, websockets inherits defaults of :func:`~zlib.compressobj`. * **Memory Level** controls the size of the compression state. It must be an - integer between 1 (lowest memory usage) and 9 (best compression). websockets - defaults to 5. A lower memory level can increase speed thanks to memory - locality. + integer between 1 (lowest memory usage) and 9 (best compression). + + websockets defaults to 5. This is lower than zlib's default of 8. Not only + does a lower memory level reduce memory usage, but it can also increase + speed thanks to memory locality. * **Compression Level** controls the effort to optimize compression. It must be an integer between 1 (lowest CPU usage) and 9 (best compression). @@ -95,16 +106,17 @@ compression side, transparently for the decompression side. * **Strategy** selects the compression strategy. The best choice depends on the type of data being compressed. -Unless mentioned otherwise, websockets uses the defaults of -:func:`zlib.compressobj` for all these settings. Tuning compression ------------------ +For servers +........... + By default, websockets enables compression with conservative settings that -optimize memory usage at the cost of a slightly worse compression rate: Window -Bits = 12 and Memory Level = 5. This strikes a good balance for small messages -that are typical of WebSocket servers. +optimize memory usage at the cost of a slightly worse compression rate: +Window Bits = 12 and Memory Level = 5. This strikes a good balance for small +messages that are typical of WebSocket servers. Here's how various compression settings affect memory usage of a single connection on a 64-bit system, as well a benchmark of compressed size and @@ -152,6 +164,33 @@ usage is: CPU usage is also higher for compression than decompression. +For clients +........... + +By default, websockets enables compression with Memory Level = 5 but leaves +the Window Bits setting up to the server. + +There's two good reasons and one bad reason for not optimizing the client side +like the server side: + +1. If the maintainers of a server configured some optimized settings, we don't + want to override them with more restrictive settings. + +2. Optimizing memory usage doesn't matter very much for clients because it's + uncommon to open thousands of client connections in a program. + +3. On a more pragmatic note, some servers misbehave badly when a client + configures compression settings. `AWS API Gateway`_ is the worst offender. + + .. _AWS API Gateway: https://github.com/aaugustin/websockets/issues/1065 + + Unfortunately, even though websockets is right and AWS is wrong, many users + jump to the conclusion that websockets doesn't work. + + Until the ecosystem levels up, interoperability with buggy servers seems + more valuable than optimizing memory usage. + + Further reading --------------- diff --git a/docs/topics/memory.rst b/docs/topics/memory.rst index ee0109c35..e44247a77 100644 --- a/docs/topics/memory.rst +++ b/docs/topics/memory.rst @@ -19,6 +19,9 @@ Baseline Compression settings are the main factor affecting the baseline amount of memory used by each connection. +With websockets' defaults, on the server side, a single connections uses +70 KiB of memory. + Refer to the :doc:`topic guide on compression <../topics/compression>` to learn more about tuning compression settings. diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index da2bc153e..fefa55643 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -448,8 +448,6 @@ def enable_client_permessage_deflate( ): extensions = list(extensions) + [ ClientPerMessageDeflateFactory( - server_max_window_bits=12, - client_max_window_bits=12, compress_settings={"memLevel": 5}, ) ] From b240c047f342411140dee8e774beae7abe55381a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 6 Nov 2021 09:14:00 +0100 Subject: [PATCH 198/762] Add wheels for more architectures. --- .github/workflows/wheels.yml | 5 +++-- docs/project/changelog.rst | 2 ++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 7e62237b3..a6df67743 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -47,9 +47,10 @@ jobs: with: platforms: all - name: Build wheels - uses: joerick/cibuildwheel@v1.11.0 + uses: pypa/cibuildwheel@v2.2.2 env: - CIBW_ARCHS_LINUX: auto aarch64 + CIBW_ARCHS_MACOS: "x86_64 universal2 arm64" + CIBW_ARCHS_LINUX: "auto aarch64" CIBW_SKIP: cp36-* - name: Save wheels uses: actions/upload-artifact@v2 diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index a4176d572..50b8ae7a5 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -33,6 +33,8 @@ They may change at any time. Improvements ............ +* Added wheels for Python 3.10, PyPy 3.7, and for more platforms. + * Made the second parameter of connection handlers optional. It will be deprecated in the next major release. The request path is available in the :attr:`~legacy.protocol.WebSocketCommonProtocol.path` attribute of From b4d9a2add6a9e716c9405dc74aae8bd521dec3ae Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 6 Nov 2021 09:31:38 +0100 Subject: [PATCH 199/762] Clarified API change on connection handlers. --- docs/project/changelog.rst | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 50b8ae7a5..5b190fb94 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -30,16 +30,30 @@ They may change at any time. *In development* -Improvements +New features ............ -* Added wheels for Python 3.10, PyPy 3.7, and for more platforms. - * Made the second parameter of connection handlers optional. It will be deprecated in the next major release. The request path is available in the :attr:`~legacy.protocol.WebSocketCommonProtocol.path` attribute of the first argument. + If you implemented the connection handler of a server as:: + + async def handler(request, path): + ... + + You should replace it by:: + + async def handler(request): + path = request.path # if handler() uses the path argument + ... + +Improvements +............ + +* Added wheels for Python 3.10, PyPy 3.7, and for more platforms. + * Reverted optimization of default compression settings for clients, mainly to avoid triggering bugs in poorly implemented servers like `AWS API Gateway`_. From 56b589368175bf18fa1c3f0036460c0c2f3e4c4b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 11 Nov 2021 19:14:09 +0100 Subject: [PATCH 200/762] Unpin Python 3.9.6. Reverts 744482a2. Fix #1051. --- .github/workflows/tests.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index d8a95a70d..3aa579aa4 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -18,7 +18,7 @@ jobs: - name: Install Python 3.x uses: actions/setup-python@v2 with: - python-version: "3.9.6" + python-version: "3.x" - name: Install tox run: pip install tox - name: Run tests with coverage @@ -38,7 +38,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python: ["3.7", "3.8", "3.9.6", "3.10"] + python: ["3.7", "3.8", "3.9", "3.10"] steps: - name: Check out repository uses: actions/checkout@v2 From 27a439b6d20089dd24ab5a08d07d0585cc31ee34 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 11 Nov 2021 19:31:33 +0100 Subject: [PATCH 201/762] Work around bug in coverage. --- tests/legacy/test_client_server.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index fea2b4178..4eb8229b2 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -1454,6 +1454,7 @@ async def run_client(): await server_ws.close() with self.assertRaises(ConnectionClosed): await ws.recv() + pass # work around bug in coverage else: # Exit block with an exception. raise Exception("BOOM!") From b6e25e290143c4dbb5b7bd0acd0c44efd086167e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 11 Nov 2021 20:16:58 +0100 Subject: [PATCH 202/762] Work around bug in mypy. --- src/websockets/legacy/protocol.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index a52f9f71d..3340c33be 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -750,7 +750,7 @@ async def close(self, code: int = 1000, reason: str = "") -> None: self.write_close_frame(Close(code, reason)), self.close_timeout, **loop_if_py_lt_38(self.loop), - ) + ) # type: ignore # remove when removing loop_if_py_lt_38 except asyncio.TimeoutError: # If the close frame cannot be sent because the send buffers # are full, the closing handshake won't complete anyway. @@ -771,7 +771,7 @@ async def close(self, code: int = 1000, reason: str = "") -> None: self.transfer_data_task, self.close_timeout, **loop_if_py_lt_38(self.loop), - ) + ) # type: ignore # remove when removing loop_if_py_lt_38 except (asyncio.TimeoutError, asyncio.CancelledError): pass @@ -1250,7 +1250,7 @@ async def keepalive_ping(self) -> None: pong_waiter, self.ping_timeout, **loop_if_py_lt_38(self.loop), - ) + ) # type: ignore # remove when removing loop_if_py_lt_38 self.logger.debug("% received keepalive pong") except asyncio.TimeoutError: if self.debug: @@ -1365,7 +1365,7 @@ async def wait_for_connection_lost(self) -> bool: asyncio.shield(self.connection_lost_waiter), self.close_timeout, **loop_if_py_lt_38(self.loop), - ) + ) # type: ignore # remove when removing loop_if_py_lt_38 except asyncio.TimeoutError: pass # Re-check self.connection_lost_waiter.done() synchronously because From cbbaa26e4ab303ef3c5aa2212cf45db9c15d14d2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 11 Nov 2021 21:12:41 +0100 Subject: [PATCH 203/762] Standardize on the code-block directive. --- docs/howto/autoreload.rst | 4 ++-- docs/howto/django.rst | 18 +++++++++--------- docs/howto/haproxy.rst | 6 +++--- docs/howto/heroku.rst | 28 ++++++++++++---------------- docs/howto/kubernetes.rst | 32 ++++++++++++++++---------------- docs/howto/nginx.rst | 6 +++--- docs/howto/supervisor.rst | 20 ++++++++++---------- docs/topics/authentication.rst | 22 +++++++++++----------- docs/topics/timeouts.rst | 2 +- 9 files changed, 67 insertions(+), 71 deletions(-) diff --git a/docs/howto/autoreload.rst b/docs/howto/autoreload.rst index 4f650acb3..edd87d0fd 100644 --- a/docs/howto/autoreload.rst +++ b/docs/howto/autoreload.rst @@ -14,7 +14,7 @@ However, you can achieve the same result easily. Install watchdog_ with the ``watchmedo`` shell utility: -.. code:: console +.. code-block:: console $ pip install watchdog[watchmedo] @@ -22,7 +22,7 @@ Install watchdog_ with the ``watchmedo`` shell utility: Run your server with ``watchmedo auto-restart``: -.. code:: console +.. code-block:: console $ watchmedo auto-restart --pattern "*.py" --recursive --signal SIGTERM \ python app.py diff --git a/docs/howto/django.rst b/docs/howto/django.rst index a3d4a2ceb..34cce58e9 100644 --- a/docs/howto/django.rst +++ b/docs/howto/django.rst @@ -44,7 +44,7 @@ We want secure, short-lived tokens containing the user ID. We'll rely on Add django-sesame to the dependencies of your Django project, install it, and configure it in the settings of the project: -.. code:: python +.. code-block:: python AUTHENTICATION_BACKENDS = [ "django.contrib.auth.backends.ModelBackend", @@ -62,7 +62,7 @@ We'd like our tokens to be valid for 30 seconds. We expect web pages to load and to establish the WebSocket connection within this delay. Configure django-sesame accordingly in the settings of your Django project: -.. code:: python +.. code-block:: python SESAME_MAX_AGE = 30 @@ -77,7 +77,7 @@ performance. Now you can generate tokens in a ``django-admin shell`` as follows: -.. code:: pycon +.. code-block:: pycon >>> from django.contrib.auth import get_user_model >>> User = get_user_model() @@ -132,7 +132,7 @@ Save this code to a file called ``authentication.py``, make sure the ``DJANGO_SETTINGS_MODULE`` environment variable is set properly, and start the websockets server: -.. code:: console +.. code-block:: console $ python authentication.py @@ -140,7 +140,7 @@ Generate a new token — remember, they're only valid for 30 seconds — and use it to connect to your server. Paste your token and press Enter when you get a prompt: -.. code:: console +.. code-block:: console $ python -m websockets ws://localhost:8888/ Connected to ws://localhost:8888/ @@ -153,7 +153,7 @@ It works! If you enter an expired or invalid token, authentication fails and the server closes the connection: -.. code:: console +.. code-block:: console $ python -m websockets ws://localhost:8888/ Connected to ws://localhost:8888. @@ -163,7 +163,7 @@ closes the connection: You can also test from a browser by generating a new token and running the following code in the JavaScript console of the browser: -.. code:: javascript +.. code-block:: javascript websocket = new WebSocket("ws://localhost:8888/"); websocket.onopen = (event) => websocket.send(""); @@ -205,7 +205,7 @@ the Redis connection directly. Install Redis, add django-redis to the dependencies of your Django project, install it, and configure it in the settings of the project: -.. code:: python +.. code-block:: python CACHES = { "default": { @@ -234,7 +234,7 @@ an event to Redis. Let's check that it works: -.. code:: console +.. code-block:: console $ redis-cli 127.0.0.1:6379> SELECT 1 diff --git a/docs/howto/haproxy.rst b/docs/howto/haproxy.rst index 0ecb46a04..fdaab0401 100644 --- a/docs/howto/haproxy.rst +++ b/docs/howto/haproxy.rst @@ -28,7 +28,7 @@ This configuration runs four instances of the app. Install Supervisor and run it: -.. code:: console +.. code-block:: console $ supervisord -c supervisord.conf -n @@ -46,13 +46,13 @@ servers. This is best for long running connections. Save the configuration to ``haproxy.cfg``, install HAProxy, and run it: -.. code:: console +.. code-block:: console $ haproxy -f haproxy.cfg You can confirm that HAProxy proxies connections properly: -.. code:: console +.. code-block:: console $ PYTHONPATH=src python -m websockets ws://localhost:8080/ Connected to ws://localhost:8080/. diff --git a/docs/howto/heroku.rst b/docs/howto/heroku.rst index b728106e9..6a7c4d00b 100644 --- a/docs/howto/heroku.rst +++ b/docs/howto/heroku.rst @@ -14,7 +14,7 @@ Create application Deploying to Heroku requires a git repository. Let's initialize one: -.. code:: console +.. code-block:: console $ mkdir websockets-echo $ cd websockets-echo @@ -32,7 +32,7 @@ Then, create a Heroku app — if you follow these instructions step-by-step, you'll have to pick a different name because I'm already using ``websockets-echo`` on Heroku: -.. code:: console +.. code-block:: console $ heroku create websockets-echo Creating ⬢ websockets-echo... done @@ -61,20 +61,16 @@ Deploy application In order to build the app, Heroku needs to know that it depends on websockets. Create a ``requirements.txt`` file containing this line: -.. code:: - - websockets +.. literalinclude:: ../../example/deployment/heroku/requirements.txt Heroku also needs to know how to run the app. Create a ``Procfile`` with this content: -.. code:: - - web: python app.py +.. literalinclude:: ../../example/deployment/heroku/Procfile Confirm that you created the correct files and commit them to git: -.. code:: console +.. code-block:: console $ ls Procfile app.py requirements.txt @@ -88,7 +84,7 @@ Confirm that you created the correct files and commit them to git: The app is ready. Let's deploy it! -.. code:: console +.. code-block:: console $ git push heroku main @@ -114,7 +110,7 @@ If you're currently building a websockets server, perhaps you're already in a virtualenv where websockets is installed. If not, you can install it in a new virtualenv as follows: -.. code:: console +.. code-block:: console $ python -m venv websockets-client $ . websockets-client/bin/activate @@ -123,7 +119,7 @@ virtualenv as follows: Connect the interactive client — using the name of your Heroku app instead of ``websockets-echo``: -.. code:: console +.. code-block:: console $ python -m websockets wss://websockets-echo.herokuapp.com/ Connected to wss://websockets-echo.herokuapp.com/. @@ -138,7 +134,7 @@ An insecure connection (``ws://``) would also work. Once you're connected, you can send any message and the server will echo it, then press Ctrl-D to terminate the connection: -.. code:: console +.. code-block:: console > Hello! < Hello! @@ -147,7 +143,7 @@ then press Ctrl-D to terminate the connection: You can also confirm that your application shuts down gracefully. Connect an interactive client again — remember to replace ``websockets-echo`` with your app: -.. code:: console +.. code-block:: console $ python -m websockets wss://websockets-echo.herokuapp.com/ Connected to wss://websockets-echo.herokuapp.com/. @@ -155,7 +151,7 @@ interactive client again — remember to replace ``websockets-echo`` with your a In another shell, restart the dyno — again, replace ``websockets-echo`` with your app: -.. code:: console +.. code-block:: console $ heroku dyno:restart -a websockets-echo Restarting dynos on ⬢ websockets-echo... done @@ -163,7 +159,7 @@ In another shell, restart the dyno — again, replace ``websockets-echo`` with y Go back to the first shell. The connection is closed with code 1001 (going away). -.. code:: console +.. code-block:: console $ python -m websockets wss://websockets-echo.herokuapp.com/ Connected to wss://websockets-echo.herokuapp.com/. diff --git a/docs/howto/kubernetes.rst b/docs/howto/kubernetes.rst index 0e77aeac1..26dbf8a94 100644 --- a/docs/howto/kubernetes.rst +++ b/docs/howto/kubernetes.rst @@ -36,13 +36,13 @@ guide, so we'll go for the simplest possible configuration instead: After saving this ``Dockerfile``, build the image: -.. code:: console +.. code-block:: console $ docker build -t websockets-test:1.0 . Test your image by running: -.. code:: console +.. code-block:: console $ docker run --name run-websockets-test --publish 32080:80 --rm \ websockets-test:1.0 @@ -50,7 +50,7 @@ Test your image by running: Then, in another shell, in a virtualenv where websockets is installed, connect to the app and check that it echoes anything you send: -.. code:: console +.. code-block:: console $ python -m websockets ws://localhost:32080/ Connected to ws://localhost:32080/. @@ -60,14 +60,14 @@ to the app and check that it echoes anything you send: Now, in yet another shell, stop the app with: -.. code:: console +.. code-block:: console $ docker kill -s TERM run-websockets-test Going to the shell where you connected to the app, you can confirm that it shut down gracefully: -.. code:: console +.. code-block:: console $ python -m websockets ws://localhost:32080/ Connected to ws://localhost:32080/. @@ -96,7 +96,7 @@ to production, you would configure an Ingress_. After saving this to a file called ``deployment.yaml``, you can deploy: -.. code:: console +.. code-block:: console $ kubectl apply -f deployment.yaml service/websockets-test created @@ -104,7 +104,7 @@ After saving this to a file called ``deployment.yaml``, you can deploy: Now you have a deployment with one pod running: -.. code:: console +.. code-block:: console $ kubectl get deployment websockets-test NAME READY UP-TO-DATE AVAILABLE AGE @@ -115,7 +115,7 @@ Now you have a deployment with one pod running: You can connect to the service — press Ctrl-D to exit: -.. code:: console +.. code-block:: console $ python -m websockets ws://localhost:32080/ Connected to ws://localhost:32080/. @@ -126,7 +126,7 @@ Validate deployment First, let's ensure the liveness probe works by making the app unresponsive: -.. code:: console +.. code-block:: console $ curl http://localhost:32080/inemuri Sleeping for 10s @@ -139,7 +139,7 @@ Therefore Kubernetes should restart the pod after at most 5 seconds. Indeed, after a few seconds, the pod reports a restart: -.. code:: console +.. code-block:: console $ kubectl get pods -l app=websockets-test NAME READY STATUS RESTARTS AGE @@ -147,14 +147,14 @@ Indeed, after a few seconds, the pod reports a restart: Next, let's take it one step further and crash the app: -.. code:: console +.. code-block:: console $ curl http://localhost:32080/seppuku Terminating The pod reports a second restart: -.. code:: console +.. code-block:: console $ kubectl get pods -l app=websockets-test NAME READY STATUS RESTARTS AGE @@ -167,14 +167,14 @@ Scale deployment Of course, Kubernetes is for scaling. Let's scale — modestly — to 10 pods: -.. code:: console +.. code-block:: console $ kubectl scale deployment.apps/websockets-test --replicas=10 deployment.apps/websockets-test scaled After a few seconds, we have 10 pods running: -.. code:: console +.. code-block:: console $ kubectl get deployment websockets-test NAME READY UP-TO-DATE AVAILABLE AGE @@ -191,7 +191,7 @@ over 50 * 6 * 0.1 = 30 seconds. Let's try it: -.. code:: console +.. code-block:: console $ ulimit -n 512 $ time python benchmark.py 500 6 @@ -204,7 +204,7 @@ stabilize the test setup. Finally, we can scale back to one pod. -.. code:: console +.. code-block:: console $ kubectl scale deployment.apps/websockets-test --replicas=1 deployment.apps/websockets-test scaled diff --git a/docs/howto/nginx.rst b/docs/howto/nginx.rst index e20f82098..30545fbc7 100644 --- a/docs/howto/nginx.rst +++ b/docs/howto/nginx.rst @@ -33,7 +33,7 @@ This configuration runs four instances of the app. Install Supervisor and run it: -.. code:: console +.. code-block:: console $ supervisord -c supervisord.conf -n @@ -69,13 +69,13 @@ Then we combine the `WebSocket proxying`_ and `load balancing`_ guides: Save the configuration to ``nginx.conf``, install nginx, and run it: -.. code:: console +.. code-block:: console $ nginx -c nginx.conf -p . You can confirm that nginx proxies connections properly: -.. code:: console +.. code-block:: console $ PYTHONPATH=src python -m websockets ws://localhost:8080/ Connected to ws://localhost:8080/. diff --git a/docs/howto/supervisor.rst b/docs/howto/supervisor.rst index 0c07aebae..5eefc7711 100644 --- a/docs/howto/supervisor.rst +++ b/docs/howto/supervisor.rst @@ -14,14 +14,14 @@ connections. Create and activate a virtualenv: -.. code:: console +.. code-block:: console $ python -m venv supervisor-websockets $ . supervisor-websockets/bin/activate Install websockets and Supervisor: -.. code:: console +.. code-block:: console $ pip install websockets $ pip install supervisor @@ -45,7 +45,7 @@ running, restarting them if they exit. Now start Supervisor in the foreground: -.. code:: console +.. code-block:: console $ supervisord -c supervisord.conf -n INFO Increased RLIMIT_NOFILE limit to 1024 @@ -62,7 +62,7 @@ Now start Supervisor in the foreground: In another shell, after activating the virtualenv, we can connect to the app — press Ctrl-D to exit: -.. code:: console +.. code-block:: console $ python -m websockets ws://localhost:8080/ Connected to ws://localhost:8080/. @@ -72,13 +72,13 @@ press Ctrl-D to exit: Look at the pid of an instance of the app in the logs and terminate it: -.. code:: console +.. code-block:: console $ kill -TERM 43597 The logs show that Supervisor restarted this instance: -.. code:: console +.. code-block:: console INFO exited: websockets-test_00 (exit status 0; expected) INFO spawned: 'websockets-test_00' with pid 43629 @@ -87,7 +87,7 @@ The logs show that Supervisor restarted this instance: Now let's check what happens when we shut down Supervisor, but first let's establish a connection and leave it open: -.. code:: console +.. code-block:: console $ python -m websockets ws://localhost:8080/ Connected to ws://localhost:8080/. @@ -95,14 +95,14 @@ establish a connection and leave it open: Look at the pid of supervisord itself in the logs and terminate it: -.. code:: console +.. code-block:: console $ kill -TERM 43596 The logs show that Supervisor terminated all instances of the app before exiting: -.. code:: console +.. code-block:: console WARN received SIGTERM indicating exit request INFO waiting for websockets-test_00, websockets-test_01, websockets-test_02, websockets-test_03 to die @@ -113,7 +113,7 @@ exiting: And you can see that the connection to the app was closed gracefully: -.. code:: console +.. code-block:: console $ python -m websockets ws://localhost:8080/ Connected to ws://localhost:8080/. diff --git a/docs/topics/authentication.rst b/docs/topics/authentication.rst index 4d702b2f6..31bfd6465 100644 --- a/docs/topics/authentication.rst +++ b/docs/topics/authentication.rst @@ -154,7 +154,7 @@ Run the experiment in an environment where websockets is installed: .. _experiments/authentication: https://github.com/aaugustin/websockets/tree/main/experiments/authentication -.. code:: console +.. code-block:: console $ python experiments/authentication/app.py Running on http://localhost:8000/ @@ -172,7 +172,7 @@ First message As soon as the connection is open, the client sends a message containing the token: -.. code:: javascript +.. code-block:: javascript const websocket = new WebSocket("ws://.../"); websocket.onopen = () => websocket.send(token); @@ -183,7 +183,7 @@ At the beginning of the connection handler, the server receives this message and authenticates the user. If authentication fails, the server closes the connection: -.. code:: python +.. code-block:: python async def first_message_handler(websocket): token = await websocket.recv() @@ -200,7 +200,7 @@ Query parameter The client adds the token to the WebSocket URI in a query parameter before opening the connection: -.. code:: javascript +.. code-block:: javascript const uri = `ws://.../?token=${token}`; const websocket = new WebSocket(uri); @@ -210,7 +210,7 @@ opening the connection: The server intercepts the HTTP request, extracts the token and authenticates the user. If authentication fails, it returns a HTTP 401: -.. code:: python +.. code-block:: python class QueryParamProtocol(websockets.WebSocketServerProtocol): async def process_request(self, path, headers): @@ -237,7 +237,7 @@ The client sets a cookie containing the token before opening the connection. The cookie must be set by an iframe loaded from the same origin as the WebSocket server. This requires passing the token to this iframe. -.. code:: javascript +.. code-block:: javascript // in main window iframe.contentWindow.postMessage(token, "http://..."); @@ -256,7 +256,7 @@ This involves several events. Look at the full implementation for details. The server intercepts the HTTP request, extracts the token and authenticates the user. If authentication fails, it returns a HTTP 401: -.. code:: python +.. code-block:: python class CookieProtocol(websockets.WebSocketServerProtocol): async def process_request(self, path, headers): @@ -284,7 +284,7 @@ User information The client adds the token to the WebSocket URI in user information before opening the connection: -.. code:: javascript +.. code-block:: javascript const uri = `ws://token:${token}@.../`; const websocket = new WebSocket(uri); @@ -297,7 +297,7 @@ than a token, we send ``token`` as username and the token as password. The server intercepts the HTTP request, extracts the token and authenticates the user. If authentication fails, it returns a HTTP 401: -.. code:: python +.. code-block:: python class UserInfoProtocol(websockets.BasicAuthWebSocketServerProtocol): async def check_credentials(self, username, password): @@ -326,7 +326,7 @@ solution in this scenario. To authenticate a websockets client with HTTP Basic Authentication (:rfc:`7617`), include the credentials in the URI: -.. code:: python +.. code-block:: python async with websockets.connect( f"wss://{username}:{password}@example.com", @@ -339,7 +339,7 @@ contain unsafe characters.) To authenticate a websockets client with HTTP Bearer Authentication (:rfc:`6750`), add a suitable ``Authorization`` header: -.. code:: python +.. code-block:: python async with websockets.connect( "wss://example.com", diff --git a/docs/topics/timeouts.rst b/docs/topics/timeouts.rst index 51666ceea..815a29b3f 100644 --- a/docs/topics/timeouts.rst +++ b/docs/topics/timeouts.rst @@ -56,7 +56,7 @@ Latency between a client and a server may increase for two reasons: You can monitor latency as follows: -.. code:: python +.. code-block:: python import asyncio import logging From 46d973a3163a92117259122b7b636cff1b142a23 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 14 Nov 2021 14:54:06 +0100 Subject: [PATCH 204/762] Upgrade RTD config to v2. --- .readthedocs.yml | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/.readthedocs.yml b/.readthedocs.yml index 109affab4..0369e0656 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -1,7 +1,13 @@ +version: 2 + build: - image: latest + os: ubuntu-20.04 + tools: + python: "3.10" -python: - version: 3.7 +sphinx: + configuration: docs/conf.py -requirements_file: docs/requirements.txt +python: + install: + - requirements: docs/requirements.txt From e6b8522b25d66b4847a29fa1d0fa27f135909279 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 14 Nov 2021 16:42:32 +0100 Subject: [PATCH 205/762] Remove dead code. --- tests/test_imports.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/test_imports.py b/tests/test_imports.py index d84808902..8f1625a9b 100644 --- a/tests/test_imports.py +++ b/tests/test_imports.py @@ -1,4 +1,3 @@ -import sys import types import unittest import warnings @@ -33,10 +32,6 @@ def test_get_deprecated_alias(self): with warnings.catch_warnings(record=True) as recorded_warnings: self.assertEqual(self.mod.bar, bar) - # No warnings raised on pre-PEP 526 Python. - if sys.version_info[:2] < (3, 7): # pragma: no cover - return - self.assertEqual(len(recorded_warnings), 1) warning = recorded_warnings[0].message self.assertEqual( From 11cbe72428bf78ac2e89b1ec2844969bc3f7ab8a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 14 Nov 2021 15:19:46 +0100 Subject: [PATCH 206/762] Introduce development version numbers. This prevents developement versions of websocket from advertising themselves as the previous release while they may contain backwards-incompatible changes since that release. --- docs/conf.py | 18 ++------- src/websockets/version.py | 77 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 79 insertions(+), 16 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index d22b85a82..750682573 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -25,8 +25,7 @@ copyright = f"2013-{datetime.date.today().year}, Aymeric Augustin and contributors" author = "Aymeric Augustin" -# The full version, including alpha/beta/rc tags -release = "10.0" +from websockets.version import tag as version, version as release # -- General configuration --------------------------------------------------- @@ -91,20 +90,9 @@ exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # Configure viewcode extension. -try: - git_sha1 = subprocess.run( - "git rev-parse --short HEAD", - capture_output=True, - shell=True, - check=True, - text=True, - ).stdout.strip() -except subprocess.SubprocessError as exc: - print("Cannot get git commit, disabling linkcode:", exc) - extensions.remove("sphinx.ext.linkcode") -else: - code_url = f"https://github.com/aaugustin/websockets/blob/{git_sha1}" +from websockets.version import commit +code_url = f"https://github.com/aaugustin/websockets/blob/{commit}" def linkcode_resolve(domain, info): assert domain == "py" diff --git a/src/websockets/version.py b/src/websockets/version.py index 168f8b054..3a6e6aa08 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -1 +1,76 @@ -version = "10.0" +from __future__ import annotations + + +__all__ = ["tag", "version", "commit"] + + +# ========= =========== =================== +# release development +# ========= =========== =================== +# tag X.Y X.Y (upcoming) +# version X.Y X.Y.dev1+g5678cde +# commit X.Y 5678cde +# ========= =========== =================== + + +# When tagging a release, set `released = True`. +# After tagging a release, set `released = False` and increment `tag`. + +released = False + +tag = version = commit = "10.1" + + +if not released: # pragma: no cover + import pathlib + import re + import subprocess + + def get_version(tag: str) -> str: + # Since setup.py executes the contents of src/websockets/version.py, + # __file__ can point to either of these two files. + file_path = pathlib.Path(__file__) + root_dir = file_path.parents[0 if file_path.name == "setup.py" else 2] + + # Read version from git if available. This prevents reading stale + # information from src/websockets.egg-info after building a sdist. + try: + description = subprocess.run( + ["git", "describe", "--dirty", "--tags", "--long"], + capture_output=True, + cwd=root_dir, + check=True, + text=True, + ).stdout.strip() + except subprocess.CalledProcessError: + pass + else: + description_re = r"[0-9.]+-([0-9]+)-(g[0-9a-f]{7}(?:-dirty)?)" + match = re.fullmatch(description_re, description) + assert match is not None + distance, remainder = match.groups() + remainder = remainder.replace("-", ".") # required by PEP 440 + return f"{tag}.dev{distance}+{remainder}" + + # Read version from package metadata if it is installed. + try: + import importlib.metadata # move up when dropping Python 3.7 + + return importlib.metadata.version("websockets") + except ImportError: + pass + + # Avoid crashing if the development version cannot be determined. + return f"{tag}.dev0+gunknown" + + version = get_version(tag) + + def get_commit(tag: str, version: str) -> str: + # Extract commit from version, falling back to tag if not available. + version_re = r"[0-9.]+\.dev[0-9]+\+g([0-9a-f]{7}|unknown)(?:\.dirty)?" + match = re.fullmatch(version_re, version) + assert match is not None + (commit,) = match.groups() + return tag if commit == "unknown" else commit + + commit = get_commit(tag, version) From 150f2a04551a3af669dccfdc5efe59d1e2cae40a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 18 Sep 2021 08:09:38 +0200 Subject: [PATCH 207/762] Update marketing pitch. --- README.rst | 23 +++++++++++------------ docs/index.rst | 8 ++++---- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/README.rst b/README.rst index 82f89f952..f8df94ba4 100644 --- a/README.rst +++ b/README.rst @@ -25,11 +25,10 @@ What is ``websockets``? ----------------------- -``websockets`` is a library for building WebSocket servers_ and clients_ in -Python with a focus on correctness and simplicity. +websockets is a library for building WebSocket_ servers and clients in Python +with a focus on correctness, simplicity, robustness, and performance. -.. _servers: https://github.com/aaugustin/websockets/blob/main/example/server.py -.. _clients: https://github.com/aaugustin/websockets/blob/main/example/client.py +.. _WebSocket: https://developer.mozilla.org/en-US/docs/Web/API/WebSockets_API Built on top of ``asyncio``, Python's standard asynchronous I/O framework, it provides an elegant coroutine-based API. @@ -92,19 +91,19 @@ Why should I use ``websockets``? The development of ``websockets`` is shaped by four principles: -1. **Simplicity**: all you need to understand is ``msg = await ws.recv()`` and - ``await ws.send(msg)``; ``websockets`` takes care of managing connections +1. **Correctness**: ``websockets`` is heavily tested for compliance + with :rfc:`6455`. Continuous integration fails under 100% branch + coverage. + +2. **Simplicity**: all you need to understand is ``msg = await ws.recv()`` and + ``await ws.send(msg)``. ``websockets`` takes care of managing connections so you can focus on your application. -2. **Robustness**: ``websockets`` is built for production; for example it was +3. **Robustness**: ``websockets`` is built for production. For example, it was the only library to `handle backpressure correctly`_ before the issue became widely known in the Python community. -3. **Quality**: ``websockets`` is heavily tested. Continuous integration fails - under 100% branch coverage. Also it passes the industry-standard `Autobahn - Testsuite`_. - -4. **Performance**: memory usage is configurable. An extension written in C +4. **Performance**: memory usage is optimized and configurable. A C extension accelerates expensive operations. It's pre-compiled for Linux, macOS and Windows and packaged in the wheel format for each system and Python version. diff --git a/docs/index.rst b/docs/index.rst index 30d01c2f8..def9076af 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -21,11 +21,11 @@ websockets .. |docs| image:: https://img.shields.io/readthedocs/websockets.svg :target: https://websockets.readthedocs.io/ -websockets is a library for building WebSocket servers_ and clients_ in Python -with a focus on correctness and simplicity. +websockets is a library for building WebSocket_ servers and +clients in Python with a focus on correctness, simplicity, robustness, and +performance. -.. _servers: https://github.com/aaugustin/websockets/blob/main/example/server.py -.. _clients: https://github.com/aaugustin/websockets/blob/main/example/client.py +.. _WebSocket: https://developer.mozilla.org/en-US/docs/Web/API/WebSockets_API Built on top of :mod:`asyncio`, Python's standard asynchronous I/O framework, it provides an elegant coroutine-based API. From 22d881b00e823cce7237ac7c1aaad2c5d09cb95f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 18 Sep 2021 08:10:14 +0200 Subject: [PATCH 208/762] Move interactive client to front page. Remove dependency on echo.websocket.org which no longer exists. --- docs/index.rst | 10 ++++++++++ docs/intro/index.rst | 7 ------- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/docs/index.rst b/docs/index.rst index def9076af..bdde6dc74 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -38,6 +38,16 @@ And here's an echo server: .. literalinclude:: ../example/echo.py +Also, websockets provides an interactive client: + +.. code-block:: console + + $ python -m websockets ws://localhost:8765/ + Connected to ws://localhost:8765/. + > Hello world! + < Hello world! + Connection closed: 1000 (OK). + Do you like it? Let's dive in! .. toctree:: diff --git a/docs/intro/index.rst b/docs/intro/index.rst index bd7c48f81..f1ea2a2a4 100644 --- a/docs/intro/index.rst +++ b/docs/intro/index.rst @@ -210,10 +210,3 @@ You don't have to worry about performing the opening or the closing handshake, answering pings, or any other behavior required by the specification. websockets handles all this under the hood so you don't have to. - -One more thing... ------------------ - -websockets provides an interactive client:: - - $ python -m websockets wss://echo.websocket.org/ From b9e11125b781da3cf567cd2cf3036478c6ee1670 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 18 Sep 2021 08:21:54 +0200 Subject: [PATCH 209/762] Move API design principle to front page. --- docs/index.rst | 4 ++++ docs/intro/index.rst | 10 ---------- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/docs/index.rst b/docs/index.rst index bdde6dc74..07835a81c 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -38,6 +38,10 @@ And here's an echo server: .. literalinclude:: ../example/echo.py +Don't worry about the opening and closing handshakes, pings and pongs, or any +other behavior described in the specification. websockets takes care of this +under the hood so you can focus on your application! + Also, websockets provides an interactive client: .. code-block:: console diff --git a/docs/intro/index.rst b/docs/intro/index.rst index f1ea2a2a4..bf3c96b33 100644 --- a/docs/intro/index.rst +++ b/docs/intro/index.rst @@ -200,13 +200,3 @@ unregister them when they disconnect. This simplistic example keeps track of connected clients in memory. This only works as long as you run a single process. In a practical application, the handler may subscribe to some channels on a message broker, for example. - -That's all! ------------ - -The design of the websockets API was driven by simplicity. - -You don't have to worry about performing the opening or the closing handshake, -answering pings, or any other behavior required by the specification. - -websockets handles all this under the hood so you don't have to. From 0fafda84ae722bf319a46d7cad4c67bcfd47c16d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 18 Sep 2021 16:57:48 +0200 Subject: [PATCH 210/762] Extract patterns to a howto guide. Improve this guide. Fix #1022. --- docs/howto/index.rst | 1 + docs/howto/patterns.rst | 110 ++++++++++++++++++++++++++++++++++++++++ docs/intro/index.rst | 82 ------------------------------ 3 files changed, 111 insertions(+), 82 deletions(-) create mode 100644 docs/howto/patterns.rst diff --git a/docs/howto/index.rst b/docs/howto/index.rst index d11789d48..d399cebc2 100644 --- a/docs/howto/index.rst +++ b/docs/howto/index.rst @@ -8,6 +8,7 @@ If you're stuck, perhaps you'll find the answer here. faq cheatsheet + patterns autoreload This guide will help you integrate websockets into a broader system. diff --git a/docs/howto/patterns.rst b/docs/howto/patterns.rst new file mode 100644 index 000000000..c6f325d21 --- /dev/null +++ b/docs/howto/patterns.rst @@ -0,0 +1,110 @@ +Patterns +======== + +.. currentmodule:: websockets + +Here are typical patterns for processing messages in a WebSocket server or +client. You will certainly implement some of them in your application. + +This page gives examples of connection handlers for a server. However, they're +also applicable to a client, simply by assuming that ``websocket`` is a +connection created with :func:`~client.connect`. + +WebSocket connections are long-lived. You will usually write a loop to process +several messages during the lifetime of a connection. + +Consumer +-------- + +To receive messages from the WebSocket connection:: + + async def consumer_handler(websocket): + async for message in websocket: + await consumer(message) + +In this example, ``consumer()`` is a coroutine implementing your business +logic for processing a message received on the WebSocket connection. Each +message may be :class:`str` or :class:`bytes`. + +Iteration terminates when the client disconnects. + +Producer +-------- + +To send messages to the WebSocket connection:: + + async def producer_handler(websocket): + while True: + message = await producer() + await websocket.send(message) + +In this example, ``producer()`` is a coroutine implementing your business +logic for generating the next message to send on the WebSocket connection. +Each message must be :class:`str` or :class:`bytes`. + +Iteration terminates when the client disconnects +because :meth:`~server.WebSocketServerProtocol.send` raises a +:exc:`~exceptions.ConnectionClosed` exception, +which breaks out of the ``while True`` loop. + +Consumer and producer +--------------------- + +You can receive and send messages on the same WebSocket connection by +combining the consumer and producer patterns. This requires running two tasks +in parallel:: + + async def handler(websocket): + await asyncio.gather( + consumer_handler(websocket), + producer_handler(websocket), + ) + +If a task terminates, :func:`~asyncio.gather` doesn't cancel the other task. +This can result in a situation where the producer keeps running after the +consumer finished, which may leak resources. + +Here's a way to exit and close the WebSocket connection as soon as a task +terminates, after canceling the other task:: + + async def handler(websocket): + consumer_task = asyncio.create_task(consumer_handler(websocket)) + producer_task = asyncio.create_task(producer_handler(websocket)) + done, pending = await asyncio.wait( + [consumer_task, producer_task], + return_when=asyncio.FIRST_COMPLETED, + ) + for task in pending: + task.cancel() + +Registration +------------ + +To keep track of currently connected clients, you can register them when they +connect and unregister them when they disconnect:: + + connected = set() + + async def handler(websocket): + # Register. + connected.add(websocket) + try: + # Broadcast a message to all connected clients. + websockets.broadcast(connected, "Hello!") + await asyncio.sleep(10) + finally: + # Unregister. + connected.remove(websocket) + +This example maintains the set of connected clients in memory. This works as +long as you run a single process. It doesn't scale to multiple processes. + +Publish–subscribe +----------------- + +If you plan to run multiple processes and you want to communicate updates +between processes, then you must deploy a messaging system. You may find +publish-subscribe functionality useful. + +A complete implementation of this idea with Redis is described in +the :doc:`Django integration guide <../howto/django>`. diff --git a/docs/intro/index.rst b/docs/intro/index.rst index bf3c96b33..e95f0889c 100644 --- a/docs/intro/index.rst +++ b/docs/intro/index.rst @@ -118,85 +118,3 @@ Then open this HTML file in several browsers. .. literalinclude:: ../../example/counter.html :language: html - -Common patterns ---------------- - -You will usually want to process several messages during the lifetime of a -connection. Therefore you must write a loop. Here are the basic patterns for -building a WebSocket server. - -Consumer -........ - -For receiving messages and passing them to a ``consumer`` coroutine:: - - async def consumer_handler(websocket): - async for message in websocket: - await consumer(message) - -In this example, ``consumer`` represents your business logic for processing -messages received on the WebSocket connection. - -Iteration terminates when the client disconnects. - -Producer -........ - -For getting messages from a ``producer`` coroutine and sending them:: - - async def producer_handler(websocket): - while True: - message = await producer() - await websocket.send(message) - -In this example, ``producer`` represents your business logic for generating -messages to send on the WebSocket connection. - -:meth:`~legacy.protocol.WebSocketCommonProtocol.send` raises a -:exc:`~exceptions.ConnectionClosed` exception when the client disconnects, -which breaks out of the ``while True`` loop. - -Both sides -.......... - -You can read and write messages on the same connection by combining the two -patterns shown above and running the two tasks in parallel:: - - async def handler(websocket): - consumer_task = asyncio.ensure_future( - consumer_handler(websocket)) - producer_task = asyncio.ensure_future( - producer_handler(websocket)) - done, pending = await asyncio.wait( - [consumer_task, producer_task], - return_when=asyncio.FIRST_COMPLETED, - ) - for task in pending: - task.cancel() - -Registration -............ - -As shown in the synchronization example above, if you need to maintain a list -of currently connected clients, you must register them when they connect and -unregister them when they disconnect. - -:: - - connected = set() - - async def handler(websocket): - # Register. - connected.add(websocket) - try: - # Broadcast a message to all connected clients. - websockets.broadcast(connected, "Hello!") - await asyncio.sleep(10) - finally: - # Unregister. - connected.remove(websocket) - -This simplistic example keeps track of connected clients in memory. This only -works as long as you run a single process. In a practical application, the -handler may subscribe to some channels on a message broker, for example. From 33b38eea7c627fb314aaaa6aa29041917fd7718e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 18 Sep 2021 17:04:15 +0200 Subject: [PATCH 211/762] Extract examples to a quick start guide. --- docs/intro/index.rst | 101 ++------------- docs/intro/quickstart.rst | 116 ++++++++++++++++++ example/counter.html | 80 ------------ example/{ => quickstart}/client.py | 5 +- .../client_secure.py} | 5 +- example/quickstart/counter.css | 33 +++++ example/quickstart/counter.html | 18 +++ example/quickstart/counter.js | 26 ++++ example/{ => quickstart}/counter.py | 32 ++--- example/{ => quickstart}/localhost.pem | 0 example/{ => quickstart}/server.py | 5 +- .../server_secure.py} | 5 +- example/quickstart/show_time.html | 9 ++ example/quickstart/show_time.js | 12 ++ example/quickstart/show_time.py | 18 +++ example/show_time.html | 20 --- example/show_time.py | 20 --- 17 files changed, 261 insertions(+), 244 deletions(-) create mode 100644 docs/intro/quickstart.rst delete mode 100644 example/counter.html rename example/{ => quickstart}/client.py (86%) rename example/{secure_client.py => quickstart/client_secure.py} (86%) create mode 100644 example/quickstart/counter.css create mode 100644 example/quickstart/counter.html create mode 100644 example/quickstart/counter.js rename example/{ => quickstart}/counter.py (57%) rename example/{ => quickstart}/localhost.pem (100%) rename example/{ => quickstart}/server.py (87%) rename example/{secure_server.py => quickstart/server_secure.py} (86%) create mode 100644 example/quickstart/show_time.html create mode 100644 example/quickstart/show_time.js create mode 100755 example/quickstart/show_time.py delete mode 100644 example/show_time.html delete mode 100755 example/show_time.py diff --git a/docs/intro/index.rst b/docs/intro/index.rst index e95f0889c..a5b68bfbf 100644 --- a/docs/intro/index.rst +++ b/docs/intro/index.rst @@ -14,6 +14,8 @@ websockets requires Python ≥ 3.7. For each minor version (3.x), only the latest bugfix or security release (3.x.y) is officially supported. +It doesn't have any dependencies. + Installation ------------ @@ -21,100 +23,13 @@ Install websockets with:: pip install websockets -Basic example -------------- - -.. _server-example: - -Here's a WebSocket server example. - -It reads a name from the client, sends a greeting, and closes the connection. - -.. literalinclude:: ../../example/server.py - :emphasize-lines: 8,18 - -.. _client-example: - -On the server side, websockets executes the handler coroutine ``hello()`` once -for each WebSocket connection. It closes the connection when the handler -coroutine returns. - -Here's a corresponding WebSocket client example. - -.. literalinclude:: ../../example/client.py - :emphasize-lines: 10 - -Using :func:`~client.connect` as an asynchronous context manager ensures the -connection is closed before exiting the ``hello()`` coroutine. - -.. _secure-server-example: - -Secure example --------------- - -Secure WebSocket connections improve confidentiality and also reliability -because they reduce the risk of interference by bad proxies. - -The ``wss`` protocol is to ``ws`` what ``https`` is to ``http``. The -connection is encrypted with TLS_ (Transport Layer Security). ``wss`` -requires certificates like ``https``. - -.. _TLS: https://developer.mozilla.org/en-US/docs/Web/Security/Transport_Layer_Security - -.. admonition:: TLS vs. SSL - :class: tip - - TLS is sometimes referred to as SSL (Secure Sockets Layer). SSL was an - earlier encryption protocol; the name stuck. - -Here's how to adapt the server example to provide secure connections. See the -documentation of the :mod:`ssl` module for configuring the context securely. - -.. literalinclude:: ../../example/secure_server.py - :emphasize-lines: 19-21,24 - -Here's how to adapt the client. - -.. literalinclude:: ../../example/secure_client.py - :emphasize-lines: 10-12,16 - -This client needs a context because the server uses a self-signed certificate. - -A client connecting to a secure WebSocket server with a valid certificate -(i.e. signed by a CA that your Python installation trusts) can simply pass -``ssl=True`` to :func:`~client.connect` instead of building a context. - -Browser-based example ---------------------- - -Here's an example of how to run a WebSocket server and connect from a browser. - -Run this script in a console: - -.. literalinclude:: ../../example/show_time.py - -Then open this HTML file in a browser. - -.. literalinclude:: ../../example/show_time.html - :language: html - -Synchronization example ------------------------ - -A WebSocket server can receive events from clients, process them to update the -application state, and synchronize the resulting state across clients. - -Here's an example where any client can increment or decrement a counter. -Updates are propagated to all connected clients. - -The concurrency model of :mod:`asyncio` guarantees that updates are -serialized. +Wheels are available for all platforms. -Run this script in a console: +First steps +----------- -.. literalinclude:: ../../example/counter.py +If you're in a hurry, check out these examples. -Then open this HTML file in several browsers. +.. toctree:: -.. literalinclude:: ../../example/counter.html - :language: html + quickstart diff --git a/docs/intro/quickstart.rst b/docs/intro/quickstart.rst new file mode 100644 index 000000000..8c1221126 --- /dev/null +++ b/docs/intro/quickstart.rst @@ -0,0 +1,116 @@ +Quick start +=========== + +.. currentmodule:: websockets + +Here are a few examples to get you started quickly with websockets. + +Hello world! +------------ + +Here's a WebSocket server. + +It receives a name from the client, sends a greeting, and closes the connection. + +.. literalinclude:: ../../example/quickstart/server.py + +:func:`~server.serve` executes the connection handler coroutine ``hello()`` +once for each WebSocket connection. It closes the WebSocket connection when +the handler returns. + +Here's a corresponding WebSocket client. + +It sends a name to the server, receives a greeting, and closes the connection. + +.. literalinclude:: ../../example/quickstart/client.py + +Using :func:`~client.connect` as an asynchronous context manager ensures the +WebSocket connection is closed. + +.. _secure-server-example: + +Encryption +---------- + +Secure WebSocket connections improve confidentiality and also reliability +because they reduce the risk of interference by bad proxies. + +The ``wss`` protocol is to ``ws`` what ``https`` is to ``http``. The +connection is encrypted with TLS_ (Transport Layer Security). ``wss`` +requires certificates like ``https``. + +.. _TLS: https://developer.mozilla.org/en-US/docs/Web/Security/Transport_Layer_Security + +.. admonition:: TLS vs. SSL + :class: tip + + TLS is sometimes referred to as SSL (Secure Sockets Layer). SSL was an + earlier encryption protocol; the name stuck. + +Here's how to adapt the server to encrypt connections. See the documentation +of the :mod:`ssl` module for configuring the context securely. + +.. literalinclude:: ../../example/quickstart/server_secure.py + +Here's how to adapt the client similarly. + +.. literalinclude:: ../../example/quickstart/client_secure.py + +This client needs a context because the server uses a self-signed certificate. + +When connecting to a secure WebSocket server with a valid certificate — any +certificate signed by a CA that your Python installation trusts — you can +simply pass ``ssl=True`` to :func:`~client.connect`. + +In a browser +------------ + +The WebSocket protocol was invented for the web — as the name says! + +Here's how to connect to a WebSocket server in a browser. + +Run this script in a console: + +.. literalinclude:: ../../example/quickstart/show_time.py + +Save this file as ``show_time.html``: + +.. literalinclude:: ../../example/quickstart/show_time.html + :language: html + +Save this file as ``show_time.js``: + +.. literalinclude:: ../../example/quickstart/show_time.js + :language: js + +Then open ``show_time.html`` in a browser and see the clock tick irregularly. + +Broadcast +--------- + +A WebSocket server can receive events from clients, process them to update the +application state, and broadcast the updated state to all connected clients. + +Here's an example where any client can increment or decrement a counter. The +concurrency model of :mod:`asyncio` guarantees that updates are serialized. + +Run this script in a console: + +.. literalinclude:: ../../example/quickstart/counter.py + +Save this file as ``counter.html``: + +.. literalinclude:: ../../example/quickstart/counter.html + :language: html + +Save this file as ``counter.css``: + +.. literalinclude:: ../../example/quickstart/counter.css + :language: css + +Save this file as ``counter.js``: + +.. literalinclude:: ../../example/quickstart/counter.js + :language: js + +Then open ``counter.html`` file in several browsers and play with [+] and [-]. diff --git a/example/counter.html b/example/counter.html deleted file mode 100644 index 6310c4a16..000000000 --- a/example/counter.html +++ /dev/null @@ -1,80 +0,0 @@ - - - - WebSocket demo - - - -
-
-
-
?
-
+
-
-
- ? online -
- - - diff --git a/example/client.py b/example/quickstart/client.py similarity index 86% rename from example/client.py rename to example/quickstart/client.py index 062540202..8d588c2b0 100755 --- a/example/client.py +++ b/example/quickstart/client.py @@ -1,7 +1,5 @@ #!/usr/bin/env python -# WS client example - import asyncio import websockets @@ -16,4 +14,5 @@ async def hello(): greeting = await websocket.recv() print(f"<<< {greeting}") -asyncio.run(hello()) +if __name__ == "__main__": + asyncio.run(hello()) diff --git a/example/secure_client.py b/example/quickstart/client_secure.py similarity index 86% rename from example/secure_client.py rename to example/quickstart/client_secure.py index 8a1551e29..f4b39f2b8 100755 --- a/example/secure_client.py +++ b/example/quickstart/client_secure.py @@ -1,7 +1,5 @@ #!/usr/bin/env python -# WSS (WS over TLS) client example, with a self-signed certificate - import asyncio import pathlib import ssl @@ -22,4 +20,5 @@ async def hello(): greeting = await websocket.recv() print(f"<<< {greeting}") -asyncio.run(hello()) +if __name__ == "__main__": + asyncio.run(hello()) diff --git a/example/quickstart/counter.css b/example/quickstart/counter.css new file mode 100644 index 000000000..e1f4b7714 --- /dev/null +++ b/example/quickstart/counter.css @@ -0,0 +1,33 @@ +body { + font-family: "Courier New", sans-serif; + text-align: center; +} +.buttons { + font-size: 4em; + display: flex; + justify-content: center; +} +.button, .value { + line-height: 1; + padding: 2rem; + margin: 2rem; + border: medium solid; + min-height: 1em; + min-width: 1em; +} +.button { + cursor: pointer; + user-select: none; +} +.minus { + color: red; +} +.plus { + color: green; +} +.value { + min-width: 2em; +} +.state { + font-size: 2em; +} diff --git a/example/quickstart/counter.html b/example/quickstart/counter.html new file mode 100644 index 000000000..2e3433bd2 --- /dev/null +++ b/example/quickstart/counter.html @@ -0,0 +1,18 @@ + + + + WebSocket demo + + + +
+
-
+
?
+
+
+
+
+ ? online +
+ + + diff --git a/example/quickstart/counter.js b/example/quickstart/counter.js new file mode 100644 index 000000000..37d892a28 --- /dev/null +++ b/example/quickstart/counter.js @@ -0,0 +1,26 @@ +window.addEventListener("DOMContentLoaded", () => { + const websocket = new WebSocket("ws://localhost:6789/"); + + document.querySelector(".minus").addEventListener("click", () => { + websocket.send(JSON.stringify({ action: "minus" })); + }); + + document.querySelector(".plus").addEventListener("click", () => { + websocket.send(JSON.stringify({ action: "plus" })); + }); + + websocket.onmessage = ({ data }) => { + const event = JSON.parse(data); + switch (event.type) { + case "value": + document.querySelector(".value").textContent = event.value; + break; + case "users": + const users = `${event.count} user${event.count == 1 ? "" : "s"}`; + document.querySelector(".users").textContent = users; + break; + default: + console.error("unsupported event", event); + } + }; +}); diff --git a/example/counter.py b/example/quickstart/counter.py similarity index 57% rename from example/counter.py rename to example/quickstart/counter.py index 6e33b3afc..566e12965 100755 --- a/example/counter.py +++ b/example/quickstart/counter.py @@ -1,7 +1,5 @@ #!/usr/bin/env python -# WS server example that synchronizes state across clients - import asyncio import json import logging @@ -9,47 +7,43 @@ logging.basicConfig() -STATE = {"value": 0} - USERS = set() - -def state_event(): - return json.dumps({"type": "state", **STATE}) - +VALUE = 0 def users_event(): return json.dumps({"type": "users", "count": len(USERS)}) +def value_event(): + return json.dumps({"type": "value", "value": VALUE}) async def counter(websocket): + global USERS, VALUE try: # Register user USERS.add(websocket) websockets.broadcast(USERS, users_event()) # Send current state to user - await websocket.send(state_event()) + await websocket.send(value_event()) # Manage state changes async for message in websocket: - data = json.loads(message) - if data["action"] == "minus": - STATE["value"] -= 1 - websockets.broadcast(USERS, state_event()) - elif data["action"] == "plus": - STATE["value"] += 1 - websockets.broadcast(USERS, state_event()) + event = json.loads(message) + if event["action"] == "minus": + VALUE -= 1 + websockets.broadcast(USERS, value_event()) + elif event["action"] == "plus": + VALUE += 1 + websockets.broadcast(USERS, value_event()) else: - logging.error("unsupported event: %s", data) + logging.error("unsupported event: %s", event) finally: # Unregister user USERS.remove(websocket) websockets.broadcast(USERS, users_event()) - async def main(): async with websockets.serve(counter, "localhost", 6789): await asyncio.Future() # run forever - if __name__ == "__main__": asyncio.run(main()) diff --git a/example/localhost.pem b/example/quickstart/localhost.pem similarity index 100% rename from example/localhost.pem rename to example/quickstart/localhost.pem diff --git a/example/server.py b/example/quickstart/server.py similarity index 87% rename from example/server.py rename to example/quickstart/server.py index 7fd7bdf4c..31b182972 100755 --- a/example/server.py +++ b/example/quickstart/server.py @@ -1,7 +1,5 @@ #!/usr/bin/env python -# WS server example - import asyncio import websockets @@ -18,4 +16,5 @@ async def main(): async with websockets.serve(hello, "localhost", 8765): await asyncio.Future() # run forever -asyncio.run(main()) +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/secure_server.py b/example/quickstart/server_secure.py similarity index 86% rename from example/secure_server.py rename to example/quickstart/server_secure.py index f0231bc16..de41d30dc 100755 --- a/example/secure_server.py +++ b/example/quickstart/server_secure.py @@ -1,7 +1,5 @@ #!/usr/bin/env python -# WSS (WS over TLS) server example, with a self-signed certificate - import asyncio import pathlib import ssl @@ -24,4 +22,5 @@ async def main(): async with websockets.serve(hello, "localhost", 8765, ssl=ssl_context): await asyncio.Future() # run forever -asyncio.run(main()) +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/quickstart/show_time.html b/example/quickstart/show_time.html new file mode 100644 index 000000000..b1c93b141 --- /dev/null +++ b/example/quickstart/show_time.html @@ -0,0 +1,9 @@ + + + + WebSocket demo + + + + + diff --git a/example/quickstart/show_time.js b/example/quickstart/show_time.js new file mode 100644 index 000000000..26bed7ec9 --- /dev/null +++ b/example/quickstart/show_time.js @@ -0,0 +1,12 @@ +window.addEventListener("DOMContentLoaded", () => { + const messages = document.createElement("ul"); + document.body.appendChild(messages); + + const websocket = new WebSocket("ws://localhost:5678/"); + websocket.onmessage = ({ data }) => { + const message = document.createElement("li"); + const content = document.createTextNode(data); + message.appendChild(content); + messages.appendChild(message); + }; +}); diff --git a/example/quickstart/show_time.py b/example/quickstart/show_time.py new file mode 100755 index 000000000..facd56b00 --- /dev/null +++ b/example/quickstart/show_time.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python + +import asyncio +import datetime +import random +import websockets + +async def show_time(websocket): + while websocket.open: + await websocket.send(datetime.datetime.utcnow().isoformat() + "Z") + await asyncio.sleep(random.random() * 2 + 1) + +async def main(): + async with websockets.serve(show_time, "localhost", 5678): + await asyncio.Future() # run forever + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/show_time.html b/example/show_time.html deleted file mode 100644 index 721f44264..000000000 --- a/example/show_time.html +++ /dev/null @@ -1,20 +0,0 @@ - - - - WebSocket demo - - - - - diff --git a/example/show_time.py b/example/show_time.py deleted file mode 100755 index b5a153b71..000000000 --- a/example/show_time.py +++ /dev/null @@ -1,20 +0,0 @@ -#!/usr/bin/env python - -# WS server that sends messages at random intervals - -import asyncio -import datetime -import random -import websockets - -async def time(websocket): - while True: - now = datetime.datetime.utcnow().isoformat() + "Z" - await websocket.send(now) - await asyncio.sleep(random.random() * 3) - -async def main(): - async with websockets.serve(time, "localhost", 5678): - await asyncio.Future() # run forever - -asyncio.run(main()) From ce611d79b3c849f66da6645513f6a28195e298e1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Sep 2021 21:20:13 +0200 Subject: [PATCH 212/762] Add tutorial. --- docs/conf.py | 6 +- docs/howto/heroku.rst | 4 +- docs/intro/index.rst | 23 +- docs/intro/tutorial1.rst | 591 ++++++++++++++++++++++++ docs/intro/tutorial2.rst | 565 ++++++++++++++++++++++ docs/intro/tutorial3.rst | 290 ++++++++++++ docs/spelling_wordlist.txt | 1 + example/tutorial/start/connect4.css | 105 +++++ example/tutorial/start/connect4.js | 45 ++ example/tutorial/start/connect4.py | 62 +++ example/tutorial/start/favicon.ico | Bin 0 -> 5430 bytes example/tutorial/step1/app.py | 65 +++ example/tutorial/step1/connect4.css | 1 + example/tutorial/step1/connect4.js | 1 + example/tutorial/step1/connect4.py | 1 + example/tutorial/step1/favicon.ico | 1 + example/tutorial/step1/index.html | 10 + example/tutorial/step1/main.js | 53 +++ example/tutorial/step2/app.py | 190 ++++++++ example/tutorial/step2/connect4.css | 1 + example/tutorial/step2/connect4.js | 1 + example/tutorial/step2/connect4.py | 1 + example/tutorial/step2/favicon.ico | 1 + example/tutorial/step2/index.html | 15 + example/tutorial/step2/main.js | 83 ++++ example/tutorial/step3/Procfile | 1 + example/tutorial/step3/app.py | 198 ++++++++ example/tutorial/step3/connect4.css | 1 + example/tutorial/step3/connect4.js | 1 + example/tutorial/step3/connect4.py | 1 + example/tutorial/step3/favicon.ico | 1 + example/tutorial/step3/index.html | 15 + example/tutorial/step3/main.js | 93 ++++ example/tutorial/step3/requirements.txt | 1 + 34 files changed, 2422 insertions(+), 6 deletions(-) create mode 100644 docs/intro/tutorial1.rst create mode 100644 docs/intro/tutorial2.rst create mode 100644 docs/intro/tutorial3.rst create mode 100644 example/tutorial/start/connect4.css create mode 100644 example/tutorial/start/connect4.js create mode 100644 example/tutorial/start/connect4.py create mode 100644 example/tutorial/start/favicon.ico create mode 100644 example/tutorial/step1/app.py create mode 120000 example/tutorial/step1/connect4.css create mode 120000 example/tutorial/step1/connect4.js create mode 120000 example/tutorial/step1/connect4.py create mode 120000 example/tutorial/step1/favicon.ico create mode 100644 example/tutorial/step1/index.html create mode 100644 example/tutorial/step1/main.js create mode 100644 example/tutorial/step2/app.py create mode 120000 example/tutorial/step2/connect4.css create mode 120000 example/tutorial/step2/connect4.js create mode 120000 example/tutorial/step2/connect4.py create mode 120000 example/tutorial/step2/favicon.ico create mode 100644 example/tutorial/step2/index.html create mode 100644 example/tutorial/step2/main.js create mode 100644 example/tutorial/step3/Procfile create mode 100644 example/tutorial/step3/app.py create mode 120000 example/tutorial/step3/connect4.css create mode 120000 example/tutorial/step3/connect4.js create mode 120000 example/tutorial/step3/connect4.py create mode 120000 example/tutorial/step3/favicon.ico create mode 100644 example/tutorial/step3/index.html create mode 100644 example/tutorial/step3/main.js create mode 100644 example/tutorial/step3/requirements.txt diff --git a/docs/conf.py b/docs/conf.py index 750682573..8d9fcdc8f 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -95,7 +95,11 @@ code_url = f"https://github.com/aaugustin/websockets/blob/{commit}" def linkcode_resolve(domain, info): - assert domain == "py" + # Non-linkable objects from the starter kit in the tutorial. + if domain == "js" or info["module"] == "connect4": + return + + assert domain == "py", "expected only Python objects" mod = importlib.import_module(info["module"]) if "." in info["fullname"]: diff --git a/docs/howto/heroku.rst b/docs/howto/heroku.rst index 6a7c4d00b..464420e05 100644 --- a/docs/howto/heroku.rst +++ b/docs/howto/heroku.rst @@ -42,6 +42,7 @@ Here's the implementation of the app, an echo server. Save it in a file called ``app.py``: .. literalinclude:: ../../example/deployment/heroku/app.py + :language: text Heroku expects the server to `listen on a specific port`_, which is provided in the ``$PORT`` environment variable. The app reads it and passes it to @@ -62,6 +63,7 @@ In order to build the app, Heroku needs to know that it depends on websockets. Create a ``requirements.txt`` file containing this line: .. literalinclude:: ../../example/deployment/heroku/requirements.txt + :language: text Heroku also needs to know how to run the app. Create a ``Procfile`` with this content: @@ -86,7 +88,7 @@ The app is ready. Let's deploy it! .. code-block:: console - $ git push heroku main + $ git push heroku ... lots of output... diff --git a/docs/intro/index.rst b/docs/intro/index.rst index a5b68bfbf..2c66dea9a 100644 --- a/docs/intro/index.rst +++ b/docs/intro/index.rst @@ -16,19 +16,34 @@ websockets requires Python ≥ 3.7. It doesn't have any dependencies. +.. _install: + Installation ------------ -Install websockets with:: +Install websockets with: + +.. code-block:: console - pip install websockets + $ pip install websockets Wheels are available for all platforms. -First steps +Tutorial +-------- + +Learn how to build an real-time web application with websockets. + +.. toctree:: + + tutorial1 + tutorial2 + tutorial3 + +In a hurry? ----------- -If you're in a hurry, check out these examples. +Check out these examples. .. toctree:: diff --git a/docs/intro/tutorial1.rst b/docs/intro/tutorial1.rst new file mode 100644 index 000000000..ab4f39d79 --- /dev/null +++ b/docs/intro/tutorial1.rst @@ -0,0 +1,591 @@ +Part 1 - Send & receive +======================= + +.. currentmodule:: websockets + +In this tutorial, you're going to build a web-based `Connect Four`_ game. + +.. _Connect Four: https://en.wikipedia.org/wiki/Connect_Four + +The web removes the constraint of being in the same room for playing a game. +Two players can connect over of the Internet, regardless of where they are, +and play in their browsers. + +When a player makes a move, it should be reflected immediately on both sides. +This is difficult to implement over HTTP due to the request-response style of +the protocol. + +Indeed, there is no good way to be notified when the other player makes a +move. Workarounds such as polling or long-polling introduce significant +overhead. + +Enter `WebSocket `_. + +The WebSocket protocol provides two-way communication between a browser and a +server over a persistent connection. That's exactly what you need to exchange +moves between players, via a server. + +.. admonition:: This is the first part of the tutorial. + + * In this :doc:`first part `, you will create a server and + connect one browser; you can play if you share the same browser. + * In the :doc:`second part `, you will connect a second + browser; you can play from different browsers on a local network. + * In the :doc:`third part `, you will deploy the game to the + web; you can play from any browser connected to the Internet. + +Prerequisites +------------- + +This tutorial assumes basic knowledge of Python and JavaScript. + +If you're comfortable with :doc:`virtual environments `, +you can use one for this tutorial. Else, don't worry: websockets doesn't have +any dependencies; it shouldn't create trouble in the default environment. + +If you haven't installed websockets yet, do it now: + +.. code-block:: console + + $ pip install websockets + +Confirm that websockets is installed: + +.. code-block:: console + + $ python -m websockets --version + +.. admonition:: This tutorial is written for websockets |release|. + :class: tip + + If you installed another version, you should switch to the corresponding + version of the documentation. + +Download the starter kit +------------------------ + +Create a directory and download these three files: +:download:`connect4.js <../../example/tutorial/start/connect4.js>`, +:download:`connect4.css <../../example/tutorial/start/connect4.css>`, +and :download:`connect4.py <../../example/tutorial/start/connect4.py>`. + +The JavaScript module, along with the CSS file, provides a web-based user +interface. Here's its API. + +.. js:module:: connect4 + +.. js:data:: PLAYER1 + + Color of the first player. + +.. js:data:: PLAYER2 + + Color of the second player. + +.. js:function:: createBoard(board) + + Draw a board. + + :param board: DOM element containing the board; must be initially empty. + +.. js:function:: playMove(board, player, column, row) + + Play a move. + + :param board: DOM element containing the board. + :param player: :js:data:`PLAYER1` or :js:data:`PLAYER2`. + :param column: between ``0`` and ``6``. + :param row: between ``0`` and ``5``. + +The Python module provides a class to record moves and tell when a player +wins. Here's its API. + +.. module:: connect4 + +.. data:: PLAYER1 + :value: "red" + + Color of the first player. + +.. data:: PLAYER2 + :value: "yellow" + + Color of the second player. + +.. class:: Connect4 + + A Connect Four game. + + .. method:: play(player, column) + + Play a move. + + :param player: :data:`~connect4.PLAYER1` or :data:`~connect4.PLAYER2`. + :param column: between ``0`` and ``6``. + :returns: Row where the checker lands, between ``0`` and ``5``. + :raises RuntimeError: if the move is illegal. + + .. attribute:: moves + + List of moves played during this game, as ``(player, column, row)`` + tuples. + + .. attribute:: winner + + :data:`~connect4.PLAYER1` or :data:`~connect4.PLAYER2` if they + won; :obj:`None` if the game is still ongoing. + +.. currentmodule:: websockets + +Bootstrap the web UI +-------------------- + +Create an ``index.html`` file next to ``connect4.js`` and ``connect4.css`` +with this content: + +.. literalinclude:: ../../example/tutorial/step1/index.html + :language: html + +This HTML page contains an empty ``
`` element where you will draw the +Connect Four board. It loads a ``main.js`` script where you will write all +your JavaScript code. + +Create a ``main.js`` file next to ``index.html``. In this script, when the +page loads, draw the board: + +.. code-block:: javascript + + import { createBoard, playMove } from "./connect4.js"; + + window.addEventListener("DOMContentLoaded", () => { + // Initialize the UI. + const board = document.querySelector(".board"); + createBoard(board); + }); + +Open a shell, navigate to the directory containing these files, and start a +HTTP server: + +.. code-block:: console + + $ python -m http.server + +Open http://localhost:8000/ in a web browser. The page displays an empty board +with seven columns and six rows. You will play moves in this board later. + +Bootstrap the server +-------------------- + +Create an ``app.py`` file next to ``connect4.py`` with this content: + +.. code-block:: python + + #!/usr/bin/env python + + import asyncio + + import websockets + + + async def handler(websocket): + while True: + message = await websocket.recv() + print(message) + + + async def main(): + async with websockets.serve(handler, "", 8001): + await asyncio.Future() # run forever + + + if __name__ == "__main__": + asyncio.run(main()) + +The entry point of this program is ``asyncio.run(main())``. It creates an +asyncio event loop, runs the ``main()`` coroutine, and shuts down the loop. + +The ``main()`` coroutine calls :func:`~server.serve` to start a websockets +server. :func:`~server.serve` takes three positional arguments: + +* ``handler`` is a coroutine that manages a connection. When a client + connects, websockets calls ``handler`` with the connection in argument. + When ``handler`` terminates, websockets closes the connection. +* The second argument defines the network interfaces where the server can be + reached. Here, the server listens on all interfaces, so that other devices + on the same local network can connect. +* The third argument is the port on which the server listens. + +Invoking :func:`~server.serve` as an asynchronous context manager, in an +``async with`` block, ensures that the server shuts down properly when +terminating the program. + +For each connection, the ``handler()`` coroutine runs an infinite loop that +receives messages from the browser and prints them. + +Open a shell, navigate to the directory containing ``app.py``, and start the +server: + +.. code-block:: console + + $ python app.py + +This doesn't display anything. Hopefully the WebSocket server is running. +Let's make sure that it works. You cannot test the WebSocket server with a +web browser like you tested the HTTP server. However, you can test it with +websockets' interactive client. + +Open another shell and run this command: + +.. code-block:: console + + $ python -m websockets ws://localhost:8001/ + +You get a prompt. Type a message and press "Enter". Switch to the shell where +the server is running and check that the server received the message. Good! + +Exit the interactive client with Ctrl-C or Ctrl-D. + +Now, if you look at the console where you started the server, you can see the +stack trace of an exception: + +.. code-block:: pytb + + connection handler failed + Traceback (most recent call last): + ... + File "app.py", line 22, in handler + message = await websocket.recv() + ... + websockets.exceptions.ConnectionClosedOK: received 1000 (OK); then sent 1000 (OK) + +Indeed, the server was waiting for the next message +with :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` when the client +disconnected. When this happens, websockets raises +a :exc:`~exceptions.ConnectionClosedOK` exception to let you know that you +won't receive another message on this connection. + +This exception creates noise in the server logs, making it more difficult to +spot real errors when you add functionality to the server. Catch it in the +``handler()`` coroutine: + +.. code-block:: python + + async def handler(websocket): + while True: + try: + message = await websocket.recv() + except websockets.ConnectionClosedOK: + break + print(message) + +Stop the server with Ctrl-C and start it again: + +.. code-block:: console + + $ python app.py + +.. admonition:: You must restart the WebSocket server when you make changes. + :class: tip + + The WebSocket server loads the Python code in ``app.py`` then serves every + WebSocket request with this version of the code. As a consequence, + changes to ``app.py`` aren't visible until you restart the server. + + This is unlike the HTTP server that you started earlier with ``python -m + http.server``. For every request, this HTTP server reads the target file + and sends it. That's why changes are immediately visible. + + It is possible to :doc:`restart the WebSocket server automatically + <../howto/autoreload>` but this isn't necessary for this tutorial. + +Try connecting and disconnecting the interactive client again. +The :exc:`~exceptions.ConnectionClosedOK` exception doesn't appear anymore. + +This pattern is so common that websockets provides a shortcut for iterating +over messages received on the connection until the client disconnects: + +.. code-block:: python + + async def handler(websocket): + async for message in websocket: + print(message) + +Restart the server and check with the interactive client that its behavior +didn't change. + +At this point, you bootstrapped a web application and a WebSocket server. +Let's connect them. + +Transmit from browser to server +------------------------------- + +In JavaScript, you open a WebSocket connection as follows: + +.. code-block:: javascript + + const websocket = new WebSocket("ws://localhost:8001/"); + +Before you exchange messages with the server, you need to decide their format. +There is no universal convention for this. + +Let's use JSON objects with a ``type`` key identifying the type of the event +and the rest of the object containing properties of the event. + +Here's an event describing a move in the middle slot of the board: + +.. code-block:: javascript + + const event = {type: "play", column: 3}; + +Here's how to serialize this event to JSON and send it to the server: + +.. code-block:: javascript + + websocket.send(JSON.stringify(event)); + +Now you have all the building blocks to send moves to the server. + +Add this function to ``main.js``: + +.. literalinclude:: ../../example/tutorial/step1/main.js + :language: js + :start-at: function sendMoves + :end-before: window.addEventListener + +``sendMoves()`` registers a listener for ``click`` events on the board. The +listener figures out which column was clicked, builds a event of type +``"play"``, serializes it, and sends it to the server. + +Modify the initialization to open the WebSocket connection and call the +``sendMoves()`` function: + +.. code-block:: javascript + + window.addEventListener("DOMContentLoaded", () => { + // Initialize the UI. + const board = document.querySelector(".board"); + createBoard(board); + // Open the WebSocket connection and register event handlers. + const websocket = new WebSocket("ws://localhost:8001/"); + sendMoves(board, websocket); + }); + +Check that the HTTP server and the WebSocket server are still running. If you +stopped them, here are the commands to start them again: + +.. code-block:: console + + $ python -m http.server + +.. code-block:: console + + $ python app.py + +Refresh http://localhost:8000/ in your web browser. Click various columns in +the board. The server receives messages with the expected column number. + +There isn't any feedback in the board because you haven't implemented that +yet. Let's do it. + +Transmit from server to browser +------------------------------- + +In JavaScript, you receive WebSocket messages by listening to ``message`` +events. Here's how to receive a message from the server and deserialize it +from JSON: + +.. code-block:: javascript + + websocket.addEventListener("message", ({ data }) => { + const event = JSON.parse(data); + // do something with event + }); + +You're going to need three types of messages from the server to the browser: + +.. code-block:: javascript + + {type: "play", player: "red", column: 3, row: 0} + {type: "win", player: "red"} + {type: "error", message: "This slot is full."} + +The JavaScript code receiving these messages will dispatch events depending on +their type and take appropriate action. For example, it will react to an +event of type ``"play"`` by displaying the move on the board with +the :js:func:`~connect4.playMove` function. + +Add this function to ``main.js``: + +.. literalinclude:: ../../example/tutorial/step1/main.js + :language: js + :start-at: function showMessage + :end-before: function sendMoves + +.. admonition:: Why does ``showMessage`` use ``window.setTimeout``? + :class: hint + + When :js:func:`playMove` modifies the state of the board, the browser + renders changes asynchronously. Conversely, ``window.alert()`` runs + synchronously and blocks rendering while the alert is visible. + + If you called ``window.alert()`` immediately after :js:func:`playMove`, + the browser could display the alert before rendering the move. You could + get a "Player red wins!" alert without seeing red's last move. + + We're using ``window.alert()`` for simplicity in this tutorial. A real + application would display these messages in the user interface instead. + It wouldn't be vulnerable to this problem. + +Modify the initialization to call the ``receiveMoves()`` function: + +.. literalinclude:: ../../example/tutorial/step1/main.js + :language: js + :start-at: window.addEventListener + +At this point, the user interface should receive events properly. Let's test +it by modifying the server to send some events. + +Sending an event from Python is quite similar to JavaScript: + +.. code-block:: python + + event = {"type": "play", "player": "red", "column": 3, "row": 0} + await websocket.send(json.dumps(event)) + +.. admonition:: Don't forget to serialize the event with :func:`json.dumps`. + :class: tip + + Else, websockets raises ``TypeError: data is a dict-like object``. + +Modify the ``handler()`` coroutine in ``app.py`` as follows: + +.. code-block:: python + + import json + + from connect4 import PLAYER1, PLAYER2 + + async def handler(websocket): + for player, column, row in [ + (PLAYER1, 3, 0), + (PLAYER2, 3, 1), + (PLAYER1, 4, 0), + (PLAYER2, 4, 1), + (PLAYER1, 2, 0), + (PLAYER2, 1, 0), + (PLAYER1, 5, 0), + ]: + event = { + "type": "play", + "player": player, + "column": column, + "row": row, + } + await websocket.send(json.dumps(event)) + await asyncio.sleep(0.5) + event = { + "type": "win", + "player": PLAYER1, + } + await websocket.send(json.dumps(event)) + +Restart the WebSocket server and refresh http://localhost:8000/ in your web +browser. Seven moves appear at 0.5 second intervals. Then an alert announces +the winner. + +Good! Now you know how to communicate both ways. + +Once you plug the game engine to process moves, you will have a fully +functional game. + +Add the game logic +------------------ + +In the ``handler()`` coroutine, you're going to initialize a game: + +.. code-block:: python + + from connect4 import Connect4 + + async def handler(websocket): + # Initialize a Connect Four game. + game = Connect4() + + ... + +Then, you're going to iterate over incoming messages and take these steps: + +* parse an event of type ``"play"``, the only type of event that the user + interface sends; +* play the move in the board with the :meth:`~connect4.Connect4.play` method, + alternating between the two players; +* if :meth:`~connect4.Connect4.play` raises :exc:`RuntimeError` because the + move is illegal, send an event of type ``"error"``; +* else, send an event of type ``"play"`` to tell the user interface where the + checker lands; +* if the move won the game, send an event of type ``"win"``. + +Try to implement this by yourself! + +Keep in mind that you must restart the WebSocket server and reload the page in +the browser when you make changes. + +When it works, you can play the game from a single browser, with players +taking alternate turns. + +.. admonition:: Enable debug logs to see all messages sent and received. + :class: tip + + Here's how to enable debug logs: + + .. code-block:: python + + import logging + + logging.basicConfig(format="%(message)s", level=logging.DEBUG) + +If you're stuck, a solution is available at the bottom of this document. + +Summary +------- + +In this first part of the tutorial, you learned how to: + +* build and run a WebSocket server in Python with :func:`~server.serve`; +* receive a message in a connection handler + with :meth:`~server.WebSocketServerProtocol.recv`; +* send a message in a connection handler + with :meth:`~server.WebSocketServerProtocol.send`; +* iterate over incoming messages with ``async for + message in websocket: ...``; +* open a WebSocket connection in JavaScript with the ``WebSocket`` API; +* send messages in a browser with ``WebSocket.send()``; +* receive messages in a browser by listening to ``message`` events; +* design a set of events to be exchanged between the browser and the server. + +You can now play a Connect Four game in a browser, communicating over a +WebSocket connection with a server where the game logic resides! + +However, the two players share a browser, so the constraint of being in the +same room still applies. + +Move on to the :doc:`second part ` of the tutorial to break this +constraint and play from separate browsers. + +Solution +-------- + +.. literalinclude:: ../../example/tutorial/step1/app.py + :caption: app.py + :language: python + :linenos: + +.. literalinclude:: ../../example/tutorial/step1/index.html + :caption: index.html + :language: html + :linenos: + +.. literalinclude:: ../../example/tutorial/step1/main.js + :caption: main.js + :language: js + :linenos: diff --git a/docs/intro/tutorial2.rst b/docs/intro/tutorial2.rst new file mode 100644 index 000000000..669d46cde --- /dev/null +++ b/docs/intro/tutorial2.rst @@ -0,0 +1,565 @@ +Part 2 - Route & broadcast +========================== + +.. currentmodule:: websockets + +.. admonition:: This is the second part of the tutorial. + + * In the :doc:`first part `, you created a server and + connected one browser; you could play if you shared the same browser. + * In this :doc:`second part `, you will connect a second + browser; you can play from different browsers on a local network. + * In the :doc:`third part `, you will deploy the game to the + web; you can play from any browser connected to the Internet. + +In the first part of the tutorial, you opened a WebSocket connection from a +browser to a server and exchanged events to play moves. The state of the game +was stored in an instance of the :class:`~connect4.Connect4` class, +referenced as a local variable in the connection handler coroutine. + +Now you want to open two WebSocket connections from two separate browsers, one +for each player, to the same server in order to play the same game. This +requires moving the state of the game to a place where both connections can +access it. + +Share game state +---------------- + +As long as you're running a single server process, you can share state by +storing it in a global variable. + +.. admonition:: What if you need to scale to multiple server processes? + :class: hint + + In that case, you must design a way for the process that handles a given + connection to be aware of relevant events for that client. This is often + achieved with a publish / subscribe mechanism. + +How can you make two connection handlers agree on which game they're playing? +When the first player starts a game, you give it an identifier. Then, you +communicate the identifier to the second player. When the second player joins +the game, you look it up with the identifier. + +In addition to the game itself, you need to keep track of the WebSocket +connections of the two players. Since both players receive the same events, +you don't need to treat the two connections differently; you can store both +in the same set. + +Let's sketch this in code. + +A module-level :class:`dict` enables lookups by identifier: + +.. code-block:: python + + JOIN = {} + +When the first player starts the game, initialize and store it: + +.. code-block:: python + + import secrets + + async def handler(websocket): + ... + + # Initialize a Connect Four game, the set of WebSocket connections + # receiving moves from this game, and secret access token. + game = Connect4() + connected = {websocket} + + join_key = secrets.token_urlsafe(12) + JOIN[join_key] = game, connected + + try: + + ... + + finally: + del JOIN[join_key] + +When the second player joins the game, look it up: + +.. code-block:: python + + async def handler(websocket): + ... + + join_key = ... # TODO + + # Find the Connect Four game. + game, connected = JOIN[join_key] + + # Register to receive moves from this game. + connected.add(websocket) + try: + + ... + + finally: + connected.remove(websocket) + +Notice how we're carefully cleaning up global state with ``try: ... +finally: ...`` blocks. Else, we could leave references to games or +connections in global state, which would cause a memory leak. + +In both connection handlers, you have a ``game`` pointing to the same +:class:`~connect4.Connect4` instance, so you can interact with the game, +and a ``connected`` set of connections, so you can send game events to +both players as follows: + +.. code-block:: python + + async def handler(websocket): + + ... + + for connection in connected: + await connection.send(json.dumps(event)) + + ... + +Perhaps you spotted a major piece missing from the puzzle. How does the second +player obtain ``join_key``? Let's design new events to carry this information. + +To start a game, the first player sends an ``"init"`` event: + +.. code-block:: javascript + + {type: "init"} + +The connection handler for the first player creates a game as shown above and +responds with: + +.. code-block:: javascript + + {type: "init", join: ""} + +With this information, the user interface of the first player can create a +link to ``http://localhost:8000/?join=``. For the sake of simplicity, +we will assume that the first player shares this link with the second player +outside of the application, for example via an instant messaging service. + +To join the game, the second player sends a different ``"init"`` event: + +.. code-block:: javascript + + {type: "init", join: ""} + +The connection handler for the second player can look up the game with the +join key as shown above. There is no need to respond. + +Let's dive into the details of implementing this design. + +Start a game +------------ + +We'll start with the initialization sequence for the first player. + +In ``main.js``, define a function to send an initialization event when the +WebSocket connection is established, which triggers an ``open`` event: + +.. code-block:: javascript + + function initGame(websocket) { + websocket.addEventListener("open", () => { + // Send an "init" event for the first player. + const event = { type: "init" }; + websocket.send(JSON.stringify(event)); + }); + } + +Update the initialization sequence to call ``initGame()``: + +.. literalinclude:: ../../example/tutorial/step2/main.js + :language: js + :start-at: window.addEventListener + +In ``app.py``, define a new ``handler`` coroutine — keep a copy of the +previous one to reuse it later: + +.. code-block:: python + + import secrets + + + JOIN = {} + + + async def start(websocket): + # Initialize a Connect Four game, the set of WebSocket connections + # receiving moves from this game, and secret access token. + game = Connect4() + connected = {websocket} + + join_key = secrets.token_urlsafe(12) + JOIN[join_key] = game, connected + + try: + # Send the secret access token to the browser of the first player, + # where it'll be used for building a "join" link. + event = { + "type": "init", + "join": join_key, + } + await websocket.send(json.dumps(event)) + + # Temporary - for testing. + print("first player started game", id(game)) + async for message in websocket: + print("first player sent", message) + + finally: + del JOIN[join_key] + + + async def handler(websocket, path): + # Receive and parse the "init" event from the UI. + message = await websocket.recv() + event = json.loads(message) + assert event["type"] == "init" + + # First player starts a new game. + await start(websocket) + +In ``index.html``, add an ```` element to display the link to share with +the other player. + +.. code-block:: html + + + + + + +In ``main.js``, modify ``receiveMoves()`` to handle the ``"init"`` message and +set the target of that link: + +.. code-block:: javascript + + switch (event.type) { + case "init": + // Create link for inviting the second player. + document.querySelector(".join").href = "?join=" + event.join; + break; + // ... + } + +Restart the WebSocket server and reload http://localhost:8000/ in the browser. +There's a link labeled JOIN below the board with a target that looks like +http://localhost:8000/?join=95ftAaU5DJVP1zvb. + +The server logs say ``first player started game ...``. If you click the board, +you see ``"play"`` events. There is no feedback in the UI, though, because +you haven't restored the game logic yet. + +Before we get there, let's handle links with a ``join`` query parameter. + +Join a game +----------- + +We'll now update the initialization sequence to account for the second +player. + +In ``main.js``, update ``initGame()`` to send the join key in the ``"init"`` +message when it's in the URL: + +.. code-block:: javascript + + function initGame(websocket) { + websocket.addEventListener("open", () => { + // Send an "init" event according to who is connecting. + const params = new URLSearchParams(window.location.search); + let event = { type: "init" }; + if (params.has("join")) { + // Second player joins an existing game. + event.join = params.get("join"); + } else { + // First player starts a new game. + } + websocket.send(JSON.stringify(event)); + }); + } + +In ``app.py``, update the ``handler`` coroutine to look for the join key in +the ``"init"`` message, then load that game: + +.. code-block:: python + + async def error(websocket, message): + event = { + "type": "error", + "message": message, + } + await websocket.send(json.dumps(event)) + + + async def join(websocket, join_key): + # Find the Connect Four game. + try: + game, connected = JOIN[join_key] + except KeyError: + await error(websocket, "Game not found.") + return + + # Register to receive moves from this game. + connected.add(websocket) + try: + + # Temporary - for testing. + print("second player joined game", id(game)) + async for message in websocket: + print("second player sent", message) + + finally: + connected.remove(websocket) + + + async def handler(websocket): + # Receive and parse the "init" event from the UI. + message = await websocket.recv() + event = json.loads(message) + assert event["type"] == "init" + + if "join" in event: + # Second player joins an existing game. + await join(websocket, event["join"]) + else: + # First player starts a new game. + await start(websocket) + +Restart the WebSocket server and reload http://localhost:8000/ in the browser. + +Copy the link labeled JOIN and open it in another browser. You may also open +it in another tab or another window of the same browser; however, that makes +it a bit tricky to remember which one is the first or second player. + +.. admonition:: You must start a new game when you restart the server. + :class: tip + + Since games are stored in the memory of the Python process, they're lost + when you stop the server. + + Whenever you make changes to ``app.py``, you must restart the server, + create a new game in a browser, and join it in another browser. + +The server logs say ``first player started game ...`` and ``second player +joined game ...``. The numbers match, proving that the ``game`` local +variable in both connection handlers points to same object in the memory of +the Python process. + +Click the board in either browser. The server receives ``"play"`` events from +the corresponding player. + +In the initialization sequence, you're routing connections to ``start()`` or +``join()`` depending on the first message received by the server. This is a +common pattern in servers that handle different clients. + +.. admonition:: Why not use different URIs for ``start()`` and ``join()``? + :class: hint + + Instead of sending an initialization event, you could encode the join key + in the WebSocket URI e.g. ``ws://localhost:8001/join/``. The + WebSocket server would parse ``websocket.path`` and route the connection, + similar to how HTTP servers route requests. + + When you need to send sensitive data like authentication credentials to + the server, sending it an event is considered more secure than encoding + it in the URI because URIs end up in logs. + + For the purposes of this tutorial, both approaches are equivalent because + the join key comes from a HTTP URL. There isn't much at risk anyway! + +Now you can restore the logic for playing moves and you'll have a fully +functional two-player game. + +Add the game logic +------------------ + +Once the initialization is done, the game is symmetrical, so you can write a +single coroutine to process the moves of both players: + +.. code-block:: python + + async def play(game, player, connected): + ... + +With such a coroutine, you can replace the temporary code for testing in +``start()`` by: + +.. code-block:: python + + await play(game, PLAYER1, connected) + +and in ``join()`` by: + +.. code-block:: python + + await play(game, PLAYER2, connected) + +The ``play()`` coroutine will reuse much of the code you wrote in the first +part of the tutorial. + +Try to implement this by yourself! + +Keep in mind that you must restart the WebSocket server, reload the page to +start a new game with the first player, copy the JOIN link, and join the game +with the second player when you make changes. + +When ``play()`` works, you can play the game from two separate browsers, +possibly running on separate computers on the same local network. + +A complete solution is available at the bottom of this document. + +Watch a game +------------ + +Let's add one more feature: allow spectators to watch the game. + +The process for inviting a spectator can be the same as for inviting the +second player. You will have to duplicate all the initialization logic: + +- declare a ``WATCH`` global variable similar to ``JOIN``; +- generate a watch key when creating a game; it must be different from the + join key, or else a spectator could hijack a game by tweaking the URL; +- include the watch key in the ``"init"`` event sent to the first player; +- generate a WATCH link in the UI with a ``watch`` query parameter; +- update the ``initGame()`` function to handle such links; +- update the ``handler()`` coroutine to invoke a ``watch()`` coroutine for + spectators; +- prevent ``sendMoves()`` from sending ``"play"`` events for spectators. + +Once the initialization sequence is done, watching a game is as simple as +registering the WebSocket connection in the ``connected`` set in order to +receive game events and doing nothing until the spectator disconnects. You +can wait for a connection to terminate with +:meth:`~legacy.protocol.WebSocketCommonProtocol.wait_closed`: + +.. code-block:: python + + async def watch(websocket, watch_key): + + ... + + connected.add(websocket) + try: + await websocket.wait_closed() + finally: + connected.remove(websocket) + +The connection can terminate because the ``receiveMoves()`` function closed it +explicitly after receiving a ``"win"`` event, because the spectator closed +their browser, or because the network failed. + +Again, try to implement this by yourself. + +When ``watch()`` works, you can invite spectators to watch the game from other +browsers, as long as they're on the same local network. + +As a further improvement, you may support adding spectators while a game is +already in progress. This requires replaying moves that were played before +the spectator was added to the ``connected`` set. Past moves are available in +the :attr:`~connect4.Connect4.moves` attribute of the game. + +This feature is included in the solution proposed below. + +Broadcast +--------- + +When you need to send a message to the two players and to all spectators, +you're using this pattern: + +.. code-block:: python + + async def handler(websocket): + + ... + + for connection in connected: + await connection.send(json.dumps(event)) + + ... + +Since this is a very common pattern in WebSocket servers, websockets provides +the :func:`broadcast` helper for this purpose: + +.. code-block:: python + + async def handler(websocket): + + ... + + websockets.broadcast(connected, json.dumps(event)) + + ... + +Calling :func:`broadcast` once is more efficient than +calling :meth:`~legacy.protocol.WebSocketCommonProtocol.send` in a loop. + +However, there's a subtle difference in behavior. Did you notice that there's +no ``await`` in the second version? Indeed, :func:`broadcast` is a function, +not a coroutine like :meth:`~legacy.protocol.WebSocketCommonProtocol.send` +or :meth:`~legacy.protocol.WebSocketCommonProtocol.recv`. + +It's quite obvious why :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` +is a coroutine. When you want to receive the next message, you have to wait +until the client sends it and the network transmits it. + +It's less obvious why :meth:`~legacy.protocol.WebSocketCommonProtocol.send` is +a coroutine. If you send many messages or large messages, you could write +data faster than the network can transmit it or the client can read it. Then, +outgoing data will pile up in buffers, which will consume memory and may +crash your application. + +To avoid this problem, :meth:`~legacy.protocol.WebSocketCommonProtocol.send` +waits until the write buffer drains. By slowing down the application as +necessary, this ensures that the server doesn't send data too quickly. This +is called backpressure and it's useful for building robust systems. + +That said, when you're sending the same messages to many clients in a loop, +applying backpressure in this way can become counterproductive. When you're +broadcasting, you don't want to slow down everyone to the pace of the slowest +clients; you want to drop clients that cannot keep up with the data stream. +That's why :func:`broadcast` doesn't wait until write buffers drain. + +For our Connect Four game, there's no difference in practice: the total amount +of data sent on a connection for a game of Connect Four is less than 64 KB, +so the write buffer never fills up and backpressure never kicks in anyway. + +Summary +------- + +In this second part of the tutorial, you learned how to: + +* configure a connection by exchanging initialization messages; +* keep track of connections within a single server process; +* wait until a client disconnects in a connection handler; +* broadcast a message to many connections efficiently. + +You can now play a Connect Four game from separate browser, communicating over +WebSocket connections with a server that synchronizes the game logic! + +However, the two players have to be on the same local network as the server, +so the constraint of being in the same place still mostly applies. + +Head over to the :doc:`third part ` of the tutorial to deploy the +game to the web and remove this constraint. + +Solution +-------- + +.. literalinclude:: ../../example/tutorial/step2/app.py + :caption: app.py + :language: python + :linenos: + +.. literalinclude:: ../../example/tutorial/step2/index.html + :caption: index.html + :language: html + :linenos: + +.. literalinclude:: ../../example/tutorial/step2/main.js + :caption: main.js + :language: js + :linenos: diff --git a/docs/intro/tutorial3.rst b/docs/intro/tutorial3.rst new file mode 100644 index 000000000..9a447f39b --- /dev/null +++ b/docs/intro/tutorial3.rst @@ -0,0 +1,290 @@ +Part 3 - Deploy to the web +========================== + +.. currentmodule:: websockets + +.. admonition:: This is the third part of the tutorial. + + * In the :doc:`first part `, you created a server and + connected one browser; you could play if you shared the same browser. + * In this :doc:`second part `, you connected a second browser; + you could play from different browsers on a local network. + * In this :doc:`third part `, you will deploy the game to the + web; you can play from any browser connected to the Internet. + +In the first and second parts of the tutorial, for local development, you ran +a HTTP server on ``http://localhost:8000/`` with: + +.. code-block:: console + + $ python -m http.server + +and a WebSocket server on ``ws://localhost:8001/`` with: + +.. code-block:: console + + $ python app.py + +Now you want to deploy these servers on the Internet. There's a vast range of +hosting providers to choose from. For the sake of simplicity, we'll rely on: + +* GitHub Pages for the HTTP server; +* Heroku for the WebSocket server. + +Commit project to git +--------------------- + +Perhaps you committed your work to git while you were progressing through the +tutorial. If you didn't, now is a good time, because GitHub and Heroku offer +git-based deployment workflows. + +Initialize a git repository: + +.. code-block:: console + + $ git init -b main + Initialized empty Git repository in websockets-tutorial/.git/ + $ git commit --allow-empty -m "Initial commit." + [main (root-commit) ...] Initial commit. + +Add all files and commit: + +.. code-block:: console + + $ git add . + $ git commit -m "Initial implementation of Connect Four game." + [main ...] Initial implementation of Connect Four game. + 6 files changed, 500 insertions(+) + create mode 100644 app.py + create mode 100644 connect4.css + create mode 100644 connect4.js + create mode 100644 connect4.py + create mode 100644 index.html + create mode 100644 main.js + +Prepare the WebSocket server +---------------------------- + +Before you deploy the server, you must adapt it to meet requirements of +Heroku's runtime. This involves two small changes: + +1. Heroku expects the server to `listen on a specific port`_, provided in the + ``$PORT`` environment variable. + +2. Heroku sends a ``SIGTERM`` signal when `shutting down a dyno`_, which + should trigger a clean exit. + +.. _listen on a specific port: https://devcenter.heroku.com/articles/preparing-a-codebase-for-heroku-deployment#4-listen-on-the-correct-port + +.. _shutting down a dyno: https://devcenter.heroku.com/articles/dynos#shutdown + +Adapt the ``main()`` coroutine accordingly: + +.. code-block:: python + + import os + import signal + +.. literalinclude:: ../../example/tutorial/step3/app.py + :pyobject: main + +To catch the ``SIGTERM`` signal, ``main()`` creates a :class:`~asyncio.Future` +called ``stop`` and registers a signal handler that sets the result of this +future. The value of the future doesn't matter; it's only for waiting for +``SIGTERM``. + +Then, by using :func:`~server.serve` as a context manager and exiting the +context when ``stop`` has a result, ``main()`` ensures that the server closes +connections cleanly and exits on ``SIGTERM``. + +The app is now fully compatible with Heroku. + +Deploy the WebSocket server +--------------------------- + +Create a ``requirements.txt`` file with this content to install ``websockets`` +when building the image: + +.. literalinclude:: ../../example/tutorial/step3/requirements.txt + :language: text + +.. admonition:: Heroku treats ``requirements.txt`` as a signal to `detect a Python app`_. + :class: tip + + That's why you don't need to declare that you need a Python runtime. + +.. _detect a Python app: https://devcenter.heroku.com/articles/python-support#recognizing-a-python-app + +Create a ``Procfile`` file with this content to configure the command for +running the server: + +.. literalinclude:: ../../example/tutorial/step3/Procfile + :language: text + +Commit your changes: + +.. code-block:: console + + $ git add . + $ git commit -m "Deploy to Heroku." + [main ...] Deploy to Heroku. + 3 files changed, 12 insertions(+), 2 deletions(-) + create mode 100644 Procfile + create mode 100644 requirements.txt + +Follow the `set-up instructions`_ to install the Heroku CLI and to log in, if +you haven't done that yet. + +.. _set-up instructions: https://devcenter.heroku.com/articles/getting-started-with-python#set-up + +Create a Heroku app. You must choose a unique name and replace +``websockets-tutorial`` by this name in the following command: + +.. code-block:: console + + $ heroku create websockets-tutorial + Creating ⬢ websockets-tutorial... done + https://websockets-tutorial.herokuapp.com/ | https://git.heroku.com/websockets-tutorial.git + +If you reuse a name that someone else already uses, you will receive this +error; if this happens, try another name: + +.. code-block:: console + + $ heroku create websockets-tutorial + Creating ⬢ websockets-tutorial... ! + ▸ Name websockets-tutorial is already taken + +Deploy by pushing the code to Heroku: + +.. code-block:: console + + $ git push heroku + + ... lots of output... + + remote: Released v1 + remote: https://websockets-tutorial.herokuapp.com/ deployed to Heroku + remote: + remote: Verifying deploy... done. + To https://git.heroku.com/websockets-tutorial.git + * [new branch] main -> main + +You can test the WebSocket server with the interactive client exactly like you +did in the first part of the tutorial. Replace ``websockets-tutorial`` by the +name of your app in the following command: + +.. code-block:: console + + $ python -m websockets wss://websockets-tutorial.herokuapp.com/ + Connected to wss://websockets-tutorial.herokuapp.com/. + > {"type": "init"} + < {"type": "init", "join": "54ICxFae_Ip7TJE2", "watch": "634w44TblL5Dbd9a"} + Connection closed: 1000 (OK). + +It works! + +Prepare the web application +--------------------------- + +Before you deploy the web application, perhaps you're wondering how it will +locate the WebSocket server? Indeed, at this point, its address is hard-coded +in ``main.js``: + +.. code-block:: javascript + + const websocket = new WebSocket("ws://localhost:8001/"); + +You can take this strategy one step further by checking the address of the +HTTP server and determining the address of the WebSocket server accordingly. + +Add this function to ``main.js``; replace ``aaugustin`` by your GitHub +username and ``websockets-tutorial`` by the name of your app on Heroku: + +.. literalinclude:: ../../example/tutorial/step3/main.js + :language: js + :start-at: function getWebSocketServer + :end-before: function initGame + +Then, update the initialization to connect to this address instead: + +.. code-block:: javascript + + const websocket = new WebSocket(getWebSocketServer()); + +Commit your changes: + +.. code-block:: console + + $ git add . + $ git commit -m "Configure WebSocket server address." + [main ...] Configure WebSocket server address. + 1 file changed, 11 insertions(+), 1 deletion(-) + +Deploy the web application +-------------------------- + +Go to GitHub and create a new repository called ``websockets-tutorial``. + +Push your code to this repository. You must replace ``aaugustin`` by your +GitHub username in the following command: + +.. code-block:: console + + $ git remote add origin git@github.com:aaugustin/websockets-tutorial.git + $ git push -u origin main + Enumerating objects: 11, done. + Counting objects: 100% (11/11), done. + Delta compression using up to 8 threads + Compressing objects: 100% (10/10), done. + Writing objects: 100% (11/11), 5.90 KiB | 2.95 MiB/s, done. + Total 11 (delta 0), reused 0 (delta 0), pack-reused 0 + To github.com:/websockets-tutorial.git + * [new branch] main -> main + Branch 'main' set up to track remote branch 'main' from 'origin'. + +Go back to GitHub, open the Settings tab of the repository and select Pages in +the menu. Select the main branch as source and click Save. GitHub tells you +that your site is published. + +Follow the link and start a game! + +Summary +------- + +In this third part of the tutorial, you learned how to deploy a WebSocket +application with Heroku. + +You can start a Connect Four game, send the JOIN link to a friend, and play +over the Internet! + +Congratulations for completing the tutorial. Enjoy building real-time web +applications with websockets! + +Solution +-------- + +.. literalinclude:: ../../example/tutorial/step3/app.py + :caption: app.py + :language: python + :linenos: + +.. literalinclude:: ../../example/tutorial/step3/index.html + :caption: index.html + :language: html + :linenos: + +.. literalinclude:: ../../example/tutorial/step3/main.js + :caption: main.js + :language: js + :linenos: + +.. literalinclude:: ../../example/tutorial/step3/Procfile + :caption: Procfile + :language: text + :linenos: + +.. literalinclude:: ../../example/tutorial/step3/requirements.txt + :caption: requirements.txt + :language: text + :linenos: diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index b57d3c77f..1d5ae527d 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -18,6 +18,7 @@ coroutines cryptocurrencies cryptocurrency ctrl +deserialize django dyno fractalideas diff --git a/example/tutorial/start/connect4.css b/example/tutorial/start/connect4.css new file mode 100644 index 000000000..27f0baf6e --- /dev/null +++ b/example/tutorial/start/connect4.css @@ -0,0 +1,105 @@ +/* General layout */ + +body { + background-color: white; + display: flex; + flex-direction: column-reverse; + justify-content: center; + align-items: center; + margin: 0; + min-height: 100vh; +} + +/* Action buttons */ + +.actions { + display: flex; + flex-direction: row; + justify-content: space-evenly; + align-items: flex-end; + width: 720px; + height: 100px; +} + +.action { + color: darkgray; + font-family: "Helvetica Neue", sans-serif; + font-size: 20px; + line-height: 20px; + font-weight: 300; + text-align: center; + text-decoration: none; + text-transform: uppercase; + padding: 20px; + width: 120px; +} + +.action:hover { + background-color: darkgray; + color: white; + font-weight: 700; +} + +.action[href=""] { + display: none; +} + +/* Connect Four board */ + +.board { + background-color: blue; + display: flex; + flex-direction: row; + padding: 0 10px; + position: relative; +} + +.board::before, +.board::after { + background-color: blue; + content: ""; + height: 720px; + width: 20px; + position: absolute; +} + +.board::before { + left: -20px; +} + +.board::after { + right: -20px; +} + +.column { + display: flex; + flex-direction: column-reverse; + padding: 10px; +} + +.cell { + border-radius: 50%; + width: 80px; + height: 80px; + margin: 10px 0; +} + +.empty { + background-color: white; +} + +.column:hover .empty { + background-color: lightgray; +} + +.column:hover .empty ~ .empty { + background-color: white; +} + +.red { + background-color: red; +} + +.yellow { + background-color: yellow; +} diff --git a/example/tutorial/start/connect4.js b/example/tutorial/start/connect4.js new file mode 100644 index 000000000..cb5eb9fa2 --- /dev/null +++ b/example/tutorial/start/connect4.js @@ -0,0 +1,45 @@ +const PLAYER1 = "red"; + +const PLAYER2 = "yellow"; + +function createBoard(board) { + // Inject stylesheet. + const linkElement = document.createElement("link"); + linkElement.href = import.meta.url.replace(".js", ".css"); + linkElement.rel = "stylesheet"; + document.head.append(linkElement); + // Generate board. + for (let column = 0; column < 7; column++) { + const columnElement = document.createElement("div"); + columnElement.className = "column"; + columnElement.dataset.column = column; + for (let row = 0; row < 6; row++) { + const cellElement = document.createElement("div"); + cellElement.className = "cell empty"; + cellElement.dataset.column = column; + columnElement.append(cellElement); + } + board.append(columnElement); + } +} + +function playMove(board, player, column, row) { + // Check values of arguments. + if (player !== PLAYER1 && player !== PLAYER2) { + throw new Error(`player must be ${PLAYER1} or ${PLAYER2}.`); + } + const columnElement = board.querySelectorAll(".column")[column]; + if (columnElement === undefined) { + throw new RangeError("column must be between 0 and 6."); + } + const cellElement = columnElement.querySelectorAll(".cell")[row]; + if (cellElement === undefined) { + throw new RangeError("row must be between 0 and 5."); + } + // Place checker in cell. + if (!cellElement.classList.replace("empty", player)) { + throw new Error("cell must be empty."); + } +} + +export { PLAYER1, PLAYER2, createBoard, playMove }; diff --git a/example/tutorial/start/connect4.py b/example/tutorial/start/connect4.py new file mode 100644 index 000000000..0a61e7c7e --- /dev/null +++ b/example/tutorial/start/connect4.py @@ -0,0 +1,62 @@ +__all__ = ["PLAYER1", "PLAYER2", "Connect4"] + +PLAYER1, PLAYER2 = "red", "yellow" + + +class Connect4: + """ + A Connect Four game. + + Play moves with :meth:`play`. + + Get past moves with :attr:`moves`. + + Check for a victory with :attr:`winner`. + + """ + + def __init__(self): + self.moves = [] + self.top = [0 for _ in range(7)] + self.winner = None + + @property + def last_player(self): + """ + Player who played the last move. + + """ + return PLAYER1 if len(self.moves) % 2 else PLAYER2 + + @property + def last_player_won(self): + """ + Whether the last move is winning. + + """ + b = sum(1 << (8 * column + row) for _, column, row in self.moves[::-2]) + return any(b & b >> v & b >> 2 * v & b >> 3 * v for v in [1, 7, 8, 9]) + + def play(self, player, column): + """ + Play a move in a column. + + Returns the row where the checker lands. + + Raises :exc:`RuntimeError` if the move is illegal. + + """ + if player == self.last_player: + raise RuntimeError("It isn't your turn.") + + row = self.top[column] + if row == 6: + raise RuntimeError("This slot is full.") + + self.moves.append((player, column, row)) + self.top[column] += 1 + + if self.winner is None and self.last_player_won: + self.winner = self.last_player + + return row diff --git a/example/tutorial/start/favicon.ico b/example/tutorial/start/favicon.ico new file mode 100644 index 0000000000000000000000000000000000000000..36e855029d705e72d44428bda6e8cb6d3dd317ed GIT binary patch literal 5430 zcmeH~eNdFw6~-?im3Ahb_(wY9G?@uAnl{s!P5}k6EMLo)Am6YM4N*)C21rB%KSG3E zv;hMxiW&u@(MF93wOCQ3rD-sRXbhdyP7o3K#!S*q0u${7l(0J*XQZVahM#RhT&Mw%^BqMk{U@J1~1H2BJEU64r)0!B+s<_@|*2 zqhiBw1#eC^Kd~FBvDYyg*$#wV#Xwjm9t&v2p)n2r@@vI8+WgsgcG9|#V(P|dLI)7j zfdSf?v#0~lgtfyurUjuNe*Ta2+T*8ca~J%;(ZdETkY^gva%XNcqcj2kdN-G|0yeVF}ZFP=4Z!J2R#i0#02+Ir+*zVG@y z*5YS?yJuBTDvF9 z+e4qHUo$?zisgL(ZFDT{$HRZz5|LeB6ok5h0Mx8E{O!%Wpod<2-phQZDi;R}!Uptt z_Sffbz;E>+X1=#0_Q!9(68Y81O(FP$*?76Z67toHt(cKzH39p~78|EMbIvF7dZI9CxZHU=B0 z<9Jap>Rt-q+)$@;bvu`KC6(K3*mRuQ9MaslWomPE8-LktGQPhv{>AfKV-~kmXM9f| z|2D6HQbVN2tW*=RJWA^l;TTP>949?wSa4Hnl)qW|Kbb4f8Fvyy2h!?Rc2a%__sk=;L*m&%aOk`2^Q3=D%J$`lrna32wt&`ua?J=IYOX`)ey| zsDH*}F--J3_a!wE;q=XC^%sGy0H6D|y~p17*k>&_PGWEtANNEx9#Nf`$LIWDXd7VU z-lc0`F21g#6;|0_CyK+_IPED%Vmjd+bq#|NSGn)ee~4%xL9&NlK@0!?spEQZeA9Sw zOg84YE{XKby>e*LH9+>d2-$me4ei*@T2#ST@l)pX-aNEoMD}B9TkIH*)9-fc#{JDF zsS`tE`y^thLD#X6zKIf~ACX*0P0;qeI{2O1>Ze+hd$YWhu%sV8DP0&!?gSFKPsfuu z=^dyc`Wh;U66imQy~wci5ZZc7tYMe4y3>E-zOC46crNM2&=cI_Q|MnbFb{~Qf9lvq zl)zjpqW|G=)`eCR4jO{~dh5&OHAgQLZpnV9ydbWbam~@=o9J4d>8Y0WqGf3>kS6_3 zH=ygFmaRky%tJKeAJJ-p{_kqA)#Yh(*$efw^_AAoNKZthCz1G^u_xP8J>82TYJl`x zPtw0=VBb|il)!vMb4^CrJ8A0?#hMfy%!Slv9#Q+&`23lKphoC2`RmZF?C`@YT|BPQis$fd)qZc3H#T~U(G%W)mfglo{sU;_HV?#Giog+ zWdZKb7(gC1)GY7CGNOdEKE$SWVT5(5)r3};^sUup2XfhuD&bJR@3U{@`dIhp%!9oK z`hCKgZ~0C9j|Y3#aMr|#eN9k{rX#rbA2Ga;A`w_nTG-C ztBDe-vb<91XNIG%>P!!=&(~otdZi|$F7=v_x=f{SqJ-l`K`Vb(5MVvJ!Jze!ng;g7 z>?QBK{=AowD1llEwR+6*IO(WiJm0l|D{Ep{uL)g4S}$^l5>9OnTX~|`Xtnjy{y6%g zO~ax51$Wq&ClMvEKQ9vBhc`yq?ukr~>U(@1IJr6Gc0;i-fhgfV}k+ptmI=o*Qx0@RDGY}eIiR4)*!i%Vr#UY#Q>p{*{nuh%O zw-{rDA=?=E^vn-xf;^a)yc&4S#>DB;t~+h17%G6M9Y7Y%ttpeEt)~ z*kB1&)1jq0b@m6ll1AkWCmP=Q^&;&&DfRHy%i&r*$hok-P}6XevH3fT@AR)C)pTj8 zN-n>!BLP|-wu%PrJgTk5T@A${HyM{c{o!N z{uyKPol0D`gW)%O|F9}G?&6Mwv&}o=zc^nRJKR(eeeKWX(PfP1M^laIn|c0j*0+@b zZDY{i*oZ&{!b(QfqZlSb-*-dLEEhC+xWJOGU};h)wgh3bYC?&N1*5L&qtO{IrNa-n mx{c17;Wp~=fE$Kp5swEk + + + Connect Four + + +
+ + + diff --git a/example/tutorial/step1/main.js b/example/tutorial/step1/main.js new file mode 100644 index 000000000..dd28f9a6a --- /dev/null +++ b/example/tutorial/step1/main.js @@ -0,0 +1,53 @@ +import { createBoard, playMove } from "./connect4.js"; + +function showMessage(message) { + window.setTimeout(() => window.alert(message), 50); +} + +function receiveMoves(board, websocket) { + websocket.addEventListener("message", ({ data }) => { + const event = JSON.parse(data); + switch (event.type) { + case "play": + // Update the UI with the move. + playMove(board, event.player, event.column, event.row); + break; + case "win": + showMessage(`Player ${event.player} wins!`); + // No further messages are expected; close the WebSocket connection. + websocket.close(1000); + break; + case "error": + showMessage(event.message); + break; + default: + throw new Error(`Unsupported event type: ${event.type}.`); + } + }); +} + +function sendMoves(board, websocket) { + // When clicking a column, send a "play" event for a move in that column. + board.addEventListener("click", ({ target }) => { + const column = target.dataset.column; + // Ignore clicks outside a column. + if (column === undefined) { + return; + } + const event = { + type: "play", + column: parseInt(column, 10), + }; + websocket.send(JSON.stringify(event)); + }); +} + +window.addEventListener("DOMContentLoaded", () => { + // Initialize the UI. + const board = document.querySelector(".board"); + createBoard(board); + // Open the WebSocket connection and register event handlers. + const websocket = new WebSocket("ws://localhost:8001/"); + receiveMoves(board, websocket); + sendMoves(board, websocket); +}); diff --git a/example/tutorial/step2/app.py b/example/tutorial/step2/app.py new file mode 100644 index 000000000..bac2b6f27 --- /dev/null +++ b/example/tutorial/step2/app.py @@ -0,0 +1,190 @@ +#!/usr/bin/env python + +import asyncio +import json +import secrets + +import websockets + +from connect4 import PLAYER1, PLAYER2, Connect4 + + +JOIN = {} + +WATCH = {} + + +async def error(websocket, message): + """ + Send an error message. + + """ + event = { + "type": "error", + "message": message, + } + await websocket.send(json.dumps(event)) + + +async def replay(websocket, game): + """ + Send previous moves. + + """ + # Make a copy to avoid an exception if game.moves changes while iteration + # is in progress. If a move is played while replay is running, moves will + # be sent out of order but each move will be sent once and eventually the + # UI will be consistent. + for player, column, row in game.moves.copy(): + event = { + "type": "play", + "player": player, + "column": column, + "row": row, + } + await websocket.send(json.dumps(event)) + + +async def play(websocket, game, player, connected): + """ + Receive and process moves from a player. + + """ + async for message in websocket: + # Parse a "play" event from the UI. + event = json.loads(message) + assert event["type"] == "play" + column = event["column"] + + try: + # Play the move. + row = game.play(player, column) + except RuntimeError as exc: + # Send an "error" event if the move was illegal. + await error(websocket, str(exc)) + continue + + # Send a "play" event to update the UI. + event = { + "type": "play", + "player": player, + "column": column, + "row": row, + } + websockets.broadcast(connected, json.dumps(event)) + + # If move is winning, send a "win" event. + if game.winner is not None: + event = { + "type": "win", + "player": game.winner, + } + websockets.broadcast(connected, json.dumps(event)) + + +async def start(websocket): + """ + Handle a connection from the first player: start a new game. + + """ + # Initialize a Connect Four game, the set of WebSocket connections + # receiving moves from this game, and secret access tokens. + game = Connect4() + connected = {websocket} + + join_key = secrets.token_urlsafe(12) + JOIN[join_key] = game, connected + + watch_key = secrets.token_urlsafe(12) + WATCH[watch_key] = game, connected + + try: + # Send the secret access tokens to the browser of the first player, + # where they'll be used for building "join" and "watch" links. + event = { + "type": "init", + "join": join_key, + "watch": watch_key, + } + await websocket.send(json.dumps(event)) + # Receive and process moves from the first player. + await play(websocket, game, PLAYER1, connected) + finally: + del JOIN[join_key] + del WATCH[watch_key] + + +async def join(websocket, join_key): + """ + Handle a connection from the second player: join an existing game. + + """ + # Find the Connect Four game. + try: + game, connected = JOIN[join_key] + except KeyError: + await error(websocket, "Game not found.") + return + + # Register to receive moves from this game. + connected.add(websocket) + try: + # Send the first move, in case the first player already played it. + await replay(websocket, game) + # Receive and process moves from the second player. + await play(websocket, game, PLAYER2, connected) + finally: + connected.remove(websocket) + + +async def watch(websocket, watch_key): + """ + Handle a connection from a spectator: watch an existing game. + + """ + # Find the Connect Four game. + try: + game, connected = WATCH[watch_key] + except KeyError: + await error(websocket, "Game not found.") + return + + # Register to receive moves from this game. + connected.add(websocket) + try: + # Send previous moves, in case the game already started. + await replay(websocket, game) + # Keep the connection open, but don't receive any messages. + await websocket.wait_closed() + finally: + connected.remove(websocket) + + +async def handler(websocket, path): + """ + Handle a connection and dispatch it according to who is connecting. + + """ + # Receive and parse the "init" event from the UI. + message = await websocket.recv() + event = json.loads(message) + assert event["type"] == "init" + + if "join" in event: + # Second player joins an existing game. + await join(websocket, event["join"]) + elif "watch" in event: + # Spectator watches an existing game. + await watch(websocket, event["watch"]) + else: + # First player starts a new game. + await start(websocket) + + +async def main(): + async with websockets.serve(handler, "", 8001): + await asyncio.Future() # run forever + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/tutorial/step2/connect4.css b/example/tutorial/step2/connect4.css new file mode 120000 index 000000000..55a9977ca --- /dev/null +++ b/example/tutorial/step2/connect4.css @@ -0,0 +1 @@ +../start/connect4.css \ No newline at end of file diff --git a/example/tutorial/step2/connect4.js b/example/tutorial/step2/connect4.js new file mode 120000 index 000000000..7c4ed2f3e --- /dev/null +++ b/example/tutorial/step2/connect4.js @@ -0,0 +1 @@ +../start/connect4.js \ No newline at end of file diff --git a/example/tutorial/step2/connect4.py b/example/tutorial/step2/connect4.py new file mode 120000 index 000000000..eab6b7dc0 --- /dev/null +++ b/example/tutorial/step2/connect4.py @@ -0,0 +1 @@ +../start/connect4.py \ No newline at end of file diff --git a/example/tutorial/step2/favicon.ico b/example/tutorial/step2/favicon.ico new file mode 120000 index 000000000..76da1c2fb --- /dev/null +++ b/example/tutorial/step2/favicon.ico @@ -0,0 +1 @@ +../../../logo/favicon.ico \ No newline at end of file diff --git a/example/tutorial/step2/index.html b/example/tutorial/step2/index.html new file mode 100644 index 000000000..1a16f72a2 --- /dev/null +++ b/example/tutorial/step2/index.html @@ -0,0 +1,15 @@ + + + + Connect Four + + +
+ New + Join + Watch +
+
+ + + diff --git a/example/tutorial/step2/main.js b/example/tutorial/step2/main.js new file mode 100644 index 000000000..d38a0140a --- /dev/null +++ b/example/tutorial/step2/main.js @@ -0,0 +1,83 @@ +import { createBoard, playMove } from "./connect4.js"; + +function initGame(websocket) { + websocket.addEventListener("open", () => { + // Send an "init" event according to who is connecting. + const params = new URLSearchParams(window.location.search); + let event = { type: "init" }; + if (params.has("join")) { + // Second player joins an existing game. + event.join = params.get("join"); + } else if (params.has("watch")) { + // Spectator watches an existing game. + event.watch = params.get("watch"); + } else { + // First player starts a new game. + } + websocket.send(JSON.stringify(event)); + }); +} + +function showMessage(message) { + window.setTimeout(() => window.alert(message), 50); +} + +function receiveMoves(board, websocket) { + websocket.addEventListener("message", ({ data }) => { + const event = JSON.parse(data); + switch (event.type) { + case "init": + // Create links for inviting the second player and spectators. + document.querySelector(".join").href = "?join=" + event.join; + document.querySelector(".watch").href = "?watch=" + event.watch; + break; + case "play": + // Update the UI with the move. + playMove(board, event.player, event.column, event.row); + break; + case "win": + showMessage(`Player ${event.player} wins!`); + // No further messages are expected; close the WebSocket connection. + websocket.close(1000); + break; + case "error": + showMessage(event.message); + break; + default: + throw new Error(`Unsupported event type: ${event.type}.`); + } + }); +} + +function sendMoves(board, websocket) { + // Don't send moves for a spectator watching a game. + const params = new URLSearchParams(window.location.search); + if (params.has("watch")) { + return; + } + + // When clicking a column, send a "play" event for a move in that column. + board.addEventListener("click", ({ target }) => { + const column = target.dataset.column; + // Ignore clicks outside a column. + if (column === undefined) { + return; + } + const event = { + type: "play", + column: parseInt(column, 10), + }; + websocket.send(JSON.stringify(event)); + }); +} + +window.addEventListener("DOMContentLoaded", () => { + // Initialize the UI. + const board = document.querySelector(".board"); + createBoard(board); + // Open the WebSocket connection and register event handlers. + const websocket = new WebSocket("ws://localhost:8001/"); + initGame(websocket); + receiveMoves(board, websocket); + sendMoves(board, websocket); +}); diff --git a/example/tutorial/step3/Procfile b/example/tutorial/step3/Procfile new file mode 100644 index 000000000..2e35818f6 --- /dev/null +++ b/example/tutorial/step3/Procfile @@ -0,0 +1 @@ +web: python app.py diff --git a/example/tutorial/step3/app.py b/example/tutorial/step3/app.py new file mode 100644 index 000000000..6fff79c95 --- /dev/null +++ b/example/tutorial/step3/app.py @@ -0,0 +1,198 @@ +#!/usr/bin/env python + +import asyncio +import json +import os +import secrets +import signal + +import websockets + +from connect4 import PLAYER1, PLAYER2, Connect4 + + +JOIN = {} + +WATCH = {} + + +async def error(websocket, message): + """ + Send an error message. + + """ + event = { + "type": "error", + "message": message, + } + await websocket.send(json.dumps(event)) + + +async def replay(websocket, game): + """ + Send previous moves. + + """ + # Make a copy to avoid an exception if game.moves changes while iteration + # is in progress. If a move is played while replay is running, moves will + # be sent out of order but each move will be sent once and eventually the + # UI will be consistent. + for player, column, row in game.moves.copy(): + event = { + "type": "play", + "player": player, + "column": column, + "row": row, + } + await websocket.send(json.dumps(event)) + + +async def play(websocket, game, player, connected): + """ + Receive and process moves from a player. + + """ + async for message in websocket: + # Parse a "play" event from the UI. + event = json.loads(message) + assert event["type"] == "play" + column = event["column"] + + try: + # Play the move. + row = game.play(player, column) + except RuntimeError as exc: + # Send an "error" event if the move was illegal. + await error(websocket, str(exc)) + continue + + # Send a "play" event to update the UI. + event = { + "type": "play", + "player": player, + "column": column, + "row": row, + } + websockets.broadcast(connected, json.dumps(event)) + + # If move is winning, send a "win" event. + if game.winner is not None: + event = { + "type": "win", + "player": game.winner, + } + websockets.broadcast(connected, json.dumps(event)) + + +async def start(websocket): + """ + Handle a connection from the first player: start a new game. + + """ + # Initialize a Connect Four game, the set of WebSocket connections + # receiving moves from this game, and secret access tokens. + game = Connect4() + connected = {websocket} + + join_key = secrets.token_urlsafe(12) + JOIN[join_key] = game, connected + + watch_key = secrets.token_urlsafe(12) + WATCH[watch_key] = game, connected + + try: + # Send the secret access tokens to the browser of the first player, + # where they'll be used for building "join" and "watch" links. + event = { + "type": "init", + "join": join_key, + "watch": watch_key, + } + await websocket.send(json.dumps(event)) + # Receive and process moves from the first player. + await play(websocket, game, PLAYER1, connected) + finally: + del JOIN[join_key] + del WATCH[watch_key] + + +async def join(websocket, join_key): + """ + Handle a connection from the second player: join an existing game. + + """ + # Find the Connect Four game. + try: + game, connected = JOIN[join_key] + except KeyError: + await error(websocket, "Game not found.") + return + + # Register to receive moves from this game. + connected.add(websocket) + try: + # Send the first move, in case the first player already played it. + await replay(websocket, game) + # Receive and process moves from the second player. + await play(websocket, game, PLAYER2, connected) + finally: + connected.remove(websocket) + + +async def watch(websocket, watch_key): + """ + Handle a connection from a spectator: watch an existing game. + + """ + # Find the Connect Four game. + try: + game, connected = WATCH[watch_key] + except KeyError: + await error(websocket, "Game not found.") + return + + # Register to receive moves from this game. + connected.add(websocket) + try: + # Send previous moves, in case the game already started. + await replay(websocket, game) + # Keep the connection open, but don't receive any messages. + await websocket.wait_closed() + finally: + connected.remove(websocket) + + +async def handler(websocket, path): + """ + Handle a connection and dispatch it according to who is connecting. + + """ + # Receive and parse the "init" event from the UI. + message = await websocket.recv() + event = json.loads(message) + assert event["type"] == "init" + + if "join" in event: + # Second player joins an existing game. + await join(websocket, event["join"]) + elif "watch" in event: + # Spectator watches an existing game. + await watch(websocket, event["watch"]) + else: + # First player starts a new game. + await start(websocket) + + +async def main(): + # Set the stop condition when receiving SIGTERM. + loop = asyncio.get_running_loop() + stop = loop.create_future() + loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) + + port = int(os.environ.get("PORT", "8001")) + async with websockets.serve(handler, "", port): + await stop + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/tutorial/step3/connect4.css b/example/tutorial/step3/connect4.css new file mode 120000 index 000000000..55a9977ca --- /dev/null +++ b/example/tutorial/step3/connect4.css @@ -0,0 +1 @@ +../start/connect4.css \ No newline at end of file diff --git a/example/tutorial/step3/connect4.js b/example/tutorial/step3/connect4.js new file mode 120000 index 000000000..7c4ed2f3e --- /dev/null +++ b/example/tutorial/step3/connect4.js @@ -0,0 +1 @@ +../start/connect4.js \ No newline at end of file diff --git a/example/tutorial/step3/connect4.py b/example/tutorial/step3/connect4.py new file mode 120000 index 000000000..eab6b7dc0 --- /dev/null +++ b/example/tutorial/step3/connect4.py @@ -0,0 +1 @@ +../start/connect4.py \ No newline at end of file diff --git a/example/tutorial/step3/favicon.ico b/example/tutorial/step3/favicon.ico new file mode 120000 index 000000000..76da1c2fb --- /dev/null +++ b/example/tutorial/step3/favicon.ico @@ -0,0 +1 @@ +../../../logo/favicon.ico \ No newline at end of file diff --git a/example/tutorial/step3/index.html b/example/tutorial/step3/index.html new file mode 100644 index 000000000..1a16f72a2 --- /dev/null +++ b/example/tutorial/step3/index.html @@ -0,0 +1,15 @@ + + + + Connect Four + + +
+ New + Join + Watch +
+
+ + + diff --git a/example/tutorial/step3/main.js b/example/tutorial/step3/main.js new file mode 100644 index 000000000..15afd4163 --- /dev/null +++ b/example/tutorial/step3/main.js @@ -0,0 +1,93 @@ +import { createBoard, playMove } from "./connect4.js"; + +function getWebSocketServer() { + if (window.location.host === "aaugustin.github.io") { + return "wss://websockets-tutorial.herokuapp.com/"; + } else if (window.location.host === "localhost:8000") { + return "ws://localhost:8001/"; + } else { + throw new Error(`Unsupported host: ${window.location.host}`); + } +} + +function initGame(websocket) { + websocket.addEventListener("open", () => { + // Send an "init" event according to who is connecting. + const params = new URLSearchParams(window.location.search); + let event = { type: "init" }; + if (params.has("join")) { + // Second player joins an existing game. + event.join = params.get("join"); + } else if (params.has("watch")) { + // Spectator watches an existing game. + event.watch = params.get("watch"); + } else { + // First player starts a new game. + } + websocket.send(JSON.stringify(event)); + }); +} + +function showMessage(message) { + window.setTimeout(() => window.alert(message), 50); +} + +function receiveMoves(board, websocket) { + websocket.addEventListener("message", ({ data }) => { + const event = JSON.parse(data); + switch (event.type) { + case "init": + // Create links for inviting the second player and spectators. + document.querySelector(".join").href = "?join=" + event.join; + document.querySelector(".watch").href = "?watch=" + event.watch; + break; + case "play": + // Update the UI with the move. + playMove(board, event.player, event.column, event.row); + break; + case "win": + showMessage(`Player ${event.player} wins!`); + // No further messages are expected; close the WebSocket connection. + websocket.close(1000); + break; + case "error": + showMessage(event.message); + break; + default: + throw new Error(`Unsupported event type: ${event.type}.`); + } + }); +} + +function sendMoves(board, websocket) { + // Don't send moves for a spectator watching a game. + const params = new URLSearchParams(window.location.search); + if (params.has("watch")) { + return; + } + + // When clicking a column, send a "play" event for a move in that column. + board.addEventListener("click", ({ target }) => { + const column = target.dataset.column; + // Ignore clicks outside a column. + if (column === undefined) { + return; + } + const event = { + type: "play", + column: parseInt(column, 10), + }; + websocket.send(JSON.stringify(event)); + }); +} + +window.addEventListener("DOMContentLoaded", () => { + // Initialize the UI. + const board = document.querySelector(".board"); + createBoard(board); + // Open the WebSocket connection and register event handlers. + const websocket = new WebSocket(getWebSocketServer()); + initGame(websocket); + receiveMoves(board, websocket); + sendMoves(board, websocket); +}); diff --git a/example/tutorial/step3/requirements.txt b/example/tutorial/step3/requirements.txt new file mode 100644 index 000000000..14774b465 --- /dev/null +++ b/example/tutorial/step3/requirements.txt @@ -0,0 +1 @@ +websockets From 448e777adc7d8cf0cbc0bd585959bbdfad682926 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 6 Nov 2021 09:01:11 +0100 Subject: [PATCH 213/762] Avoid shutdown on closed socket. Fix #1072. --- src/websockets/legacy/protocol.py | 2 +- tests/legacy/test_protocol.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 3340c33be..23d46d018 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -1303,7 +1303,7 @@ async def close_connection(self) -> None: self.logger.debug("! timed out waiting for TCP close") # Half-close the TCP connection if possible (when there's no TLS). - if self.transport.can_write_eof(): + if self.transport.can_write_eof() and not self.transport.is_closing(): if self.debug: self.logger.debug("x half-closing TCP connection") self.transport.write_eof() diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index 1672ab1ed..22e72a696 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -59,6 +59,9 @@ def setup_mock(self, loop, protocol): def can_write_eof(self): return True + def is_closing(self): + return False + def write_eof(self): # When the protocol half-closes the TCP connection, it expects the # other end to close it. Simulate that. From d7320ca11e47499ae8cbbaddfe8f8bd69802bc48 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 14 Nov 2021 20:52:21 +0100 Subject: [PATCH 214/762] Add missing items to changelog. --- docs/project/changelog.rst | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 5b190fb94..a8b77c034 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -33,6 +33,8 @@ They may change at any time. New features ............ +* Added a tutorial. + * Made the second parameter of connection handlers optional. It will be deprecated in the next major release. The request path is available in the :attr:`~legacy.protocol.WebSocketCommonProtocol.path` attribute of @@ -49,6 +51,8 @@ New features path = request.path # if handler() uses the path argument ... +* Added ``python -m websockets --version``. + Improvements ............ @@ -66,6 +70,11 @@ Improvements * Documented how to auto-reload on code changes in development. +Bug fixes +......... + +* Avoided half-closing TCP connections that are already closed. + 10.0 ---- From 5f63545517270cd3218c2ed19df190f5ce88f795 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 14 Nov 2021 20:45:50 +0100 Subject: [PATCH 215/762] Release 10.1. --- docs/project/changelog.rst | 2 +- src/websockets/version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index a8b77c034..ea9079166 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -28,7 +28,7 @@ They may change at any time. 10.1 ---- -*In development* +*November 14, 2021* New features ............ diff --git a/src/websockets/version.py b/src/websockets/version.py index 3a6e6aa08..cb76be5d2 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -16,7 +16,7 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = False +released = True tag = version = commit = "10.1" From b835315f465124e7fedf1a753d07cfbce95327af Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 14 Nov 2021 21:29:44 +0100 Subject: [PATCH 216/762] Start 10.2. --- docs/project/changelog.rst | 5 +++++ src/websockets/version.py | 4 ++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index ea9079166..b58ff15b1 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -25,6 +25,11 @@ fixing regressions shortly after a release. Only documented APIs are public. Undocumented APIs are considered private. They may change at any time. +10.2 +---- + +*In development* + 10.1 ---- diff --git a/src/websockets/version.py b/src/websockets/version.py index cb76be5d2..605c8264a 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -16,9 +16,9 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = True +released = False -tag = version = commit = "10.1" +tag = version = commit = "10.2" if not released: # pragma: no cover From 67b19e6e9ac96722a584e35d7e9dc3a2173af6b8 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 20 Nov 2021 20:33:47 +0100 Subject: [PATCH 217/762] Clean up mypy cache. --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index b15cd13c9..6f8130840 100644 --- a/Makefile +++ b/Makefile @@ -27,4 +27,4 @@ build: clean: find . -name '*.pyc' -o -name '*.so' -delete find . -name __pycache__ -delete - rm -rf .coverage build compliance/reports dist docs/_build htmlcov MANIFEST src/websockets.egg-info + rm -rf .coverage .mypy_cache build compliance/reports dist docs/_build htmlcov MANIFEST src/websockets.egg-info From 14891ffa20f6acc0f6ec24a96b6bd4d3504ea3eb Mon Sep 17 00:00:00 2001 From: David Sanders Date: Tue, 21 Dec 2021 15:54:27 -0800 Subject: [PATCH 218/762] Fix some typos in comments --- src/websockets/client.py | 2 +- src/websockets/connection.py | 2 +- src/websockets/datastructures.py | 2 +- src/websockets/http11.py | 2 +- src/websockets/server.py | 2 +- src/websockets/speedups.c | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 34732b3a6..9b86b4d0a 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -201,7 +201,7 @@ def process_extensions(self, headers: Headers) -> List[Extension]: client configuration. If no match is found, an exception is raised. If several variants of the same extension are accepted by the server, - it may be configured severel times, which won't make sense in general. + it may be configured several times, which won't make sense in general. Extensions must implement their own requirements. For this purpose, the list of previously accepted extensions is provided. diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 8661a148b..0a4d3c7bc 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -106,7 +106,7 @@ def __init__( # Connection side. CLIENT or SERVER. self.side = side - # Connnection state. Initially OPEN because subclasses handle CONNECTING. + # Connection state. Initially OPEN because subclasses handle CONNECTING. self.state = state # Maximum size of incoming messages in bytes. diff --git a/src/websockets/datastructures.py b/src/websockets/datastructures.py index d5c061cf8..37d3b5f86 100644 --- a/src/websockets/datastructures.py +++ b/src/websockets/datastructures.py @@ -118,7 +118,7 @@ def __setitem__(self, key: str, value: str) -> None: def __delitem__(self, key: str) -> None: key_lower = key.lower() self._dict.__delitem__(key_lower) - # This is inefficent. Fortunately deleting HTTP headers is uncommon. + # This is inefficient. Fortunately deleting HTTP headers is uncommon. self._list = [(k, v) for k, v in self._list if k.lower() != key_lower] def __eq__(self, other: Any) -> bool: diff --git a/src/websockets/http11.py b/src/websockets/http11.py index 052719c67..a2fd22dd2 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -185,7 +185,7 @@ def parse( read_exact: generator-based coroutine that reads the requested bytes or raises an exception if there isn't enough data. read_to_eof: generator-based coroutine that reads until the end - of the strem. + of the stream. Raises: EOFError: if the connection is closed without a full HTTP response. diff --git a/src/websockets/server.py b/src/websockets/server.py index 2d9b4f9a8..a94c0b629 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -174,7 +174,7 @@ def process_request( self, request: Request ) -> Tuple[str, Optional[str], Optional[str]]: """ - Check a handshake request and negociate extensions and subprotocol. + Check a handshake request and negotiate extensions and subprotocol. This function doesn't verify that the request is an HTTP/1.1 or higher GET request and doesn't check the ``Host`` header. These controls are diff --git a/src/websockets/speedups.c b/src/websockets/speedups.c index f8d24ec7a..a19590419 100644 --- a/src/websockets/speedups.c +++ b/src/websockets/speedups.c @@ -121,7 +121,7 @@ apply_mask(PyObject *self, PyObject *args, PyObject *kwds) goto exit; } - // Since we juste created result, we don't need error checks. + // Since we just created result, we don't need error checks. output = PyBytes_AS_STRING(result); // Perform the masking operation. From 668f320e0547d80afe6529528e1ecc6088955cdc Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 24 Dec 2021 18:51:53 +0100 Subject: [PATCH 219/762] Update for the latest version fo mypy. --- src/websockets/legacy/protocol.py | 11 +++++------ src/websockets/legacy/server.py | 10 +++++----- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 23d46d018..c1809e20d 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -750,7 +750,7 @@ async def close(self, code: int = 1000, reason: str = "") -> None: self.write_close_frame(Close(code, reason)), self.close_timeout, **loop_if_py_lt_38(self.loop), - ) # type: ignore # remove when removing loop_if_py_lt_38 + ) except asyncio.TimeoutError: # If the close frame cannot be sent because the send buffers # are full, the closing handshake won't complete anyway. @@ -771,7 +771,7 @@ async def close(self, code: int = 1000, reason: str = "") -> None: self.transfer_data_task, self.close_timeout, **loop_if_py_lt_38(self.loop), - ) # type: ignore # remove when removing loop_if_py_lt_38 + ) except (asyncio.TimeoutError, asyncio.CancelledError): pass @@ -1072,8 +1072,7 @@ def append(frame: Frame) -> None: raise ProtocolError("unexpected opcode") append(frame) - # mypy cannot figure out that chunks have the proper type. - return ("" if text else b"").join(chunks) # type: ignore + return ("" if text else b"").join(chunks) async def read_data_frame(self, max_size: Optional[int]) -> Optional[Frame]: """ @@ -1250,7 +1249,7 @@ async def keepalive_ping(self) -> None: pong_waiter, self.ping_timeout, **loop_if_py_lt_38(self.loop), - ) # type: ignore # remove when removing loop_if_py_lt_38 + ) self.logger.debug("% received keepalive pong") except asyncio.TimeoutError: if self.debug: @@ -1365,7 +1364,7 @@ async def wait_for_connection_lost(self) -> bool: asyncio.shield(self.connection_lost_waiter), self.close_timeout, **loop_if_py_lt_38(self.loop), - ) # type: ignore # remove when removing loop_if_py_lt_38 + ) except asyncio.TimeoutError: pass # Re-check self.connection_lost_waiter.done() synchronously because diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 98712ff86..3172059d2 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -14,6 +14,7 @@ Awaitable, Callable, Generator, + Iterable, List, Optional, Sequence, @@ -361,7 +362,7 @@ async def process_request( warnings.warn( "declare process_request as a coroutine", DeprecationWarning ) - return response # type: ignore + return response return None @staticmethod @@ -589,7 +590,7 @@ async def handshake( else: # For backwards compatibility with 7.0. warnings.warn("declare process_request as a coroutine", DeprecationWarning) - early_response = early_response_awaitable # type: ignore + early_response = early_response_awaitable # The connection may drop while process_request is running. if self.state is State.CLOSED: @@ -677,7 +678,7 @@ def __init__(self, logger: Optional[LoggerLike] = None): # Completed when the server is closed and connections are terminated. self.closed_waiter: asyncio.Future[None] - def wrap(self, server: asyncio.AbstractServer) -> None: + def wrap(self, server: asyncio.base_events.Server) -> None: """ Attach to a given :class:`~asyncio.Server`. @@ -692,7 +693,6 @@ def wrap(self, server: asyncio.AbstractServer) -> None: """ self.server = server - assert server.sockets is not None for sock in server.sockets: if sock.family == socket.AF_INET: name = "%s:%d" % sock.getsockname() @@ -842,7 +842,7 @@ async def serve_forever(self) -> None: await self.server.serve_forever() # pragma: no cover @property - def sockets(self) -> Optional[List[socket.socket]]: + def sockets(self) -> Iterable[socket.socket]: """ See :attr:`asyncio.Server.sockets`. From 2b965a146eb41d05073264ddd296542992fb327f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 19 Feb 2022 09:50:26 +0100 Subject: [PATCH 220/762] Document that convenience imports break IDEs. Fix #1124. --- docs/project/changelog.rst | 8 +++++--- docs/reference/index.rst | 6 +++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index b58ff15b1..14945b7bf 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -259,10 +259,12 @@ Backwards-incompatible changes .. admonition:: Convenience imports from ``websockets`` are performed lazily. :class: note - While Python supports this, static code analysis tools such as mypy are - unable to understand the behavior. + While Python supports this, tools relying on static code analysis don't. + This breaks autocompletion in an IDE or type checking with mypy_. - If you depend on such tools, use the real import path, which can be found + .. _mypy: https://github.com/python/mypy + + If you depend on such tools, use the real import paths, which can be found in the API documentation, for example:: from websockets.client import connect diff --git a/docs/reference/index.rst b/docs/reference/index.rst index 385beab29..a5ee57843 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -59,8 +59,8 @@ Anything that isn't listed in the API reference is a private API. There's no guarantees of behavior or backwards-compatibility for private APIs. For convenience, many public APIs can be imported from the ``websockets`` -package. This feature is incompatible with static code analysis tools such as -mypy_, though. If you're using such tools, use the full import path. +package. However, this feature is incompatible with static code analysis. It +breaks autocompletion in an IDE or type checking with mypy_. If you're using +such tools, use the real import paths. .. _mypy: https://github.com/python/mypy - From 1e44065f49647abb6bd58bdb9a435dd064e40c74 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 19 Feb 2022 13:41:05 +0100 Subject: [PATCH 221/762] Document OSError: [Errno 99]. Fix #1102. --- docs/howto/faq.rst | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/docs/howto/faq.rst b/docs/howto/faq.rst index 20b957745..128d1dbd0 100644 --- a/docs/howto/faq.rst +++ b/docs/howto/faq.rst @@ -124,6 +124,16 @@ Look at the ``host`` argument of :meth:`~asyncio.loop.create_server`. :func:`~server.serve` accepts the same arguments as :meth:`~asyncio.loop.create_server`. +What does ``OSError: [Errno 99] error while attempting to bind on address ('::1', 80, 0, 0): address not available`` mean? +.......................................................................................................................... + +You are calling :func:`~server.serve` without a ``host`` argument in a context +where IPv6 isn't available. + +To listen only on IPv4, specify ``host="0.0.0.0"`` or ``family=socket.AF_INET``. + +Refer to the documentation of :meth:`~asyncio.loop.create_server` for details. + How do I close a connection properly? ..................................... From cd8659e25444c14f4b1e67b5a612a297d0a00cf4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 19 Feb 2022 13:49:48 +0100 Subject: [PATCH 222/762] Standardize style. (black has gotten better at preserving multi-line style.) --- tests/extensions/test_permessage_deflate.py | 116 +++++++++++++++----- 1 file changed, 91 insertions(+), 25 deletions(-) diff --git a/tests/extensions/test_permessage_deflate.py b/tests/extensions/test_permessage_deflate.py index bcd08a7ef..0d56917b4 100644 --- a/tests/extensions/test_permessage_deflate.py +++ b/tests/extensions/test_permessage_deflate.py @@ -27,16 +27,20 @@ class ExtensionTestsMixin: def assertExtensionEqual(self, extension1, extension2): self.assertEqual( - extension1.remote_no_context_takeover, extension2.remote_no_context_takeover + extension1.remote_no_context_takeover, + extension2.remote_no_context_takeover, ) self.assertEqual( - extension1.local_no_context_takeover, extension2.local_no_context_takeover + extension1.local_no_context_takeover, + extension2.local_no_context_takeover, ) self.assertEqual( - extension1.remote_max_window_bits, extension2.remote_max_window_bits + extension1.remote_max_window_bits, + extension2.remote_max_window_bits, ) self.assertEqual( - extension1.local_max_window_bits, extension2.local_max_window_bits + extension1.local_max_window_bits, + extension2.local_max_window_bits, ) @@ -84,7 +88,8 @@ def test_encode_decode_text_frame(self): enc_frame = self.extension.encode(frame) self.assertEqual( - enc_frame, dataclasses.replace(frame, rsv1=True, data=b"JNL;\xbc\x12\x00") + enc_frame, + dataclasses.replace(frame, rsv1=True, data=b"JNL;\xbc\x12\x00"), ) dec_frame = self.extension.decode(enc_frame) @@ -97,7 +102,8 @@ def test_encode_decode_binary_frame(self): enc_frame = self.extension.encode(frame) self.assertEqual( - enc_frame, dataclasses.replace(frame, rsv1=True, data=b"*IM\x04\x00") + enc_frame, + dataclasses.replace(frame, rsv1=True, data=b"*IM\x04\x00"), ) dec_frame = self.extension.decode(enc_frame) @@ -120,10 +126,12 @@ def test_encode_decode_fragmented_text_frame(self): ), ) self.assertEqual( - enc_frame2, dataclasses.replace(frame2, data=b"RPS\x00\x00\x00\x00\xff\xff") + enc_frame2, + dataclasses.replace(frame2, data=b"RPS\x00\x00\x00\x00\xff\xff"), ) self.assertEqual( - enc_frame3, dataclasses.replace(frame3, data=b"J.\xca\xcf,.N\xcc+)\x06\x00") + enc_frame3, + dataclasses.replace(frame3, data=b"J.\xca\xcf,.N\xcc+)\x06\x00"), ) dec_frame1 = self.extension.decode(enc_frame1) @@ -304,16 +312,34 @@ def test_init_error(self): def test_get_request_params(self): for config, result in [ # Test without any parameter - ((False, False, None, None), []), + ( + (False, False, None, None), + [], + ), # Test server_no_context_takeover - ((True, False, None, None), [("server_no_context_takeover", None)]), + ( + (True, False, None, None), + [("server_no_context_takeover", None)], + ), # Test client_no_context_takeover - ((False, True, None, None), [("client_no_context_takeover", None)]), + ( + (False, True, None, None), + [("client_no_context_takeover", None)], + ), # Test server_max_window_bits - ((False, False, 10, None), [("server_max_window_bits", "10")]), + ( + (False, False, 10, None), + [("server_max_window_bits", "10")], + ), # Test client_max_window_bits - ((False, False, None, 10), [("client_max_window_bits", "10")]), - ((False, False, None, True), [("client_max_window_bits", None)]), + ( + (False, False, None, 10), + [("client_max_window_bits", "10")], + ), + ( + (False, False, None, True), + [("client_max_window_bits", None)], + ), # Test all parameters together ( (True, True, 12, 12), @@ -332,15 +358,27 @@ def test_get_request_params(self): def test_process_response_params(self): for config, response_params, result in [ # Test without any parameter - ((False, False, None, None), [], (False, False, 15, 15)), - ((False, False, None, None), [("unknown", None)], InvalidParameterName), + ( + (False, False, None, None), + [], + (False, False, 15, 15), + ), + ( + (False, False, None, None), + [("unknown", None)], + InvalidParameterName, + ), # Test server_no_context_takeover ( (False, False, None, None), [("server_no_context_takeover", None)], (True, False, 15, 15), ), - ((True, False, None, None), [], NegotiationError), + ( + (True, False, None, None), + [], + NegotiationError, + ), ( (True, False, None, None), [("server_no_context_takeover", None)], @@ -362,7 +400,11 @@ def test_process_response_params(self): [("client_no_context_takeover", None)], (False, True, 15, 15), ), - ((False, True, None, None), [], (False, True, 15, 15)), + ( + (False, True, None, None), + [], + (False, True, 15, 15), + ), ( (False, True, None, None), [("client_no_context_takeover", None)], @@ -394,7 +436,11 @@ def test_process_response_params(self): [("server_max_window_bits", "16")], NegotiationError, ), - ((False, False, 12, None), [], NegotiationError), + ( + (False, False, 12, None), + [], + NegotiationError, + ), ( (False, False, 12, None), [("server_max_window_bits", "10")], @@ -426,7 +472,11 @@ def test_process_response_params(self): [("client_max_window_bits", "10")], NegotiationError, ), - ((False, False, None, True), [], (False, False, 15, 15)), + ( + (False, False, None, True), + [], + (False, False, 15, 15), + ), ( (False, False, None, True), [("client_max_window_bits", "7")], @@ -442,7 +492,11 @@ def test_process_response_params(self): [("client_max_window_bits", "16")], NegotiationError, ), - ((False, False, None, 12), [], (False, False, 15, 12)), + ( + (False, False, None, 12), + [], + (False, False, 15, 12), + ), ( (False, False, None, 12), [("client_max_window_bits", "10")], @@ -558,7 +612,8 @@ def test_enable_client_permessage_deflate(self): extension = extensions[expected_position] self.assertIsInstance(extension, ClientPerMessageDeflateFactory) self.assertEqual( - extension.compress_settings, expected_compress_settings + extension.compress_settings, + expected_compress_settings, ) @@ -597,7 +652,12 @@ def test_process_request_params(self): # (remote, local) vs. (server, client). for config, request_params, response_params, result in [ # Test without any parameter - ((False, False, None, None), [], [], (False, False, 15, 15)), + ( + (False, False, None, None), + [], + [], + (False, False, 15, 15), + ), ( (False, False, None, None), [("unknown", None)], @@ -746,7 +806,12 @@ def test_process_request_params(self): None, InvalidParameterValue, ), - ((False, False, None, 12), [], None, NegotiationError), + ( + (False, False, None, 12), + [], + None, + NegotiationError, + ), ( (False, False, None, 12), [("client_max_window_bits", None)], @@ -895,5 +960,6 @@ def test_enable_server_permessage_deflate(self): extension = extensions[expected_position] self.assertIsInstance(extension, ServerPerMessageDeflateFactory) self.assertEqual( - extension.compress_settings, expected_compress_settings + extension.compress_settings, + expected_compress_settings, ) From 9535c2137bdcdc0d34cf8367d2bb16c91a6fc083 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 19 Feb 2022 15:24:14 +0100 Subject: [PATCH 223/762] Support building docs without pyenchant. --- docs/conf.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/conf.py b/docs/conf.py index 8d9fcdc8f..fe6282b5a 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -66,6 +66,11 @@ "sphinxcontrib_trio", "sphinxext.opengraph", ] +# It is currently inconvenient to install PyEnchant on Apple Silicon. +try: + import sphinxcontrib.spelling +except ImportError: + extensions.remove("sphinxcontrib.spelling") autodoc_typehints = "description" From c3560c9f6ff33d058459c2e626c0bd0b8c2c8e1a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 19 Feb 2022 15:24:31 +0100 Subject: [PATCH 224/762] Made compression negotation more lax. This change means connections from Firefox get compression by default, while they didn't use to since the compression "optimizations" in 10.0. Fix #1109. --- docs/project/changelog.rst | 5 +++++ docs/topics/compression.rst | 21 +++++++++++++++++-- .../extensions/permessage_deflate.py | 19 +++++++++++++---- tests/extensions/test_permessage_deflate.py | 10 ++++++++- tests/legacy/test_client_server.py | 5 ++++- 5 files changed, 52 insertions(+), 8 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 14945b7bf..f888aeb2c 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -30,6 +30,11 @@ They may change at any time. *In development* +Improvements +............ + +* Made compression negotiation more lax for compatibility with Firefox. + 10.1 ---- diff --git a/docs/topics/compression.rst b/docs/topics/compression.rst index 0f264dc66..f0b7ce898 100644 --- a/docs/topics/compression.rst +++ b/docs/topics/compression.rst @@ -85,8 +85,13 @@ and memory usage for both sides. an integer between 9 (lowest memory usage) and 15 (best compression). Setting it to 8 is possible but rejected by some versions of zlib. - On the server side, websockets defaults to 12. On the client side, it lets - the server pick a suitable value, which is the same as defaulting to 15. + On the server side, websockets defaults to 12. Specifically, the compression + window size (server to client) is always 12 while the decompression window + (client to server) size may be 12 or 15 depending on whether the client + supports configuring it. + + On the client side, websockets lets the server pick a suitable value, which + has the same effect as defaulting to 15. :mod:`zlib` offers additional parameters for tuning compression. They control the trade-off between compression rate, memory usage, and CPU usage only for @@ -164,6 +169,18 @@ usage is: CPU usage is also higher for compression than decompression. +While it's always possible for a server to use a smaller window size for +compressing outgoing messages, using a smaller window size for decompressing +incoming messages requires collaboration from clients. + +When a client doesn't support configuring the size of its compression window, +websockets enables compression with the largest possible decompression window. +In most use cases, this is more efficient than disabling compression both ways. + +If you are very sensitive to memory usage, you can reverse this behavior by +setting the ``require_client_max_window_bits`` parameter of +:class:`ServerPerMessageDeflateFactory` to ``True``. + For clients ........... diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index fefa55643..017e3d843 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -471,10 +471,13 @@ class ServerPerMessageDeflateFactory(ServerExtensionFactory): server_max_window_bits: maximum size of the server's LZ77 sliding window in bits, between 8 and 15. client_max_window_bits: maximum size of the client's LZ77 sliding window - in bits, between 8 and 15, or :obj:`True` to indicate support without - setting a limit. + in bits, between 8 and 15. compress_settings: additional keyword arguments for :func:`zlib.compressobj`, excluding ``wbits``. + require_client_max_window_bits: do not enable compression at all if + client doesn't advertise support for ``client_max_window_bits``; + the default behavior is to enable compression without enforcing + ``client_max_window_bits``. """ @@ -487,6 +490,7 @@ def __init__( server_max_window_bits: Optional[int] = None, client_max_window_bits: Optional[int] = None, compress_settings: Optional[Dict[str, Any]] = None, + require_client_max_window_bits: bool = False, ) -> None: """ Configure the Per-Message Deflate extension factory. @@ -501,12 +505,18 @@ def __init__( "compress_settings must not include wbits, " "set server_max_window_bits instead" ) + if client_max_window_bits is None and require_client_max_window_bits: + raise ValueError( + "require_client_max_window_bits is enabled, " + "but client_max_window_bits isn't configured" + ) self.server_no_context_takeover = server_no_context_takeover self.client_no_context_takeover = client_no_context_takeover self.server_max_window_bits = server_max_window_bits self.client_max_window_bits = client_max_window_bits self.compress_settings = compress_settings + self.require_client_max_window_bits = require_client_max_window_bits def process_request_params( self, @@ -587,7 +597,7 @@ def process_request_params( # None None None # None True None - must change value # None 8≤M≤15 M (or None) - # 8≤N≤15 None Error! + # 8≤N≤15 None None or Error! # 8≤N≤15 True N - must change value # 8≤N≤15 8≤M≤N M (or None) # 8≤N≤15 N Date: Sat, 19 Feb 2022 17:35:01 +0100 Subject: [PATCH 225/762] Work around coverage issue in tests. --- tests/legacy/test_client_server.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index eee50be1d..142e4099e 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -1438,7 +1438,8 @@ async def run_client(): connect_inst.BACKOFF_MIN = 10 * MS connect_inst.BACKOFF_MAX = 99 * MS connect_inst.BACKOFF_INITIAL = 0 - async for ws in connect_inst: + # coverage has a hard time dealing with this code - I give up. + async for ws in connect_inst: # pragma: no cover await ws.send("spam") msg = await ws.recv() self.assertEqual(msg, "spam") @@ -1457,10 +1458,10 @@ async def run_client(): await server_ws.close() with self.assertRaises(ConnectionClosed): await ws.recv() - pass # work around bug in coverage else: # Exit block with an exception. raise Exception("BOOM!") + pass # work around bug in coverage with self.assertLogs("websockets", logging.INFO) as logs: with self.assertRaisesRegex(Exception, "BOOM!"): From 88d2e2fb51c9154d603861249d26efb0b5e55d80 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Feb 2022 08:03:41 +0100 Subject: [PATCH 226/762] Refactor nested try/except for clarity. --- src/websockets/legacy/client.py | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 6704d16ce..7253f4c59 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -655,23 +655,24 @@ async def __await_impl__(self) -> WebSocketClientProtocol: protocol = cast(WebSocketClientProtocol, protocol) try: - try: - await protocol.handshake( - self._wsuri, - origin=protocol.origin, - available_extensions=protocol.available_extensions, - available_subprotocols=protocol.available_subprotocols, - extra_headers=protocol.extra_headers, - ) - except Exception: - protocol.fail_connection() - await protocol.wait_closed() - raise - else: - self.protocol = protocol - return protocol + await protocol.handshake( + self._wsuri, + origin=protocol.origin, + available_extensions=protocol.available_extensions, + available_subprotocols=protocol.available_subprotocols, + extra_headers=protocol.extra_headers, + ) except RedirectHandshake as exc: + protocol.fail_connection() + await protocol.wait_closed() self.handle_redirect(exc.uri) + except Exception: + protocol.fail_connection() + await protocol.wait_closed() + raise + else: + self.protocol = protocol + return protocol else: raise SecurityError("too many redirects") From 8516801f7f41b91cc491b3466f57a7bd4c8ee544 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Feb 2022 10:52:48 +0100 Subject: [PATCH 227/762] Avoid leaking sockets when connect() is canceled. Fix #1113. --- docs/project/changelog.rst | 5 ++++ src/websockets/legacy/client.py | 3 ++- tests/legacy/test_client_server.py | 37 +++++++++++++++++++++++++----- 3 files changed, 38 insertions(+), 7 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index f888aeb2c..3c264a31b 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -35,6 +35,11 @@ Improvements * Made compression negotiation more lax for compatibility with Firefox. +Bug fixes +......... + +* Avoided leaking open sockets when :func:`~client.connect` is canceled. + 10.1 ---- diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 7253f4c59..0bb2cb690 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -666,7 +666,8 @@ async def __await_impl__(self) -> WebSocketClientProtocol: protocol.fail_connection() await protocol.wait_closed() self.handle_redirect(exc.uri) - except Exception: + # Avoid leaking a connected socket when the handshake fails. + except (Exception, asyncio.CancelledError): protocol.fail_connection() await protocol.wait_closed() raise diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 142e4099e..2275ecdf3 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -147,15 +147,11 @@ def with_client(*args, **kwargs): return with_manager(temp_test_client, *args, **kwargs) -def get_server_uri(server, secure=False, resource_name="/", user_info=None): +def get_server_address(server): """ - Return a WebSocket URI for connecting to the given server. + Return an address on which the given server listens. """ - proto = "wss" if secure else "ws" - - user_info = ":".join(user_info) + "@" if user_info else "" - # Pick a random socket in order to test both IPv4 and IPv6 on systems # where both are available. Randomizing tests is usually a bad idea. If # needed, either use the first socket, or test separately IPv4 and IPv6. @@ -169,6 +165,17 @@ def get_server_uri(server, secure=False, resource_name="/", user_info=None): else: # pragma: no cover raise ValueError("expected an IPv6, IPv4, or Unix socket") + return host, port + + +def get_server_uri(server, secure=False, resource_name="/", user_info=None): + """ + Return a WebSocket URI for connecting to the given server. + + """ + proto = "wss" if secure else "ws" + user_info = ":".join(user_info) + "@" if user_info else "" + host, port = get_server_address(server) return f"{proto}://{user_info}{host}:{port}{resource_name}" @@ -1067,6 +1074,21 @@ def test_server_error_in_handshake(self, _process_request): with self.assertRaises(InvalidHandshake): self.start_client() + @with_server(create_protocol=SlowOpeningHandshakeProtocol) + def test_client_connect_canceled_during_handshake(self): + sock = socket.create_connection(get_server_address(self.server)) + sock.send(b"") # socket is connected + + async def cancelled_client(): + start_client = connect(get_server_uri(self.server), sock=sock) + await asyncio.wait_for(start_client, 5 * MS) + + with self.assertRaises(asyncio.TimeoutError): + self.loop.run_until_complete(cancelled_client()) + + with self.assertRaises(OSError): + sock.send(b"") # socket is closed + @with_server() @unittest.mock.patch("websockets.legacy.server.WebSocketServerProtocol.send") def test_server_handler_crashes(self, send): @@ -1199,6 +1221,9 @@ class SecureClientServerTests( CommonClientServerTests, SecureClientServerTestsMixin, AsyncioTestCase ): + # The implementation of this test makes it hard to run it over TLS. + test_client_connect_canceled_during_handshake = None + # TLS over Unix sockets doesn't make sense. test_unix_socket = None From b3794a9fe1e042be38515c9ae4d22922d2d35382 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Feb 2022 10:54:46 +0100 Subject: [PATCH 228/762] Fail connection when send() is cancelled in a fragmented message. Ref #1129. --- src/websockets/legacy/protocol.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index c1809e20d..aa737ca26 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -664,7 +664,7 @@ async def send( # Final fragment. await self.write_frame(True, OP_CONT, b"") - except Exception: + except (Exception, asyncio.CancelledError): # We're half-way through a fragmented message and we can't # complete it. This makes the connection unusable. self.fail_connection(1011) @@ -708,7 +708,7 @@ async def send( # Final fragment. await self.write_frame(True, OP_CONT, b"") - except Exception: + except (Exception, asyncio.CancelledError): # We're half-way through a fragmented message and we can't # complete it. This makes the connection unusable. self.fail_connection(1011) From c6f05b6a594bf5ef5857b2ef23402105d98f3a7a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Feb 2022 10:58:26 +0100 Subject: [PATCH 229/762] Update to latest black version. --- src/websockets/client.py | 2 +- src/websockets/connection.py | 2 +- src/websockets/http11.py | 2 +- src/websockets/legacy/client.py | 8 ++++---- src/websockets/legacy/protocol.py | 8 ++++---- src/websockets/legacy/server.py | 8 ++++---- src/websockets/server.py | 2 +- 7 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 9b86b4d0a..7a904b151 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -74,7 +74,7 @@ def __init__( extensions: Optional[Sequence[ClientExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, state: State = CONNECTING, - max_size: Optional[int] = 2 ** 20, + max_size: Optional[int] = 2**20, logger: Optional[LoggerLike] = None, ): super().__init__( diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 0a4d3c7bc..15a40f80c 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -87,7 +87,7 @@ def __init__( self, side: Side, state: State = OPEN, - max_size: Optional[int] = 2 ** 20, + max_size: Optional[int] = 2**20, logger: Optional[LoggerLike] = None, ) -> None: # Unique identifier. For logs. diff --git a/src/websockets/http11.py b/src/websockets/http11.py index a2fd22dd2..502ca64f8 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -18,7 +18,7 @@ # Support for HTTP response bodies is intended to read an error message # returned by a server. It isn't designed to perform large file transfers. -MAX_BODY = 2 ** 20 # 1 MiB +MAX_BODY = 2**20 # 1 MiB def d(value: bytes) -> str: diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 0bb2cb690..fadc3efe8 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -442,10 +442,10 @@ def __init__( ping_interval: Optional[float] = 20, ping_timeout: Optional[float] = 20, close_timeout: Optional[float] = None, - max_size: Optional[int] = 2 ** 20, - max_queue: Optional[int] = 2 ** 5, - read_limit: int = 2 ** 16, - write_limit: int = 2 ** 16, + max_size: Optional[int] = 2**20, + max_queue: Optional[int] = 2**5, + read_limit: int = 2**16, + write_limit: int = 2**16, **kwargs: Any, ) -> None: # Backwards compatibility: close_timeout used to be called timeout. diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index aa737ca26..bbcc19664 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -171,10 +171,10 @@ def __init__( ping_interval: Optional[float] = 20, ping_timeout: Optional[float] = 20, close_timeout: Optional[float] = None, - max_size: Optional[int] = 2 ** 20, - max_queue: Optional[int] = 2 ** 5, - read_limit: int = 2 ** 16, - write_limit: int = 2 ** 16, + max_size: Optional[int] = 2**20, + max_queue: Optional[int] = 2**5, + read_limit: int = 2**16, + write_limit: int = 2**16, # The following arguments are kept only for backwards compatibility. host: Optional[str] = None, port: Optional[int] = None, diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 3172059d2..8bc466a38 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -969,10 +969,10 @@ def __init__( ping_interval: Optional[float] = 20, ping_timeout: Optional[float] = 20, close_timeout: Optional[float] = None, - max_size: Optional[int] = 2 ** 20, - max_queue: Optional[int] = 2 ** 5, - read_limit: int = 2 ** 16, - write_limit: int = 2 ** 16, + max_size: Optional[int] = 2**20, + max_queue: Optional[int] = 2**5, + read_limit: int = 2**16, + write_limit: int = 2**16, **kwargs: Any, ) -> None: # Backwards compatibility: close_timeout used to be called timeout. diff --git a/src/websockets/server.py b/src/websockets/server.py index a94c0b629..bb0e0d7e2 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -72,7 +72,7 @@ def __init__( extensions: Optional[Sequence[ServerExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, state: State = CONNECTING, - max_size: Optional[int] = 2 ** 20, + max_size: Optional[int] = 2**20, logger: Optional[LoggerLike] = None, ): super().__init__( From fe946ef0d1fb6dac982879a20f584ef66bc0a879 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Feb 2022 11:11:30 +0100 Subject: [PATCH 230/762] Fix flaky test. Ref #1113. --- tests/legacy/test_client_server.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 2275ecdf3..e6fb05d57 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -158,15 +158,12 @@ def get_server_address(server): server_socket = random.choice(server.sockets) if server_socket.family == socket.AF_INET6: # pragma: no cover - host, port = server_socket.getsockname()[:2] # (no IPv6 on CI) - host = f"[{host}]" + return server_socket.getsockname()[:2] # (no IPv6 on CI) elif server_socket.family == socket.AF_INET: - host, port = server_socket.getsockname() + return server_socket.getsockname() else: # pragma: no cover raise ValueError("expected an IPv6, IPv4, or Unix socket") - return host, port - def get_server_uri(server, secure=False, resource_name="/", user_info=None): """ @@ -176,6 +173,8 @@ def get_server_uri(server, secure=False, resource_name="/", user_info=None): proto = "wss" if secure else "ws" user_info = ":".join(user_info) + "@" if user_info else "" host, port = get_server_address(server) + if ":" in host: # IPv6 address + host = f"[{host}]" return f"{proto}://{user_info}{host}:{port}{resource_name}" From 5b3a6d26c493f1aa54d267223ba3a908f0793fc8 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Feb 2022 11:12:31 +0100 Subject: [PATCH 231/762] Avoid OSError: [Errno 107] noise in logs. Fix #1117. Ref #1072. --- src/websockets/legacy/protocol.py | 10 ++++++++-- tests/legacy/test_protocol.py | 3 --- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index bbcc19664..49c81da4f 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -1302,10 +1302,16 @@ async def close_connection(self) -> None: self.logger.debug("! timed out waiting for TCP close") # Half-close the TCP connection if possible (when there's no TLS). - if self.transport.can_write_eof() and not self.transport.is_closing(): + if self.transport.can_write_eof(): if self.debug: self.logger.debug("x half-closing TCP connection") - self.transport.write_eof() + # write_eof() doesn't document which exceptions it raises. + # "[Errno 107] Transport endpoint is not connected" happens + # but it isn't completely clear under which circumstances. + try: + self.transport.write_eof() + except OSError: # pragma: no cover + pass if await self.wait_for_connection_lost(): # Coverage marks this line as a partially executed branch. diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index 22e72a696..1672ab1ed 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -59,9 +59,6 @@ def setup_mock(self, loop, protocol): def can_write_eof(self): return True - def is_closing(self): - return False - def write_eof(self): # When the protocol half-closes the TCP connection, it expects the # other end to close it. Simulate that. From 72e9b37d422191faef3e939905a4edaee1145c24 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Feb 2022 14:24:28 +0100 Subject: [PATCH 232/762] Expand discussion of keepalives. Also mention pitfall in browsers. Fix #1084. --- docs/topics/timeouts.rst | 84 +++++++++++++++++++++++++++++++++------- 1 file changed, 69 insertions(+), 15 deletions(-) diff --git a/docs/topics/timeouts.rst b/docs/topics/timeouts.rst index 815a29b3f..dcf0322a4 100644 --- a/docs/topics/timeouts.rst +++ b/docs/topics/timeouts.rst @@ -3,34 +3,88 @@ Timeouts .. currentmodule:: websockets +Long-lived connections +---------------------- + Since the WebSocket protocol is intended for real-time communications over long-lived connections, it is desirable to ensure that connections don't break, and if they do, to report the problem quickly. -WebSocket is built on top of HTTP/1.1 where connections are short-lived, even -with ``Connection: keep-alive``. Typically, HTTP/1.1 infrastructure closes -idle connections after 30 to 120 seconds. +Connections can drop as a consequence of temporary network connectivity issues, +which are very common, even within datacenters. + +Furthermore, WebSocket builds on top of HTTP/1.1 where connections are +short-lived, even with ``Connection: keep-alive``. Typically, HTTP/1.1 +infrastructure closes idle connections after 30 to 120 seconds. -As a consequence, proxies may terminate WebSocket connections prematurely, -when no message was exchanged in 30 seconds. +As a consequence, proxies may terminate WebSocket connections prematurely when +no message was exchanged in 30 seconds. -In order to avoid this problem, websockets implements a keepalive mechanism -based on WebSocket Ping_ and Pong_ frames. Ping and Pong are designed for this -purpose. +Keepalive in websockets +----------------------- + +To avoid these problems, websockets runs a keepalive and heartbeat mechanism +based on WebSocket Ping_ and Pong_ frames, which are designed for this purpose. .. _Ping: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 .. _Pong: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 -By default, websockets waits 20 seconds, then sends a Ping frame, and expects -to receive the corresponding Pong frame within 20 seconds. Else, it considers -the connection broken and closes it. +It loops through these steps: + +1. Wait 20 seconds. +2. Send a Ping frame. +3. Receive a corresponding Pong frame within 20 seconds. + +If the Pong frame isn't received, websockets considers the connection broken and +closes it. + +This mechanism serves two purposes: + +1. It creates a trickle of traffic so that the TCP connection isn't idle and + network infrastructure along the path keeps it open ("keepalive"). +2. It detects if the connection drops or becomes so slow that it's unusable in + practice ("heartbeat"). In that case, it terminates the connection and your + application gets a :exc:`~exceptions.ConnectionClosed` exception. Timings are configurable with the ``ping_interval`` and ``ping_timeout`` -arguments of :func:`~client.connect` and :func:`~server.serve`. +arguments of :func:`~client.connect` and :func:`~server.serve`. Shorter values +will detect connection drops faster but they will increase network traffic and +they will be more sensitive to latency. + +Setting ``ping_interval`` to :obj:`None` disables the whole keepalive and +heartbeat mechanism. + +Setting ``ping_timeout`` to :obj:`None` disables only timeouts. This enables +keepalive, to keep idle connections open, and disables heartbeat, to support large +latency spikes. + +.. admonition:: Why doesn't websockets rely on TCP keepalive? + :class: hint + + TCP keepalive is disabled by default on most operating systems. When + enabled, the default interval is two hours or more, which is far too much. + +Keepalive in browsers +--------------------- + +Browsers don't enable a keepalive mechanism like websockets by default. As a +consequence, they can fail to notice that a WebSocket connection is broken for +an extended period of time, until the TCP connection times out. + +In this scenario, the ``WebSocket`` object in the browser doesn't fire a +``close`` event. If you have a reconnection mechanism, it doesn't kick in +because it believes that the connection is still working. + +If your browser-based app mysteriously and randomly fails to receive events, +this is a likely cause. You need a keepalive mechanism in the browser to avoid +this scenario. + +Unfortunately, the WebSocket API in browsers doesn't expose the native Ping and +Pong functionality in the WebSocket protocol. You have to roll your own in the +application layer. -While WebSocket runs on top of TCP, websockets doesn't rely on TCP keepalive -because it's disabled by default and, if enabled, the default interval is no -less than two hours, which doesn't meet requirements. +Latency issues +-------------- Latency between a client and a server may increase for two reasons: From 1167e1d61d9fa4916a70a212441dbcc3e6624268 Mon Sep 17 00:00:00 2001 From: Carlos Sobrinho Date: Tue, 1 Feb 2022 11:56:34 -0800 Subject: [PATCH 233/762] Update version.py Allow version check to ignore `git` if the command is not available on the path. --- src/websockets/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/websockets/version.py b/src/websockets/version.py index 605c8264a..324f43168 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -42,7 +42,7 @@ def get_version(tag: str) -> str: check=True, text=True, ).stdout.strip() - except subprocess.CalledProcessError: + except (subprocess.CalledProcessError, FileNotFoundError): pass else: description_re = r"[0-9.]+-([0-9]+)-(g[0-9a-f]{7}(?:-dirty)?)" From 46bc5e4dba19aa7052c5777657f9608fd6245c43 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Feb 2022 14:46:27 +0100 Subject: [PATCH 234/762] Add timeout to get_version(). --- src/websockets/version.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/websockets/version.py b/src/websockets/version.py index 324f43168..a7a12a61e 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -39,10 +39,12 @@ def get_version(tag: str) -> str: ["git", "describe", "--dirty", "--tags", "--long"], capture_output=True, cwd=root_dir, + timeout=1, check=True, text=True, ).stdout.strip() - except (subprocess.CalledProcessError, FileNotFoundError): + # subprocess.run raises FileNotFoundError if git isn't on $PATH. + except (FileNotFoundError, subprocess.CalledProcessError): pass else: description_re = r"[0-9.]+-([0-9]+)-(g[0-9a-f]{7}(?:-dirty)?)" From a2ca001a06ebc0d4af7ae66776cb6212e62f8e24 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Feb 2022 16:22:00 +0100 Subject: [PATCH 235/762] Add to FAQ how to message one, several, or all users. Fix #1083. --- docs/howto/faq.rst | 94 ++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 86 insertions(+), 8 deletions(-) diff --git a/docs/howto/faq.rst b/docs/howto/faq.rst index 128d1dbd0..03f85802e 100644 --- a/docs/howto/faq.rst +++ b/docs/howto/faq.rst @@ -73,6 +73,92 @@ See also Python's documentation about `running blocking code`_. .. _running blocking code: https://docs.python.org/3/library/asyncio-dev.html#running-blocking-code +.. _send-message-to-all-users: + +How do I send a message to all users? +..................................... + +Record all connections in a global variable:: + + CONNECTIONS = set() + + async def handler(websocket): + CONNECTIONS.add(websocket) + try: + await websocket.wait_closed() + finally: + CONNECTIONS.remove(websocket) + +Then, call :func:`~websockets.broadcast`:: + + import websockets + + def message_all(message): + websockets.broadcast(CONNECTIONS, message) + +If you're running multiple server processes, make sure you call ``message_all`` +in each process. + +.. _send-message-to-single-user: + +How do I send a message to a single user? +......................................... + +Record connections in a global variable, keyed by user identifier:: + + CONNECTIONS = {} + + async def handler(websocket): + user_id = ... # identify user in your app's context + CONNECTIONS[user_id] = websocket + try: + await websocket.wait_closed() + finally: + del CONNECTIONS[user_id] + +Then, call :meth:`~legacy.protocol.WebSocketCommonProtocol.send`:: + + async def message_user(user_id, message): + websocket = CONNECTIONS[user_id] # raises KeyError if user disconnected + await websocket.send(message) # may raise websockets.ConnectionClosed + +Add error handling according to the behavior you want if the user disconnected +before the message could be sent. + +This example supports only one connection per user. To support concurrent +connects by the same user, you can change ``CONNECTIONS`` to store a set of +connections for each user. + +If you're running multiple server processes, call ``message_user`` in each +process. The process managing the user's connection sends the message; other +processes do nothing. + +When you reach a scale where server processes cannot keep up with the stream of +all messages, you need a better architecture. For example, you could deploy an +external publish / subscribe system such as Redis_. Server processes would +subscribe their clients. Then, they would receive messages only for the +connections that they're managing. + +.. _Redis: https://redis.io/ + +How do I send a message to a channel, a topic, or a subset of users? +.................................................................... + +websockets doesn't provide built-in publish / subscribe functionality. + +Record connections in a global variable, keyed by user identifier, as shown in +:ref:`How do I send a message to a single user?` + +Then, build the set of recipients and broadcast the message to them, as shown in +:ref:`How do I send a message to all users?` + +:doc:`django` contains a complete implementation of this pattern. + +Again, as you scale, you may reach the performance limits of a basic in-process +implementation. You may need an external publish / subscribe system like Redis_. + +.. _Redis: https://redis.io/ + How can I pass additional arguments to the connection handler? .............................................................. @@ -417,14 +503,6 @@ websockets takes care of responding to pings with pongs. Miscellaneous ------------- -How do I create channels or topics? -................................... - -websockets doesn't have built-in publish / subscribe for these use cases. - -Depending on the scale of your service, a simple in-memory implementation may -do the job or you may need an external publish / subscribe component. - Can I use websockets synchronously, without ``async`` / ``await``? .................................................................. From 731ad8c127bae60b59d622b87d9406333432cfbd Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Feb 2022 19:26:27 +0100 Subject: [PATCH 236/762] Add broadcast example to quick start guide. Also review and improve that guide. Fix #1070. --- docs/intro/quickstart.rst | 84 ++++++++++++++++++++++++++------- example/quickstart/show_time.py | 5 +- 2 files changed, 69 insertions(+), 20 deletions(-) diff --git a/docs/intro/quickstart.rst b/docs/intro/quickstart.rst index 8c1221126..da3c3999e 100644 --- a/docs/intro/quickstart.rst +++ b/docs/intro/quickstart.rst @@ -5,14 +5,17 @@ Quick start Here are a few examples to get you started quickly with websockets. -Hello world! ------------- +Say "Hello world!" +------------------ Here's a WebSocket server. It receives a name from the client, sends a greeting, and closes the connection. .. literalinclude:: ../../example/quickstart/server.py + :caption: server.py + :language: python + :linenos: :func:`~server.serve` executes the connection handler coroutine ``hello()`` once for each WebSocket connection. It closes the WebSocket connection when @@ -23,14 +26,17 @@ Here's a corresponding WebSocket client. It sends a name to the server, receives a greeting, and closes the connection. .. literalinclude:: ../../example/quickstart/client.py + :caption: client.py + :language: python + :linenos: Using :func:`~client.connect` as an asynchronous context manager ensures the WebSocket connection is closed. .. _secure-server-example: -Encryption ----------- +Encrypt connections +------------------- Secure WebSocket connections improve confidentiality and also reliability because they reduce the risk of interference by bad proxies. @@ -47,46 +53,79 @@ requires certificates like ``https``. TLS is sometimes referred to as SSL (Secure Sockets Layer). SSL was an earlier encryption protocol; the name stuck. -Here's how to adapt the server to encrypt connections. See the documentation -of the :mod:`ssl` module for configuring the context securely. +Here's how to adapt the server to encrypt connections. You must download +:download:`localhost.pem <../../example/quickstart/localhost.pem>` and save it +in the same directory as ``server_secure.py``. + +See the documentation of the :mod:`ssl` module for details on configuring the +TLS context securely. .. literalinclude:: ../../example/quickstart/server_secure.py + :caption: server_secure.py + :language: python + :linenos: Here's how to adapt the client similarly. .. literalinclude:: ../../example/quickstart/client_secure.py + :caption: client_secure.py + :language: python + :linenos: -This client needs a context because the server uses a self-signed certificate. +In this example, the client needs a TLS context because the server uses a +self-signed certificate. When connecting to a secure WebSocket server with a valid certificate — any certificate signed by a CA that your Python installation trusts — you can simply pass ``ssl=True`` to :func:`~client.connect`. -In a browser ------------- +Connect from a browser +---------------------- The WebSocket protocol was invented for the web — as the name says! -Here's how to connect to a WebSocket server in a browser. +Here's how to connect to a WebSocket server from a browser. Run this script in a console: .. literalinclude:: ../../example/quickstart/show_time.py + :caption: show_time.py + :language: python + :linenos: Save this file as ``show_time.html``: .. literalinclude:: ../../example/quickstart/show_time.html - :language: html + :caption: show_time.html + :language: html + :linenos: Save this file as ``show_time.js``: .. literalinclude:: ../../example/quickstart/show_time.js - :language: js + :caption: show_time.js + :language: js + :linenos: + +Then, open ``show_time.html`` in several browsers. Clocks tick irregularly. + +Broadcast messages +------------------ + +Let's change the previous example to send the same timestamps to all browsers, +instead of generating independent sequences for each client. + +Stop the previous script if it's still running and run this script in a console: + +.. literalinclude:: ../../example/quickstart/show_time_2.py + :caption: show_time_2.py + :language: python + :linenos: -Then open ``show_time.html`` in a browser and see the clock tick irregularly. +Refresh ``show_time.html`` in all browsers. Clocks tick in sync. -Broadcast ---------- +Manage application state +------------------------ A WebSocket server can receive events from clients, process them to update the application state, and broadcast the updated state to all connected clients. @@ -97,20 +136,29 @@ concurrency model of :mod:`asyncio` guarantees that updates are serialized. Run this script in a console: .. literalinclude:: ../../example/quickstart/counter.py + :caption: counter.py + :language: python + :linenos: Save this file as ``counter.html``: .. literalinclude:: ../../example/quickstart/counter.html - :language: html + :caption: counter.html + :language: html + :linenos: Save this file as ``counter.css``: .. literalinclude:: ../../example/quickstart/counter.css - :language: css + :caption: counter.css + :language: css + :linenos: Save this file as ``counter.js``: .. literalinclude:: ../../example/quickstart/counter.js - :language: js + :caption: counter.js + :language: js + :linenos: Then open ``counter.html`` file in several browsers and play with [+] and [-]. diff --git a/example/quickstart/show_time.py b/example/quickstart/show_time.py index facd56b00..a83078e8a 100755 --- a/example/quickstart/show_time.py +++ b/example/quickstart/show_time.py @@ -6,8 +6,9 @@ import websockets async def show_time(websocket): - while websocket.open: - await websocket.send(datetime.datetime.utcnow().isoformat() + "Z") + while True: + message = datetime.datetime.utcnow().isoformat() + "Z" + await websocket.send(message) await asyncio.sleep(random.random() * 2 + 1) async def main(): From 330dd3dab7af30fbda404e0b0c03deb7d3cd9dd9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Feb 2022 19:36:41 +0100 Subject: [PATCH 237/762] Update FAQ on reconnection for clients. --- docs/howto/faq.rst | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/docs/howto/faq.rst b/docs/howto/faq.rst index 03f85802e..262551664 100644 --- a/docs/howto/faq.rst +++ b/docs/howto/faq.rst @@ -276,9 +276,16 @@ The easiest is to use :func:`~client.connect` as a context manager:: How do I reconnect automatically when the connection drops? ........................................................... -See `issue 414`_. +Use :func:`connect` as an asynchronous iterator:: -.. _issue 414: https://github.com/aaugustin/websockets/issues/414 + async for websocket in websockets.connect(...): + try: + ... + except websockets.ConnectionClosed: + continue + +Make sure you handle exceptions in the ``async for`` loop. Uncaught exceptions +will break out of the loop. How do I stop a client that is continuously processing messages? ................................................................ From a113adc08a97b1689acd5d79fbf9ef7f00805c0a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Feb 2022 19:39:05 +0100 Subject: [PATCH 238/762] Add FAQ on stopping a server. --- docs/howto/faq.rst | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/docs/howto/faq.rst b/docs/howto/faq.rst index 262551664..44ae889da 100644 --- a/docs/howto/faq.rst +++ b/docs/howto/faq.rst @@ -220,11 +220,22 @@ To listen only on IPv4, specify ``host="0.0.0.0"`` or ``family=socket.AF_INET``. Refer to the documentation of :meth:`~asyncio.loop.create_server` for details. -How do I close a connection properly? -..................................... +How do I close a connection? +............................ websockets takes care of closing the connection when the handler exits. +How do I stop a server? +....................... + +Exit the :func:`~server.serve` context manager. + +Here's an example that terminates cleanly when it receives SIGTERM on Unix: + +.. literalinclude:: ../../example/shutdown_server.py + :emphasize-lines: 12-15,18 + + How do I run a HTTP server and WebSocket server on the same port? ................................................................. @@ -265,8 +276,8 @@ change it to:: async with connect(...) as websocket: await do_some_work() -How do I close a connection properly? -..................................... +How do I close a connection? +............................ The easiest is to use :func:`~client.connect` as a context manager:: @@ -287,8 +298,8 @@ Use :func:`connect` as an asynchronous iterator:: Make sure you handle exceptions in the ``async for`` loop. Uncaught exceptions will break out of the loop. -How do I stop a client that is continuously processing messages? -................................................................ +How do I stop a client that is processing messages in a loop? +............................................................. You can close the connection. From c7d8b587c1d2124ca04ffb618f15566a64b9eaa6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Feb 2022 20:09:43 +0100 Subject: [PATCH 239/762] Shorten questions in FAQ. --- docs/howto/faq.rst | 70 ++++++++++++++++++++++++---------------------- 1 file changed, 36 insertions(+), 34 deletions(-) diff --git a/docs/howto/faq.rst b/docs/howto/faq.rst index 44ae889da..30259e37e 100644 --- a/docs/howto/faq.rst +++ b/docs/howto/faq.rst @@ -15,8 +15,8 @@ FAQ Server side ----------- -Why does my server close the connection prematurely? -.................................................... +Why does the server close the connection prematurely? +..................................................... Your connection handler exits prematurely. Wait for the work to be finished before returning. @@ -31,8 +31,8 @@ change it to:: async def handler(websocket): await do_some_work() -Why does the server close the connection after processing one message? -...................................................................... +Why does the server close the connection after one message? +........................................................... Your connection handler exits after processing one message. Write a loop to process multiple messages. @@ -141,8 +141,8 @@ connections that they're managing. .. _Redis: https://redis.io/ -How do I send a message to a channel, a topic, or a subset of users? -.................................................................... +How do I send a message to a channel, a topic, or some users? +............................................................. websockets doesn't provide built-in publish / subscribe functionality. @@ -159,8 +159,8 @@ implementation. You may need an external publish / subscribe system like Redis_. .. _Redis: https://redis.io/ -How can I pass additional arguments to the connection handler? -.............................................................. +How do I pass arguments to the connection handler? +.................................................. You can bind additional arguments to the connection handler with :func:`functools.partial`:: @@ -179,8 +179,8 @@ Another way to achieve this result is to define the ``handler`` coroutine in a scope where the ``extra_argument`` variable exists instead of injecting it through an argument. -How do I get access HTTP headers, for example cookies? -...................................................... +How do I access HTTP headers, like cookies? +........................................... To access HTTP headers during the WebSocket handshake, you can override :attr:`~server.WebSocketServerProtocol.process_request`:: @@ -194,16 +194,16 @@ Once the connection is established, they're available in async def handler(websocket): cookies = websocket.request_headers["Cookie"] -How do I get the IP address of the client connecting to my server? -.................................................................. +How do I get the IP address of the client? +.......................................... It's available in :attr:`~legacy.protocol.WebSocketCommonProtocol.remote_address`:: async def handler(websocket): remote_ip = websocket.remote_address[0] -How do I set which IP addresses my server listens to? -..................................................... +How do I set the IP addresses my server listens on? +................................................... Look at the ``host`` argument of :meth:`~asyncio.loop.create_server`. @@ -236,13 +236,13 @@ Here's an example that terminates cleanly when it receives SIGTERM on Unix: :emphasize-lines: 12-15,18 -How do I run a HTTP server and WebSocket server on the same port? -................................................................. +How do I run HTTP and WebSocket servers on the same port? +......................................................... You don't. -HTTP and WebSockets have widely different operational characteristics. -Running them with the same server becomes inconvenient when you scale. +HTTP and WebSocket have widely different operational characteristics. Running +them with the same server becomes inconvenient when you scale. Providing a HTTP server is out of scope for websockets. It only aims at providing a WebSocket server. @@ -260,8 +260,8 @@ support WebSocket connections, like Sanic_. Client side ----------- -Why does my client close the connection prematurely? -.................................................... +Why does the client close the connection prematurely? +..................................................... You're exiting the context manager prematurely. Wait for the work to be finished before exiting. @@ -284,8 +284,8 @@ The easiest is to use :func:`~client.connect` as a context manager:: async with connect(...) as websocket: ... -How do I reconnect automatically when the connection drops? -........................................................... +How do I reconnect when the connection drops? +............................................. Use :func:`connect` as an asynchronous iterator:: @@ -319,8 +319,8 @@ Look at the ``ssl`` argument of :meth:`~asyncio.loop.create_connection`. asyncio usage ------------- -How do I do two things in parallel? How do I integrate with another coroutine? -.............................................................................. +How do I run two coroutines in parallel? +........................................ You must start two tasks, which the event loop will run concurrently. You can achieve this with :func:`asyncio.gather` or :func:`asyncio.create_task`. @@ -359,8 +359,8 @@ See `issue 867`_. .. _issue 867: https://github.com/aaugustin/websockets/issues/867 -Why does my very simple program misbehave mysteriously? -....................................................... +Why does my simple program misbehave mysteriously? +.................................................. You are using :func:`time.sleep` instead of :func:`asyncio.sleep`, which blocks the event loop and prevents asyncio from operating normally. @@ -484,8 +484,8 @@ See `issue 574`_. .. _issue 574: https://github.com/aaugustin/websockets/issues/574 -How can I pass additional arguments to a custom protocol subclass? -.................................................................. +How can I pass arguments to a custom protocol subclass? +....................................................... You can bind additional arguments to the protocol factory with :func:`functools.partial`:: @@ -513,16 +513,18 @@ It closes the connection if it doesn't get a pong within 20 seconds. You can adjust this behavior with ``ping_interval`` and ``ping_timeout``. +See :doc:`../topics/timeouts` for details. + How do I respond to pings? .......................... -websockets takes care of responding to pings with pongs. +Don't bother; websockets takes care of responding to pings with pongs. Miscellaneous ------------- -Can I use websockets synchronously, without ``async`` / ``await``? -.................................................................. +Can I use websockets without ``async`` and ``await``? +..................................................... You can convert every asynchronous call to a synchronous call by wrapping it in ``asyncio.get_event_loop().run_until_complete(...)``. Unfortunately, this @@ -547,9 +549,9 @@ Often, this is because you created a script called ``websockets.py`` in your current working directory. Then ``import websockets`` imports this module instead of the websockets library. -I'm having problems with threads -................................ +Why am I having problems with threads? +...................................... You shouldn't use threads. Use tasks instead. -:meth:`~asyncio.loop.call_soon_threadsafe` may help. +If you have to, :meth:`~asyncio.loop.call_soon_threadsafe` may help. From 5407db58d5a840a78dbc3572545030dc2d5036fe Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Feb 2022 20:10:20 +0100 Subject: [PATCH 240/762] Clarify answer on async/await. --- docs/howto/faq.rst | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/docs/howto/faq.rst b/docs/howto/faq.rst index 30259e37e..1eda6adad 100644 --- a/docs/howto/faq.rst +++ b/docs/howto/faq.rst @@ -526,11 +526,7 @@ Miscellaneous Can I use websockets without ``async`` and ``await``? ..................................................... -You can convert every asynchronous call to a synchronous call by wrapping it -in ``asyncio.get_event_loop().run_until_complete(...)``. Unfortunately, this -is deprecated as of Python 3.10. - -If this turns out to be impractical, you should use another library. +No, there is no convenient way to do this. You should use another library. Are there ``onopen``, ``onmessage``, ``onerror``, and ``onclose`` callbacks? ............................................................................ From 695d43a90819d817562b5d51bb2eb8c1b641a324 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Feb 2022 20:31:23 +0100 Subject: [PATCH 241/762] Improve FAQ about HTTP headers. Fix #798. --- docs/howto/faq.rst | 47 +++++++++++++++++++++++++++++++++++++++------- 1 file changed, 40 insertions(+), 7 deletions(-) diff --git a/docs/howto/faq.rst b/docs/howto/faq.rst index 1eda6adad..bc742f458 100644 --- a/docs/howto/faq.rst +++ b/docs/howto/faq.rst @@ -179,20 +179,30 @@ Another way to achieve this result is to define the ``handler`` coroutine in a scope where the ``extra_argument`` variable exists instead of injecting it through an argument. -How do I access HTTP headers, like cookies? -........................................... +How do I access HTTP headers? +............................. To access HTTP headers during the WebSocket handshake, you can override :attr:`~server.WebSocketServerProtocol.process_request`:: async def process_request(self, path, request_headers): - cookies = request_header["Cookie"] + authorization = request_headers["Authorization"] -Once the connection is established, they're available in -:attr:`~server.WebSocketServerProtocol.request_headers`:: +Once the connection is established, HTTP headers are available in +:attr:`~server.WebSocketServerProtocol.request_headers` and +:attr:`~server.WebSocketServerProtocol.response_headers`:: async def handler(websocket): - cookies = websocket.request_headers["Cookie"] + authorization = websocket.request_headers["Authorization"] + +How do I set HTTP headers? +.......................... + +To set the ``Sec-WebSocket-Extensions`` or ``Sec-WebSocket-Protocol`` headers in +the WebSocket handshake response, use the ``extensions`` or ``subprotocols`` +arguments of :func:`~server.serve`. + +To set other HTTP headers, use the ``extra_headers`` argument. How do I get the IP address of the client? .......................................... @@ -276,6 +286,27 @@ change it to:: async with connect(...) as websocket: await do_some_work() +How do I access HTTP headers? +............................. + +Once the connection is established, HTTP headers are available in +:attr:`~client.WebSocketClientProtocol.request_headers` and +:attr:`~client.WebSocketClientProtocol.response_headers`. + +How do I set HTTP headers? +.......................... + +To set the ``Origin``, ``Sec-WebSocket-Extensions``, or +``Sec-WebSocket-Protocol`` headers in the WebSocket handshake request, use the +``origin``, ``extensions``, or ``subprotocols`` arguments of +:func:`~client.connect`. + +To set other HTTP headers, for example the ``Authorization`` header, use the +``extra_headers`` argument:: + + async with connect(..., extra_headers={"Authorization": ...}) as websocket: + ... + How do I close a connection? ............................ @@ -284,10 +315,12 @@ The easiest is to use :func:`~client.connect` as a context manager:: async with connect(...) as websocket: ... +The connection is closed when exiting the context manager. + How do I reconnect when the connection drops? ............................................. -Use :func:`connect` as an asynchronous iterator:: +Use :func:`~client.connect` as an asynchronous iterator:: async for websocket in websockets.connect(...): try: From f605e86aeab7ac045c21e2eac0d677916031022e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Feb 2022 21:28:28 +0100 Subject: [PATCH 242/762] Restore backwards-compatibility for partial handlers. Fix #1095. --- docs/project/changelog.rst | 3 +++ src/websockets/legacy/server.py | 28 +++++++++++++++++++--------- tests/legacy/test_client_server.py | 17 +++++++++++++++++ 3 files changed, 39 insertions(+), 9 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 3c264a31b..b47ae2341 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -38,6 +38,9 @@ Improvements Bug fixes ......... +* Fixed backwards-incompatibility in 10.1 for connection handlers created with + :func:`functools.partial`. + * Avoided leaking open sockets when :func:`~client.connect` is canceled. 10.1 diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 8bc466a38..ea5b0d1fa 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -1126,17 +1126,27 @@ def remove_path_argument( Callable[[WebSocketServerProtocol, str], Awaitable[Any]], ] ) -> Callable[[WebSocketServerProtocol], Awaitable[Any]]: - if len(inspect.signature(ws_handler).parameters) == 2: - # Enable deprecation warning and announce deprecation in 11.0. - # warnings.warn("remove second argument of ws_handler", DeprecationWarning) + try: + inspect.signature(ws_handler).bind(None) + except TypeError: + try: + inspect.signature(ws_handler).bind(None, "") + except TypeError: # pragma: no cover + # ws_handler accepts neither one nor two arguments; leave it alone. + pass + else: + # ws_handler accepts two arguments; activate backwards compatibility. + + # Enable deprecation warning and announce deprecation in 11.0. + # warnings.warn("remove second argument of ws_handler", DeprecationWarning) - async def _ws_handler(websocket: WebSocketServerProtocol) -> Any: - return await cast( - Callable[[WebSocketServerProtocol, str], Awaitable[Any]], - ws_handler, - )(websocket, websocket.path) + async def _ws_handler(websocket: WebSocketServerProtocol) -> Any: + return await cast( + Callable[[WebSocketServerProtocol, str], Awaitable[Any]], + ws_handler, + )(websocket, websocket.path) - return _ws_handler + return _ws_handler return cast( Callable[[WebSocketServerProtocol], Awaitable[Any]], diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index e6fb05d57..acc231d76 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -497,6 +497,23 @@ async def handler_with_path(ws, path): "/path", ) + def test_ws_handler_argument_backwards_compatibility_partial(self): + async def handler_with_path(ws, path, extra): + await ws.send(path) + + bound_handler_with_path = functools.partial(handler_with_path, extra=None) + + with self.temp_server( + handler=bound_handler_with_path, + # Enable deprecation warning and announce deprecation in 11.0. + # deprecation_warnings=["remove second argument of ws_handler"], + ): + with self.temp_client("/path"): + self.assertEqual( + self.loop.run_until_complete(self.client.recv()), + "/path", + ) + async def process_request_OK(path, request_headers): return http.HTTPStatus.OK, [], b"OK\n" From a1d16ba97344e7f98c3dfce9295231c3862fde10 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 21 Feb 2022 07:40:14 +0100 Subject: [PATCH 243/762] Update changelog for 10.2. --- docs/project/changelog.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index b47ae2341..027b2d8bf 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -35,6 +35,8 @@ Improvements * Made compression negotiation more lax for compatibility with Firefox. +* Improved FAQ and quick start guide. + Bug fixes ......... From 498cc8c061e53f0001cb2e3ade22ee8ce5ff11a1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 21 Feb 2022 07:41:08 +0100 Subject: [PATCH 244/762] Release 10.2. --- docs/project/changelog.rst | 2 +- src/websockets/version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 027b2d8bf..019aac324 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -28,7 +28,7 @@ They may change at any time. 10.2 ---- -*In development* +*February 21, 2022* Improvements ............ diff --git a/src/websockets/version.py b/src/websockets/version.py index a7a12a61e..ca379afe2 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -16,7 +16,7 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = False +released = True tag = version = commit = "10.2" From 778a1ca6936ac67e7a3fe1bbe585db2eafeaa515 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 21 Feb 2022 07:42:14 +0100 Subject: [PATCH 245/762] Start 10.3. --- docs/project/changelog.rst | 5 +++++ src/websockets/version.py | 4 ++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 019aac324..81dee9e10 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -25,6 +25,11 @@ fixing regressions shortly after a release. Only documented APIs are public. Undocumented APIs are considered private. They may change at any time. +10.3 +---- + +*In development* + 10.2 ---- diff --git a/src/websockets/version.py b/src/websockets/version.py index ca379afe2..605a0eeef 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -16,9 +16,9 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = True +released = False -tag = version = commit = "10.2" +tag = version = commit = "10.3" if not released: # pragma: no cover From c60df611023ac47345d9201b0a4785c4d8dbdbd9 Mon Sep 17 00:00:00 2001 From: Erik van Raalte Date: Sat, 26 Feb 2022 21:33:26 +0100 Subject: [PATCH 246/762] Update documentation Fix misleading function definition for `play` function in intro/tutorial2 --- docs/intro/tutorial2.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/intro/tutorial2.rst b/docs/intro/tutorial2.rst index 669d46cde..5f4be2047 100644 --- a/docs/intro/tutorial2.rst +++ b/docs/intro/tutorial2.rst @@ -382,7 +382,7 @@ single coroutine to process the moves of both players: .. code-block:: python - async def play(game, player, connected): + async def play(websocket, game, player, connected): ... With such a coroutine, you can replace the temporary code for testing in @@ -390,13 +390,13 @@ With such a coroutine, you can replace the temporary code for testing in .. code-block:: python - await play(game, PLAYER1, connected) + await play(websocket, game, PLAYER1, connected) and in ``join()`` by: .. code-block:: python - await play(game, PLAYER2, connected) + await play(websocket, game, PLAYER2, connected) The ``play()`` coroutine will reuse much of the code you wrote in the first part of the tutorial. From a7b9860538c50e58a06f751b5f9eecde575fae2a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 13 Mar 2022 08:19:26 +0100 Subject: [PATCH 247/762] Add file forgotten in 731ad8c. --- example/quickstart/show_time_2.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100755 example/quickstart/show_time_2.py diff --git a/example/quickstart/show_time_2.py b/example/quickstart/show_time_2.py new file mode 100755 index 000000000..08e87f593 --- /dev/null +++ b/example/quickstart/show_time_2.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python + +import asyncio +import datetime +import random +import websockets + +CONNECTIONS = set() + +async def register(websocket): + CONNECTIONS.add(websocket) + try: + await websocket.wait_closed() + finally: + CONNECTIONS.remove(websocket) + +async def show_time(): + while True: + message = datetime.datetime.utcnow().isoformat() + "Z" + websockets.broadcast(CONNECTIONS, message) + await asyncio.sleep(random.random() * 2 + 1) + +async def main(): + async with websockets.serve(register, "localhost", 5678): + await show_time() + +if __name__ == "__main__": + asyncio.run(main()) From db600e51a0edd6d1fbc18891775182cbc00f5f62 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Mar 2022 22:37:05 +0100 Subject: [PATCH 248/762] Work around problematic change in typeshed. Refs https://github.com/python/typeshed/pull/6653. --- src/websockets/datastructures.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/websockets/datastructures.py b/src/websockets/datastructures.py index 37d3b5f86..77b68ba9a 100644 --- a/src/websockets/datastructures.py +++ b/src/websockets/datastructures.py @@ -134,7 +134,7 @@ def clear(self) -> None: self._dict = {} self._list = [] - def update(self, *args: HeadersLike, **kwargs: str) -> None: + def update(self, *args: HeadersLike, **kwargs: str) -> None: # type: ignore """ Update from a :class:`Headers` instance and/or keyword arguments. From 747df3d174acdbcda5506ed1588bd635fedd7078 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 21 Mar 2022 22:35:24 +0100 Subject: [PATCH 249/762] Widen HeadersLike type. Restore compatibility with the latest mypy. Refs https://github.com/python/typeshed/pull/6653. --- docs/reference/types.rst | 2 ++ src/websockets/datastructures.py | 32 +++++++++++++++++++++++++++++--- 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/docs/reference/types.rst b/docs/reference/types.rst index 4b7952553..d86429be4 100644 --- a/docs/reference/types.rst +++ b/docs/reference/types.rst @@ -18,3 +18,5 @@ Types .. autodata:: websockets.connection.Event .. autodata:: websockets.datastructures.HeadersLike + +.. autodata:: websockets.datastructures.SupportsKeysAndGetItem diff --git a/src/websockets/datastructures.py b/src/websockets/datastructures.py index 77b68ba9a..a0a648463 100644 --- a/src/websockets/datastructures.py +++ b/src/websockets/datastructures.py @@ -8,6 +8,7 @@ List, Mapping, MutableMapping, + Protocol, Tuple, Union, ) @@ -134,7 +135,7 @@ def clear(self) -> None: self._dict = {} self._list = [] - def update(self, *args: HeadersLike, **kwargs: str) -> None: # type: ignore + def update(self, *args: HeadersLike, **kwargs: str) -> None: """ Update from a :class:`Headers` instance and/or keyword arguments. @@ -164,5 +165,30 @@ def raw_items(self) -> Iterator[Tuple[str, str]]: return iter(self._list) -HeadersLike = Union[Headers, Mapping[str, str], Iterable[Tuple[str, str]]] -"""Types accepted where :class:`Headers` is expected.""" +# copy of _typeshed.SupportsKeysAndGetItem. +class SupportsKeysAndGetItem(Protocol): # pragma: no cover + """ + Dict-like types with ``keys() -> str`` and ``__getitem__(key: str) -> str`` methods. + + """ + + def keys(self) -> Iterable[str]: + ... + + def __getitem__(self, key: str) -> str: + ... + + +HeadersLike = Union[ + Headers, + Mapping[str, str], + Iterable[Tuple[str, str]], + SupportsKeysAndGetItem, +] +""" +Types accepted where :class:`Headers` is expected. + +In addition to :class:`Headers` itself, this includes dict-like types where both +keys and values are :class:`str`. + +""" From f7478be1321e5ccc9cd8995e2e95676faff57944 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 23 Mar 2022 07:42:36 +0100 Subject: [PATCH 250/762] Restore compatibility with Python 3.7. --- src/websockets/datastructures.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/websockets/datastructures.py b/src/websockets/datastructures.py index a0a648463..36a2cbaf9 100644 --- a/src/websockets/datastructures.py +++ b/src/websockets/datastructures.py @@ -1,5 +1,6 @@ from __future__ import annotations +import sys from typing import ( Any, Dict, @@ -8,12 +9,17 @@ List, Mapping, MutableMapping, - Protocol, Tuple, Union, ) +if sys.version_info[:2] >= (3, 8): + from typing import Protocol +else: # pragma: no cover + Protocol = object # mypy will report errors on Python 3.7. + + __all__ = ["Headers", "HeadersLike", "MultipleValuesError"] From 286768512b0c2bd671cae0ae3e64c1545632b6d4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 26 Mar 2022 08:51:27 +0100 Subject: [PATCH 251/762] Remove path parameters in connection handlers. 9b8a8d1c wasn't reflected properly before the tutorial was merged. Fix #1154. --- docs/intro/tutorial2.rst | 2 +- example/tutorial/step2/app.py | 2 +- example/tutorial/step3/app.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/intro/tutorial2.rst b/docs/intro/tutorial2.rst index 5f4be2047..53f295d3f 100644 --- a/docs/intro/tutorial2.rst +++ b/docs/intro/tutorial2.rst @@ -212,7 +212,7 @@ previous one to reuse it later: del JOIN[join_key] - async def handler(websocket, path): + async def handler(websocket): # Receive and parse the "init" event from the UI. message = await websocket.recv() event = json.loads(message) diff --git a/example/tutorial/step2/app.py b/example/tutorial/step2/app.py index bac2b6f27..2693d4304 100644 --- a/example/tutorial/step2/app.py +++ b/example/tutorial/step2/app.py @@ -160,7 +160,7 @@ async def watch(websocket, watch_key): connected.remove(websocket) -async def handler(websocket, path): +async def handler(websocket): """ Handle a connection and dispatch it according to who is connecting. diff --git a/example/tutorial/step3/app.py b/example/tutorial/step3/app.py index 6fff79c95..c2ee020d2 100644 --- a/example/tutorial/step3/app.py +++ b/example/tutorial/step3/app.py @@ -162,7 +162,7 @@ async def watch(websocket, watch_key): connected.remove(websocket) -async def handler(websocket, path): +async def handler(websocket): """ Handle a connection and dispatch it according to who is connecting. From 0796c43c5c88e385d11471264eb519d421c7232d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 3 Apr 2022 09:25:55 +0200 Subject: [PATCH 252/762] Simplify implementation of close_expected(). --- src/websockets/connection.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 15a40f80c..ca996ba62 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -481,7 +481,6 @@ def close_expected(self) -> bool: bool: Whether the TCP connection is expected to close soon. """ - # We already got a TCP Close if and only if the state is CLOSED. # We expect a TCP close if and only if we sent a close frame: # * Normal closure: once we send a close frame, we expect a TCP close: # server waits for client to complete the TCP closing handshake; @@ -489,7 +488,8 @@ def close_expected(self) -> bool: # * Abnormal closure: we always send a close frame and the same logic # applies, except on EOFError where we don't send a close frame # because we already received the TCP close, so we don't expect it. - return self.state is not CLOSED and self.close_sent is not None + # We already got a TCP Close if and only if the state is CLOSED. + return self.state is CLOSING # Private methods for receiving data. From 5dc16c27001a9465b22832e5b02036d2f10862bb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 3 Apr 2022 10:28:12 +0200 Subject: [PATCH 253/762] Store handshake exceptions in connection objects. Storing them in request/response objects was probably legacy. --- docs/howto/sansio.rst | 14 ++++++++------ docs/project/changelog.rst | 11 +++++++++++ docs/reference/client.rst | 2 ++ docs/reference/server.rst | 2 ++ src/websockets/client.py | 3 ++- src/websockets/connection.py | 9 +++++++++ src/websockets/http11.py | 27 +++++++++++++++++++++------ src/websockets/server.py | 12 ++++++++---- tests/test_client.py | 26 +++++++++++++------------- tests/test_server.py | 32 ++++++++++++++++---------------- 10 files changed, 92 insertions(+), 46 deletions(-) diff --git a/docs/howto/sansio.rst b/docs/howto/sansio.rst index 83496bff2..1373c81a5 100644 --- a/docs/howto/sansio.rst +++ b/docs/howto/sansio.rst @@ -52,10 +52,11 @@ the network, as described in `Send data`_ below. The first event returned by :meth:`~connection.Connection.events_received` is the WebSocket handshake response. -When the handshake fails, the reason is available in ``response.exception``:: +When the handshake fails, the reason is available in +:attr:`~client.ClientConnection.handshake_exc`:: - if response.exception is not None: - raise response.exception + if connection.handshake_exc is not None: + raise connection.handshake_exc Else, the WebSocket connection is open. @@ -96,10 +97,11 @@ the network, as described in `Send data`_ below. Even when you call :meth:`~server.ServerConnection.accept`, the WebSocket handshake may fail if the request is incorrect or unsupported. -When the handshake fails, the reason is available in ``request.exception``:: +When the handshake fails, the reason is available in +:attr:`~server.ServerConnection.handshake_exc`:: - if request.exception is not None: - raise request.exception + if connection.handshake_exc is not None: + raise connection.handshake_exc Else, the WebSocket connection is open. diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 81dee9e10..26e9a5cdc 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -30,6 +30,17 @@ They may change at any time. *In development* +Backwards-incompatible changes +.............................. + +.. admonition:: The ``exception`` attribute of :class:`~http11.Request` and :class:`~http11.Response` is deprecated. + :class: note + + Use the ``handshake_exc`` attribute of :class:`~server.ServerConnection` and + :class:`~client.ClientConnection` instead. + + See :doc:`../howto/sansio` for details. + 10.2 ---- diff --git a/docs/reference/client.rst b/docs/reference/client.rst index daf01ef58..379765397 100644 --- a/docs/reference/client.rst +++ b/docs/reference/client.rst @@ -103,6 +103,8 @@ Sans-I/O .. autoproperty:: state + .. autoattribute:: handshake_exc + .. autoproperty:: close_code .. autoproperty:: close_reason diff --git a/docs/reference/server.rst b/docs/reference/server.rst index 97bf320b6..0e446f382 100644 --- a/docs/reference/server.rst +++ b/docs/reference/server.rst @@ -151,6 +151,8 @@ Sans-I/O .. autoproperty:: state + .. autoattribute:: handshake_exc + .. autoproperty:: close_code .. autoproperty:: close_reason diff --git a/src/websockets/client.py b/src/websockets/client.py index 7a904b151..8d826feaa 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -334,7 +334,8 @@ def parse(self) -> Generator[None, None, None]: try: self.process_response(response) except InvalidHandshake as exc: - response.exception = exc + response._exception = exc + self.handshake_exc = exc else: assert self.state is CONNECTING self.state = OPEN diff --git a/src/websockets/connection.py b/src/websockets/connection.py index ca996ba62..967bd8fa5 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -130,6 +130,15 @@ def __init__( self.close_sent: Optional[Close] = None self.close_rcvd_then_sent: Optional[bool] = None + # Track if an exception happened during the handshake. + self.handshake_exc: Optional[Exception] = None + """ + Exception to raise if the opening handshake failed. + + :obj:`None` if the opening handshake succeeded. + + """ + # Track if send_eof() was called. self.eof_sent = False diff --git a/src/websockets/http11.py b/src/websockets/http11.py index 502ca64f8..84048fa47 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -2,6 +2,7 @@ import dataclasses import re +import warnings from typing import Callable, Generator, Optional from . import datastructures, exceptions @@ -57,15 +58,22 @@ class Request: Attributes: path: Request path, including optional query. headers: Request headers. - exception: If processing the response triggers an exception, - the exception is stored in this attribute. """ path: str headers: datastructures.Headers # body isn't useful is the context of this library. - exception: Optional[Exception] = None + _exception: Optional[Exception] = None + + @property + def exception(self) -> Optional[Exception]: # pragma: no cover + warnings.warn( + "Request.exception is deprecated; " + "use ServerConnection.handshake_exc instead", + DeprecationWarning, + ) + return self._exception @classmethod def parse( @@ -152,8 +160,6 @@ class Response: reason_phrase: Response reason. headers: Response headers. body: Response body, if any. - exception: if processing the response triggers an exception, - the exception is stored in this attribute. """ @@ -162,7 +168,16 @@ class Response: headers: datastructures.Headers body: Optional[bytes] = None - exception: Optional[Exception] = None + _exception: Optional[Exception] = None + + @property + def exception(self) -> Optional[Exception]: # pragma: no cover + warnings.warn( + "Response.exception is deprecated; " + "use ClientConnection.handshake_exc instead", + DeprecationWarning, + ) + return self._exception @classmethod def parse( diff --git a/src/websockets/server.py b/src/websockets/server.py index bb0e0d7e2..214417ad0 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -110,7 +110,8 @@ def accept(self, request: Request) -> Response: protocol_header, ) = self.process_request(request) except InvalidOrigin as exc: - request.exception = exc + request._exception = exc + self.handshake_exc = exc if self.debug: self.logger.debug("! invalid origin", exc_info=True) return self.reject( @@ -118,7 +119,8 @@ def accept(self, request: Request) -> Response: f"Failed to open a WebSocket connection: {exc}.\n", ) except InvalidUpgrade as exc: - request.exception = exc + request._exception = exc + self.handshake_exc = exc if self.debug: self.logger.debug("! invalid upgrade", exc_info=True) response = self.reject( @@ -133,7 +135,8 @@ def accept(self, request: Request) -> Response: response.headers["Upgrade"] = "websocket" return response except InvalidHandshake as exc: - request.exception = exc + request._exception = exc + self.handshake_exc = exc if self.debug: self.logger.debug("! invalid handshake", exc_info=True) return self.reject( @@ -141,7 +144,8 @@ def accept(self, request: Request) -> Response: f"Failed to open a WebSocket connection: {exc}.\n", ) except Exception as exc: - request.exception = exc + request._exception = exc + self.handshake_exc = exc self.logger.error("opening handshake failed", exc_info=True) return self.reject( http.HTTPStatus.INTERNAL_SERVER_ERROR, diff --git a/tests/test_client.py b/tests/test_client.py index 1c1452d41..12fd8726f 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -255,7 +255,7 @@ def test_missing_connection(self): self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHeader) as raised: - raise response.exception + raise client.handshake_exc self.assertEqual(str(raised.exception), "missing Connection header") def test_invalid_connection(self): @@ -268,7 +268,7 @@ def test_invalid_connection(self): self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHeader) as raised: - raise response.exception + raise client.handshake_exc self.assertEqual(str(raised.exception), "invalid Connection header: close") def test_missing_upgrade(self): @@ -280,7 +280,7 @@ def test_missing_upgrade(self): self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHeader) as raised: - raise response.exception + raise client.handshake_exc self.assertEqual(str(raised.exception), "missing Upgrade header") def test_invalid_upgrade(self): @@ -293,7 +293,7 @@ def test_invalid_upgrade(self): self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHeader) as raised: - raise response.exception + raise client.handshake_exc self.assertEqual(str(raised.exception), "invalid Upgrade header: h2c") def test_missing_accept(self): @@ -305,7 +305,7 @@ def test_missing_accept(self): self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHeader) as raised: - raise response.exception + raise client.handshake_exc self.assertEqual(str(raised.exception), "missing Sec-WebSocket-Accept header") def test_multiple_accept(self): @@ -317,7 +317,7 @@ def test_multiple_accept(self): self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHeader) as raised: - raise response.exception + raise client.handshake_exc self.assertEqual( str(raised.exception), "invalid Sec-WebSocket-Accept header: " @@ -334,7 +334,7 @@ def test_invalid_accept(self): self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHeader) as raised: - raise response.exception + raise client.handshake_exc self.assertEqual( str(raised.exception), f"invalid Sec-WebSocket-Accept header: {ACCEPT}" ) @@ -383,7 +383,7 @@ def test_unexpected_extension(self): self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHandshake) as raised: - raise response.exception + raise client.handshake_exc self.assertEqual(str(raised.exception), "no extensions supported") def test_unsupported_extension(self): @@ -398,7 +398,7 @@ def test_unsupported_extension(self): self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHandshake) as raised: - raise response.exception + raise client.handshake_exc self.assertEqual( str(raised.exception), "Unsupported extension: name = x-op, params = [('op', None)]", @@ -429,7 +429,7 @@ def test_unsupported_extension_parameters(self): self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHandshake) as raised: - raise response.exception + raise client.handshake_exc self.assertEqual( str(raised.exception), "Unsupported extension: name = x-op, params = [('op', 'that')]", @@ -520,7 +520,7 @@ def test_unexpected_subprotocol(self): self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHandshake) as raised: - raise response.exception + raise client.handshake_exc self.assertEqual(str(raised.exception), "no subprotocols supported") def test_multiple_subprotocols(self): @@ -536,7 +536,7 @@ def test_multiple_subprotocols(self): self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHandshake) as raised: - raise response.exception + raise client.handshake_exc self.assertEqual( str(raised.exception), "multiple subprotocols: superchat, chat" ) @@ -566,7 +566,7 @@ def test_unsupported_subprotocol(self): self.assertEqual(client.state, CONNECTING) with self.assertRaises(InvalidHandshake) as raised: - raise response.exception + raise client.handshake_exc self.assertEqual(str(raised.exception), "unsupported subprotocol: otherchat") diff --git a/tests/test_server.py b/tests/test_server.py index 54699c3ef..43bc03e14 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -188,7 +188,7 @@ def test_unexpected_exception(self): self.assertEqual(response.status_code, 500) with self.assertRaises(Exception) as raised: - raise request.exception + raise server.handshake_exc self.assertEqual(str(raised.exception), "BOOM") def test_missing_connection(self): @@ -200,7 +200,7 @@ def test_missing_connection(self): self.assertEqual(response.status_code, 426) self.assertEqual(response.headers["Upgrade"], "websocket") with self.assertRaises(InvalidUpgrade) as raised: - raise request.exception + raise server.handshake_exc self.assertEqual(str(raised.exception), "missing Connection header") def test_invalid_connection(self): @@ -213,7 +213,7 @@ def test_invalid_connection(self): self.assertEqual(response.status_code, 426) self.assertEqual(response.headers["Upgrade"], "websocket") with self.assertRaises(InvalidUpgrade) as raised: - raise request.exception + raise server.handshake_exc self.assertEqual(str(raised.exception), "invalid Connection header: close") def test_missing_upgrade(self): @@ -225,7 +225,7 @@ def test_missing_upgrade(self): self.assertEqual(response.status_code, 426) self.assertEqual(response.headers["Upgrade"], "websocket") with self.assertRaises(InvalidUpgrade) as raised: - raise request.exception + raise server.handshake_exc self.assertEqual(str(raised.exception), "missing Upgrade header") def test_invalid_upgrade(self): @@ -238,7 +238,7 @@ def test_invalid_upgrade(self): self.assertEqual(response.status_code, 426) self.assertEqual(response.headers["Upgrade"], "websocket") with self.assertRaises(InvalidUpgrade) as raised: - raise request.exception + raise server.handshake_exc self.assertEqual(str(raised.exception), "invalid Upgrade header: h2c") def test_missing_key(self): @@ -249,7 +249,7 @@ def test_missing_key(self): self.assertEqual(response.status_code, 400) with self.assertRaises(InvalidHeader) as raised: - raise request.exception + raise server.handshake_exc self.assertEqual(str(raised.exception), "missing Sec-WebSocket-Key header") def test_multiple_key(self): @@ -260,7 +260,7 @@ def test_multiple_key(self): self.assertEqual(response.status_code, 400) with self.assertRaises(InvalidHeader) as raised: - raise request.exception + raise server.handshake_exc self.assertEqual( str(raised.exception), "invalid Sec-WebSocket-Key header: " @@ -276,7 +276,7 @@ def test_invalid_key(self): self.assertEqual(response.status_code, 400) with self.assertRaises(InvalidHeader) as raised: - raise request.exception + raise server.handshake_exc self.assertEqual( str(raised.exception), "invalid Sec-WebSocket-Key header: not Base64 data!" ) @@ -292,7 +292,7 @@ def test_truncated_key(self): self.assertEqual(response.status_code, 400) with self.assertRaises(InvalidHeader) as raised: - raise request.exception + raise server.handshake_exc self.assertEqual( str(raised.exception), f"invalid Sec-WebSocket-Key header: {KEY[:16]}" ) @@ -305,7 +305,7 @@ def test_missing_version(self): self.assertEqual(response.status_code, 400) with self.assertRaises(InvalidHeader) as raised: - raise request.exception + raise server.handshake_exc self.assertEqual(str(raised.exception), "missing Sec-WebSocket-Version header") def test_multiple_version(self): @@ -316,7 +316,7 @@ def test_multiple_version(self): self.assertEqual(response.status_code, 400) with self.assertRaises(InvalidHeader) as raised: - raise request.exception + raise server.handshake_exc self.assertEqual( str(raised.exception), "invalid Sec-WebSocket-Version header: " @@ -332,7 +332,7 @@ def test_invalid_version(self): self.assertEqual(response.status_code, 400) with self.assertRaises(InvalidHeader) as raised: - raise request.exception + raise server.handshake_exc self.assertEqual( str(raised.exception), "invalid Sec-WebSocket-Version header: 11" ) @@ -344,7 +344,7 @@ def test_no_origin(self): self.assertEqual(response.status_code, 403) with self.assertRaises(InvalidOrigin) as raised: - raise request.exception + raise server.handshake_exc self.assertEqual(str(raised.exception), "missing Origin header") def test_origin(self): @@ -364,7 +364,7 @@ def test_unexpected_origin(self): self.assertEqual(response.status_code, 403) with self.assertRaises(InvalidOrigin) as raised: - raise request.exception + raise server.handshake_exc self.assertEqual( str(raised.exception), "invalid Origin header: https://other.example.com" ) @@ -382,7 +382,7 @@ def test_multiple_origin(self): # 400 Bad Request rather than 403 Forbidden. self.assertEqual(response.status_code, 400) with self.assertRaises(InvalidHeader) as raised: - raise request.exception + raise server.handshake_exc self.assertEqual( str(raised.exception), "invalid Origin header: more than one Origin header found", @@ -409,7 +409,7 @@ def test_unsupported_origin(self): self.assertEqual(response.status_code, 403) with self.assertRaises(InvalidOrigin) as raised: - raise request.exception + raise server.handshake_exc self.assertEqual( str(raised.exception), "invalid Origin header: https://original.example.com" ) From 4034d8dbc81742dc2c0688dc9f29c5161bec66e9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 3 Apr 2022 10:43:21 +0200 Subject: [PATCH 254/762] Expect TCP close after a failed opening handshake. --- src/websockets/client.py | 2 ++ src/websockets/connection.py | 2 +- src/websockets/server.py | 11 ++++++++++- tests/test_client.py | 5 +++++ tests/test_server.py | 4 ++++ 5 files changed, 22 insertions(+), 2 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 8d826feaa..df8e53429 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -336,6 +336,8 @@ def parse(self) -> Generator[None, None, None]: except InvalidHandshake as exc: response._exception = exc self.handshake_exc = exc + self.parser = self.discard() + next(self.parser) # start coroutine else: assert self.state is CONNECTING self.state = OPEN diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 967bd8fa5..db8b53699 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -498,7 +498,7 @@ def close_expected(self) -> bool: # applies, except on EOFError where we don't send a close frame # because we already received the TCP close, so we don't expect it. # We already got a TCP Close if and only if the state is CLOSED. - return self.state is CLOSING + return self.state is CLOSING or self.handshake_exc is not None # Private methods for receiving data. diff --git a/src/websockets/server.py b/src/websockets/server.py index 214417ad0..5dad50b6a 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -13,6 +13,7 @@ InvalidHeader, InvalidHeaderValue, InvalidOrigin, + InvalidStatus, InvalidUpgrade, NegotiationError, ) @@ -471,8 +472,14 @@ def reject( ("Server", USER_AGENT), ] ) + response = Response(status.value, status.phrase, headers, body) + # When reject() is called from accept(), handshake_exc is already set. + # If a user calls reject(), set handshake_exc to guarantee invariant: + # "handshake_exc is None if and only if opening handshake succeded." + if self.handshake_exc is None: + self.handshake_exc = InvalidStatus(response) self.logger.info("connection failed (%d %s)", status.value, status.phrase) - return Response(status.value, status.phrase, headers, body) + return response def send_response(self, response: Response) -> None: """ @@ -497,6 +504,8 @@ def send_response(self, response: Response) -> None: self.state = OPEN else: self.send_eof() + self.parser = self.discard() + next(self.parser) # start coroutine def parse(self) -> Generator[None, None, None]: if self.state is CONNECTING: diff --git a/tests/test_client.py b/tests/test_client.py index 12fd8726f..a843d3272 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -42,6 +42,7 @@ def test_send_connect(self): f"\r\n".encode() ], ) + self.assertFalse(client.close_expected()) def test_connect_request(self): with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): @@ -135,6 +136,8 @@ def test_receive_accept(self): ) [response] = client.events_received() self.assertIsInstance(response, Response) + self.assertEqual(client.data_to_send(), []) + self.assertFalse(client.close_expected()) self.assertEqual(client.state, OPEN) def test_receive_reject(self): @@ -155,6 +158,8 @@ def test_receive_reject(self): ) [response] = client.events_received() self.assertIsInstance(response, Response) + self.assertEqual(client.data_to_send(), []) + self.assertTrue(client.close_expected()) self.assertEqual(client.state, CONNECTING) def test_accept_response(self): diff --git a/tests/test_server.py b/tests/test_server.py index 43bc03e14..e3e802239 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -38,6 +38,8 @@ def test_receive_connect(self): ) [request] = server.events_received() self.assertIsInstance(request, Request) + self.assertEqual(server.data_to_send(), []) + self.assertFalse(server.close_expected()) def test_connect_request(self): server = ServerConnection() @@ -104,6 +106,7 @@ def test_send_accept(self): f"\r\n".encode() ], ) + self.assertFalse(server.close_expected()) self.assertEqual(server.state, OPEN) def test_send_reject(self): @@ -126,6 +129,7 @@ def test_send_reject(self): b"", ], ) + self.assertTrue(server.close_expected()) self.assertEqual(server.state, CONNECTING) def test_accept_response(self): From 5e78f15929569e5557930089613a5760557a3a59 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E0=AE=AE=E0=AE=A9=E0=AF=8B=E0=AE=9C=E0=AF=8D=E0=AE=95?= =?UTF-8?q?=E0=AF=81=E0=AE=AE=E0=AE=BE=E0=AE=B0=E0=AF=8D=20=E0=AE=AA?= =?UTF-8?q?=E0=AE=B4=E0=AE=A9=E0=AE=BF=E0=AE=9A=E0=AF=8D=E0=AE=9A=E0=AE=BE?= =?UTF-8?q?=E0=AE=AE=E0=AE=BF?= Date: Fri, 15 Apr 2022 09:51:55 +0530 Subject: [PATCH 255/762] Update README.rst --- README.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.rst b/README.rst index f8df94ba4..2b9a445ea 100644 --- a/README.rst +++ b/README.rst @@ -125,7 +125,7 @@ Why shouldn't I use ``websockets``? and :rfc:`7692`: Compression Extensions for WebSocket. Its support for HTTP is minimal — just enough for a HTTP health check. - If you want do to both in the same server, look at HTTP frameworks that + If you want to do both in the same server, look at HTTP frameworks that build on top of ``websockets`` to support WebSocket connections, like Sanic_. From 8a58de259381f82036ab624255d01a8fbc6234de Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 6 Apr 2022 08:10:53 +0200 Subject: [PATCH 256/762] Update Tidelift security link. --- SECURITY.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/SECURITY.md b/SECURITY.md index 556217a4d..82024b485 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -2,4 +2,4 @@ Only the latest version receives security updates. -Please report vulnerabilities [via Tidelift](https://tidelift.com/docs/security). +Please report vulnerabilities [via Tidelift](https://tidelift.com/security). From 81cb6f55e50d54640f9a0cdff31e3eb0d079434d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20G=C3=B3rny?= Date: Fri, 8 Apr 2022 15:57:51 +0200 Subject: [PATCH 257/762] Skip non-contiguous buffer tests on NotImplementedError (pypy) PyPy3.9 7.3.9 does not implement creating contiguous buffers from non-contiguous, causing the respective tests to fail due to NotImplementedError. Catch it and skip the tests appropriately as discussed in #1157. Fixes #1158 --- tests/test_utils.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index a9ea8dcbd..528eeaf24 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -88,4 +88,10 @@ def test_apply_mask_check_mask_length(self): class SpeedupsTests(ApplyMaskTests): @staticmethod def apply_mask(*args, **kwargs): - return c_apply_mask(*args, **kwargs) + try: + return c_apply_mask(*args, **kwargs) + except NotImplementedError as e: + # PyPy3.9 as of 7.3.9 does not implement creating + # contiguous buffers from non-contiguous and raises + # NotImplementedError. Catch it and skip the test. + raise unittest.SkipTest(str(e)) From 72abbd3d61801f4f45ed7babdfe14ebbd7c2ccf1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 17 Apr 2022 09:07:57 +0200 Subject: [PATCH 258/762] Narrow down exception for PyPy To avoid accidentally masking real test failures on other platforms. --- tests/test_utils.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 528eeaf24..acd60edfc 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,5 +1,6 @@ import base64 import itertools +import platform import unittest from websockets.utils import accept_key, apply_mask as py_apply_mask, generate_key @@ -90,8 +91,13 @@ class SpeedupsTests(ApplyMaskTests): def apply_mask(*args, **kwargs): try: return c_apply_mask(*args, **kwargs) - except NotImplementedError as e: - # PyPy3.9 as of 7.3.9 does not implement creating - # contiguous buffers from non-contiguous and raises - # NotImplementedError. Catch it and skip the test. - raise unittest.SkipTest(str(e)) + except NotImplementedError as exc: # pragma: no cover + # PyPy doesn't implement creating contiguous readonly buffer + # from non-contiguous. We don't care about this edge case. + if ( + platform.python_implementation() == "PyPy" + and "not implemented yet" in str(exc) + ): + raise unittest.SkipTest(str(exc)) + else: + raise From 79182b66798ed96c622462dbe9e28457655f887b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 17 Apr 2022 09:43:42 +0200 Subject: [PATCH 259/762] Update regex matching git hashes. A branch (test-on-pypy) now uses 8 characters for git hashes. --- src/websockets/version.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/websockets/version.py b/src/websockets/version.py index 605a0eeef..65123b3fa 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -47,7 +47,7 @@ def get_version(tag: str) -> str: except (FileNotFoundError, subprocess.CalledProcessError): pass else: - description_re = r"[0-9.]+-([0-9]+)-(g[0-9a-f]{7}(?:-dirty)?)" + description_re = r"[0-9.]+-([0-9]+)-(g[0-9a-f]{7,}(?:-dirty)?)" match = re.fullmatch(description_re, description) assert match is not None distance, remainder = match.groups() @@ -69,7 +69,7 @@ def get_version(tag: str) -> str: def get_commit(tag: str, version: str) -> str: # Extract commit from version, falling back to tag if not available. - version_re = r"[0-9.]+\.dev[0-9]+\+g([0-9a-f]{7}|unknown)(?:\.dirty)?" + version_re = r"[0-9.]+\.dev[0-9]+\+g([0-9a-f]{7,}|unknown)(?:\.dirty)?" match = re.fullmatch(version_re, version) assert match is not None (commit,) = match.groups() From f761e23614c4d8b35ce461e70896e7c6c5bc1cfa Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 17 Apr 2022 10:42:33 +0200 Subject: [PATCH 260/762] Prevent AttributeError on server shutdown. close() shouldn't be called on a CONNECTING connection because the transfer_data_task attribute isn't initialized. --- src/websockets/legacy/server.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index ea5b0d1fa..3e51db1b7 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -768,13 +768,15 @@ async def _close(self) -> None: # closed, handshake() closes OPENING connections with a HTTP 503 # error. Wait until all connections are closed. - # asyncio.wait doesn't accept an empty first argument - if self.websockets: + close_tasks = [ + asyncio.create_task(websocket.close(1001)) + for websocket in self.websockets + if websocket.state is not State.CONNECTING + ] + # asyncio.wait doesn't accept an empty first argument. + if close_tasks: await asyncio.wait( - [ - asyncio.create_task(websocket.close(1001)) - for websocket in self.websockets - ], + close_tasks, **loop_if_py_lt_38(self.get_loop()), ) From 9e2503a5c5cc4c48f223ec8aa780dfa7268622a1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 17 Apr 2022 09:14:58 +0200 Subject: [PATCH 261/762] Add PyPy to testing matrix. --- .github/workflows/tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 3aa579aa4..9412c0ea5 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -38,7 +38,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python: ["3.7", "3.8", "3.9", "3.10"] + python: ["3.7", "3.8", "3.9", "3.10", "pypy-3.9"] steps: - name: Check out repository uses: actions/checkout@v2 From 8dd8e410431408cf33fe761d8b273a95cf45da5f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 17 Apr 2022 10:36:32 +0200 Subject: [PATCH 262/762] Ignore failing test on PyPy. This has to do with closing TLS connections: transport.close() is enough on CPython while PyPy requires transport.abort() (which websockets will call eventually, but not before that test fails.) --- tests/legacy/test_client_server.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index acc231d76..f9de70c9c 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -4,6 +4,7 @@ import http import logging import pathlib +import platform import random import socket import ssl @@ -1243,6 +1244,10 @@ class SecureClientServerTests( # TLS over Unix sockets doesn't make sense. test_unix_socket = None + # This test fails under PyPy due to a difference with CPython. + if platform.python_implementation() == "PyPy": # pragma: no cover + test_http_request_ws_endpoint = None + @with_server() def test_ws_uri_is_rejected(self): with self.assertRaises(ValueError): From 1e8e8ce4bcccbbf922d941e2fe61a28a4e370a46 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 17 Apr 2022 11:01:10 +0200 Subject: [PATCH 263/762] Bump timings when testing on PyPy. --- tests/legacy/utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/legacy/utils.py b/tests/legacy/utils.py index 4d4306232..1fa2b53c8 100644 --- a/tests/legacy/utils.py +++ b/tests/legacy/utils.py @@ -3,6 +3,7 @@ import functools import logging import os +import platform import time import unittest @@ -90,5 +91,9 @@ def assertDeprecationWarnings(self, recorded_warnings, expected_warnings): if os.environ.get("PYTHONASYNCIODEBUG"): # pragma: no cover MS *= 10 +# PyPy has a performance penalty for this test suite. +if platform.python_implementation() == "PyPy": # pragma: no cover + MS *= 5 + # Ensure that timeouts are larger than the clock's resolution (for Windows). MS = max(MS, 2.5 * time.get_clock_info("monotonic").resolution) From 19367485bbdda2eb7c55ad3a43b5b1296743b1f3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 17 Apr 2022 11:57:29 +0200 Subject: [PATCH 264/762] Handle SSLError when receiving messages. Refs #1160. --- src/websockets/legacy/protocol.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 49c81da4f..66751a477 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -5,6 +5,7 @@ import collections import logging import random +import ssl import struct import uuid import warnings @@ -974,13 +975,14 @@ async def transfer_data(self) -> None: self.transfer_data_exc = exc self.fail_connection(1002) - except (ConnectionError, TimeoutError, EOFError) as exc: + except (ConnectionError, TimeoutError, EOFError, ssl.SSLError) as exc: # Reading data with self.reader.readexactly may raise: # - most subclasses of ConnectionError if the TCP connection # breaks, is reset, or is aborted; # - TimeoutError if the TCP connection times out; # - IncompleteReadError, a subclass of EOFError, if fewer - # bytes are available than requested. + # bytes are available than requested; + # - ssl.SSLError if the other side infringes the TLS protocol. self.transfer_data_exc = exc self.fail_connection(1006) From a446fb78d9994c30a479719ea579c349de4f372f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 17 Apr 2022 12:03:30 +0200 Subject: [PATCH 265/762] Handle zlib.error when receiving compressed messages. Refs #1160, #665. --- src/websockets/extensions/permessage_deflate.py | 5 ++++- tests/extensions/test_permessage_deflate.py | 5 ++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index 017e3d843..e0de5e8f8 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -125,7 +125,10 @@ def decode( if frame.fin: data += _EMPTY_UNCOMPRESSED_BLOCK max_length = 0 if max_size is None else max_size - data = self.decoder.decompress(data, max_length) + try: + data = self.decoder.decompress(data, max_length) + except zlib.error as exc: + raise exceptions.ProtocolError("decompression failed") from exc if self.decoder.unconsumed_tail: raise exceptions.PayloadTooBig(f"over size limit (? > {max_size} bytes)") diff --git a/tests/extensions/test_permessage_deflate.py b/tests/extensions/test_permessage_deflate.py index a50685587..3cc9172df 100644 --- a/tests/extensions/test_permessage_deflate.py +++ b/tests/extensions/test_permessage_deflate.py @@ -1,6 +1,5 @@ import dataclasses import unittest -import zlib from websockets.exceptions import ( DuplicateParameter, @@ -8,6 +7,7 @@ InvalidParameterValue, NegotiationError, PayloadTooBig, + ProtocolError, ) from websockets.extensions.permessage_deflate import * from websockets.frames import ( @@ -225,9 +225,8 @@ def test_remote_no_context_takeover(self): dec_frame1 = self.extension.decode(enc_frame1) self.assertEqual(dec_frame1, frame) - with self.assertRaises(zlib.error) as exc: + with self.assertRaises(ProtocolError): self.extension.decode(enc_frame2) - self.assertIn("invalid distance too far back", str(exc.exception)) def test_local_no_context_takeover(self): # No context takeover when encoding and decoding messages. From 8850eb06c7955462dd641acf50d1dcfcfc95b0ee Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Mon, 28 Feb 2022 14:10:16 +0800 Subject: [PATCH 266/762] Fix logging error when sending a memoryview. See also: https://bugs.python.org/issue15945 --- src/websockets/frames.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 82b4a1403..07aabac8d 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -139,7 +139,7 @@ def __str__(self) -> str: # Encode just what we need, plus two dummy bytes to elide later. binary = self.data if len(binary) > 25: - binary = binary[:16] + b"\x00\x00" + binary[-8:] + binary = b"".join([binary[:16], b"\x00\x00", binary[-8:]]) data = " ".join(f"{byte:02x}" for byte in binary) elif self.opcode is OP_CLOSE: data = str(Close.parse(self.data)) @@ -153,7 +153,7 @@ def __str__(self) -> str: except UnicodeDecodeError: binary = self.data if len(binary) > 25: - binary = binary[:16] + b"\x00\x00" + binary[-8:] + binary = b"".join([binary[:16], b"\x00\x00", binary[-8:]]) data = " ".join(f"{byte:02x}" for byte in binary) coding = "binary" else: From 778879eb0144a14a2a406fb2c3fa45f80afcd421 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 17 Apr 2022 12:29:21 +0200 Subject: [PATCH 267/762] Add tests for previous commit. --- src/websockets/frames.py | 8 +++++--- tests/test_frames.py | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 3 deletions(-) diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 07aabac8d..043b688b5 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -145,12 +145,14 @@ def __str__(self) -> str: data = str(Close.parse(self.data)) elif self.data: # We don't know if a Continuation frame contains text or binary. - # Ping and Pong frames could contain UTF-8. Attempt to decode as - # UTF-8 and display it as text; fallback to binary. + # Ping and Pong frames could contain UTF-8. + # Attempt to decode as UTF-8 and display it as text; fallback to + # binary. If self.data is a memoryview, it has no decode() method, + # which raises AttributeError. try: data = repr(self.data.decode()) coding = "text" - except UnicodeDecodeError: + except (UnicodeDecodeError, AttributeError): binary = self.data if len(binary) > 25: binary = b"".join([binary[:16], b"\x00\x00", binary[-8:]]) diff --git a/tests/test_frames.py b/tests/test_frames.py index c8f9867d4..e7c48b930 100644 --- a/tests/test_frames.py +++ b/tests/test_frames.py @@ -207,6 +207,12 @@ def test_cont_binary(self): "CONT fc fd fe ff [binary, 4 bytes, continued]", ) + def test_cont_binary_from_memoryview(self): + self.assertEqual( + str(Frame(OP_CONT, memoryview(b"\xfc\xfd\xfe\xff"), fin=False)), + "CONT fc fd fe ff [binary, 4 bytes, continued]", + ) + def test_cont_final_text(self): self.assertEqual( str(Frame(OP_CONT, b" cr\xc3\xa8me")), @@ -219,6 +225,12 @@ def test_cont_final_binary(self): "CONT fc fd fe ff [binary, 4 bytes]", ) + def test_cont_final_binary_from_memoryview(self): + self.assertEqual( + str(Frame(OP_CONT, memoryview(b"\xfc\xfd\xfe\xff"))), + "CONT fc fd fe ff [binary, 4 bytes]", + ) + def test_cont_text_truncated(self): self.assertEqual( str(Frame(OP_CONT, b"caf\xc3\xa9 " * 16, fin=False)), @@ -233,6 +245,13 @@ def test_cont_binary_truncated(self): " f8 f9 fa fb fc fd fe ff [binary, 256 bytes, continued]", ) + def test_cont_binary_truncated_from_memoryview(self): + self.assertEqual( + str(Frame(OP_CONT, memoryview(bytes(range(256))), fin=False)), + "CONT 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f ..." + " f8 f9 fa fb fc fd fe ff [binary, 256 bytes, continued]", + ) + def test_text(self): self.assertEqual( str(Frame(OP_TEXT, b"caf\xc3\xa9")), @@ -264,12 +283,24 @@ def test_binary(self): "BINARY 00 01 02 03 [4 bytes]", ) + def test_binary_from_memoryview(self): + self.assertEqual( + str(Frame(OP_BINARY, memoryview(b"\x00\x01\x02\x03"))), + "BINARY 00 01 02 03 [4 bytes]", + ) + def test_binary_non_final(self): self.assertEqual( str(Frame(OP_BINARY, b"\x00\x01\x02\x03", fin=False)), "BINARY 00 01 02 03 [4 bytes, continued]", ) + def test_binary_non_final_from_memoryview(self): + self.assertEqual( + str(Frame(OP_BINARY, memoryview(b"\x00\x01\x02\x03"), fin=False)), + "BINARY 00 01 02 03 [4 bytes, continued]", + ) + def test_binary_truncated(self): self.assertEqual( str(Frame(OP_BINARY, bytes(range(256)))), @@ -277,6 +308,13 @@ def test_binary_truncated(self): " f8 f9 fa fb fc fd fe ff [256 bytes]", ) + def test_binary_truncated_from_memoryview(self): + self.assertEqual( + str(Frame(OP_BINARY, memoryview(bytes(range(256))))), + "BINARY 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f ..." + " f8 f9 fa fb fc fd fe ff [256 bytes]", + ) + def test_close(self): self.assertEqual( str(Frame(OP_CLOSE, b"\x03\xe8")), From a96fffa85907a392711c75ad3039c2949531a268 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 17 Apr 2022 12:57:01 +0200 Subject: [PATCH 268/762] Add FAQ on request path. Fix #1154. --- docs/howto/faq.rst | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/docs/howto/faq.rst b/docs/howto/faq.rst index bc742f458..e9aeae881 100644 --- a/docs/howto/faq.rst +++ b/docs/howto/faq.rst @@ -179,6 +179,30 @@ Another way to achieve this result is to define the ``handler`` coroutine in a scope where the ``extra_argument`` variable exists instead of injecting it through an argument. +How do I access the request path? +................................. + +It is available in the :attr:`~server.WebSocketServerProtocol.path` attribute. + +You may route a connection to different handlers depending on the request path:: + + async def handler(websocket): + if websocket.path == "/blue": + await blue_handler(websocket) + elif websocket.path == "/green": + await green_handler(websocket) + else: + # No handler for this path; close the connection. + return + +You may also route the connection based on the first message received from the +client, as shown in the :doc:`tutorial <../intro/tutorial2>`. When you want to +authenticate the connection before routing it, this is usually more convenient. + +Generally speaking, there is far less emphasis on the request path in WebSocket +servers than in HTTP servers. When a WebSockt server provides a single endpoint, +it may ignore the request path entirely. + How do I access HTTP headers? ............................. From f9fd2cebcd42633ed917cd64e805bea17879c2d7 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 17 Apr 2022 13:27:32 +0200 Subject: [PATCH 269/762] Catch RuntimeError to cater to uvloop. Fix #1138, #1072 (again). --- src/websockets/legacy/protocol.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 66751a477..d1d52bfac 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -1310,9 +1310,10 @@ async def close_connection(self) -> None: # write_eof() doesn't document which exceptions it raises. # "[Errno 107] Transport endpoint is not connected" happens # but it isn't completely clear under which circumstances. + # uvloop can raise RuntimeError here. try: self.transport.write_eof() - except OSError: # pragma: no cover + except (OSError, RuntimeError): # pragma: no cover pass if await self.wait_for_connection_lost(): From 42e33f17cdaaa258f3035618fb806c74744406aa Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 17 Apr 2022 15:39:18 +0200 Subject: [PATCH 270/762] Clarify answer about threads. Fix #1156. Ref #1162. --- docs/howto/faq.rst | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/docs/howto/faq.rst b/docs/howto/faq.rst index e9aeae881..a103f677c 100644 --- a/docs/howto/faq.rst +++ b/docs/howto/faq.rst @@ -607,4 +607,15 @@ Why am I having problems with threads? You shouldn't use threads. Use tasks instead. -If you have to, :meth:`~asyncio.loop.call_soon_threadsafe` may help. +Indeed, when you chose websockets, you chose :mod:`asyncio` as the primary +framework to handle concurrency. This choice is mutually exclusive with +:mod:`threading`. + +If you believe that you need to run websockets in a thread and some logic in +another thread, you should run that logic in a :class:`~asyncio.Task` instead. + +If you believe that you cannot run that logic in the same event loop because it +will block websockets, :meth:`~asyncio.loop.run_in_executor` may help. + +This question is really about :mod:`asyncio`. Please review the advice about +:ref:`asyncio-multithreading` in the Python documentation. From 3fb06c347a6039959c38ff41a7d1057971ca17c5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 17 Apr 2022 16:12:30 +0200 Subject: [PATCH 271/762] Updated changelog for 10.3. --- docs/project/changelog.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 26e9a5cdc..9bdc99db5 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -41,6 +41,11 @@ Backwards-incompatible changes See :doc:`../howto/sansio` for details. +Improvements +............ + +* Reduced noise in logs when :mod:`ssl` or :mod:`zlib` raise exceptions. + 10.2 ---- From 0d9a251a338f549b91e13cc346b0f73fd5964493 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 17 Apr 2022 16:12:42 +0200 Subject: [PATCH 272/762] Release version 10.3 --- docs/project/changelog.rst | 2 +- src/websockets/version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 9bdc99db5..398fa65a9 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -28,7 +28,7 @@ They may change at any time. 10.3 ---- -*In development* +*April 17, 2022* Backwards-incompatible changes .............................. diff --git a/src/websockets/version.py b/src/websockets/version.py index 65123b3fa..c30bfd68f 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -16,7 +16,7 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = False +released = True tag = version = commit = "10.3" From 3742b429a25d5f51511b626435c6a1acdd9027a3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 17 Apr 2022 16:14:11 +0200 Subject: [PATCH 273/762] Start version 10.4. --- docs/project/changelog.rst | 5 +++++ src/websockets/version.py | 4 ++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 398fa65a9..105065c8c 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -25,6 +25,11 @@ fixing regressions shortly after a release. Only documented APIs are public. Undocumented APIs are considered private. They may change at any time. +10.4 +---- + +*In development* + 10.3 ---- diff --git a/src/websockets/version.py b/src/websockets/version.py index c30bfd68f..29d658ce2 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -16,9 +16,9 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = True +released = False -tag = version = commit = "10.3" +tag = version = commit = "10.4" if not released: # pragma: no cover From c590dc8a69ca9f3052aa34f7db017b9a253a48a1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 23 Apr 2022 19:28:43 +0200 Subject: [PATCH 274/762] Tweak coverage setup. This makes it possible to run coverage with the --include option without creating a conflict with the --source option. --- Makefile | 3 +-- setup.cfg | 3 --- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/Makefile b/Makefile index 6f8130840..a38634aee 100644 --- a/Makefile +++ b/Makefile @@ -16,8 +16,7 @@ test: python -m unittest coverage: - coverage erase - coverage run -m unittest + coverage run --source websockets,tests -m unittest coverage html coverage report --show-missing --fail-under=100 diff --git a/setup.cfg b/setup.cfg index 9ff1939c7..91f769620 100644 --- a/setup.cfg +++ b/setup.cfg @@ -22,9 +22,6 @@ lines_after_imports = 2 branch = True omit = */__main__.py -source = - websockets - tests [coverage:paths] source = From 9c87d43f1d7bbf6847350087aae74fd35f73a642 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 24 Apr 2022 10:24:30 +0200 Subject: [PATCH 275/762] Fix two typos. Found by codereview.doctor (among many false positives). --- example/deployment/kubernetes/benchmark.py | 2 +- tests/test_headers.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/example/deployment/kubernetes/benchmark.py b/example/deployment/kubernetes/benchmark.py index 600c47316..22ee4c5bd 100755 --- a/example/deployment/kubernetes/benchmark.py +++ b/example/deployment/kubernetes/benchmark.py @@ -11,7 +11,7 @@ async def run(client_id, messages): async with websockets.connect(URI) as websocket: for message_id in range(messages): - await websocket.send("{client_id}:{message_id}") + await websocket.send(f"{client_id}:{message_id}") await websocket.recv() diff --git a/tests/test_headers.py b/tests/test_headers.py index a2d51fc6a..4ebd8b90c 100644 --- a/tests/test_headers.py +++ b/tests/test_headers.py @@ -140,7 +140,7 @@ def test_parse_subprotocol_invalid_header(self): for header in [ # Truncated examples "", - ",\t," + ",\t,", # Wrong delimiter "foo; bar", ]: From ac45051ecb34c07407c4315a08729e9967be8e51 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 3 May 2022 13:53:03 +0200 Subject: [PATCH 276/762] Fix typing in client connection initialization. --- src/websockets/legacy/client.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index fadc3efe8..29550b694 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -651,9 +651,8 @@ async def __await_impl_timeout__(self) -> WebSocketClientProtocol: async def __await_impl__(self) -> WebSocketClientProtocol: for redirects in range(self.MAX_REDIRECTS_ALLOWED): - transport, protocol = await self._create_connection() - protocol = cast(WebSocketClientProtocol, protocol) - + _transport, _protocol = await self._create_connection() + protocol = cast(WebSocketClientProtocol, _protocol) try: await protocol.handshake( self._wsuri, From ebb591811e9626272168ed1ee1cf653d0760d92a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 3 May 2022 14:30:18 +0200 Subject: [PATCH 277/762] Move type: ignore comment. It needs to be in a new location to be recognized. --- src/websockets/legacy/protocol.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index d1d52bfac..9a7ad1d35 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -684,9 +684,7 @@ async def send( try: # message_chunk = anext(aiter_message) without anext # https://github.com/python/mypy/issues/5738 - message_chunk = await type(aiter_message).__anext__( # type: ignore - aiter_message - ) + message_chunk = await type(aiter_message).__anext__(aiter_message) # type: ignore # noqa except StopAsyncIteration: return opcode, data = prepare_data(message_chunk) From 670f7c5e8c934c29f46f24089a84dd65709c1d13 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 3 May 2022 14:33:56 +0200 Subject: [PATCH 278/762] Rename chunks to fragments. --- src/websockets/legacy/protocol.py | 38 +++++++++++++++---------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 9a7ad1d35..b5cd64618 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -645,10 +645,10 @@ async def send( iter_message = iter(message) try: - message_chunk = next(iter_message) + fragment = next(iter_message) except StopIteration: return - opcode, data = prepare_data(message_chunk) + opcode, data = prepare_data(fragment) self._fragmented_message_waiter = asyncio.Future() try: @@ -656,8 +656,8 @@ async def send( await self.write_frame(False, opcode, data) # Other fragments. - for message_chunk in iter_message: - confirm_opcode, data = prepare_data(message_chunk) + for fragment in iter_message: + confirm_opcode, data = prepare_data(fragment) if confirm_opcode != opcode: raise TypeError("data contains inconsistent types") await self.write_frame(False, OP_CONT, data) @@ -682,12 +682,12 @@ async def send( # https://github.com/python/mypy/issues/5738 aiter_message = type(message).__aiter__(message) # type: ignore try: - # message_chunk = anext(aiter_message) without anext + # fragment = anext(aiter_message) without anext # https://github.com/python/mypy/issues/5738 - message_chunk = await type(aiter_message).__anext__(aiter_message) # type: ignore # noqa + fragment = await type(aiter_message).__anext__(aiter_message) # type: ignore # noqa except StopAsyncIteration: return - opcode, data = prepare_data(message_chunk) + opcode, data = prepare_data(fragment) self._fragmented_message_waiter = asyncio.Future() try: @@ -698,8 +698,8 @@ async def send( # https://github.com/python/mypy/issues/5738 # coverage reports this code as not covered, but it is # exercised by tests - changing it breaks the tests! - async for message_chunk in aiter_message: # type: ignore # pragma: no cover # noqa - confirm_opcode, data = prepare_data(message_chunk) + async for fragment in aiter_message: # type: ignore # pragma: no cover # noqa + confirm_opcode, data = prepare_data(fragment) if confirm_opcode != opcode: raise TypeError("data contains inconsistent types") await self.write_frame(False, OP_CONT, data) @@ -1028,7 +1028,7 @@ async def read_message(self) -> Optional[Data]: return frame.data.decode("utf-8") if text else frame.data # 5.4. Fragmentation - chunks: List[Data] = [] + fragments: List[Data] = [] max_size = self.max_size if text: decoder_factory = codecs.getincrementaldecoder("utf-8") @@ -1036,14 +1036,14 @@ async def read_message(self) -> Optional[Data]: if max_size is None: def append(frame: Frame) -> None: - nonlocal chunks - chunks.append(decoder.decode(frame.data, frame.fin)) + nonlocal fragments + fragments.append(decoder.decode(frame.data, frame.fin)) else: def append(frame: Frame) -> None: - nonlocal chunks, max_size - chunks.append(decoder.decode(frame.data, frame.fin)) + nonlocal fragments, max_size + fragments.append(decoder.decode(frame.data, frame.fin)) assert isinstance(max_size, int) max_size -= len(frame.data) @@ -1051,14 +1051,14 @@ def append(frame: Frame) -> None: if max_size is None: def append(frame: Frame) -> None: - nonlocal chunks - chunks.append(frame.data) + nonlocal fragments + fragments.append(frame.data) else: def append(frame: Frame) -> None: - nonlocal chunks, max_size - chunks.append(frame.data) + nonlocal fragments, max_size + fragments.append(frame.data) assert isinstance(max_size, int) max_size -= len(frame.data) @@ -1072,7 +1072,7 @@ def append(frame: Frame) -> None: raise ProtocolError("unexpected opcode") append(frame) - return ("" if text else b"").join(chunks) + return ("" if text else b"").join(fragments) async def read_data_frame(self, max_size: Optional[int]) -> Optional[Frame]: """ From 50483798cdbdf0ab3edee5cd8d4a9b4069d7dfc4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 3 May 2022 07:29:59 +0200 Subject: [PATCH 279/762] Bump cibuildwheel version. Fix #1169. --- .github/workflows/wheels.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index a6df67743..6f529cfdb 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -47,7 +47,7 @@ jobs: with: platforms: all - name: Build wheels - uses: pypa/cibuildwheel@v2.2.2 + uses: pypa/cibuildwheel@v2.5.0 env: CIBW_ARCHS_MACOS: "x86_64 universal2 arm64" CIBW_ARCHS_LINUX: "auto aarch64" From 3e6f0c474fbb89909988cdfdfa8dbee7ac9cb84d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 3 May 2022 07:32:21 +0200 Subject: [PATCH 280/762] Use the latest version of each OS. --- .github/workflows/wheels.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 6f529cfdb..6933c941d 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -30,7 +30,7 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: [ubuntu-20.04, windows-2019, macOS-10.15] + os: [ubuntu-latest, windows-latest, macOS-latest] steps: - name: Check out repository From ef8a4de1e1a97ad2ed15637b4abaab89c819a2d6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 3 May 2022 08:06:39 +0200 Subject: [PATCH 281/762] Run a subset of tests on branches. --- .github/workflows/tests.yml | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 9412c0ea5..09ec750df 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -38,7 +38,17 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python: ["3.7", "3.8", "3.9", "3.10", "pypy-3.9"] + python: + - "3.7" + - "3.8" + - "3.9" + - "3.10" + - "pypy-3.9" + is_main: + - ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }} + exclude: + - python: "pypy-3.9" + is_main: false steps: - name: Check out repository uses: actions/checkout@v2 From 8b603d824786de82330d21b89a109e720082440d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 3 May 2022 08:26:51 +0200 Subject: [PATCH 282/762] Test on all PyPy versions. Refs #1169. --- .github/workflows/tests.yml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 09ec750df..03ce3aff9 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -43,10 +43,16 @@ jobs: - "3.8" - "3.9" - "3.10" + - "pypy-3.7" + - "pypy-3.8" - "pypy-3.9" is_main: - ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }} exclude: + - python: "pypy-3.7" + is_main: false + - python: "pypy-3.8" + is_main: false - python: "pypy-3.9" is_main: false steps: From 71dd4f881365a910972c6dc1326638e90eda5cf5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 3 May 2022 18:51:04 +0200 Subject: [PATCH 283/762] Stop building wheels with every push to main. No need to heat the planet. It's working. Keep the possibility to do it manually. --- .github/workflows/wheels.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 6933c941d..6966b647f 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -2,10 +2,9 @@ name: Build wheels on: push: - branches: - - main tags: - '*' + workflow_dispatch: jobs: sdist: From 6de9572076500959d7c8fa1d5889acc1fa06edac Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 3 May 2022 08:27:28 +0200 Subject: [PATCH 284/762] Standardize on multiline list syntax. --- .github/workflows/wheels.yml | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 6966b647f..0322055b1 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -29,8 +29,10 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: [ubuntu-latest, windows-latest, macOS-latest] - + os: + - ubuntu-latest + - windows-latest + - macOS-latest steps: - name: Check out repository uses: actions/checkout@v2 @@ -58,7 +60,9 @@ jobs: upload_pypi: name: Upload to PyPI - needs: [sdist, wheels] + needs: + - sdist + - wheels runs-on: ubuntu-latest if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') steps: From c4a4b6f45af607431c5707d56420bbaf471bbb6e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 5 May 2022 08:32:07 +0200 Subject: [PATCH 285/762] Remove all # type: ignore but one. Prefer cast() instead. The only remaining # type: ignore is for the legacy_recv behavior. --- src/websockets/legacy/framing.py | 4 ++-- src/websockets/legacy/protocol.py | 22 ++++++++++++++-------- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/src/websockets/legacy/framing.py b/src/websockets/legacy/framing.py index c4de7eb28..04cddc0e0 100644 --- a/src/websockets/legacy/framing.py +++ b/src/websockets/legacy/framing.py @@ -1,6 +1,5 @@ from __future__ import annotations -import dataclasses import struct from typing import Any, Awaitable, Callable, NamedTuple, Optional, Sequence, Tuple @@ -163,7 +162,8 @@ def parse_close(data: bytes) -> Tuple[int, str]: UnicodeDecodeError: if the reason isn't valid UTF-8. """ - return dataclasses.astuple(Close.parse(data)) # type: ignore + close = Close.parse(data) + return close.code, close.reason def serialize_close(code: int, reason: str) -> bytes: diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index b5cd64618..651ff824d 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -14,6 +14,7 @@ AsyncIterable, AsyncIterator, Awaitable, + Callable, Deque, Dict, Iterable, @@ -678,13 +679,19 @@ async def send( # Fragmented message -- asynchronous iterator elif isinstance(message, AsyncIterable): - # aiter_message = aiter(message) without aiter - # https://github.com/python/mypy/issues/5738 - aiter_message = type(message).__aiter__(message) # type: ignore + # Implement aiter_message = aiter(message) without aiter + # Work around https://github.com/python/mypy/issues/5738 + aiter_message = cast( + Callable[[AsyncIterable[Data]], AsyncIterator[Data]], + type(message).__aiter__, + )(message) try: - # fragment = anext(aiter_message) without anext - # https://github.com/python/mypy/issues/5738 - fragment = await type(aiter_message).__anext__(aiter_message) # type: ignore # noqa + # Implement fragment = anext(aiter_message) without anext + # Work around https://github.com/python/mypy/issues/5738 + fragment = await cast( + Callable[[AsyncIterator[Data]], Awaitable[Data]], + type(aiter_message).__anext__, + )(aiter_message) except StopAsyncIteration: return opcode, data = prepare_data(fragment) @@ -695,10 +702,9 @@ async def send( await self.write_frame(False, opcode, data) # Other fragments. - # https://github.com/python/mypy/issues/5738 # coverage reports this code as not covered, but it is # exercised by tests - changing it breaks the tests! - async for fragment in aiter_message: # type: ignore # pragma: no cover # noqa + async for fragment in aiter_message: # pragma: no cover confirm_opcode, data = prepare_data(fragment) if confirm_opcode != opcode: raise TypeError("data contains inconsistent types") From 225a57c450a4064e04dcbcc50dc85dc7e9452628 Mon Sep 17 00:00:00 2001 From: Shoaib Date: Sat, 25 Jun 2022 03:33:42 +0500 Subject: [PATCH 286/762] Replaced unsafe reference to self.pings with a variable --- src/websockets/legacy/protocol.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 651ff824d..2d1391a01 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -846,11 +846,12 @@ async def ping(self, data: Optional[Data] = None) -> Awaitable[None]: while data is None or data in self.pings: data = struct.pack("!I", random.getrandbits(32)) - self.pings[data] = self.loop.create_future() + ping_future = self.loop.create_future() + self.pings[data] = ping_future await self.write_frame(True, OP_PING, data) - return asyncio.shield(self.pings[data]) + return asyncio.shield(ping_future) async def pong(self, data: Data = b"") -> None: """ From 57a1325795cc7b9166a5b4599c34c4869d8b2ab6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 17 Jul 2022 09:47:16 +0200 Subject: [PATCH 287/762] Update link to django-sesame documentation. --- docs/howto/django.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/howto/django.rst b/docs/howto/django.rst index 34cce58e9..5bb2e296b 100644 --- a/docs/howto/django.rst +++ b/docs/howto/django.rst @@ -113,10 +113,10 @@ your settings module. The connection handler reads the first message received from the client, which is expected to contain a django-sesame token. Then it authenticates the user -with ``get_user()``, the API for `authentication outside views`_. If +with ``get_user()``, the API for `authentication outside a view`_. If authentication fails, it closes the connection and exits. -.. _authentication outside views: https://github.com/aaugustin/django-sesame#authentication-outside-views +.. _authentication outside a view: https://django-sesame.readthedocs.io/en/stable/howto.html#outside-a-view When we call an API that makes a database query such as ``get_user()``, we wrap the call in :func:`~asyncio.to_thread`. Indeed, the Django ORM doesn't From fad12d475e4c65aff9040d9e5f34a4924c470774 Mon Sep 17 00:00:00 2001 From: Tim Gates Date: Sun, 31 Jul 2022 13:17:18 +1000 Subject: [PATCH 288/762] docs: Fix a few typos There are small typos in: - src/websockets/legacy/protocol.py - src/websockets/legacy/server.py - tests/legacy/test_protocol.py - tests/legacy/utils.py Fixes: - Should read `suspect` rather than `supect`. - Should read `acknowledged` rather than `acknowleged`. - Should read `attempts` rather than `attemps`. - Should read `asynchronous` rather than `asychronous`. Signed-off-by: Tim Gates --- src/websockets/legacy/protocol.py | 6 +++--- src/websockets/legacy/server.py | 2 +- tests/legacy/test_protocol.py | 4 ++-- tests/legacy/utils.py | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 2d1391a01..04726033e 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -1303,7 +1303,7 @@ async def close_connection(self) -> None: if self.is_client and hasattr(self, "transfer_data_task"): if await self.wait_for_connection_lost(): # Coverage marks this line as a partially executed branch. - # I supect a bug in coverage. Ignore it for now. + # I suspect a bug in coverage. Ignore it for now. return # pragma: no cover if self.debug: self.logger.debug("! timed out waiting for TCP close") @@ -1323,7 +1323,7 @@ async def close_connection(self) -> None: if await self.wait_for_connection_lost(): # Coverage marks this line as a partially executed branch. - # I supect a bug in coverage. Ignore it for now. + # I suspect a bug in coverage. Ignore it for now. return # pragma: no cover if self.debug: self.logger.debug("! timed out waiting for TCP close") @@ -1361,7 +1361,7 @@ async def close_transport(self) -> None: # connection_lost() is called quickly after aborting. # Coverage marks this line as a partially executed branch. - # I supect a bug in coverage. Ignore it for now. + # I suspect a bug in coverage. Ignore it for now. await self.wait_for_connection_lost() # pragma: no cover async def wait_for_connection_lost(self) -> bool: diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 3e51db1b7..6b4eccd02 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -153,7 +153,7 @@ async def handler(self) -> None: Handle the lifecycle of a WebSocket connection. Since this method doesn't have a caller able to handle exceptions, it - attemps to log relevant ones and guarantees that the TCP connection is + attempts to log relevant ones and guarantees that the TCP connection is closed before exiting. """ diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index 1672ab1ed..ed6424694 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -1188,7 +1188,7 @@ def test_keepalive_ping(self): def test_keepalive_ping_not_acknowledged_closes_connection(self): self.restart_protocol_with_keepalive_ping() - # Ping is sent at 3ms and not acknowleged. + # Ping is sent at 3ms and not acknowledged. self.loop.run_until_complete(asyncio.sleep(4 * MS)) (ping_1,) = tuple(self.protocol.pings) self.assertOneFrameSent(True, OP_PING, ping_1) @@ -1257,7 +1257,7 @@ def test_keepalive_ping_with_no_ping_interval(self): def test_keepalive_ping_with_no_ping_timeout(self): self.restart_protocol_with_keepalive_ping(ping_timeout=None) - # Ping is sent at 3ms and not acknowleged. + # Ping is sent at 3ms and not acknowledged. self.loop.run_until_complete(asyncio.sleep(4 * MS)) (ping_1,) = tuple(self.protocol.pings) self.assertOneFrameSent(True, OP_PING, ping_1) diff --git a/tests/legacy/utils.py b/tests/legacy/utils.py index 1fa2b53c8..fd5dfc294 100644 --- a/tests/legacy/utils.py +++ b/tests/legacy/utils.py @@ -18,7 +18,7 @@ def __init_subclass__(cls, **kwargs): """ Convert test coroutines to test functions. - This supports asychronous tests transparently. + This supports asynchronous tests transparently. """ super().__init_subclass__(**kwargs) From d84108b2bf2256187509276a7cb2a905f39af392 Mon Sep 17 00:00:00 2001 From: Dmitry Lavrentev Date: Wed, 10 Aug 2022 19:31:29 +0300 Subject: [PATCH 289/762] Added default None values for Optional fields in WebSocketURI for compatibility with Dataclass syntax. --- src/websockets/uri.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/websockets/uri.py b/src/websockets/uri.py index fff0c3806..385090f66 100644 --- a/src/websockets/uri.py +++ b/src/websockets/uri.py @@ -33,8 +33,8 @@ class WebSocketURI: port: int path: str query: str - username: Optional[str] - password: Optional[str] + username: Optional[str] = None + password: Optional[str] = None @property def resource_name(self) -> str: From 5d2a04ae65b48c8d02e37318bc6e8d43be9ae18e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 16 Aug 2022 09:05:34 +0200 Subject: [PATCH 290/762] Add usage docs for start_serving and serve_forever. The only example for Python I found was https://bugs.python.org/issue32662: async def main(): srv = await asyncio.start_server(...) async with srv: await srv.serve_forever() asyncio.run(main()) It looks pretty bad to me: it starts the server thrice and stops it twice, relying on idempotency to avoid issues. I dislike this style of "pile it up to be sure that it works" so I'm showing a subset of possibilities only. Ref #1197. --- src/websockets/legacy/server.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 6b4eccd02..2ee469825 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -833,6 +833,13 @@ async def start_serving(self) -> None: """ See :meth:`asyncio.Server.start_serving`. + Typical use:: + + server = await serve(..., start_serving=False) + # perform additional setup here... + # ... then start the server + await server.start_serving() + """ await self.server.start_serving() # pragma: no cover @@ -840,6 +847,17 @@ async def serve_forever(self) -> None: """ See :meth:`asyncio.Server.serve_forever`. + Typical use:: + + server = await serve(...) + # this coroutine doesn't return + # canceling it stops the server + await server.serve_forever() + + This is an alternative to using :func:`serve` as an asynchronous context + manager. Shutdown is triggered by canceling :meth:`serve_forever` + instead of exiting a :func:`serve` context. + """ await self.server.serve_forever() # pragma: no cover From c7c8ecc12323855c37f3c77874ade5fdd082890f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 16 Aug 2022 09:22:26 +0200 Subject: [PATCH 291/762] Escape square brackets. They're special characters in some shells. --- docs/howto/autoreload.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/howto/autoreload.rst b/docs/howto/autoreload.rst index edd87d0fd..fc736a591 100644 --- a/docs/howto/autoreload.rst +++ b/docs/howto/autoreload.rst @@ -16,7 +16,7 @@ Install watchdog_ with the ``watchmedo`` shell utility: .. code-block:: console - $ pip install watchdog[watchmedo] + $ pip install 'watchdog[watchmedo]' .. _watchdog: https://pypi.org/project/watchdog/ From 7980c00812c683ddbfbf244f6ec0ca6faca8c023 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 16 Aug 2022 09:27:58 +0200 Subject: [PATCH 292/762] Standardize title style. --- docs/howto/extensions.rst | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/docs/howto/extensions.rst b/docs/howto/extensions.rst index 2baead3f0..3c8a7d72a 100644 --- a/docs/howto/extensions.rst +++ b/docs/howto/extensions.rst @@ -1,5 +1,5 @@ -Writing an extension -==================== +Write an extension +================== .. currentmodule:: websockets.extensions @@ -28,5 +28,3 @@ As a consequence, writing an extension requires implementing several classes: websockets provides base classes for extension factories and extensions. See :class:`ClientExtensionFactory`, :class:`ServerExtensionFactory`, and :class:`Extension` for details. - - From 28e1c7d60214f85d058d4f63629dc94a7e686dc4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 16 Aug 2022 11:58:57 +0200 Subject: [PATCH 293/762] Move FAQ to top-level. This should make it more discoverable. Also: * split FAQ in multiple pages * move quick start to the howto section --- docs/faq/asyncio.rst | 73 ++++ docs/faq/client.rst | 83 ++++ docs/faq/common.rst | 151 +++++++ docs/faq/index.rst | 21 + docs/faq/misc.rst | 26 ++ docs/faq/server.rst | 279 ++++++++++++ docs/howto/faq.rst | 621 --------------------------- docs/howto/index.rst | 8 +- docs/{intro => howto}/quickstart.rst | 0 docs/index.rst | 3 +- docs/intro/index.rst | 6 +- docs/reference/index.rst | 2 +- 12 files changed, 644 insertions(+), 629 deletions(-) create mode 100644 docs/faq/asyncio.rst create mode 100644 docs/faq/client.rst create mode 100644 docs/faq/common.rst create mode 100644 docs/faq/index.rst create mode 100644 docs/faq/misc.rst create mode 100644 docs/faq/server.rst delete mode 100644 docs/howto/faq.rst rename docs/{intro => howto}/quickstart.rst (100%) diff --git a/docs/faq/asyncio.rst b/docs/faq/asyncio.rst new file mode 100644 index 000000000..7db43e1be --- /dev/null +++ b/docs/faq/asyncio.rst @@ -0,0 +1,73 @@ +asyncio usage +============= + +.. currentmodule:: websockets + +How do I run two coroutines in parallel? +---------------------------------------- + +You must start two tasks, which the event loop will run concurrently. You can +achieve this with :func:`asyncio.gather` or :func:`asyncio.create_task`. + +Keep track of the tasks and make sure they terminate or you cancel them when +the connection terminates. + +Why does my program never receive any messages? +----------------------------------------------- + +Your program runs a coroutine that never yields control to the event loop. The +coroutine that receives messages never gets a chance to run. + +Putting an ``await`` statement in a ``for`` or a ``while`` loop isn't enough +to yield control. Awaiting a coroutine may yield control, but there's no +guarantee that it will. + +For example, :meth:`~legacy.protocol.WebSocketCommonProtocol.send` only yields +control when send buffers are full, which never happens in most practical +cases. + +If you run a loop that contains only synchronous operations and +a :meth:`~legacy.protocol.WebSocketCommonProtocol.send` call, you must yield +control explicitly with :func:`asyncio.sleep`:: + + async def producer(websocket): + message = generate_next_message() + await websocket.send(message) + await asyncio.sleep(0) # yield control to the event loop + +:func:`asyncio.sleep` always suspends the current task, allowing other tasks +to run. This behavior is documented precisely because it isn't expected from +every coroutine. + +See `issue 867`_. + +.. _issue 867: https://github.com/aaugustin/websockets/issues/867 + +Why am I having problems with threads? +-------------------------------------- + +You shouldn't use threads. Use tasks instead. + +Indeed, when you chose websockets, you chose :mod:`asyncio` as the primary +framework to handle concurrency. This choice is mutually exclusive with +:mod:`threading`. + +If you believe that you need to run websockets in a thread and some logic in +another thread, you should run that logic in a :class:`~asyncio.Task` instead. + +If you believe that you cannot run that logic in the same event loop because it +will block websockets, :meth:`~asyncio.loop.run_in_executor` may help. + +This question is really about :mod:`asyncio`. Please review the advice about +:ref:`asyncio-multithreading` in the Python documentation. + +Why does my simple program misbehave mysteriously? +-------------------------------------------------- + +You are using :func:`time.sleep` instead of :func:`asyncio.sleep`, which +blocks the event loop and prevents asyncio from operating normally. + +This may lead to messages getting send but not received, to connection +timeouts, and to unexpected results of shotgun debugging e.g. adding an +unnecessary call to :meth:`~legacy.protocol.WebSocketCommonProtocol.send` +makes the program functional. diff --git a/docs/faq/client.rst b/docs/faq/client.rst new file mode 100644 index 000000000..5b39cf1ec --- /dev/null +++ b/docs/faq/client.rst @@ -0,0 +1,83 @@ +Client +====== + +.. currentmodule:: websockets + +Why does the client close the connection prematurely? +----------------------------------------------------- + +You're exiting the context manager prematurely. Wait for the work to be +finished before exiting. + +For example, if your code has a structure similar to:: + + async with connect(...) as websocket: + asyncio.create_task(do_some_work()) + +change it to:: + + async with connect(...) as websocket: + await do_some_work() + +How do I access HTTP headers? +----------------------------- + +Once the connection is established, HTTP headers are available in +:attr:`~client.WebSocketClientProtocol.request_headers` and +:attr:`~client.WebSocketClientProtocol.response_headers`. + +How do I set HTTP headers? +-------------------------- + +To set the ``Origin``, ``Sec-WebSocket-Extensions``, or +``Sec-WebSocket-Protocol`` headers in the WebSocket handshake request, use the +``origin``, ``extensions``, or ``subprotocols`` arguments of +:func:`~client.connect`. + +To set other HTTP headers, for example the ``Authorization`` header, use the +``extra_headers`` argument:: + + async with connect(..., extra_headers={"Authorization": ...}) as websocket: + ... + +How do I close a connection? +---------------------------- + +The easiest is to use :func:`~client.connect` as a context manager:: + + async with connect(...) as websocket: + ... + +The connection is closed when exiting the context manager. + +How do I reconnect when the connection drops? +--------------------------------------------- + +Use :func:`~client.connect` as an asynchronous iterator:: + + async for websocket in websockets.connect(...): + try: + ... + except websockets.ConnectionClosed: + continue + +Make sure you handle exceptions in the ``async for`` loop. Uncaught exceptions +will break out of the loop. + +How do I stop a client that is processing messages in a loop? +------------------------------------------------------------- + +You can close the connection. + +Here's an example that terminates cleanly when it receives SIGTERM on Unix: + +.. literalinclude:: ../../example/shutdown_client.py + :emphasize-lines: 10-13 + +How do I disable TLS/SSL certificate verification? +-------------------------------------------------- + +Look at the ``ssl`` argument of :meth:`~asyncio.loop.create_connection`. + +:func:`~client.connect` accepts the same arguments as +:meth:`~asyncio.loop.create_connection`. diff --git a/docs/faq/common.rst b/docs/faq/common.rst new file mode 100644 index 000000000..dff64f67c --- /dev/null +++ b/docs/faq/common.rst @@ -0,0 +1,151 @@ +Both sides +========== + +.. currentmodule:: websockets + +What does ``ConnectionClosedError: no close frame received or sent`` mean? +-------------------------------------------------------------------------- + +If you're seeing this traceback in the logs of a server: + +.. code-block:: pytb + + connection handler failed + Traceback (most recent call last): + ... + asyncio.exceptions.IncompleteReadError: 0 bytes read on a total of 2 expected bytes + + The above exception was the direct cause of the following exception: + + Traceback (most recent call last): + ... + websockets.exceptions.ConnectionClosedError: no close frame received or sent + +or if a client crashes with this traceback: + +.. code-block:: pytb + + Traceback (most recent call last): + ... + ConnectionResetError: [Errno 54] Connection reset by peer + + The above exception was the direct cause of the following exception: + + Traceback (most recent call last): + ... + websockets.exceptions.ConnectionClosedError: no close frame received or sent + +it means that the TCP connection was lost. As a consequence, the WebSocket +connection was closed without receiving and sending a close frame, which is +abnormal. + +You can catch and handle :exc:`~exceptions.ConnectionClosed` to prevent it +from being logged. + +There are several reasons why long-lived connections may be lost: + +* End-user devices tend to lose network connectivity often and unpredictably + because they can move out of wireless network coverage, get unplugged from + a wired network, enter airplane mode, be put to sleep, etc. +* HTTP load balancers or proxies that aren't configured for long-lived + connections may terminate connections after a short amount of time, usually + 30 seconds, despite websockets' keepalive mechanism. + +If you're facing a reproducible issue, :ref:`enable debug logs ` to +see when and how connections are closed. + +What does ``ConnectionClosedError: sent 1011 (unexpected error) keepalive ping timeout; no close frame received`` mean? +----------------------------------------------------------------------------------------------------------------------- + +If you're seeing this traceback in the logs of a server: + +.. code-block:: pytb + + connection handler failed + Traceback (most recent call last): + ... + asyncio.exceptions.CancelledError + + The above exception was the direct cause of the following exception: + + Traceback (most recent call last): + ... + websockets.exceptions.ConnectionClosedError: sent 1011 (unexpected error) keepalive ping timeout; no close frame received + +or if a client crashes with this traceback: + +.. code-block:: pytb + + Traceback (most recent call last): + ... + asyncio.exceptions.CancelledError + + The above exception was the direct cause of the following exception: + + Traceback (most recent call last): + ... + websockets.exceptions.ConnectionClosedError: sent 1011 (unexpected error) keepalive ping timeout; no close frame received + +it means that the WebSocket connection suffered from excessive latency and was +closed after reaching the timeout of websockets' keepalive mechanism. + +You can catch and handle :exc:`~exceptions.ConnectionClosed` to prevent it +from being logged. + +There are two main reasons why latency may increase: + +* Poor network connectivity. +* More traffic than the recipient can handle. + +See the discussion of :doc:`timeouts <../topics/timeouts>` for details. + +If websockets' default timeout of 20 seconds is too short for your use case, +you can adjust it with the ``ping_timeout`` argument. + +How do I set a timeout on :meth:`~legacy.protocol.WebSocketCommonProtocol.recv`? +-------------------------------------------------------------------------------- + +Use :func:`~asyncio.wait_for`:: + + await asyncio.wait_for(websocket.recv(), timeout=10) + +This technique works for most APIs, except for asynchronous context managers. +See `issue 574`_. + +.. _issue 574: https://github.com/aaugustin/websockets/issues/574 + +How can I pass arguments to a custom protocol subclass? +------------------------------------------------------- + +You can bind additional arguments to the protocol factory with +:func:`functools.partial`:: + + import asyncio + import functools + import websockets + + class MyServerProtocol(websockets.WebSocketServerProtocol): + def __init__(self, extra_argument, *args, **kwargs): + super().__init__(*args, **kwargs) + # do something with extra_argument + + create_protocol = functools.partial(MyServerProtocol, extra_argument='spam') + start_server = websockets.serve(..., create_protocol=create_protocol) + +This example was for a server. The same pattern applies on a client. + +How do I keep idle connections open? +------------------------------------ + +websockets sends pings at 20 seconds intervals to keep the connection open. + +It closes the connection if it doesn't get a pong within 20 seconds. + +You can adjust this behavior with ``ping_interval`` and ``ping_timeout``. + +See :doc:`../topics/timeouts` for details. + +How do I respond to pings? +-------------------------- + +Don't bother; websockets takes care of responding to pings with pongs. diff --git a/docs/faq/index.rst b/docs/faq/index.rst new file mode 100644 index 000000000..9d5b0d538 --- /dev/null +++ b/docs/faq/index.rst @@ -0,0 +1,21 @@ +Frequently asked questions +========================== + +.. currentmodule:: websockets + +.. admonition:: Many questions asked in websockets' issue tracker are really + about :mod:`asyncio`. + :class: seealso + + Python's documentation about `developing with asyncio`_ is a good + complement. + + .. _developing with asyncio: https://docs.python.org/3/library/asyncio-dev.html + +.. toctree:: + + server + client + common + asyncio + misc diff --git a/docs/faq/misc.rst b/docs/faq/misc.rst new file mode 100644 index 000000000..3606937b8 --- /dev/null +++ b/docs/faq/misc.rst @@ -0,0 +1,26 @@ +Miscellaneous +============= + +.. currentmodule:: websockets + +Can I use websockets without ``async`` and ``await``? +..................................................... + +No, there is no convenient way to do this. You should use another library. + +Are there ``onopen``, ``onmessage``, ``onerror``, and ``onclose`` callbacks? +............................................................................ + +No, there aren't. + +websockets provides high-level, coroutine-based APIs. Compared to callbacks, +coroutines make it easier to manage control flow in concurrent code. + +If you prefer callback-based APIs, you should use another library. + +Why do I get the error: ``module 'websockets' has no attribute '...'``? +....................................................................... + +Often, this is because you created a script called ``websockets.py`` in your +current working directory. Then ``import websockets`` imports this module +instead of the websockets library. diff --git a/docs/faq/server.rst b/docs/faq/server.rst new file mode 100644 index 000000000..a83f879b9 --- /dev/null +++ b/docs/faq/server.rst @@ -0,0 +1,279 @@ +Server +====== + +.. currentmodule:: websockets + +Why does the server close the connection prematurely? +----------------------------------------------------- + +Your connection handler exits prematurely. Wait for the work to be finished +before returning. + +For example, if your handler has a structure similar to:: + + async def handler(websocket): + asyncio.create_task(do_some_work()) + +change it to:: + + async def handler(websocket): + await do_some_work() + +Why does the server close the connection after one message? +----------------------------------------------------------- + +Your connection handler exits after processing one message. Write a loop to +process multiple messages. + +For example, if your handler looks like this:: + + async def handler(websocket): + print(websocket.recv()) + +change it like this:: + + async def handler(websocket): + async for message in websocket: + print(message) + +*Don't feel bad if this happens to you — it's the most common question in +websockets' issue tracker :-)* + +Why can only one client connect at a time? +------------------------------------------ + +Your connection handler blocks the event loop. Look for blocking calls. +Any call that may take some time must be asynchronous. + +For example, if you have:: + + async def handler(websocket): + time.sleep(1) + +change it to:: + + async def handler(websocket): + await asyncio.sleep(1) + +This is part of learning asyncio. It isn't specific to websockets. + +See also Python's documentation about `running blocking code`_. + +.. _running blocking code: https://docs.python.org/3/library/asyncio-dev.html#running-blocking-code + +.. _send-message-to-all-users: + +How do I send a message to all users? +------------------------------------- + +Record all connections in a global variable:: + + CONNECTIONS = set() + + async def handler(websocket): + CONNECTIONS.add(websocket) + try: + await websocket.wait_closed() + finally: + CONNECTIONS.remove(websocket) + +Then, call :func:`~websockets.broadcast`:: + + import websockets + + def message_all(message): + websockets.broadcast(CONNECTIONS, message) + +If you're running multiple server processes, make sure you call ``message_all`` +in each process. + +.. _send-message-to-single-user: + +How do I send a message to a single user? +----------------------------------------- + +Record connections in a global variable, keyed by user identifier:: + + CONNECTIONS = {} + + async def handler(websocket): + user_id = ... # identify user in your app's context + CONNECTIONS[user_id] = websocket + try: + await websocket.wait_closed() + finally: + del CONNECTIONS[user_id] + +Then, call :meth:`~legacy.protocol.WebSocketCommonProtocol.send`:: + + async def message_user(user_id, message): + websocket = CONNECTIONS[user_id] # raises KeyError if user disconnected + await websocket.send(message) # may raise websockets.ConnectionClosed + +Add error handling according to the behavior you want if the user disconnected +before the message could be sent. + +This example supports only one connection per user. To support concurrent +connects by the same user, you can change ``CONNECTIONS`` to store a set of +connections for each user. + +If you're running multiple server processes, call ``message_user`` in each +process. The process managing the user's connection sends the message; other +processes do nothing. + +When you reach a scale where server processes cannot keep up with the stream of +all messages, you need a better architecture. For example, you could deploy an +external publish / subscribe system such as Redis_. Server processes would +subscribe their clients. Then, they would receive messages only for the +connections that they're managing. + +.. _Redis: https://redis.io/ + +How do I send a message to a channel, a topic, or some users? +------------------------------------------------------------- + +websockets doesn't provide built-in publish / subscribe functionality. + +Record connections in a global variable, keyed by user identifier, as shown in +:ref:`How do I send a message to a single user?` + +Then, build the set of recipients and broadcast the message to them, as shown in +:ref:`How do I send a message to all users?` + +:doc:`../howto/django` contains a complete implementation of this pattern. + +Again, as you scale, you may reach the performance limits of a basic in-process +implementation. You may need an external publish / subscribe system like Redis_. + +.. _Redis: https://redis.io/ + +How do I pass arguments to the connection handler? +-------------------------------------------------- + +You can bind additional arguments to the connection handler with +:func:`functools.partial`:: + + import asyncio + import functools + import websockets + + async def handler(websocket, extra_argument): + ... + + bound_handler = functools.partial(handler, extra_argument='spam') + start_server = websockets.serve(bound_handler, ...) + +Another way to achieve this result is to define the ``handler`` coroutine in +a scope where the ``extra_argument`` variable exists instead of injecting it +through an argument. + +How do I access the request path? +--------------------------------- + +It is available in the :attr:`~server.WebSocketServerProtocol.path` attribute. + +You may route a connection to different handlers depending on the request path:: + + async def handler(websocket): + if websocket.path == "/blue": + await blue_handler(websocket) + elif websocket.path == "/green": + await green_handler(websocket) + else: + # No handler for this path; close the connection. + return + +You may also route the connection based on the first message received from the +client, as shown in the :doc:`tutorial <../intro/tutorial2>`. When you want to +authenticate the connection before routing it, this is usually more convenient. + +Generally speaking, there is far less emphasis on the request path in WebSocket +servers than in HTTP servers. When a WebSockt server provides a single endpoint, +it may ignore the request path entirely. + +How do I access HTTP headers? +----------------------------- + +To access HTTP headers during the WebSocket handshake, you can override +:attr:`~server.WebSocketServerProtocol.process_request`:: + + async def process_request(self, path, request_headers): + authorization = request_headers["Authorization"] + +Once the connection is established, HTTP headers are available in +:attr:`~server.WebSocketServerProtocol.request_headers` and +:attr:`~server.WebSocketServerProtocol.response_headers`:: + + async def handler(websocket): + authorization = websocket.request_headers["Authorization"] + +How do I set HTTP headers? +-------------------------- + +To set the ``Sec-WebSocket-Extensions`` or ``Sec-WebSocket-Protocol`` headers in +the WebSocket handshake response, use the ``extensions`` or ``subprotocols`` +arguments of :func:`~server.serve`. + +To set other HTTP headers, use the ``extra_headers`` argument. + +How do I get the IP address of the client? +------------------------------------------ + +It's available in :attr:`~legacy.protocol.WebSocketCommonProtocol.remote_address`:: + + async def handler(websocket): + remote_ip = websocket.remote_address[0] + +How do I set the IP addresses my server listens on? +--------------------------------------------------- + +Look at the ``host`` argument of :meth:`~asyncio.loop.create_server`. + +:func:`~server.serve` accepts the same arguments as +:meth:`~asyncio.loop.create_server`. + +What does ``OSError: [Errno 99] error while attempting to bind on address ('::1', 80, 0, 0): address not available`` mean? +-------------------------------------------------------------------------------------------------------------------------- + +You are calling :func:`~server.serve` without a ``host`` argument in a context +where IPv6 isn't available. + +To listen only on IPv4, specify ``host="0.0.0.0"`` or ``family=socket.AF_INET``. + +Refer to the documentation of :meth:`~asyncio.loop.create_server` for details. + +How do I close a connection? +---------------------------- + +websockets takes care of closing the connection when the handler exits. + +How do I stop a server? +----------------------- + +Exit the :func:`~server.serve` context manager. + +Here's an example that terminates cleanly when it receives SIGTERM on Unix: + +.. literalinclude:: ../../example/shutdown_server.py + :emphasize-lines: 12-15,18 + +How do I run HTTP and WebSocket servers on the same port? +--------------------------------------------------------- + +You don't. + +HTTP and WebSocket have widely different operational characteristics. Running +them with the same server becomes inconvenient when you scale. + +Providing a HTTP server is out of scope for websockets. It only aims at +providing a WebSocket server. + +There's limited support for returning HTTP responses with the +:attr:`~server.WebSocketServerProtocol.process_request` hook. + +If you need more, pick a HTTP server and run it separately. + +Alternatively, pick a HTTP framework that builds on top of ``websockets`` to +support WebSocket connections, like Sanic_. + +.. _Sanic: https://sanicframework.org/en/ diff --git a/docs/howto/faq.rst b/docs/howto/faq.rst deleted file mode 100644 index a103f677c..000000000 --- a/docs/howto/faq.rst +++ /dev/null @@ -1,621 +0,0 @@ -FAQ -=== - -.. currentmodule:: websockets - -.. admonition:: Many questions asked in websockets' issue tracker are really - about :mod:`asyncio`. - :class: seealso - - Python's documentation about `developing with asyncio`_ is a good - complement. - - .. _developing with asyncio: https://docs.python.org/3/library/asyncio-dev.html - -Server side ------------ - -Why does the server close the connection prematurely? -..................................................... - -Your connection handler exits prematurely. Wait for the work to be finished -before returning. - -For example, if your handler has a structure similar to:: - - async def handler(websocket): - asyncio.create_task(do_some_work()) - -change it to:: - - async def handler(websocket): - await do_some_work() - -Why does the server close the connection after one message? -........................................................... - -Your connection handler exits after processing one message. Write a loop to -process multiple messages. - -For example, if your handler looks like this:: - - async def handler(websocket): - print(websocket.recv()) - -change it like this:: - - async def handler(websocket): - async for message in websocket: - print(message) - -*Don't feel bad if this happens to you — it's the most common question in -websockets' issue tracker :-)* - -Why can only one client connect at a time? -.......................................... - -Your connection handler blocks the event loop. Look for blocking calls. -Any call that may take some time must be asynchronous. - -For example, if you have:: - - async def handler(websocket): - time.sleep(1) - -change it to:: - - async def handler(websocket): - await asyncio.sleep(1) - -This is part of learning asyncio. It isn't specific to websockets. - -See also Python's documentation about `running blocking code`_. - -.. _running blocking code: https://docs.python.org/3/library/asyncio-dev.html#running-blocking-code - -.. _send-message-to-all-users: - -How do I send a message to all users? -..................................... - -Record all connections in a global variable:: - - CONNECTIONS = set() - - async def handler(websocket): - CONNECTIONS.add(websocket) - try: - await websocket.wait_closed() - finally: - CONNECTIONS.remove(websocket) - -Then, call :func:`~websockets.broadcast`:: - - import websockets - - def message_all(message): - websockets.broadcast(CONNECTIONS, message) - -If you're running multiple server processes, make sure you call ``message_all`` -in each process. - -.. _send-message-to-single-user: - -How do I send a message to a single user? -......................................... - -Record connections in a global variable, keyed by user identifier:: - - CONNECTIONS = {} - - async def handler(websocket): - user_id = ... # identify user in your app's context - CONNECTIONS[user_id] = websocket - try: - await websocket.wait_closed() - finally: - del CONNECTIONS[user_id] - -Then, call :meth:`~legacy.protocol.WebSocketCommonProtocol.send`:: - - async def message_user(user_id, message): - websocket = CONNECTIONS[user_id] # raises KeyError if user disconnected - await websocket.send(message) # may raise websockets.ConnectionClosed - -Add error handling according to the behavior you want if the user disconnected -before the message could be sent. - -This example supports only one connection per user. To support concurrent -connects by the same user, you can change ``CONNECTIONS`` to store a set of -connections for each user. - -If you're running multiple server processes, call ``message_user`` in each -process. The process managing the user's connection sends the message; other -processes do nothing. - -When you reach a scale where server processes cannot keep up with the stream of -all messages, you need a better architecture. For example, you could deploy an -external publish / subscribe system such as Redis_. Server processes would -subscribe their clients. Then, they would receive messages only for the -connections that they're managing. - -.. _Redis: https://redis.io/ - -How do I send a message to a channel, a topic, or some users? -............................................................. - -websockets doesn't provide built-in publish / subscribe functionality. - -Record connections in a global variable, keyed by user identifier, as shown in -:ref:`How do I send a message to a single user?` - -Then, build the set of recipients and broadcast the message to them, as shown in -:ref:`How do I send a message to all users?` - -:doc:`django` contains a complete implementation of this pattern. - -Again, as you scale, you may reach the performance limits of a basic in-process -implementation. You may need an external publish / subscribe system like Redis_. - -.. _Redis: https://redis.io/ - -How do I pass arguments to the connection handler? -.................................................. - -You can bind additional arguments to the connection handler with -:func:`functools.partial`:: - - import asyncio - import functools - import websockets - - async def handler(websocket, extra_argument): - ... - - bound_handler = functools.partial(handler, extra_argument='spam') - start_server = websockets.serve(bound_handler, ...) - -Another way to achieve this result is to define the ``handler`` coroutine in -a scope where the ``extra_argument`` variable exists instead of injecting it -through an argument. - -How do I access the request path? -................................. - -It is available in the :attr:`~server.WebSocketServerProtocol.path` attribute. - -You may route a connection to different handlers depending on the request path:: - - async def handler(websocket): - if websocket.path == "/blue": - await blue_handler(websocket) - elif websocket.path == "/green": - await green_handler(websocket) - else: - # No handler for this path; close the connection. - return - -You may also route the connection based on the first message received from the -client, as shown in the :doc:`tutorial <../intro/tutorial2>`. When you want to -authenticate the connection before routing it, this is usually more convenient. - -Generally speaking, there is far less emphasis on the request path in WebSocket -servers than in HTTP servers. When a WebSockt server provides a single endpoint, -it may ignore the request path entirely. - -How do I access HTTP headers? -............................. - -To access HTTP headers during the WebSocket handshake, you can override -:attr:`~server.WebSocketServerProtocol.process_request`:: - - async def process_request(self, path, request_headers): - authorization = request_headers["Authorization"] - -Once the connection is established, HTTP headers are available in -:attr:`~server.WebSocketServerProtocol.request_headers` and -:attr:`~server.WebSocketServerProtocol.response_headers`:: - - async def handler(websocket): - authorization = websocket.request_headers["Authorization"] - -How do I set HTTP headers? -.......................... - -To set the ``Sec-WebSocket-Extensions`` or ``Sec-WebSocket-Protocol`` headers in -the WebSocket handshake response, use the ``extensions`` or ``subprotocols`` -arguments of :func:`~server.serve`. - -To set other HTTP headers, use the ``extra_headers`` argument. - -How do I get the IP address of the client? -.......................................... - -It's available in :attr:`~legacy.protocol.WebSocketCommonProtocol.remote_address`:: - - async def handler(websocket): - remote_ip = websocket.remote_address[0] - -How do I set the IP addresses my server listens on? -................................................... - -Look at the ``host`` argument of :meth:`~asyncio.loop.create_server`. - -:func:`~server.serve` accepts the same arguments as -:meth:`~asyncio.loop.create_server`. - -What does ``OSError: [Errno 99] error while attempting to bind on address ('::1', 80, 0, 0): address not available`` mean? -.......................................................................................................................... - -You are calling :func:`~server.serve` without a ``host`` argument in a context -where IPv6 isn't available. - -To listen only on IPv4, specify ``host="0.0.0.0"`` or ``family=socket.AF_INET``. - -Refer to the documentation of :meth:`~asyncio.loop.create_server` for details. - -How do I close a connection? -............................ - -websockets takes care of closing the connection when the handler exits. - -How do I stop a server? -....................... - -Exit the :func:`~server.serve` context manager. - -Here's an example that terminates cleanly when it receives SIGTERM on Unix: - -.. literalinclude:: ../../example/shutdown_server.py - :emphasize-lines: 12-15,18 - - -How do I run HTTP and WebSocket servers on the same port? -......................................................... - -You don't. - -HTTP and WebSocket have widely different operational characteristics. Running -them with the same server becomes inconvenient when you scale. - -Providing a HTTP server is out of scope for websockets. It only aims at -providing a WebSocket server. - -There's limited support for returning HTTP responses with the -:attr:`~server.WebSocketServerProtocol.process_request` hook. - -If you need more, pick a HTTP server and run it separately. - -Alternatively, pick a HTTP framework that builds on top of ``websockets`` to -support WebSocket connections, like Sanic_. - -.. _Sanic: https://sanicframework.org/en/ - -Client side ------------ - -Why does the client close the connection prematurely? -..................................................... - -You're exiting the context manager prematurely. Wait for the work to be -finished before exiting. - -For example, if your code has a structure similar to:: - - async with connect(...) as websocket: - asyncio.create_task(do_some_work()) - -change it to:: - - async with connect(...) as websocket: - await do_some_work() - -How do I access HTTP headers? -............................. - -Once the connection is established, HTTP headers are available in -:attr:`~client.WebSocketClientProtocol.request_headers` and -:attr:`~client.WebSocketClientProtocol.response_headers`. - -How do I set HTTP headers? -.......................... - -To set the ``Origin``, ``Sec-WebSocket-Extensions``, or -``Sec-WebSocket-Protocol`` headers in the WebSocket handshake request, use the -``origin``, ``extensions``, or ``subprotocols`` arguments of -:func:`~client.connect`. - -To set other HTTP headers, for example the ``Authorization`` header, use the -``extra_headers`` argument:: - - async with connect(..., extra_headers={"Authorization": ...}) as websocket: - ... - -How do I close a connection? -............................ - -The easiest is to use :func:`~client.connect` as a context manager:: - - async with connect(...) as websocket: - ... - -The connection is closed when exiting the context manager. - -How do I reconnect when the connection drops? -............................................. - -Use :func:`~client.connect` as an asynchronous iterator:: - - async for websocket in websockets.connect(...): - try: - ... - except websockets.ConnectionClosed: - continue - -Make sure you handle exceptions in the ``async for`` loop. Uncaught exceptions -will break out of the loop. - -How do I stop a client that is processing messages in a loop? -............................................................. - -You can close the connection. - -Here's an example that terminates cleanly when it receives SIGTERM on Unix: - -.. literalinclude:: ../../example/shutdown_client.py - :emphasize-lines: 10-13 - -How do I disable TLS/SSL certificate verification? -.................................................. - -Look at the ``ssl`` argument of :meth:`~asyncio.loop.create_connection`. - -:func:`~client.connect` accepts the same arguments as -:meth:`~asyncio.loop.create_connection`. - -asyncio usage -------------- - -How do I run two coroutines in parallel? -........................................ - -You must start two tasks, which the event loop will run concurrently. You can -achieve this with :func:`asyncio.gather` or :func:`asyncio.create_task`. - -Keep track of the tasks and make sure they terminate or you cancel them when -the connection terminates. - -Why does my program never receive any messages? -............................................... - -Your program runs a coroutine that never yields control to the event loop. The -coroutine that receives messages never gets a chance to run. - -Putting an ``await`` statement in a ``for`` or a ``while`` loop isn't enough -to yield control. Awaiting a coroutine may yield control, but there's no -guarantee that it will. - -For example, :meth:`~legacy.protocol.WebSocketCommonProtocol.send` only yields -control when send buffers are full, which never happens in most practical -cases. - -If you run a loop that contains only synchronous operations and -a :meth:`~legacy.protocol.WebSocketCommonProtocol.send` call, you must yield -control explicitly with :func:`asyncio.sleep`:: - - async def producer(websocket): - message = generate_next_message() - await websocket.send(message) - await asyncio.sleep(0) # yield control to the event loop - -:func:`asyncio.sleep` always suspends the current task, allowing other tasks -to run. This behavior is documented precisely because it isn't expected from -every coroutine. - -See `issue 867`_. - -.. _issue 867: https://github.com/aaugustin/websockets/issues/867 - -Why does my simple program misbehave mysteriously? -.................................................. - -You are using :func:`time.sleep` instead of :func:`asyncio.sleep`, which -blocks the event loop and prevents asyncio from operating normally. - -This may lead to messages getting send but not received, to connection -timeouts, and to unexpected results of shotgun debugging e.g. adding an -unnecessary call to :meth:`~legacy.protocol.WebSocketCommonProtocol.send` -makes the program functional. - -Both sides ----------- - -What does ``ConnectionClosedError: no close frame received or sent`` mean? -.......................................................................... - -If you're seeing this traceback in the logs of a server: - -.. code-block:: pytb - - connection handler failed - Traceback (most recent call last): - ... - asyncio.exceptions.IncompleteReadError: 0 bytes read on a total of 2 expected bytes - - The above exception was the direct cause of the following exception: - - Traceback (most recent call last): - ... - websockets.exceptions.ConnectionClosedError: no close frame received or sent - -or if a client crashes with this traceback: - -.. code-block:: pytb - - Traceback (most recent call last): - ... - ConnectionResetError: [Errno 54] Connection reset by peer - - The above exception was the direct cause of the following exception: - - Traceback (most recent call last): - ... - websockets.exceptions.ConnectionClosedError: no close frame received or sent - -it means that the TCP connection was lost. As a consequence, the WebSocket -connection was closed without receiving and sending a close frame, which is -abnormal. - -You can catch and handle :exc:`~exceptions.ConnectionClosed` to prevent it -from being logged. - -There are several reasons why long-lived connections may be lost: - -* End-user devices tend to lose network connectivity often and unpredictably - because they can move out of wireless network coverage, get unplugged from - a wired network, enter airplane mode, be put to sleep, etc. -* HTTP load balancers or proxies that aren't configured for long-lived - connections may terminate connections after a short amount of time, usually - 30 seconds, despite websockets' keepalive mechanism. - -If you're facing a reproducible issue, :ref:`enable debug logs ` to -see when and how connections are closed. - -What does ``ConnectionClosedError: sent 1011 (unexpected error) keepalive ping timeout; no close frame received`` mean? -....................................................................................................................... - -If you're seeing this traceback in the logs of a server: - -.. code-block:: pytb - - connection handler failed - Traceback (most recent call last): - ... - asyncio.exceptions.CancelledError - - The above exception was the direct cause of the following exception: - - Traceback (most recent call last): - ... - websockets.exceptions.ConnectionClosedError: sent 1011 (unexpected error) keepalive ping timeout; no close frame received - -or if a client crashes with this traceback: - -.. code-block:: pytb - - Traceback (most recent call last): - ... - asyncio.exceptions.CancelledError - - The above exception was the direct cause of the following exception: - - Traceback (most recent call last): - ... - websockets.exceptions.ConnectionClosedError: sent 1011 (unexpected error) keepalive ping timeout; no close frame received - -it means that the WebSocket connection suffered from excessive latency and was -closed after reaching the timeout of websockets' keepalive mechanism. - -You can catch and handle :exc:`~exceptions.ConnectionClosed` to prevent it -from being logged. - -There are two main reasons why latency may increase: - -* Poor network connectivity. -* More traffic than the recipient can handle. - -See the discussion of :doc:`timeouts <../topics/timeouts>` for details. - -If websockets' default timeout of 20 seconds is too short for your use case, -you can adjust it with the ``ping_timeout`` argument. - -How do I set a timeout on :meth:`~legacy.protocol.WebSocketCommonProtocol.recv`? -................................................................................ - -Use :func:`~asyncio.wait_for`:: - - await asyncio.wait_for(websocket.recv(), timeout=10) - -This technique works for most APIs, except for asynchronous context managers. -See `issue 574`_. - -.. _issue 574: https://github.com/aaugustin/websockets/issues/574 - -How can I pass arguments to a custom protocol subclass? -....................................................... - -You can bind additional arguments to the protocol factory with -:func:`functools.partial`:: - - import asyncio - import functools - import websockets - - class MyServerProtocol(websockets.WebSocketServerProtocol): - def __init__(self, extra_argument, *args, **kwargs): - super().__init__(*args, **kwargs) - # do something with extra_argument - - create_protocol = functools.partial(MyServerProtocol, extra_argument='spam') - start_server = websockets.serve(..., create_protocol=create_protocol) - -This example was for a server. The same pattern applies on a client. - -How do I keep idle connections open? -.................................... - -websockets sends pings at 20 seconds intervals to keep the connection open. - -It closes the connection if it doesn't get a pong within 20 seconds. - -You can adjust this behavior with ``ping_interval`` and ``ping_timeout``. - -See :doc:`../topics/timeouts` for details. - -How do I respond to pings? -.......................... - -Don't bother; websockets takes care of responding to pings with pongs. - -Miscellaneous -------------- - -Can I use websockets without ``async`` and ``await``? -..................................................... - -No, there is no convenient way to do this. You should use another library. - -Are there ``onopen``, ``onmessage``, ``onerror``, and ``onclose`` callbacks? -............................................................................ - -No, there aren't. - -websockets provides high-level, coroutine-based APIs. Compared to callbacks, -coroutines make it easier to manage control flow in concurrent code. - -If you prefer callback-based APIs, you should use another library. - -Why do I get the error: ``module 'websockets' has no attribute '...'``? -....................................................................... - -Often, this is because you created a script called ``websockets.py`` in your -current working directory. Then ``import websockets`` imports this module -instead of the websockets library. - -Why am I having problems with threads? -...................................... - -You shouldn't use threads. Use tasks instead. - -Indeed, when you chose websockets, you chose :mod:`asyncio` as the primary -framework to handle concurrency. This choice is mutually exclusive with -:mod:`threading`. - -If you believe that you need to run websockets in a thread and some logic in -another thread, you should run that logic in a :class:`~asyncio.Task` instead. - -If you believe that you cannot run that logic in the same event loop because it -will block websockets, :meth:`~asyncio.loop.run_in_executor` may help. - -This question is really about :mod:`asyncio`. Please review the advice about -:ref:`asyncio-multithreading` in the Python documentation. diff --git a/docs/howto/index.rst b/docs/howto/index.rst index d399cebc2..dafb72391 100644 --- a/docs/howto/index.rst +++ b/docs/howto/index.rst @@ -1,12 +1,18 @@ How-to guides ============= +In a hurry? Check out these examples. + +.. toctree:: + :titlesonly: + + quickstart + If you're stuck, perhaps you'll find the answer here. .. toctree:: :titlesonly: - faq cheatsheet patterns autoreload diff --git a/docs/intro/quickstart.rst b/docs/howto/quickstart.rst similarity index 100% rename from docs/intro/quickstart.rst rename to docs/howto/quickstart.rst diff --git a/docs/index.rst b/docs/index.rst index 07835a81c..00a9b5999 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -52,13 +52,14 @@ Also, websockets provides an interactive client: < Hello world! Connection closed: 1000 (OK). -Do you like it? Let's dive in! +Do you like it? :doc:`Let's dive in! ` .. toctree:: :hidden: intro/index howto/index + faq/index reference/index topics/index project/index diff --git a/docs/intro/index.rst b/docs/intro/index.rst index 2c66dea9a..fe4e704d6 100644 --- a/docs/intro/index.rst +++ b/docs/intro/index.rst @@ -43,8 +43,4 @@ Learn how to build an real-time web application with websockets. In a hurry? ----------- -Check out these examples. - -.. toctree:: - - quickstart +Look at the :doc:`quick start guide <../howto/quickstart>`. diff --git a/docs/reference/index.rst b/docs/reference/index.rst index a5ee57843..5f51a1c1c 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -43,8 +43,8 @@ and server connections, common methods are documented in a "Both sides" page. .. toctree:: :titlesonly: - client server + client common utilities exceptions From 8e096fc5a2a49b3c42fb2543547867cab11297fb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 16 Aug 2022 12:39:33 +0200 Subject: [PATCH 294/762] Add FAQ on benchmarking. Fix #1189. --- docs/faq/misc.rst | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/docs/faq/misc.rst b/docs/faq/misc.rst index 3606937b8..9ef07ef2d 100644 --- a/docs/faq/misc.rst +++ b/docs/faq/misc.rst @@ -24,3 +24,16 @@ Why do I get the error: ``module 'websockets' has no attribute '...'``? Often, this is because you created a script called ``websockets.py`` in your current working directory. Then ``import websockets`` imports this module instead of the websockets library. + +Why is websockets slower than another Python library in my benchmark? +..................................................................... + +Not all libraries are as feature-complete as websockets. For a fair benchmark, +you should disable features that the other library doesn't provide. Typically, +you may need to disable: + +* Compression: set ``compression=None`` +* Keepalive: set ``ping_interval=None`` +* UTF-8 decoding: send ``bytes`` rather than ``str`` + +If websockets is still slower than another Python library, please file a bug. From 2a07325cecf8be4732cd991ead0314c530dd7cdd Mon Sep 17 00:00:00 2001 From: Irfanuddin Date: Wed, 17 Aug 2022 10:09:05 +0200 Subject: [PATCH 295/762] Support removing Server and User-Agent headers. Fix #1193. --- docs/faq/client.rst | 3 ++ docs/faq/server.rst | 3 ++ docs/reference/client.rst | 6 ++-- docs/reference/server.rst | 6 ++-- src/websockets/client.py | 8 ++++- src/websockets/legacy/client.py | 13 ++++++-- src/websockets/legacy/server.py | 16 +++++++-- src/websockets/server.py | 12 +++++-- tests/legacy/test_client_server.py | 52 ++++++++++++++++++++++++++---- tests/test_client.py | 16 +++++++++ tests/test_server.py | 22 +++++++++++++ 11 files changed, 137 insertions(+), 20 deletions(-) diff --git a/docs/faq/client.rst b/docs/faq/client.rst index 5b39cf1ec..5bbbd6ded 100644 --- a/docs/faq/client.rst +++ b/docs/faq/client.rst @@ -34,6 +34,9 @@ To set the ``Origin``, ``Sec-WebSocket-Extensions``, or ``origin``, ``extensions``, or ``subprotocols`` arguments of :func:`~client.connect`. +To override the ``User-Agent`` header, use the ``user_agent_header`` argument. +Set it to :obj:`None` to remove the header. + To set other HTTP headers, for example the ``Authorization`` header, use the ``extra_headers`` argument:: diff --git a/docs/faq/server.rst b/docs/faq/server.rst index a83f879b9..68490d755 100644 --- a/docs/faq/server.rst +++ b/docs/faq/server.rst @@ -214,6 +214,9 @@ To set the ``Sec-WebSocket-Extensions`` or ``Sec-WebSocket-Protocol`` headers in the WebSocket handshake response, use the ``extensions`` or ``subprotocols`` arguments of :func:`~server.serve`. +To override the ``Server`` header, use the ``server_header`` argument. Set it to +:obj:`None` to remove the header. + To set other HTTP headers, use the ``extra_headers`` argument. How do I get the IP address of the client? diff --git a/docs/reference/client.rst b/docs/reference/client.rst index 379765397..3016e85d7 100644 --- a/docs/reference/client.rst +++ b/docs/reference/client.rst @@ -9,16 +9,16 @@ asyncio Opening a connection .................... -.. autofunction:: connect(uri, *, create_protocol=None, logger=None, compression="deflate", origin=None, extensions=None, subprotocols=None, extra_headers=None, open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) +.. autofunction:: connect(uri, *, create_protocol=None, logger=None, compression="deflate", origin=None, extensions=None, subprotocols=None, extra_headers=None, user_agent_header="Python/x.y.z websockets/X.Y", open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) :async: -.. autofunction:: unix_connect(path, uri="ws://localhost/", *, create_protocol=None, logger=None, compression="deflate", origin=None, extensions=None, subprotocols=None, extra_headers=None, open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) +.. autofunction:: unix_connect(path, uri="ws://localhost/", *, create_protocol=None, logger=None, compression="deflate", origin=None, extensions=None, subprotocols=None, extra_headers=None, user_agent_header="Python/x.y.z websockets/X.Y", open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) :async: Using a connection .................. -.. autoclass:: WebSocketClientProtocol(*, logger=None, origin=None, extensions=None, subprotocols=None, extra_headers=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) +.. autoclass:: WebSocketClientProtocol(*, logger=None, origin=None, extensions=None, subprotocols=None, extra_headers=None, user_agent_header="Python/x.y.z websockets/X.Y", ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) .. automethod:: recv diff --git a/docs/reference/server.rst b/docs/reference/server.rst index 0e446f382..65f98842a 100644 --- a/docs/reference/server.rst +++ b/docs/reference/server.rst @@ -9,10 +9,10 @@ asyncio Starting a server ................. -.. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, logger=None, compression="deflate", origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) +.. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, logger=None, compression="deflate", origins=None, extensions=None, subprotocols=None, extra_headers=None, server_header="Python/x.y.z websockets/X.Y", process_request=None, select_subprotocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) :async: -.. autofunction:: unix_serve(ws_handler, path=None, *, create_protocol=None, logger=None, compression="deflate", origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) +.. autofunction:: unix_serve(ws_handler, path=None, *, create_protocol=None, logger=None, compression="deflate", origins=None, extensions=None, subprotocols=None, extra_headers=None, server_header="Python/x.y.z websockets/X.Y", process_request=None, select_subprotocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) :async: Stopping a server @@ -37,7 +37,7 @@ Stopping a server Using a connection .................. -.. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, logger=None, origins=None, extensions=None, subprotocols=None, extra_headers=None, process_request=None, select_subprotocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) +.. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, logger=None, origins=None, extensions=None, subprotocols=None, extra_headers=None, server_header="Python/x.y.z websockets/X.Y", process_request=None, select_subprotocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) .. automethod:: recv diff --git a/src/websockets/client.py b/src/websockets/client.py index df8e53429..1c33fae0b 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -64,6 +64,9 @@ class ClientConnection(Connection): logger: logger for this connection; defaults to ``logging.getLogger("websockets.client")``; see the :doc:`logging guide <../topics/logging>` for details. + user_agent_header: value of the ``User-Agent`` request header; + defauts to ``"Python/x.y.z websockets/X.Y"``; + :obj:`None` removes the header. """ @@ -73,6 +76,7 @@ def __init__( origin: Optional[Origin] = None, extensions: Optional[Sequence[ClientExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, + user_agent_header: Optional[str] = USER_AGENT, state: State = CONNECTING, max_size: Optional[int] = 2**20, logger: Optional[LoggerLike] = None, @@ -87,6 +91,7 @@ def __init__( self.origin = origin self.available_extensions = extensions self.available_subprotocols = subprotocols + self.user_agent_header = user_agent_header self.key = generate_key() def connect(self) -> Request: # noqa: F811 @@ -131,7 +136,8 @@ def connect(self) -> Request: # noqa: F811 protocol_header = build_subprotocol(self.available_subprotocols) headers["Sec-WebSocket-Protocol"] = protocol_header - headers["User-Agent"] = USER_AGENT + if self.user_agent_header is not None: + headers["User-Agent"] = self.user_agent_header return Request(self.wsuri.resource_name, headers) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 29550b694..93566b87e 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -70,7 +70,8 @@ class WebSocketClientProtocol(WebSocketCommonProtocol): is closed with any other code. See :func:`connect` for the documentation of ``logger``, ``origin``, - ``extensions``, ``subprotocols``, and ``extra_headers``. + ``extensions``, ``subprotocols``, ``extra_headers``, and + ``user_agent_header``. See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``, @@ -89,6 +90,7 @@ def __init__( extensions: Optional[Sequence[ClientExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, extra_headers: Optional[HeadersLike] = None, + user_agent_header: Optional[str] = USER_AGENT, **kwargs: Any, ) -> None: if logger is None: @@ -98,6 +100,7 @@ def __init__( self.available_extensions = extensions self.available_subprotocols = subprotocols self.extra_headers = extra_headers + self.user_agent_header = user_agent_header def write_http_request(self, path: str, headers: Headers) -> None: """ @@ -315,7 +318,8 @@ async def handshake( if self.extra_headers is not None: request_headers.update(self.extra_headers) - request_headers.setdefault("User-Agent", USER_AGENT) + if self.user_agent_header is not None: + request_headers.setdefault("User-Agent", self.user_agent_header) self.write_http_request(wsuri.resource_name, request_headers) @@ -393,6 +397,9 @@ class Connect: subprotocols: list of supported subprotocols, in order of decreasing preference. extra_headers: arbitrary HTTP headers to add to the request. + user_agent_header: value of the ``User-Agent`` request header; + defauts to ``"Python/x.y.z websockets/X.Y"``; + :obj:`None` removes the header. open_timeout: timeout for opening the connection in seconds; :obj:`None` to disable the timeout @@ -438,6 +445,7 @@ def __init__( extensions: Optional[Sequence[ClientExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, extra_headers: Optional[HeadersLike] = None, + user_agent_header: Optional[str] = USER_AGENT, open_timeout: Optional[float] = 10, ping_interval: Optional[float] = 20, ping_timeout: Optional[float] = 20, @@ -503,6 +511,7 @@ def __init__( extensions=extensions, subprotocols=subprotocols, extra_headers=extra_headers, + user_agent_header=user_agent_header, ping_interval=ping_interval, ping_timeout=ping_timeout, close_timeout=close_timeout, diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 2ee469825..836496b6e 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -84,7 +84,7 @@ class WebSocketServerProtocol(WebSocketCommonProtocol): ws_server: WebSocket server that created this connection. See :func:`serve` for the documentation of ``ws_handler``, ``logger``, ``origins``, - ``extensions``, ``subprotocols``, and ``extra_headers``. + ``extensions``, ``subprotocols``, ``extra_headers``, and ``server_header``. See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``, @@ -108,6 +108,7 @@ def __init__( extensions: Optional[Sequence[ServerExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, extra_headers: Optional[HeadersLikeOrCallable] = None, + server_header: Optional[str] = USER_AGENT, process_request: Optional[ Callable[[str, Headers], Awaitable[Optional[HTTPResponse]]] ] = None, @@ -132,6 +133,7 @@ def __init__( self.available_extensions = extensions self.available_subprotocols = subprotocols self.extra_headers = extra_headers + self.server_header = server_header self._process_request = process_request self._select_subprotocol = select_subprotocol @@ -216,7 +218,9 @@ async def handler(self) -> None: ) headers.setdefault("Date", email.utils.formatdate(usegmt=True)) - headers.setdefault("Server", USER_AGENT) + if self.server_header is not None: + headers.setdefault("Server", self.server_header) + headers.setdefault("Content-Length", str(len(body))) headers.setdefault("Content-Type", "text/plain") headers.setdefault("Connection", "close") @@ -635,7 +639,8 @@ async def handshake( response_headers.update(extra_headers) response_headers.setdefault("Date", email.utils.formatdate(usegmt=True)) - response_headers.setdefault("Server", USER_AGENT) + if self.server_header is not None: + response_headers.setdefault("Server", self.server_header) self.write_http_response(http.HTTPStatus.SWITCHING_PROTOCOLS, response_headers) @@ -938,6 +943,9 @@ class Serve: a :data:`~websockets.datastructures.HeadersLike` or a callable taking the request path and headers in arguments and returning a :data:`~websockets.datastructures.HeadersLike`. + server_header: value of the ``Server`` response header; + defauts to ``"Python/x.y.z websockets/X.Y"``; + :obj:`None` removes the header. process_request (Optional[Callable[[str, Headers], \ Awaitable[Optional[Tuple[http.HTTPStatus, HeadersLike, bytes]]]]]): intercept HTTP request before the opening handshake; @@ -980,6 +988,7 @@ def __init__( extensions: Optional[Sequence[ServerExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, extra_headers: Optional[HeadersLikeOrCallable] = None, + server_header: Optional[str] = USER_AGENT, process_request: Optional[ Callable[[str, Headers], Awaitable[Optional[HTTPResponse]]] ] = None, @@ -1061,6 +1070,7 @@ def __init__( extensions=extensions, subprotocols=subprotocols, extra_headers=extra_headers, + server_header=server_header, process_request=process_request, select_subprotocol=select_subprotocol, logger=logger, diff --git a/src/websockets/server.py b/src/websockets/server.py index 5dad50b6a..2057d9fb9 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -64,6 +64,9 @@ class ServerConnection(Connection): logger: logger for this connection; defaults to ``logging.getLogger("websockets.client")``; see the :doc:`logging guide <../topics/logging>` for details. + server_header: value of the ``Server`` response header; + defauts to ``"Python/x.y.z websockets/X.Y"``; + :obj:`None` removes the header. """ @@ -72,6 +75,7 @@ def __init__( origins: Optional[Sequence[Optional[Origin]]] = None, extensions: Optional[Sequence[ServerExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, + server_header: Optional[str] = USER_AGENT, state: State = CONNECTING, max_size: Optional[int] = 2**20, logger: Optional[LoggerLike] = None, @@ -85,6 +89,7 @@ def __init__( self.origins = origins self.available_extensions = extensions self.available_subprotocols = subprotocols + self.server_header = server_header def accept(self, request: Request) -> Response: """ @@ -170,7 +175,8 @@ def accept(self, request: Request) -> Response: if protocol_header is not None: headers["Sec-WebSocket-Protocol"] = protocol_header - headers["Server"] = USER_AGENT + if self.server_header is not None: + headers["Server"] = self.server_header self.logger.info("connection open") return Response(101, "Switching Protocols", headers) @@ -469,9 +475,11 @@ def reject( ("Connection", "close"), ("Content-Length", str(len(body))), ("Content-Type", "text/plain; charset=utf-8"), - ("Server", USER_AGENT), ] ) + if self.server_header is not None: + headers["Server"] = self.server_header + response = Response(status.value, status.phrase, headers, body) # When reject() is called from accept(), handshake_exc is already set. # If a user calls reject(), set handshake_exc to guarantee invariant: diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index f9de70c9c..22d72d1f7 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -655,12 +655,27 @@ def test_protocol_custom_request_headers(self): self.assertIn("('X-Spam', 'Eggs')", req_headers) @with_server() - @with_client("/headers", extra_headers={"User-Agent": "Eggs"}) - def test_protocol_custom_request_user_agent(self): + @with_client("/headers", extra_headers={"User-Agent": "websockets"}) + def test_protocol_custom_user_agent_header_legacy(self): req_headers = self.loop.run_until_complete(self.client.recv()) self.loop.run_until_complete(self.client.recv()) self.assertEqual(req_headers.count("User-Agent"), 1) - self.assertIn("('User-Agent', 'Eggs')", req_headers) + self.assertIn("('User-Agent', 'websockets')", req_headers) + + @with_server() + @with_client("/headers", user_agent_header=None) + def test_protocol_no_user_agent_header(self): + req_headers = self.loop.run_until_complete(self.client.recv()) + self.loop.run_until_complete(self.client.recv()) + self.assertNotIn("User-Agent", req_headers) + + @with_server() + @with_client("/headers", user_agent_header="websockets") + def test_protocol_custom_user_agent_header(self): + req_headers = self.loop.run_until_complete(self.client.recv()) + self.loop.run_until_complete(self.client.recv()) + self.assertEqual(req_headers.count("User-Agent"), 1) + self.assertIn("('User-Agent', 'websockets')", req_headers) @with_server(extra_headers=lambda p, r: {"X-Spam": "Eggs"}) @with_client("/headers") @@ -682,13 +697,28 @@ def test_protocol_custom_response_headers(self): resp_headers = self.loop.run_until_complete(self.client.recv()) self.assertIn("('X-Spam', 'Eggs')", resp_headers) - @with_server(extra_headers={"Server": "Eggs"}) + @with_server(extra_headers={"Server": "websockets"}) + @with_client("/headers") + def test_protocol_custom_server_header_legacy(self): + self.loop.run_until_complete(self.client.recv()) + resp_headers = self.loop.run_until_complete(self.client.recv()) + self.assertEqual(resp_headers.count("Server"), 1) + self.assertIn("('Server', 'websockets')", resp_headers) + + @with_server(server_header=None) @with_client("/headers") - def test_protocol_custom_response_user_agent(self): + def test_protocol_no_server_header(self): + self.loop.run_until_complete(self.client.recv()) + resp_headers = self.loop.run_until_complete(self.client.recv()) + self.assertNotIn("Server", resp_headers) + + @with_server(server_header="websockets") + @with_client("/headers") + def test_protocol_custom_server_header(self): self.loop.run_until_complete(self.client.recv()) resp_headers = self.loop.run_until_complete(self.client.recv()) self.assertEqual(resp_headers.count("Server"), 1) - self.assertIn("('Server', 'Eggs')", resp_headers) + self.assertIn("('Server', 'websockets')", resp_headers) @with_server(create_protocol=HealthCheckServerProtocol) def test_http_request_http_endpoint(self): @@ -724,6 +754,16 @@ def test_ws_connection_ws_endpoint(self): self.loop.run_until_complete(self.client.recv()) self.stop_client() + @with_server(create_protocol=HealthCheckServerProtocol, server_header=None) + def test_http_request_no_server_header(self): + response = self.loop.run_until_complete(self.make_http_request("/__health__/")) + self.assertNotIn("Server", response.headers) + + @with_server(create_protocol=HealthCheckServerProtocol, server_header="websockets") + def test_http_request_custom_server_header(self): + response = self.loop.run_until_complete(self.make_http_request("/__health__/")) + self.assertEqual(response.headers["Server"], "websockets") + def assert_client_raises_code(self, status_code): with self.assertRaises(InvalidStatusCode) as raised: self.start_client() diff --git a/tests/test_client.py b/tests/test_client.py index a843d3272..0504f79e6 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -574,6 +574,22 @@ def test_unsupported_subprotocol(self): raise client.handshake_exc self.assertEqual(str(raised.exception), "unsupported subprotocol: otherchat") + def test_no_user_agent_header(self): + client = ClientConnection( + parse_uri("wss://example.com/"), + user_agent_header=None, + ) + request = client.connect() + self.assertNotIn("User-Agent", request.headers) + + def test_custom_user_agent_header(self): + client = ClientConnection( + parse_uri("wss://example.com/"), + user_agent_header="websockets", + ) + request = client.connect() + self.assertEqual(request.headers["User-Agent"], "websockets") + class MiscTests(unittest.TestCase): def test_bypass_handshake(self): diff --git a/tests/test_server.py b/tests/test_server.py index e3e802239..c7f398cdf 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -608,6 +608,28 @@ def test_unsupported_subprotocol(self): self.assertNotIn("Sec-WebSocket-Protocol", response.headers) self.assertIsNone(server.subprotocol) + def test_no_server_header(self): + server = ServerConnection(server_header=None) + request = self.make_request() + response = server.accept(request) + self.assertNotIn("Server", response.headers) + + def test_custom_server_header(self): + server = ServerConnection(server_header="websockets") + request = self.make_request() + response = server.accept(request) + self.assertEqual(response.headers["Server"], "websockets") + + def test_reject_response_no_server_header(self): + server = ServerConnection(server_header=None) + response = server.reject(http.HTTPStatus.OK, "Hello world!\n") + self.assertNotIn("Server", response.headers) + + def test_reject_response_custom_server_header(self): + server = ServerConnection(server_header="websockets") + response = server.reject(http.HTTPStatus.OK, "Hello world!\n") + self.assertEqual(response.headers["Server"], "websockets") + class MiscTests(unittest.TestCase): def test_bypass_handshake(self): From b5a2f37f8b2d0149e81ab380b84441a37a2db83e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20G=C3=B3rny?= Date: Sun, 15 May 2022 10:27:09 +0200 Subject: [PATCH 296/762] Wrap recv_into() in test_explicit_socket to fix py3.11 Extend TrackedSocket class in test_explicit_socket to wrap the recv_into() method. The Python 3.11 implementation of asyncio is calling it rather than recv(), therefore causing used_for_read not to be set if it's not wrapped. --- tests/legacy/test_client_server.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 22d72d1f7..c7e2e1cae 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -422,10 +422,14 @@ def __init__(self, *args, **kwargs): self.used_for_write = False super().__init__(*args, **kwargs) - def recv(self, *args, **kwargs): + def recv(self, *args, **kwargs): # pragma: no cover self.used_for_read = True return super().recv(*args, **kwargs) + def recv_into(self, *args, **kwargs): # pragma: no cover + self.used_for_read = True + return super().recv_into(*args, **kwargs) + def send(self, *args, **kwargs): self.used_for_write = True return super().send(*args, **kwargs) From bebbf40b5a14a75cd06e58be79dc1eed4255d8ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20G=C3=B3rny?= Date: Sun, 15 May 2022 10:32:03 +0200 Subject: [PATCH 297/762] Skip YieldFromTests in Python 3.11+ asyncio.coroutine has been removed in Python 3.11, so skip it if Python is newer than that. Fixes #1175 --- tests/legacy/test_client_server.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index c7e2e1cae..8d0f8725f 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -1343,6 +1343,9 @@ def test_checking_lack_of_origin_succeeds_backwards_compatibility(self): self.assertEqual(self.loop.run_until_complete(self.client.recv()), "Hello!") +@unittest.skipIf( + sys.version_info[:2] >= (3, 11), "asyncio.coroutine has been removed in Python 3.11" +) class YieldFromTests(ClientServerTestsMixin, AsyncioTestCase): @with_server() def test_client(self): From 0ea2818b370ae572d52ec5d7063de2f8b64faa38 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 17 Aug 2022 10:19:59 +0200 Subject: [PATCH 298/762] Add changelog for 2a07325c. --- docs/project/changelog.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 105065c8c..b52943b8d 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -30,6 +30,12 @@ They may change at any time. *In development* +New features +............ + +* Supported overriding or removing the ``User-Agent`` header in clients and the + ``Server`` header in servers. + 10.3 ---- From 653466389996b466e7b0a6a83ec9e573ca5b09c5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 18 Aug 2022 12:45:04 +0200 Subject: [PATCH 299/762] Fix tests with PyPy. --- tests/legacy/test_client_server.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 8d0f8725f..f13ef6882 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -761,12 +761,16 @@ def test_ws_connection_ws_endpoint(self): @with_server(create_protocol=HealthCheckServerProtocol, server_header=None) def test_http_request_no_server_header(self): response = self.loop.run_until_complete(self.make_http_request("/__health__/")) - self.assertNotIn("Server", response.headers) + + with contextlib.closing(response): + self.assertNotIn("Server", response.headers) @with_server(create_protocol=HealthCheckServerProtocol, server_header="websockets") def test_http_request_custom_server_header(self): response = self.loop.run_until_complete(self.make_http_request("/__health__/")) - self.assertEqual(response.headers["Server"], "websockets") + + with contextlib.closing(response): + self.assertEqual(response.headers["Server"], "websockets") def assert_client_raises_code(self, status_code): with self.assertRaises(InvalidStatusCode) as raised: From f7f8313f4078a75db4213d4340234b0715567349 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 22 Aug 2022 22:19:12 +0200 Subject: [PATCH 300/762] Enable dependabot --- .github/dependabot.yml | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 .github/dependabot.yml diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 000000000..123014908 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,6 @@ +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "daily" From e36f5b673bc32885d3c08102a19f5194951e7940 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 22 Aug 2022 20:19:42 +0000 Subject: [PATCH 301/762] Bump actions/setup-python from 2 to 4 Bumps [actions/setup-python](https://github.com/actions/setup-python) from 2 to 4. - [Release notes](https://github.com/actions/setup-python/releases) - [Commits](https://github.com/actions/setup-python/compare/v2...v4) --- updated-dependencies: - dependency-name: actions/setup-python dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- .github/workflows/tests.yml | 4 ++-- .github/workflows/wheels.yml | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 03ce3aff9..2e6e0f10e 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -16,7 +16,7 @@ jobs: - name: Check out repository uses: actions/checkout@v2 - name: Install Python 3.x - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: "3.x" - name: Install tox @@ -59,7 +59,7 @@ jobs: - name: Check out repository uses: actions/checkout@v2 - name: Install Python ${{ matrix.python }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python }} - name: Install tox diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 0322055b1..aff7dcec9 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -14,7 +14,7 @@ jobs: - name: Check out repository uses: actions/checkout@v2 - name: Install Python 3.x - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: 3.x - name: Build sdist @@ -39,7 +39,7 @@ jobs: - name: Make extension build mandatory run: touch .cibuildwheel - name: Install Python 3.x - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: 3.x - name: Set up QEMU From e22a19a0666dfe6e0deeb0fd72adbdbbabfe5fea Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 22 Aug 2022 20:19:46 +0000 Subject: [PATCH 302/762] Bump docker/setup-qemu-action from 1 to 2 Bumps [docker/setup-qemu-action](https://github.com/docker/setup-qemu-action) from 1 to 2. - [Release notes](https://github.com/docker/setup-qemu-action/releases) - [Commits](https://github.com/docker/setup-qemu-action/compare/v1...v2) --- updated-dependencies: - dependency-name: docker/setup-qemu-action dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- .github/workflows/wheels.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index aff7dcec9..110e6add8 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -44,7 +44,7 @@ jobs: python-version: 3.x - name: Set up QEMU if: runner.os == 'Linux' - uses: docker/setup-qemu-action@v1 + uses: docker/setup-qemu-action@v2 with: platforms: all - name: Build wheels From 8836c4e7e1b70ca41752e0dfb65f777062264f54 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 22 Aug 2022 20:19:50 +0000 Subject: [PATCH 303/762] Bump pypa/cibuildwheel from 2.5.0 to 2.9.0 Bumps [pypa/cibuildwheel](https://github.com/pypa/cibuildwheel) from 2.5.0 to 2.9.0. - [Release notes](https://github.com/pypa/cibuildwheel/releases) - [Changelog](https://github.com/pypa/cibuildwheel/blob/main/docs/changelog.md) - [Commits](https://github.com/pypa/cibuildwheel/compare/v2.5.0...v2.9.0) --- updated-dependencies: - dependency-name: pypa/cibuildwheel dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- .github/workflows/wheels.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 110e6add8..fa1c660b5 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -48,7 +48,7 @@ jobs: with: platforms: all - name: Build wheels - uses: pypa/cibuildwheel@v2.5.0 + uses: pypa/cibuildwheel@v2.9.0 env: CIBW_ARCHS_MACOS: "x86_64 universal2 arm64" CIBW_ARCHS_LINUX: "auto aarch64" From fdb4b68d0dc5c4b439f1e2d81bf91fe70b8c8c95 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 22 Aug 2022 20:19:53 +0000 Subject: [PATCH 304/762] Bump actions/download-artifact from 2 to 3 Bumps [actions/download-artifact](https://github.com/actions/download-artifact) from 2 to 3. - [Release notes](https://github.com/actions/download-artifact/releases) - [Commits](https://github.com/actions/download-artifact/compare/v2...v3) --- updated-dependencies: - dependency-name: actions/download-artifact dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- .github/workflows/wheels.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index fa1c660b5..7b07d9b31 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -66,7 +66,7 @@ jobs: runs-on: ubuntu-latest if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') steps: - - uses: actions/download-artifact@v2 + - uses: actions/download-artifact@v3 with: name: artifact path: dist From 83204e26354be8f574e4c339f6d013899cc773ac Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 22 Aug 2022 20:20:00 +0000 Subject: [PATCH 305/762] Bump actions/upload-artifact from 2 to 3 Bumps [actions/upload-artifact](https://github.com/actions/upload-artifact) from 2 to 3. - [Release notes](https://github.com/actions/upload-artifact/releases) - [Commits](https://github.com/actions/upload-artifact/compare/v2...v3) --- updated-dependencies: - dependency-name: actions/upload-artifact dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- .github/workflows/wheels.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 7b07d9b31..2ce8a36fe 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -20,7 +20,7 @@ jobs: - name: Build sdist run: python setup.py sdist - name: Save sdist - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: path: dist/*.tar.gz @@ -54,7 +54,7 @@ jobs: CIBW_ARCHS_LINUX: "auto aarch64" CIBW_SKIP: cp36-* - name: Save wheels - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: path: wheelhouse/*.whl From b839d36b493a49e6da2a72a8883e5b4d48c3aff6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 22 Aug 2022 22:37:57 +0200 Subject: [PATCH 306/762] Configure dependabot schedule. --- .github/dependabot.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 123014908..ad1e824b4 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -3,4 +3,7 @@ updates: - package-ecosystem: "github-actions" directory: "/" schedule: - interval: "daily" + interval: "weekly" + day: "saturday" + time: "07:00" + timezone: "Europe/Paris" From ff0e13010684e08f81c2534b315d228aac0acfae Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 22 Aug 2022 20:38:59 +0000 Subject: [PATCH 307/762] Bump actions/checkout from 2 to 3 Bumps [actions/checkout](https://github.com/actions/checkout) from 2 to 3. - [Release notes](https://github.com/actions/checkout/releases) - [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md) - [Commits](https://github.com/actions/checkout/compare/v2...v3) --- updated-dependencies: - dependency-name: actions/checkout dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- .github/workflows/tests.yml | 4 ++-- .github/workflows/wheels.yml | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 2e6e0f10e..4785e8d20 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -14,7 +14,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Check out repository - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: Install Python 3.x uses: actions/setup-python@v4 with: @@ -57,7 +57,7 @@ jobs: is_main: false steps: - name: Check out repository - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: Install Python ${{ matrix.python }} uses: actions/setup-python@v4 with: diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 2ce8a36fe..fe80b3431 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -12,7 +12,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Check out repository - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: Install Python 3.x uses: actions/setup-python@v4 with: @@ -35,7 +35,7 @@ jobs: - macOS-latest steps: - name: Check out repository - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: Make extension build mandatory run: touch .cibuildwheel - name: Install Python 3.x From cfa70cac5214c5eb2ba28077b53da037ee51d9a0 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 23 Aug 2022 00:07:34 +0200 Subject: [PATCH 308/762] Pin action to a released version --- .github/workflows/wheels.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index fe80b3431..e7b9dc431 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -70,6 +70,6 @@ jobs: with: name: artifact path: dist - - uses: pypa/gh-action-pypi-publish@master + - uses: pypa/gh-action-pypi-publish@release/v1 with: password: ${{ secrets.PYPI_API_TOKEN }} From fd17f4684ef41750bdad09e09e22ee9a653f1238 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 23 Aug 2022 22:39:03 +0200 Subject: [PATCH 309/762] Add OSS-Fuzz fuzz targets (experimental). --- fuzzing/fuzz_http11_request_parser.py | 35 ++++++++++++++++++++++++ fuzzing/fuzz_http11_response_parser.py | 38 ++++++++++++++++++++++++++ 2 files changed, 73 insertions(+) create mode 100644 fuzzing/fuzz_http11_request_parser.py create mode 100644 fuzzing/fuzz_http11_response_parser.py diff --git a/fuzzing/fuzz_http11_request_parser.py b/fuzzing/fuzz_http11_request_parser.py new file mode 100644 index 000000000..4879899e1 --- /dev/null +++ b/fuzzing/fuzz_http11_request_parser.py @@ -0,0 +1,35 @@ +import sys + +import atheris + + +with atheris.instrument_imports(): + from websockets.exceptions import SecurityError + from websockets.http11 import Request + from websockets.streams import StreamReader + + +def test_one_input(data): + reader = StreamReader() + reader.feed_data(data) + reader.feed_eof() + + try: + Request.parse( + reader.read_line, + ) + except ( + EOFError, # connection is closed without a full HTTP request + SecurityError, # request exceeds a security limit + ValueError, # request isn't well formatted + ): + pass + + +def main(): + atheris.Setup(sys.argv, test_one_input) + atheris.Fuzz() + + +if __name__ == "__main__": + main() diff --git a/fuzzing/fuzz_http11_response_parser.py b/fuzzing/fuzz_http11_response_parser.py new file mode 100644 index 000000000..2f0bcd3ed --- /dev/null +++ b/fuzzing/fuzz_http11_response_parser.py @@ -0,0 +1,38 @@ +import sys + +import atheris + + +with atheris.instrument_imports(): + from websockets.exceptions import SecurityError + from websockets.http11 import Response + from websockets.streams import StreamReader + + +def test_one_input(data): + reader = StreamReader() + reader.feed_data(data) + reader.feed_eof() + + try: + Response.parse( + reader.read_line, + reader.read_exact, + reader.read_to_eof, + ) + except ( + EOFError, # connection is closed without a full HTTP response. + SecurityError, # response exceeds a security limit. + LookupError, # response isn't well formatted. + ValueError, # response isn't well formatted. + ): + pass + + +def main(): + atheris.Setup(sys.argv, test_one_input) + atheris.Fuzz() + + +if __name__ == "__main__": + main() From a26cec226e63894b2cff3c5ea1ad3a643fb1d889 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 24 Aug 2022 07:50:33 +0200 Subject: [PATCH 310/762] Make fuzz targets actually run. --- fuzzing/fuzz_http11_request_parser.py | 10 +++++++--- fuzzing/fuzz_http11_response_parser.py | 21 ++++++++++++--------- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/fuzzing/fuzz_http11_request_parser.py b/fuzzing/fuzz_http11_request_parser.py index 4879899e1..148785385 100644 --- a/fuzzing/fuzz_http11_request_parser.py +++ b/fuzzing/fuzz_http11_request_parser.py @@ -14,10 +14,14 @@ def test_one_input(data): reader.feed_data(data) reader.feed_eof() + parser = Request.parse( + reader.read_line, + ) + try: - Request.parse( - reader.read_line, - ) + next(parser) + except StopIteration: + pass # request is available in exc.value except ( EOFError, # connection is closed without a full HTTP request SecurityError, # request exceeds a security limit diff --git a/fuzzing/fuzz_http11_response_parser.py b/fuzzing/fuzz_http11_response_parser.py index 2f0bcd3ed..0f783f6fd 100644 --- a/fuzzing/fuzz_http11_response_parser.py +++ b/fuzzing/fuzz_http11_response_parser.py @@ -14,17 +14,20 @@ def test_one_input(data): reader.feed_data(data) reader.feed_eof() + parser = Response.parse( + reader.read_line, + reader.read_exact, + reader.read_to_eof, + ) try: - Response.parse( - reader.read_line, - reader.read_exact, - reader.read_to_eof, - ) + next(parser) + except StopIteration: + pass # response is available in exc.value except ( - EOFError, # connection is closed without a full HTTP response. - SecurityError, # response exceeds a security limit. - LookupError, # response isn't well formatted. - ValueError, # response isn't well formatted. + EOFError, # connection is closed without a full HTTP response + SecurityError, # response exceeds a security limit + LookupError, # response isn't well formatted + ValueError, # response isn't well formatted ): pass From 5d1bad7ebb1121349d08260554368553c02d1a37 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 24 Aug 2022 07:50:48 +0200 Subject: [PATCH 311/762] Add fuzz target for WebSocket parser. --- fuzzing/fuzz_websocket_parser.py | 46 ++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 fuzzing/fuzz_websocket_parser.py diff --git a/fuzzing/fuzz_websocket_parser.py b/fuzzing/fuzz_websocket_parser.py new file mode 100644 index 000000000..7569d0b61 --- /dev/null +++ b/fuzzing/fuzz_websocket_parser.py @@ -0,0 +1,46 @@ +import sys + +import atheris + + +with atheris.instrument_imports(): + from websockets.exceptions import PayloadTooBig, ProtocolError + from websockets.frames import Frame + from websockets.streams import StreamReader + + +def test_one_input(data): + fdp = atheris.FuzzedDataProvider(data) + mask = fdp.ConsumeBool() + max_size_enabled = fdp.ConsumeBool() + max_size = fdp.ConsumeInt(4) + payload = fdp.ConsumeBytes(atheris.ALL_REMAINING) + + reader = StreamReader() + reader.feed_data(payload) + reader.feed_eof() + + parser = Frame.parse( + reader.read_exact, + mask=mask, + max_size=max_size if max_size_enabled else None, + ) + + try: + next(parser) + except StopIteration: + pass # response is available in exc.value + except ( + PayloadTooBig, # frame's payload size exceeds ``max_size`` + ProtocolError, # frame contains incorrect values + ): + pass + + +def main(): + atheris.Setup(sys.argv, test_one_input) + atheris.Fuzz() + + +if __name__ == "__main__": + main() From 61e0e1c10f9895a15b34f0d43f160c6b4861e18b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 24 Aug 2022 08:04:54 +0200 Subject: [PATCH 312/762] Ensure fuzz targets work as expected. --- fuzzing/fuzz_http11_request_parser.py | 9 ++++++--- fuzzing/fuzz_http11_response_parser.py | 9 ++++++--- fuzzing/fuzz_websocket_parser.py | 9 ++++++--- 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/fuzzing/fuzz_http11_request_parser.py b/fuzzing/fuzz_http11_request_parser.py index 148785385..59e0cea0f 100644 --- a/fuzzing/fuzz_http11_request_parser.py +++ b/fuzzing/fuzz_http11_request_parser.py @@ -20,14 +20,17 @@ def test_one_input(data): try: next(parser) - except StopIteration: - pass # request is available in exc.value + except StopIteration as exc: + assert isinstance(exc.value, Request) + return # input accepted except ( EOFError, # connection is closed without a full HTTP request SecurityError, # request exceeds a security limit ValueError, # request isn't well formatted ): - pass + return # input rejected with a documented exception + + raise RuntimeError("parsing didn't complete") def main(): diff --git a/fuzzing/fuzz_http11_response_parser.py b/fuzzing/fuzz_http11_response_parser.py index 0f783f6fd..6906720a4 100644 --- a/fuzzing/fuzz_http11_response_parser.py +++ b/fuzzing/fuzz_http11_response_parser.py @@ -21,15 +21,18 @@ def test_one_input(data): ) try: next(parser) - except StopIteration: - pass # response is available in exc.value + except StopIteration as exc: + assert isinstance(exc.value, Response) + return # input accepted except ( EOFError, # connection is closed without a full HTTP response SecurityError, # response exceeds a security limit LookupError, # response isn't well formatted ValueError, # response isn't well formatted ): - pass + return # input rejected with a documented exception + + raise RuntimeError("parsing didn't complete") def main(): diff --git a/fuzzing/fuzz_websocket_parser.py b/fuzzing/fuzz_websocket_parser.py index 7569d0b61..ab9c1dd2e 100644 --- a/fuzzing/fuzz_websocket_parser.py +++ b/fuzzing/fuzz_websocket_parser.py @@ -28,13 +28,16 @@ def test_one_input(data): try: next(parser) - except StopIteration: - pass # response is available in exc.value + except StopIteration as exc: + assert isinstance(exc.value, Frame) + return # input accepted except ( PayloadTooBig, # frame's payload size exceeds ``max_size`` ProtocolError, # frame contains incorrect values ): - pass + return # input rejected with a documented exception + + raise RuntimeError("parsing didn't complete") def main(): From 973edf67f2956c5fc6e5bd11c779f267224d08e6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 24 Aug 2022 08:28:45 +0200 Subject: [PATCH 313/762] Add expected exceptions in Frame.parse. --- fuzzing/fuzz_websocket_parser.py | 2 ++ src/websockets/frames.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/fuzzing/fuzz_websocket_parser.py b/fuzzing/fuzz_websocket_parser.py index ab9c1dd2e..1509a3549 100644 --- a/fuzzing/fuzz_websocket_parser.py +++ b/fuzzing/fuzz_websocket_parser.py @@ -32,6 +32,8 @@ def test_one_input(data): assert isinstance(exc.value, Frame) return # input accepted except ( + EOFError, # connection is closed without a full WebSocket frame + UnicodeDecodeError, # frame contains invalid UTF-8 PayloadTooBig, # frame's payload size exceeds ``max_size`` ProtocolError, # frame contains incorrect values ): diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 043b688b5..ec6b8547d 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -191,6 +191,8 @@ def parse( extensions: list of extensions, applied in reverse order. Raises: + EOFError: if the connection is closed without a full WebSocket frame. + UnicodeDecodeError: if the frame contains invalid UTF-8. PayloadTooBig: if the frame's payload size exceeds ``max_size``. ProtocolError: if the frame contains incorrect values. From eeedb71b1b2f88633d3b0cc8715f2ce8e62f77de Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 10 Sep 2022 08:43:32 +0200 Subject: [PATCH 314/762] Add CII Best Practices badge. Remove "wheel" badge which has become standard practice and isn't particularly interesting. --- README.rst | 10 +++++----- docs/index.rst | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/README.rst b/README.rst index 2b9a445ea..7cbccfe13 100644 --- a/README.rst +++ b/README.rst @@ -2,7 +2,7 @@ :width: 480px :alt: websockets -|licence| |version| |pyversions| |wheel| |tests| |docs| +|licence| |version| |pyversions| |tests| |docs| |openssf| .. |licence| image:: https://img.shields.io/pypi/l/websockets.svg :target: https://pypi.python.org/pypi/websockets @@ -13,15 +13,15 @@ .. |pyversions| image:: https://img.shields.io/pypi/pyversions/websockets.svg :target: https://pypi.python.org/pypi/websockets -.. |wheel| image:: https://img.shields.io/pypi/wheel/websockets.svg - :target: https://pypi.python.org/pypi/websockets - -.. |tests| image:: https://img.shields.io/github/checks-status/aaugustin/websockets/main +.. |tests| image:: https://img.shields.io/github/checks-status/aaugustin/websockets/main?label=tests :target: https://github.com/aaugustin/websockets/actions/workflows/tests.yml .. |docs| image:: https://img.shields.io/readthedocs/websockets.svg :target: https://websockets.readthedocs.io/ +.. |openssf| image:: https://bestpractices.coreinfrastructure.org/projects/6475/badge + :target: https://bestpractices.coreinfrastructure.org/projects/6475 + What is ``websockets``? ----------------------- diff --git a/docs/index.rst b/docs/index.rst index 00a9b5999..d64d80ac3 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,7 +1,7 @@ websockets ========== -|licence| |version| |pyversions| |wheel| |tests| |docs| +|licence| |version| |pyversions| |tests| |docs| |openssf| .. |licence| image:: https://img.shields.io/pypi/l/websockets.svg :target: https://pypi.python.org/pypi/websockets @@ -12,15 +12,15 @@ websockets .. |pyversions| image:: https://img.shields.io/pypi/pyversions/websockets.svg :target: https://pypi.python.org/pypi/websockets -.. |wheel| image:: https://img.shields.io/pypi/wheel/websockets.svg - :target: https://pypi.python.org/pypi/websockets - -.. |tests| image:: https://img.shields.io/github/checks-status/aaugustin/websockets/main +.. |tests| image:: https://img.shields.io/github/checks-status/aaugustin/websockets/main?label=tests :target: https://github.com/aaugustin/websockets/actions/workflows/tests.yml .. |docs| image:: https://img.shields.io/readthedocs/websockets.svg :target: https://websockets.readthedocs.io/ +.. |openssf| image:: https://bestpractices.coreinfrastructure.org/projects/6475/badge + :target: https://bestpractices.coreinfrastructure.org/projects/6475 + websockets is a library for building WebSocket_ servers and clients in Python with a focus on correctness, simplicity, robustness, and performance. From eabb4b6a218eaa4be908e1038ebe7eb0724031de Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sat, 17 Sep 2022 05:29:33 +0000 Subject: [PATCH 315/762] Bump pypa/cibuildwheel from 2.9.0 to 2.10.0 Bumps [pypa/cibuildwheel](https://github.com/pypa/cibuildwheel) from 2.9.0 to 2.10.0. - [Release notes](https://github.com/pypa/cibuildwheel/releases) - [Changelog](https://github.com/pypa/cibuildwheel/blob/main/docs/changelog.md) - [Commits](https://github.com/pypa/cibuildwheel/compare/v2.9.0...v2.10.0) --- updated-dependencies: - dependency-name: pypa/cibuildwheel dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- .github/workflows/wheels.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index e7b9dc431..cc22cd4c6 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -48,7 +48,7 @@ jobs: with: platforms: all - name: Build wheels - uses: pypa/cibuildwheel@v2.9.0 + uses: pypa/cibuildwheel@v2.10.0 env: CIBW_ARCHS_MACOS: "x86_64 universal2 arm64" CIBW_ARCHS_LINUX: "auto aarch64" From 569ba1e04aa78313fd4bceb1cdae59c822682add Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 26 Sep 2022 21:59:13 +0200 Subject: [PATCH 316/762] Add disclaimer to Heroku tutorial. I wouldn't write it today :-( --- docs/howto/heroku.rst | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/docs/howto/heroku.rst b/docs/howto/heroku.rst index 464420e05..2ac5a8c1e 100644 --- a/docs/howto/heroku.rst +++ b/docs/howto/heroku.rst @@ -4,6 +4,14 @@ Deploy to Heroku This guide describes how to deploy a websockets server to Heroku_. The same principles should apply to other Platform as a Service providers. +.. admonition:: Heroku no longer offers a free tier. + :class: attention + + When this tutorial was written, in September 2021, Heroku offered a free + tier where a websockets app could run at no cost. In November 2022, Heroku + removed the free tier, making it impossible to maintain this document. As a + consequence, it isn't updated anymore and may be removed in the future. + We're going to deploy a very simple app. The process would be identical for a more realistic app. @@ -42,7 +50,7 @@ Here's the implementation of the app, an echo server. Save it in a file called ``app.py``: .. literalinclude:: ../../example/deployment/heroku/app.py - :language: text + :language: python Heroku expects the server to `listen on a specific port`_, which is provided in the ``$PORT`` environment variable. The app reads it and passes it to From 7eedf7aab316cd26b8db2db2d7113d479d2f4d0b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 Oct 2022 17:59:29 +0200 Subject: [PATCH 317/762] Add guide for deploying on Render. --- docs/howto/index.rst | 1 + docs/howto/render.rst | 173 +++++++++++++++++++++ example/deployment/render/app.py | 36 +++++ example/deployment/render/requirements.txt | 1 + 4 files changed, 211 insertions(+) create mode 100644 docs/howto/render.rst create mode 100644 example/deployment/render/app.py create mode 100644 example/deployment/render/requirements.txt diff --git a/docs/howto/index.rst b/docs/howto/index.rst index dafb72391..23fe823e1 100644 --- a/docs/howto/index.rst +++ b/docs/howto/index.rst @@ -39,6 +39,7 @@ Once your application is ready, learn how to deploy it on various platforms. .. toctree:: :titlesonly: + render heroku kubernetes supervisor diff --git a/docs/howto/render.rst b/docs/howto/render.rst new file mode 100644 index 000000000..b8c417b66 --- /dev/null +++ b/docs/howto/render.rst @@ -0,0 +1,173 @@ +Deploy to Render +================ + +This guide describes how to deploy a websockets server to Render_. + +.. _Render: https://render.com/ + +.. admonition:: The free plan of Render is sufficient for trying this guide. + :class: tip + + However, on a `free plan`__, connections are dropped after five minutes, + which is quite short for WebSocket application. + + __ https://render.com/docs/free + +We're going to deploy a very simple app. The process would be identical for a +more realistic app. + +Create repository +----------------- + +Deploying to Render requires a git repository. Let's initialize one: + +.. code-block:: console + + $ mkdir websockets-echo + $ cd websockets-echo + $ git init -b main + Initialized empty Git repository in websockets-echo/.git/ + $ git commit --allow-empty -m "Initial commit." + [main (root-commit) 816c3b1] Initial commit. + +Render requires the git repository to be hosted at GitHub or GitLab. + +Sign up or log in to GitHub. Create a new repository named ``websockets-echo``. +Don't enable any of the initialization options offered by GitHub. Then, follow +instructions for pushing an existing repository from the command line. + +After pushing, refresh your repository's homepage on GitHub. You should see an +empty repository with an empty initial commit. + +Create application +------------------ + +Here's the implementation of the app, an echo server. Save it in a file called +``app.py``: + +.. literalinclude:: ../../example/deployment/render/app.py + :language: python + +This app implements requirements for `zero downtime deploys`_: + +* it provides a health check at ``/healthz``; +* it closes connections and exits cleanly when it receives a ``SIGTERM`` signal. + +.. _zero downtime deploys: https://render.com/docs/deploys#zero-downtime-deploys + +Create a ``requirements.txt`` file containing this line to declare a dependency +on websockets: + +.. literalinclude:: ../../example/deployment/render/requirements.txt + :language: text + +Confirm that you created the correct files and commit them to git: + +.. code-block:: console + + $ ls + app.py requirements.txt + $ git add . + $ git commit -m "Initial implementation." + [main f26bf7f] Initial implementation. + 2 files changed, 37 insertions(+) + create mode 100644 app.py + create mode 100644 requirements.txt + +Push the changes to GitHub: + +.. code-block:: console + + $ git push + ... + To github.com:/websockets-echo.git + 816c3b1..f26bf7f main -> main + +The app is ready. Let's deploy it! + +Deploy application +------------------ + +Sign up or log in to Render. + +Create a new web service. Connect the git repository that you just created. + +Then, finalize the configuration of your app as follows: + +* **Name**: websockets-echo +* **Start Command**: ``python app.py`` + +If you're just experimenting, select the free plan. Create the web service. + +To configure the health check, go to Settings, scroll down to Health & Alerts, +and set: + +* **Health Check Path**: /healthz + +This triggers a new deployment. + +Validate deployment +------------------- + +Let's confirm that your application is running as expected. + +Since it's a WebSocket server, you need a WebSocket client, such as the +interactive client that comes with websockets. + +If you're currently building a websockets server, perhaps you're already in a +virtualenv where websockets is installed. If not, you can install it in a new +virtualenv as follows: + +.. code-block:: console + + $ python -m venv websockets-client + $ . websockets-client/bin/activate + $ pip install websockets + +Connect the interactive client — you must replace ``websockets-echo`` with the +name of your Render app in this command: + +.. code-block:: console + + $ python -m websockets wss://websockets-echo.onrender.com/ + Connected to wss://websockets-echo.onrender.com/. + > + +Great! Your app is running! + +Once you're connected, you can send any message and the server will echo it, +or press Ctrl-D to terminate the connection: + +.. code-block:: console + + > Hello! + < Hello! + Connection closed: 1000 (OK). + +You can also confirm that your application shuts down gracefully when you deploy +a new version. Due to limitations of Render's free plan, you must upgrade to a +paid plan before you perform this test. + +Connect an interactive client again — remember to replace ``websockets-echo`` +with your app: + +.. code-block:: console + + $ python -m websockets wss://websockets-echo.onrender.com/ + Connected to wss://websockets-echo.onrender.com/. + > + +Trigger a new deployment with Manual Deploy > Deploy latest commit. When the +deployment completes, the connection is closed with code 1001 (going away). + +.. code-block:: console + + $ python -m websockets wss://websockets-echo.onrender.com/ + Connected to wss://websockets-echo.onrender.com/. + Connection closed: 1001 (going away). + +If graceful shutdown wasn't working, the server wouldn't perform a closing +handshake and the connection would be closed with code 1006 (connection closed +abnormally). + +Remember to downgrade to a free plan if you upgraded just for testing this feature. diff --git a/example/deployment/render/app.py b/example/deployment/render/app.py new file mode 100644 index 000000000..4ca34d23b --- /dev/null +++ b/example/deployment/render/app.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python + +import asyncio +import http +import signal + +import websockets + + +async def echo(websocket): + async for message in websocket: + await websocket.send(message) + + +async def health_check(path, request_headers): + if path == "/healthz": + return http.HTTPStatus.OK, [], b"OK\n" + + +async def main(): + # Set the stop condition when receiving SIGTERM. + loop = asyncio.get_running_loop() + stop = loop.create_future() + loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) + + async with websockets.serve( + echo, + host="", + port=8080, + process_request=health_check, + ): + await stop + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/deployment/render/requirements.txt b/example/deployment/render/requirements.txt new file mode 100644 index 000000000..14774b465 --- /dev/null +++ b/example/deployment/render/requirements.txt @@ -0,0 +1 @@ +websockets From 98999b66d502a067cdb0d241dcd7003f6f185da4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 Oct 2022 20:39:35 +0200 Subject: [PATCH 318/762] Add guide for deploying on Fly. --- docs/howto/fly.rst | 178 ++++++++++++++++++++++++ docs/howto/index.rst | 1 + docs/howto/kubernetes.rst | 2 + docs/spelling_wordlist.txt | 4 + example/deployment/fly/Procfile | 1 + example/deployment/fly/app.py | 36 +++++ example/deployment/fly/fly.toml | 16 +++ example/deployment/fly/requirements.txt | 1 + 8 files changed, 239 insertions(+) create mode 100644 docs/howto/fly.rst create mode 100644 example/deployment/fly/Procfile create mode 100644 example/deployment/fly/app.py create mode 100644 example/deployment/fly/fly.toml create mode 100644 example/deployment/fly/requirements.txt diff --git a/docs/howto/fly.rst b/docs/howto/fly.rst new file mode 100644 index 000000000..7e404de61 --- /dev/null +++ b/docs/howto/fly.rst @@ -0,0 +1,178 @@ +Deploy to Fly +================ + +This guide describes how to deploy a websockets server to Fly_. + +.. _Fly: https://fly.io/ + +.. admonition:: The free tier of Fly is sufficient for trying this guide. + :class: tip + + The `free tier`__ include up to three small VMs. This guide uses only one. + + __ https://fly.io/docs/about/pricing/ + +We're going to deploy a very simple app. The process would be identical for a +more realistic app. + +Create application +------------------ + +Here's the implementation of the app, an echo server. Save it in a file called +``app.py``: + +.. literalinclude:: ../../example/deployment/fly/app.py + :language: python + +This app implements typical requirements for running on a Platform as a Service: + +* it provides a health check at ``/healthz``; +* it closes connections and exits cleanly when it receives a ``SIGTERM`` signal. + +Create a ``requirements.txt`` file containing this line to declare a dependency +on websockets: + +.. literalinclude:: ../../example/deployment/fly/requirements.txt + :language: text + +The app is ready. Let's deploy it! + +Deploy application +------------------ + +Follow the instructions__ to install the Fly CLI, if you haven't done that yet. + +__ https://fly.io/docs/hands-on/install-flyctl/ + +Sign up or log in to Fly. + +Launch the app — you'll have to pick a different name because I'm already using +``websockets-echo``: + +.. code-block:: console + + $ fly launch + Creating app in ... + Scanning source code + Detected a Python app + Using the following build configuration: + Builder: paketobuildpacks/builder:base + ? App Name (leave blank to use an auto-generated name): websockets-echo + ? Select organization: ... + ? Select region: ... + Created app websockets-echo in organization ... + Wrote config file fly.toml + ? Would you like to set up a Postgresql database now? No + We have generated a simple Procfile for you. Modify it to fit your needs and run "fly deploy" to deploy your application. + +.. admonition:: This will build the image with a generic buildpack. + :class: tip + + Fly can `build images`__ with a Dockerfile or a buildpack. Here, ``fly + launch`` configures a generic Paketo buildpack. + + If you'd rather package the app with a Dockerfile, check out the guide to + :ref:`containerize an application `. + + __ https://fly.io/docs/reference/builders/ + +Replace the auto-generated ``fly.toml`` with: + +.. literalinclude:: ../../example/deployment/fly/fly.toml + :language: toml + +This configuration: + +* listens on port 443, terminates TLS, and forwards to the app on port 8080; +* declares a health check at ``/healthz``; +* requests a ``SIGTERM`` for terminating the app. + +Replace the auto-generated ``Procfile`` with: + +.. literalinclude:: ../../example/deployment/fly/Procfile + :language: text + +This tells Fly how to run the app. + +Now you can deploy it: + +.. code-block:: console + + $ fly deploy + + ... lots of output... + + ==> Monitoring deployment + + 1 desired, 1 placed, 1 healthy, 0 unhealthy [health checks: 1 total, 1 passing] + --> v0 deployed successfully + +Validate deployment +------------------- + +Let's confirm that your application is running as expected. + +Since it's a WebSocket server, you need a WebSocket client, such as the +interactive client that comes with websockets. + +If you're currently building a websockets server, perhaps you're already in a +virtualenv where websockets is installed. If not, you can install it in a new +virtualenv as follows: + +.. code-block:: console + + $ python -m venv websockets-client + $ . websockets-client/bin/activate + $ pip install websockets + +Connect the interactive client — you must replace ``websockets-echo`` with the +name of your Fly app in this command: + +.. code-block:: console + + $ python -m websockets wss://websockets-echo.fly.dev/ + Connected to wss://websockets-echo.fly.dev/. + > + +Great! Your app is running! + +Once you're connected, you can send any message and the server will echo it, +or press Ctrl-D to terminate the connection: + +.. code-block:: console + + > Hello! + < Hello! + Connection closed: 1000 (OK). + +You can also confirm that your application shuts down gracefully. + +Connect an interactive client again — remember to replace ``websockets-echo`` +with your app: + +.. code-block:: console + + $ python -m websockets wss://websockets-echo.fly.dev/ + Connected to wss://websockets-echo.fly.dev/. + > + +In another shell, restart the app — again, replace ``websockets-echo`` with your +app: + +.. code-block:: console + + $ fly restart websockets-echo + websockets-echo is being restarted + +Go back to the first shell. The connection is closed with code 1001 (going +away). + +.. code-block:: console + + $ python -m websockets wss://websockets-echo.fly.dev/ + Connected to wss://websockets-echo.fly.dev/. + Connection closed: 1001 (going away). + +If graceful shutdown wasn't working, the server wouldn't perform a closing +handshake and the connection would be closed with code 1006 (connection closed +abnormally). diff --git a/docs/howto/index.rst b/docs/howto/index.rst index 23fe823e1..ddbe67d3a 100644 --- a/docs/howto/index.rst +++ b/docs/howto/index.rst @@ -40,6 +40,7 @@ Once your application is ready, learn how to deploy it on various platforms. :titlesonly: render + fly heroku kubernetes supervisor diff --git a/docs/howto/kubernetes.rst b/docs/howto/kubernetes.rst index 26dbf8a94..c217e5946 100644 --- a/docs/howto/kubernetes.rst +++ b/docs/howto/kubernetes.rst @@ -13,6 +13,8 @@ websockets is concerned. .. _Kubernetes: https://kubernetes.io/ +.. _containerize-application: + Containerize application ------------------------ diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 1d5ae527d..d5b093e2d 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -10,6 +10,7 @@ balancers bottlenecked bufferbloat bugfix +buildpack bytestring bytestrings changelog @@ -20,9 +21,11 @@ cryptocurrency ctrl deserialize django +Dockerfile dyno fractalideas gunicorn +healthz hypercorn iframe IPv @@ -38,6 +41,7 @@ lookups MiB mypy nginx +Paketo permessage pid proxying diff --git a/example/deployment/fly/Procfile b/example/deployment/fly/Procfile new file mode 100644 index 000000000..2e35818f6 --- /dev/null +++ b/example/deployment/fly/Procfile @@ -0,0 +1 @@ +web: python app.py diff --git a/example/deployment/fly/app.py b/example/deployment/fly/app.py new file mode 100644 index 000000000..4ca34d23b --- /dev/null +++ b/example/deployment/fly/app.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python + +import asyncio +import http +import signal + +import websockets + + +async def echo(websocket): + async for message in websocket: + await websocket.send(message) + + +async def health_check(path, request_headers): + if path == "/healthz": + return http.HTTPStatus.OK, [], b"OK\n" + + +async def main(): + # Set the stop condition when receiving SIGTERM. + loop = asyncio.get_running_loop() + stop = loop.create_future() + loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) + + async with websockets.serve( + echo, + host="", + port=8080, + process_request=health_check, + ): + await stop + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/deployment/fly/fly.toml b/example/deployment/fly/fly.toml new file mode 100644 index 000000000..5290072ed --- /dev/null +++ b/example/deployment/fly/fly.toml @@ -0,0 +1,16 @@ +app = "websockets-echo" +kill_signal = "SIGTERM" + +[build] + builder = "paketobuildpacks/builder:base" + +[[services]] + internal_port = 8080 + protocol = "tcp" + + [[services.http_checks]] + path = "/healthz" + + [[services.ports]] + handlers = ["tls", "http"] + port = 443 diff --git a/example/deployment/fly/requirements.txt b/example/deployment/fly/requirements.txt new file mode 100644 index 000000000..14774b465 --- /dev/null +++ b/example/deployment/fly/requirements.txt @@ -0,0 +1 @@ +websockets From c77b3087a25b2eb459b1049b5352b6a4e912970e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 Oct 2022 20:49:35 +0200 Subject: [PATCH 319/762] Uniformize deployment guides. --- docs/howto/heroku.rst | 80 +++++++++++++++++++++++-------------------- 1 file changed, 42 insertions(+), 38 deletions(-) diff --git a/docs/howto/heroku.rst b/docs/howto/heroku.rst index 2ac5a8c1e..2b3a44819 100644 --- a/docs/howto/heroku.rst +++ b/docs/howto/heroku.rst @@ -4,6 +4,8 @@ Deploy to Heroku This guide describes how to deploy a websockets server to Heroku_. The same principles should apply to other Platform as a Service providers. +.. _Heroku: https://www.heroku.com/ + .. admonition:: Heroku no longer offers a free tier. :class: attention @@ -15,10 +17,8 @@ principles should apply to other Platform as a Service providers. We're going to deploy a very simple app. The process would be identical for a more realistic app. -.. _Heroku: https://www.heroku.com/ - -Create application ------------------- +Create repository +----------------- Deploying to Heroku requires a git repository. Let's initialize one: @@ -31,20 +31,8 @@ Deploying to Heroku requires a git repository. Let's initialize one: $ git commit --allow-empty -m "Initial commit." [main (root-commit) 1e7947d] Initial commit. -Follow the `set-up instructions`_ to install the Heroku CLI and to log in, if -you haven't done that yet. - -.. _set-up instructions: https://devcenter.heroku.com/articles/getting-started-with-python#set-up - -Then, create a Heroku app — if you follow these instructions step-by-step, -you'll have to pick a different name because I'm already using -``websockets-echo`` on Heroku: - -.. code-block:: console - - $ heroku create websockets-echo - Creating ⬢ websockets-echo... done - https://websockets-echo.herokuapp.com/ | https://git.heroku.com/websockets-echo.git +Create application +------------------ Here's the implementation of the app, an echo server. Save it in a file called ``app.py``: @@ -64,20 +52,18 @@ cleanly. .. _shutting down a dyno: https://devcenter.heroku.com/articles/dynos#shutdown -Deploy application ------------------- - -In order to build the app, Heroku needs to know that it depends on websockets. -Create a ``requirements.txt`` file containing this line: +Create a ``requirements.txt`` file containing this line to declare a dependency +on websockets: .. literalinclude:: ../../example/deployment/heroku/requirements.txt :language: text -Heroku also needs to know how to run the app. Create a ``Procfile`` with this -content: +Create a ``Procfile``. .. literalinclude:: ../../example/deployment/heroku/Procfile +This tells Heroku how to run the app. + Confirm that you created the correct files and commit them to git: .. code-block:: console @@ -85,8 +71,8 @@ Confirm that you created the correct files and commit them to git: $ ls Procfile app.py requirements.txt $ git add . - $ git commit -m "Deploy echo server to Heroku." - [main 8418c62] Deploy echo server to Heroku. + $ git commit -m "Initial implementation." + [main 8418c62] Initial implementation.  3 files changed, 32 insertions(+)  create mode 100644 Procfile  create mode 100644 app.py @@ -94,6 +80,25 @@ Confirm that you created the correct files and commit them to git: The app is ready. Let's deploy it! +Deploy application +------------------ + +Follow the instructions_ to install the Heroku CLI, if you haven't done that +yet. + +.. _instructions: https://devcenter.heroku.com/articles/getting-started-with-python#set-up + +Sign up or log in to Heroku. + +Create a Heroku app — you'll have to pick a different name because I'm already +using ``websockets-echo``: + +.. code-block:: console + + $ heroku create websockets-echo + Creating ⬢ websockets-echo... done + https://websockets-echo.herokuapp.com/ | https://git.heroku.com/websockets-echo.git + .. code-block:: console $ git push heroku @@ -111,7 +116,7 @@ The app is ready. Let's deploy it! Validate deployment ------------------- -Of course you'd like to confirm that your application is running as expected! +Let's confirm that your application is running as expected. Since it's a WebSocket server, you need a WebSocket client, such as the interactive client that comes with websockets. @@ -126,8 +131,8 @@ virtualenv as follows: $ . websockets-client/bin/activate $ pip install websockets -Connect the interactive client — using the name of your Heroku app instead of -``websockets-echo``: +Connect the interactive client — you must replace ``websockets-echo`` with the +name of your Heroku app in this command: .. code-block:: console @@ -137,12 +142,8 @@ Connect the interactive client — using the name of your Heroku app instead of Great! Your app is running! -In this example, I used a secure connection (``wss://``). It worked because -Heroku served a valid TLS certificate for ``websockets-echo.herokuapp.com``. -An insecure connection (``ws://``) would also work. - Once you're connected, you can send any message and the server will echo it, -then press Ctrl-D to terminate the connection: +or press Ctrl-D to terminate the connection: .. code-block:: console @@ -150,8 +151,10 @@ then press Ctrl-D to terminate the connection: < Hello! Connection closed: 1000 (OK). -You can also confirm that your application shuts down gracefully. Connect an -interactive client again — remember to replace ``websockets-echo`` with your app: +You can also confirm that your application shuts down gracefully. + +Connect an interactive client again — remember to replace ``websockets-echo`` +with your app: .. code-block:: console @@ -159,7 +162,8 @@ interactive client again — remember to replace ``websockets-echo`` with your a Connected to wss://websockets-echo.herokuapp.com/. > -In another shell, restart the dyno — again, replace ``websockets-echo`` with your app: +In another shell, restart the app — again, replace ``websockets-echo`` with your +app: .. code-block:: console From 16a85b8362738851fc622b8ffd8092f7f1bb3e87 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 Oct 2022 20:52:52 +0200 Subject: [PATCH 320/762] Run spell check. --- docs/project/changelog.rst | 2 +- docs/reference/index.rst | 2 +- docs/spelling_wordlist.txt | 1 + docs/topics/timeouts.rst | 2 +- src/websockets/client.py | 2 +- src/websockets/legacy/client.py | 2 +- src/websockets/legacy/server.py | 2 +- src/websockets/server.py | 2 +- 8 files changed, 8 insertions(+), 7 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index b52943b8d..dd19228be 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -307,7 +307,7 @@ Backwards-incompatible changes :class: note While Python supports this, tools relying on static code analysis don't. - This breaks autocompletion in an IDE or type checking with mypy_. + This breaks auto-completion in an IDE or type checking with mypy_. .. _mypy: https://github.com/python/mypy diff --git a/docs/reference/index.rst b/docs/reference/index.rst index 5f51a1c1c..8147c4cf3 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -60,7 +60,7 @@ guarantees of behavior or backwards-compatibility for private APIs. For convenience, many public APIs can be imported from the ``websockets`` package. However, this feature is incompatible with static code analysis. It -breaks autocompletion in an IDE or type checking with mypy_. If you're using +breaks auto-completion in an IDE or type checking with mypy_. If you're using such tools, use the real import paths. .. _mypy: https://github.com/python/mypy diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index d5b093e2d..1d342515b 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -23,6 +23,7 @@ deserialize django Dockerfile dyno +formatter fractalideas gunicorn healthz diff --git a/docs/topics/timeouts.rst b/docs/topics/timeouts.rst index dcf0322a4..23e8020c4 100644 --- a/docs/topics/timeouts.rst +++ b/docs/topics/timeouts.rst @@ -11,7 +11,7 @@ long-lived connections, it is desirable to ensure that connections don't break, and if they do, to report the problem quickly. Connections can drop as a consequence of temporary network connectivity issues, -which are very common, even within datacenters. +which are very common, even within data centers. Furthermore, WebSocket builds on top of HTTP/1.1 where connections are short-lived, even with ``Connection: keep-alive``. Typically, HTTP/1.1 diff --git a/src/websockets/client.py b/src/websockets/client.py index 1c33fae0b..c18c76f22 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -65,7 +65,7 @@ class ClientConnection(Connection): defaults to ``logging.getLogger("websockets.client")``; see the :doc:`logging guide <../topics/logging>` for details. user_agent_header: value of the ``User-Agent`` request header; - defauts to ``"Python/x.y.z websockets/X.Y"``; + defaults to ``"Python/x.y.z websockets/X.Y"``; :obj:`None` removes the header. """ diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 93566b87e..1e3c6e741 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -398,7 +398,7 @@ class Connect: preference. extra_headers: arbitrary HTTP headers to add to the request. user_agent_header: value of the ``User-Agent`` request header; - defauts to ``"Python/x.y.z websockets/X.Y"``; + defaults to ``"Python/x.y.z websockets/X.Y"``; :obj:`None` removes the header. open_timeout: timeout for opening the connection in seconds; :obj:`None` to disable the timeout diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 836496b6e..e01548896 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -944,7 +944,7 @@ class Serve: taking the request path and headers in arguments and returning a :data:`~websockets.datastructures.HeadersLike`. server_header: value of the ``Server`` response header; - defauts to ``"Python/x.y.z websockets/X.Y"``; + defaults to ``"Python/x.y.z websockets/X.Y"``; :obj:`None` removes the header. process_request (Optional[Callable[[str, Headers], \ Awaitable[Optional[Tuple[http.HTTPStatus, HeadersLike, bytes]]]]]): diff --git a/src/websockets/server.py b/src/websockets/server.py index 2057d9fb9..ca48449d6 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -65,7 +65,7 @@ class ServerConnection(Connection): defaults to ``logging.getLogger("websockets.client")``; see the :doc:`logging guide <../topics/logging>` for details. server_header: value of the ``Server`` response header; - defauts to ``"Python/x.y.z websockets/X.Y"``; + defaults to ``"Python/x.y.z websockets/X.Y"``; :obj:`None` removes the header. """ From 270d5dae8f87afa4f0a340fc0045e4cc980e5d44 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 7 Oct 2022 13:06:01 +0200 Subject: [PATCH 321/762] Don't require src to be on PYTHONPATH. --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index a38634aee..578f6b1ae 100644 --- a/Makefile +++ b/Makefile @@ -7,7 +7,7 @@ export PYTHONWARNINGS=default default: coverage style style: - isort src tests + isort --project websockets src tests black src tests flake8 src tests mypy --strict src From a899d5a1e5453767cda984324c61caa83ef0654c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 8 Oct 2022 14:03:48 +0200 Subject: [PATCH 322/762] Avoid triggerring deprecation warnings in tests. This also avoids having to account for them. --- tests/legacy/test_client_server.py | 50 ++++++++++++++++-------------- tests/legacy/test_protocol.py | 46 ++++++++++++++++----------- 2 files changed, 54 insertions(+), 42 deletions(-) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index f13ef6882..62cddd2b4 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -226,16 +226,16 @@ def start_server(self, deprecation_warnings=None, **kwargs): # Disable pings by default in tests. kwargs.setdefault("ping_interval", None) + # This logic is encapsulated in a coroutine to prevent it from executing + # before the event loop is running which causes asyncio.get_event_loop() + # to raise a DeprecationWarning on Python ≥ 3.10. + async def start_server(): + return await serve(handler, "localhost", 0, **kwargs) + with warnings.catch_warnings(record=True) as recorded_warnings: - start_server = serve(handler, "localhost", 0, **kwargs) - self.server = self.loop.run_until_complete(start_server) + self.server = self.loop.run_until_complete(start_server()) expected_warnings = [] if deprecation_warnings is None else deprecation_warnings - if ( - sys.version_info[:2] >= (3, 10) - and "remove loop argument" not in expected_warnings - ): # pragma: no cover - expected_warnings += ["There is no current event loop"] self.assertDeprecationWarnings(recorded_warnings, expected_warnings) def start_client( @@ -252,16 +252,16 @@ def start_client( except KeyError: server_uri = get_server_uri(self.server, secure, resource_name, user_info) + # This logic is encapsulated in a coroutine to prevent it from executing + # before the event loop is running which causes asyncio.get_event_loop() + # to raise a DeprecationWarning on Python ≥ 3.10. + async def start_client(): + return await connect(server_uri, **kwargs) + with warnings.catch_warnings(record=True) as recorded_warnings: - start_client = connect(server_uri, **kwargs) - self.client = self.loop.run_until_complete(start_client) + self.client = self.loop.run_until_complete(start_client()) expected_warnings = [] if deprecation_warnings is None else deprecation_warnings - if ( - sys.version_info[:2] >= (3, 10) - and "remove loop argument" not in expected_warnings - ): # pragma: no cover - expected_warnings += ["There is no current event loop"] self.assertDeprecationWarnings(recorded_warnings, expected_warnings) def stop_client(self): @@ -465,25 +465,26 @@ def test_unix_socket(self): path = bytes(pathlib.Path(temp_dir) / "websockets") # Like self.start_server() but with unix_serve(). - with warnings.catch_warnings(record=True) as recorded_warnings: - unix_server = unix_serve(default_handler, path, loop=self.loop) - self.server = self.loop.run_until_complete(unix_server) - self.assertDeprecationWarnings(recorded_warnings, ["remove loop argument"]) + async def start_server(): + return await unix_serve(default_handler, path) + + self.server = self.loop.run_until_complete(start_server()) try: # Like self.start_client() but with unix_connect() - with warnings.catch_warnings(record=True) as recorded_warnings: - unix_client = unix_connect(path, loop=self.loop) - self.client = self.loop.run_until_complete(unix_client) - self.assertDeprecationWarnings( - recorded_warnings, ["remove loop argument"] - ) + async def start_client(): + return await unix_connect(path) + + self.client = self.loop.run_until_complete(start_client()) + try: self.loop.run_until_complete(self.client.send("Hello!")) reply = self.loop.run_until_complete(self.client.recv()) self.assertEqual(reply, "Hello!") + finally: self.stop_client() + finally: self.stop_server() @@ -1341,6 +1342,7 @@ def test_checking_lack_of_origin_succeeds(self): self.assertEqual(self.loop.run_until_complete(self.client.recv()), "Hello!") @with_server(origins=[""]) + # The deprecation warning is raised when a client connects to the server. @with_client(deprecation_warnings=["use None instead of '' in origins"]) def test_checking_lack_of_origin_succeeds_backwards_compatibility(self): self.loop.run_until_complete(self.client.send("Hello!")) diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index ed6424694..6e7f3727b 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -1,6 +1,5 @@ import asyncio import contextlib -import sys import unittest import unittest.mock import warnings @@ -88,9 +87,16 @@ class CommonTests: def setUp(self): super().setUp() - with warnings.catch_warnings(record=True): + + # This logic is encapsulated in a coroutine to prevent it from executing + # before the event loop is running which causes asyncio.get_event_loop() + # to raise a DeprecationWarning on Python ≥ 3.10. + + async def create_protocol(): # Disable pings to make it easier to test what frames are sent exactly. - self.protocol = WebSocketCommonProtocol(ping_interval=None) + return WebSocketCommonProtocol(ping_interval=None) + + self.protocol = self.loop.run_until_complete(create_protocol()) self.transport = TransportMock() self.transport.setup_mock(self.loop, self.protocol) @@ -312,29 +318,24 @@ def assertCompletesWithin(self, min_time, max_time): # Test constructor. def test_timeout_backwards_compatibility(self): + async def create_protocol(): + return WebSocketCommonProtocol(ping_interval=None, timeout=5) + with warnings.catch_warnings(record=True) as recorded: - protocol = WebSocketCommonProtocol(timeout=5) + protocol = self.loop.run_until_complete(create_protocol()) self.assertEqual(protocol.close_timeout, 5) - - expected = ["rename timeout to close_timeout"] - if sys.version_info[:2] >= (3, 10): # pragma: no cover - expected += ["There is no current event loop"] - - self.assertDeprecationWarnings(recorded, expected) + self.assertDeprecationWarnings(recorded, ["rename timeout to close_timeout"]) def test_loop_backwards_compatibility(self): loop = asyncio.new_event_loop() self.addCleanup(loop.close) with warnings.catch_warnings(record=True) as recorded: - protocol = WebSocketCommonProtocol(loop=loop) + protocol = WebSocketCommonProtocol(ping_interval=None, loop=loop) self.assertEqual(protocol.loop, loop) - - expected = ["remove loop argument"] - - self.assertDeprecationWarnings(recorded, expected) + self.assertDeprecationWarnings(recorded, ["remove loop argument"]) # Test public attributes. @@ -1151,18 +1152,27 @@ def test_connection_closed_attributes(self): # Test the protocol logic for sending keepalive pings. def restart_protocol_with_keepalive_ping( - self, ping_interval=3 * MS, ping_timeout=3 * MS + self, + ping_interval=3 * MS, + ping_timeout=3 * MS, ): initial_protocol = self.protocol + # copied from tearDown + self.transport.close() self.loop.run_until_complete(self.protocol.close()) + # copied from setUp, but enables keepalive pings - with warnings.catch_warnings(record=True): - self.protocol = WebSocketCommonProtocol( + + async def create_protocol(): + return WebSocketCommonProtocol( ping_interval=ping_interval, ping_timeout=ping_timeout, ) + + self.protocol = self.loop.run_until_complete(create_protocol()) + self.transport = TransportMock() self.transport.setup_mock(self.loop, self.protocol) self.protocol.is_client = initial_protocol.is_client From 15791c57c28b334bf548e401d037b51f16620604 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 8 Oct 2022 14:20:17 +0200 Subject: [PATCH 323/762] Clean deprecation warnings in test suite. --- tests/legacy/test_client_server.py | 12 ++++++++++-- tests/legacy/test_framing.py | 6 ++++-- tests/legacy/test_protocol.py | 2 ++ tests/test_imports.py | 1 + tox.ini | 2 +- 5 files changed, 18 insertions(+), 5 deletions(-) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 62cddd2b4..d02daede1 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -233,6 +233,7 @@ async def start_server(): return await serve(handler, "localhost", 0, **kwargs) with warnings.catch_warnings(record=True) as recorded_warnings: + warnings.simplefilter("always") self.server = self.loop.run_until_complete(start_server()) expected_warnings = [] if deprecation_warnings is None else deprecation_warnings @@ -259,6 +260,7 @@ async def start_client(): return await connect(server_uri, **kwargs) with warnings.catch_warnings(record=True) as recorded_warnings: + warnings.simplefilter("always") self.client = self.loop.run_until_complete(start_client()) expected_warnings = [] if deprecation_warnings is None else deprecation_warnings @@ -536,6 +538,7 @@ def legacy_process_request_OK(path, request_headers): @with_server(process_request=legacy_process_request_OK) def test_process_request_argument_backwards_compatibility(self): with warnings.catch_warnings(record=True) as recorded_warnings: + warnings.simplefilter("always") response = self.loop.run_until_complete(self.make_http_request("/")) with contextlib.closing(response): @@ -563,6 +566,7 @@ def process_request(self, path, request_headers): @with_server(create_protocol=LegacyProcessRequestOKServerProtocol) def test_process_request_override_backwards_compatibility(self): with warnings.catch_warnings(record=True) as recorded_warnings: + warnings.simplefilter("always") response = self.loop.run_until_complete(self.make_http_request("/")) with contextlib.closing(response): @@ -607,6 +611,7 @@ def test_protocol_deprecated_attributes(self): for server_socket in self.server.sockets ] with warnings.catch_warnings(record=True) as recorded_warnings: + warnings.simplefilter("always") client_attrs = (self.client.host, self.client.port, self.client.secure) self.assertDeprecationWarnings( recorded_warnings, @@ -620,6 +625,7 @@ def test_protocol_deprecated_attributes(self): expected_server_attrs = ("localhost", 0, self.secure) with warnings.catch_warnings(record=True) as recorded_warnings: + warnings.simplefilter("always") self.loop.run_until_complete(self.client.send("")) server_attrs = self.loop.run_until_complete(self.client.recv()) self.assertDeprecationWarnings( @@ -1356,7 +1362,8 @@ class YieldFromTests(ClientServerTestsMixin, AsyncioTestCase): @with_server() def test_client(self): # @asyncio.coroutine is deprecated on Python ≥ 3.8 - with warnings.catch_warnings(record=True): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") @asyncio.coroutine def run_client(): @@ -1370,7 +1377,8 @@ def run_client(): def test_server(self): # @asyncio.coroutine is deprecated on Python ≥ 3.8 - with warnings.catch_warnings(record=True): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") @asyncio.coroutine def run_server(): diff --git a/tests/legacy/test_framing.py b/tests/legacy/test_framing.py index 4646817a8..035f0f03c 100644 --- a/tests/legacy/test_framing.py +++ b/tests/legacy/test_framing.py @@ -17,7 +17,8 @@ def decode(self, message, mask=False, max_size=None, extensions=None): stream = asyncio.StreamReader(loop=self.loop) stream.feed_data(message) stream.feed_eof() - with warnings.catch_warnings(record=True): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") frame = self.loop.run_until_complete( Frame.read( stream.readexactly, @@ -32,7 +33,8 @@ def decode(self, message, mask=False, max_size=None, extensions=None): def encode(self, frame, mask=False, extensions=None): write = unittest.mock.Mock() - with warnings.catch_warnings(record=True): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") frame.write(write, mask=mask, extensions=extensions) # Ensure the entire frame is sent with a single call to write(). # Multiple calls cause TCP fragmentation and degrade performance. diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index 6e7f3727b..4fe0092a5 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -322,6 +322,7 @@ async def create_protocol(): return WebSocketCommonProtocol(ping_interval=None, timeout=5) with warnings.catch_warnings(record=True) as recorded: + warnings.simplefilter("always") protocol = self.loop.run_until_complete(create_protocol()) self.assertEqual(protocol.close_timeout, 5) @@ -332,6 +333,7 @@ def test_loop_backwards_compatibility(self): self.addCleanup(loop.close) with warnings.catch_warnings(record=True) as recorded: + warnings.simplefilter("always") protocol = WebSocketCommonProtocol(ping_interval=None, loop=loop) self.assertEqual(protocol.loop, loop) diff --git a/tests/test_imports.py b/tests/test_imports.py index 8f1625a9b..b69ed9316 100644 --- a/tests/test_imports.py +++ b/tests/test_imports.py @@ -30,6 +30,7 @@ def test_get_deprecated_alias(self): ) with warnings.catch_warnings(record=True) as recorded_warnings: + warnings.simplefilter("always") self.assertEqual(self.mod.bar, bar) self.assertEqual(len(recorded_warnings), 1) diff --git a/tox.ini b/tox.ini index c243e9880..20c5320c9 100644 --- a/tox.ini +++ b/tox.ini @@ -2,7 +2,7 @@ envlist = py37,py38,py39,py310,coverage,black,flake8,isort,mypy [testenv] -commands = python -W default -m unittest {posargs} +commands = python -W error::DeprecationWarning -W error::PendingDeprecationWarning -m unittest {posargs} [testenv:coverage] commands = From 3a4ea9d270b5af431179adf193e1357a46e5fc95 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 8 Oct 2022 14:22:49 +0200 Subject: [PATCH 324/762] Confirm support for Python 3.11. --- docs/project/changelog.rst | 2 ++ tox.ini | 12 +++++++++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index dd19228be..6f44bb36b 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -33,6 +33,8 @@ They may change at any time. New features ............ +* Validated compatibility with Python 3.11. + * Supported overriding or removing the ``User-Agent`` header in clients and the ``Server`` header in servers. diff --git a/tox.ini b/tox.ini index 20c5320c9..3a284ed31 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,15 @@ [tox] -envlist = py37,py38,py39,py310,coverage,black,flake8,isort,mypy +envlist = + py37 + py38 + py39 + py310 + py311 + coverage + black + flake8 + isort + mypy [testenv] commands = python -W error::DeprecationWarning -W error::PendingDeprecationWarning -m unittest {posargs} From f461295aefaa517054ff374bdff6362ad1dfac7b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 8 Oct 2022 14:43:14 +0200 Subject: [PATCH 325/762] Add Python 3.11 for the next release. --- setup.cfg | 2 +- setup.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 91f769620..a300ce628 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bdist_wheel] -python-tag = py37.py38.py39.py310 +python-tag = py37.py38.py39.py310.py311 [metadata] license_file = LICENSE diff --git a/setup.py b/setup.py index b2d07737d..492b1597c 100644 --- a/setup.py +++ b/setup.py @@ -51,6 +51,7 @@ 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', + 'Programming Language :: Python :: 3.11', ], package_dir = {'': 'src'}, package_data = {'websockets': ['py.typed']}, From ee54c4db1ad0d7a0701bad90e44950cc51c73ce9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 8 Oct 2022 14:45:45 +0200 Subject: [PATCH 326/762] Clarify comment. --- src/websockets/legacy/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index e01548896..0c69d1317 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -712,7 +712,7 @@ def wrap(self, server: asyncio.base_events.Server) -> None: self.logger.info("server listening on %s", name) # Initialized here because we need a reference to the event loop. - # This should be moved back to __init__ in Python 3.10. + # This should be moved back to __init__ when dropping Python < 3.10. self.closed_waiter = server.get_loop().create_future() def register(self, protocol: WebSocketServerProtocol) -> None: From 8ce4739b7efed3ac78b287da7fb5e537f78e72aa Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 8 Oct 2022 19:31:23 +0200 Subject: [PATCH 327/762] Increase maximum header length (again). Fix #1239. --- src/websockets/http11.py | 14 ++++++-------- src/websockets/legacy/http.py | 2 +- tests/test_http11.py | 6 +++--- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/src/websockets/http11.py b/src/websockets/http11.py index 84048fa47..68249192c 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -8,14 +8,12 @@ from . import datastructures, exceptions -# Maximum total size of headers is around 256 * 4 KiB = 1 MiB -MAX_HEADERS = 256 - -# We can use the same limit for the request line and header lines: -# "GET <4096 bytes> HTTP/1.1\r\n" = 4111 bytes -# "Set-Cookie: <4097 bytes>\r\n" = 4111 bytes -# (RFC requires 4096 bytes; for some reason Firefox supports 4097 bytes.) -MAX_LINE = 4111 +# Maximum total size of headers is around 128 * 8 KiB = 1 MiB. +MAX_HEADERS = 128 + +# Limit request line and header lines. 8KiB is the most common default +# configuration of popular HTTP servers. +MAX_LINE = 8192 # Support for HTTP response bodies is intended to read an error message # returned by a server. It isn't designed to perform large file transfers. diff --git a/src/websockets/legacy/http.py b/src/websockets/legacy/http.py index cc2ef1f06..d9e44cc28 100644 --- a/src/websockets/legacy/http.py +++ b/src/websockets/legacy/http.py @@ -192,7 +192,7 @@ async def read_line(stream: asyncio.StreamReader) -> bytes: """ # Security: this is bounded by the StreamReader's limit (default = 32 KiB). line = await stream.readline() - # Security: this guarantees header values are small (hard-coded = 4 KiB) + # Security: this guarantees header values are small (hard-coded = 8 KiB) if len(line) > MAX_LINE: raise SecurityError("line too long") # Not mandatory but safe - https://www.rfc-editor.org/rfc/rfc7230.html#section-3.5 diff --git a/tests/test_http11.py b/tests/test_http11.py index 61d377925..d2e5e0462 100644 --- a/tests/test_http11.py +++ b/tests/test_http11.py @@ -328,13 +328,13 @@ def test_parse_invalid_value(self): next(self.parse_headers()) def test_parse_too_long_value(self): - self.reader.feed_data(b"foo: bar\r\n" * 257 + b"\r\n") + self.reader.feed_data(b"foo: bar\r\n" * 129 + b"\r\n") with self.assertRaises(SecurityError): next(self.parse_headers()) def test_parse_too_long_line(self): - # Header line contains 5 + 4105 + 2 = 4112 bytes. - self.reader.feed_data(b"foo: " + b"a" * 4105 + b"\r\n\r\n") + # Header line contains 5 + 8186 + 2 = 8193 bytes. + self.reader.feed_data(b"foo: " + b"a" * 8186 + b"\r\n\r\n") with self.assertRaises(SecurityError): next(self.parse_headers()) From 86961582f40596efe81d616ac9daa5369e0e17e9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 9 Oct 2022 14:53:10 +0200 Subject: [PATCH 328/762] Add API for connection latency. Fix #1195. --- docs/project/changelog.rst | 6 ++++ docs/reference/client.rst | 2 ++ docs/reference/common.rst | 2 ++ docs/reference/server.rst | 2 ++ docs/topics/timeouts.rst | 28 ++++++----------- src/websockets/legacy/protocol.py | 52 +++++++++++++++++++++---------- tests/legacy/test_protocol.py | 49 +++++++++++++++++++---------- 7 files changed, 90 insertions(+), 51 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 6f44bb36b..8ad7bc5ab 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -35,6 +35,12 @@ New features * Validated compatibility with Python 3.11. +* Added the :attr:`~legacy.protocol.WebSocketCommonProtocol.latency` property to + protocols. + +* Changed :attr:`~legacy.protocol.WebSocketCommonProtocol.ping` to return the + latency of the connection. + * Supported overriding or removing the ``User-Agent`` header in clients and the ``Server`` header in servers. diff --git a/docs/reference/client.rst b/docs/reference/client.rst index 3016e85d7..b72f49f5d 100644 --- a/docs/reference/client.rst +++ b/docs/reference/client.rst @@ -46,6 +46,8 @@ Using a connection .. autoproperty:: closed + .. autoattribute:: latency + The following attributes are available after the opening handshake, once the WebSocket connection is open: diff --git a/docs/reference/common.rst b/docs/reference/common.rst index f2683bc77..6ba11bff5 100644 --- a/docs/reference/common.rst +++ b/docs/reference/common.rst @@ -34,6 +34,8 @@ asyncio .. autoproperty:: closed + .. autoattribute:: latency + The following attributes are available after the opening handshake, once the WebSocket connection is open: diff --git a/docs/reference/server.rst b/docs/reference/server.rst index 65f98842a..12fe1f806 100644 --- a/docs/reference/server.rst +++ b/docs/reference/server.rst @@ -71,6 +71,8 @@ Using a connection .. autoproperty:: closed + .. autoattribute:: latency + The following attributes are available after the opening handshake, once the WebSocket connection is open: diff --git a/docs/topics/timeouts.rst b/docs/topics/timeouts.rst index 23e8020c4..633fc1ab4 100644 --- a/docs/topics/timeouts.rst +++ b/docs/topics/timeouts.rst @@ -20,6 +20,8 @@ infrastructure closes idle connections after 30 to 120 seconds. As a consequence, proxies may terminate WebSocket connections prematurely when no message was exchanged in 30 seconds. +.. _keepalive: + Keepalive in websockets ----------------------- @@ -101,26 +103,14 @@ Latency between a client and a server may increase for two reasons: the default timeout elapses. As a consequence, it closes the connection. This is a reasonable choice to prevent overload. - If traffic spikes cause unwanted timeouts and you're confident that the - server will catch up eventually, you can increase ``ping_timeout`` or you - can disable keepalive entirely with ``ping_interval=None``. + If traffic spikes cause unwanted timeouts and you're confident that the server + will catch up eventually, you can increase ``ping_timeout`` or you can set it + to :obj:`None` to disable heartbeat entirely. The same reasoning applies to situations where the server sends more traffic than the client can accept. -You can monitor latency as follows: - -.. code-block:: python - - import asyncio - import logging - import time - - async def log_latency(websocket, logger): - t0 = time.perf_counter() - pong_waiter = await websocket.ping() - await pong_waiter - t1 = time.perf_counter() - logger.info("Connection latency: %.3f seconds", t1 - t0) - - asyncio.create_task(log_latency(websocket, logging.getLogger())) +The latency measured during the last exchange of Ping and Pong frames is +available in the :attr:`~legacy.protocol.WebSocketCommonProtocol.latency` +attribute. Alternatively, you can measure the latency at any time with the +:attr:`~legacy.protocol.WebSocketCommonProtocol.ping` method. diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 04726033e..1b6e58efa 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -7,6 +7,7 @@ import random import ssl import struct +import time import uuid import warnings from typing import ( @@ -21,6 +22,7 @@ List, Mapping, Optional, + Tuple, Union, cast, ) @@ -286,7 +288,19 @@ def __init__( self._fragmented_message_waiter: Optional[asyncio.Future[None]] = None # Mapping of ping IDs to pong waiters, in chronological order. - self.pings: Dict[bytes, asyncio.Future[None]] = {} + self.pings: Dict[bytes, Tuple[asyncio.Future[float], float]] = {} + + self.latency: float = 0 + """ + Latency of the connection, in seconds. + + This value is updated after sending a ping frame and receiving a + matching pong frame. Before the first ping, :attr:`latency` is ``0``. + + By default, websockets enables a :ref:`keepalive ` mechanism + that sends ping frames automatically at regular intervals. You can also + send ping frames and measure latency with :meth:`ping`. + """ # Task running the data transfer. self.transfer_data_task: asyncio.Task[None] @@ -802,8 +816,8 @@ async def ping(self, data: Optional[Data] = None) -> Awaitable[None]: .. _Ping: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 - A ping may serve as a keepalive or as a check that the remote endpoint - received all messages up to this point + A ping may serve as a keepalive, as a check that the remote endpoint + received all messages up to this point, or to measure :attr:`latency`. Canceling :meth:`ping` is discouraged. If :meth:`ping` doesn't return immediately, it means the write buffer is full. If you don't want to @@ -818,14 +832,16 @@ async def ping(self, data: Optional[Data] = None) -> Awaitable[None]: containing four random bytes. Returns: - ~asyncio.Future: A future that will be completed when the - corresponding pong is received. You can ignore it if you - don't intend to wait. + ~asyncio.Future[float]: A future that will be completed when the + corresponding pong is received. You can ignore it if you don't + intend to wait. The result of the future is the latency of the + connection in seconds. :: pong_waiter = await ws.ping() - await pong_waiter # only if you want to wait for the pong + # only if you want to wait for the corresponding pong + latency = await pong_waiter Raises: ConnectionClosed: when the connection is closed. @@ -846,12 +862,14 @@ async def ping(self, data: Optional[Data] = None) -> Awaitable[None]: while data is None or data in self.pings: data = struct.pack("!I", random.getrandbits(32)) - ping_future = self.loop.create_future() - self.pings[data] = ping_future + pong_waiter = self.loop.create_future() + # Resolution of time.monotonic() may be too low on Windows. + ping_timestamp = time.perf_counter() + self.pings[data] = (pong_waiter, ping_timestamp) await self.write_frame(True, OP_PING, data) - return asyncio.shield(ping_future) + return asyncio.shield(pong_waiter) async def pong(self, data: Data = b"") -> None: """ @@ -1122,15 +1140,17 @@ async def read_data_frame(self, max_size: Optional[int]) -> Optional[Frame]: elif frame.opcode == OP_PONG: if frame.data in self.pings: + pong_timestamp = time.perf_counter() # Sending a pong for only the most recent ping is legal. # Acknowledge all previous pings too in that case. ping_id = None ping_ids = [] - for ping_id, ping in self.pings.items(): + for ping_id, (pong_waiter, ping_timestamp) in self.pings.items(): ping_ids.append(ping_id) - if not ping.done(): - ping.set_result(None) + if not pong_waiter.done(): + pong_waiter.set_result(pong_timestamp - ping_timestamp) if ping_id == frame.data: + self.latency = pong_timestamp - ping_timestamp break else: # pragma: no cover assert False, "ping_id is in self.pings" @@ -1454,13 +1474,13 @@ def abort_pings(self) -> None: assert self.state is State.CLOSED exc = self.connection_closed_exc() - for ping in self.pings.values(): - ping.set_exception(exc) + for pong_waiter, _ping_timestamp in self.pings.values(): + pong_waiter.set_exception(exc) # If the exception is never retrieved, it will be logged when ping # is garbage-collected. This is confusing for users. # Given that ping is done (with an exception), canceling it does # nothing, but it prevents logging the exception. - ping.cancel() + pong_waiter.cancel() # asyncio.Protocol methods diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index 4fe0092a5..1f830ebee 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -938,33 +938,33 @@ def test_ignore_pong(self): self.assertNoFrameSent() def test_acknowledge_ping(self): - ping = self.loop.run_until_complete(self.protocol.ping()) - self.assertFalse(ping.done()) + pong_waiter = self.loop.run_until_complete(self.protocol.ping()) + self.assertFalse(pong_waiter.done()) ping_frame = self.last_sent_frame() pong_frame = Frame(True, OP_PONG, ping_frame.data) self.receive_frame(pong_frame) self.run_loop_once() self.run_loop_once() - self.assertTrue(ping.done()) + self.assertTrue(pong_waiter.done()) def test_abort_ping(self): - ping = self.loop.run_until_complete(self.protocol.ping()) + pong_waiter = self.loop.run_until_complete(self.protocol.ping()) # Remove the frame from the buffer, else close_connection() complains. self.last_sent_frame() - self.assertFalse(ping.done()) + self.assertFalse(pong_waiter.done()) self.close_connection() - self.assertTrue(ping.done()) - self.assertIsInstance(ping.exception(), ConnectionClosed) + self.assertTrue(pong_waiter.done()) + self.assertIsInstance(pong_waiter.exception(), ConnectionClosed) def test_abort_ping_does_not_log_exception_if_not_retreived(self): self.loop.run_until_complete(self.protocol.ping()) # Get the internal Future, which isn't directly returned by ping(). - (ping,) = self.protocol.pings.values() + ((pong_waiter, _timestamp),) = self.protocol.pings.values() # Remove the frame from the buffer, else close_connection() complains. self.last_sent_frame() self.close_connection() # Check a private attribute, for lack of a better solution. - self.assertFalse(ping._log_traceback) + self.assertFalse(pong_waiter._log_traceback) def test_acknowledge_previous_pings(self): pings = [ @@ -987,7 +987,7 @@ def test_acknowledge_previous_pings(self): self.assertFalse(pings[2][0].done()) def test_acknowledge_aborted_ping(self): - ping = self.loop.run_until_complete(self.protocol.ping()) + pong_waiter = self.loop.run_until_complete(self.protocol.ping()) ping_frame = self.last_sent_frame() # Clog incoming queue. This lets connection_lost() abort pending pings # with a ConnectionClosed exception before transfer_data_task @@ -1003,7 +1003,7 @@ def test_acknowledge_aborted_ping(self): self.loop.run_until_complete(self.protocol.wait_closed()) # Ping receives a ConnectionClosed exception. with self.assertRaises(ConnectionClosed): - ping.result() + pong_waiter.result() # transfer_data doesn't crash, which would be logged. with self.assertNoLogs(): @@ -1012,14 +1012,14 @@ def test_acknowledge_aborted_ping(self): self.loop.run_until_complete(self.protocol.recv()) def test_canceled_ping(self): - ping = self.loop.run_until_complete(self.protocol.ping()) + pong_waiter = self.loop.run_until_complete(self.protocol.ping()) ping_frame = self.last_sent_frame() - ping.cancel() + pong_waiter.cancel() pong_frame = Frame(True, OP_PONG, ping_frame.data) self.receive_frame(pong_frame) self.run_loop_once() self.run_loop_once() - self.assertTrue(ping.cancelled()) + self.assertTrue(pong_waiter.cancelled()) def test_duplicate_ping(self): self.loop.run_until_complete(self.protocol.ping(b"foobar")) @@ -1028,6 +1028,23 @@ def test_duplicate_ping(self): self.loop.run_until_complete(self.protocol.ping(b"foobar")) self.assertNoFrameSent() + # Test the protocol's logic for measuring latency + + def test_record_latency_on_pong(self): + self.assertEqual(self.protocol.latency, 0) + self.loop.run_until_complete(self.protocol.ping(b"test")) + self.receive_frame(Frame(True, OP_PONG, b"test")) + self.run_loop_once() + self.assertGreater(self.protocol.latency, 0) + + def test_return_latency_on_pong(self): + pong_waiter = self.loop.run_until_complete(self.protocol.ping()) + ping_frame = self.last_sent_frame() + pong_frame = Frame(True, OP_PONG, ping_frame.data) + self.receive_frame(pong_frame) + latency = self.loop.run_until_complete(pong_waiter) + self.assertGreater(latency, 0) + # Test the protocol's logic for rebuilding fragmented messages. def test_fragmented_text(self): @@ -1244,14 +1261,14 @@ def test_keepalive_ping_does_not_crash_when_connection_lost(self): self.receive_frame(Frame(True, OP_TEXT, b"2")) # Ping is sent at 3ms. self.loop.run_until_complete(asyncio.sleep(4 * MS)) - (ping_waiter,) = tuple(self.protocol.pings.values()) + ((pong_waiter, _timestamp),) = self.protocol.pings.values() # Connection drops. self.receive_eof() self.loop.run_until_complete(self.protocol.wait_closed()) # The ping waiter receives a ConnectionClosed exception. with self.assertRaises(ConnectionClosed): - ping_waiter.result() + pong_waiter.result() # The keepalive ping task terminated properly. self.assertIsNone(self.protocol.keepalive_ping_task.result()) From d5cf4a94e599dd678255a2d23712cb82a43ec41a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 9 Oct 2022 15:32:09 +0200 Subject: [PATCH 329/762] Document workaround for bug in Python < 3.10. Fix #1182. --- docs/topics/logging.rst | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/topics/logging.rst b/docs/topics/logging.rst index 393efeb71..95acf57ff 100644 --- a/docs/topics/logging.rst +++ b/docs/topics/logging.rst @@ -126,7 +126,8 @@ Here's how to include them in logs, assuming they're in the async with websockets.serve( ..., - logger=LoggerAdapter(logging.getLogger("websockets.server")), + # Python < 3.10 requires passing None as the second argument. + logger=LoggerAdapter(logging.getLogger("websockets.server"), None), ): ... @@ -167,7 +168,8 @@ a :class:`~logging.LoggerAdapter`:: async with websockets.serve( ..., - logger=LoggerAdapter(logging.getLogger("websockets.server")), + # Python < 3.10 requires passing None as the second argument. + logger=LoggerAdapter(logging.getLogger("websockets.server"), None), ): ... From a07f4e8190755de47f9915ef89535de18c02f88e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 9 Oct 2022 15:45:48 +0200 Subject: [PATCH 330/762] Highlight issue with convenience imports in FAQ. Fix #1183. --- docs/faq/misc.rst | 51 +++++++++++++++++++++++++++++----------- docs/reference/index.rst | 15 ++++++++---- 2 files changed, 47 insertions(+), 19 deletions(-) diff --git a/docs/faq/misc.rst b/docs/faq/misc.rst index 9ef07ef2d..681c5e45a 100644 --- a/docs/faq/misc.rst +++ b/docs/faq/misc.rst @@ -3,27 +3,35 @@ Miscellaneous .. currentmodule:: websockets -Can I use websockets without ``async`` and ``await``? -..................................................... +Why do I get the error: ``module 'websockets' has no attribute '...'``? +....................................................................... -No, there is no convenient way to do this. You should use another library. +Often, this is because you created a script called ``websockets.py`` in your +current working directory. Then ``import websockets`` imports this module +instead of the websockets library. -Are there ``onopen``, ``onmessage``, ``onerror``, and ``onclose`` callbacks? -............................................................................ +Why does my IDE fail to show documentation for websockets APIs? +............................................................... -No, there aren't. +You are probably using the convenience imports e.g.:: -websockets provides high-level, coroutine-based APIs. Compared to callbacks, -coroutines make it easier to manage control flow in concurrent code. + import websockets -If you prefer callback-based APIs, you should use another library. + websockets.connect(...) + websockets.serve(...) -Why do I get the error: ``module 'websockets' has no attribute '...'``? -....................................................................... +This is incompatible with static code analysis. It may break auto-completion and +contextual documentation in IDEs, type checking with mypy_, etc. -Often, this is because you created a script called ``websockets.py`` in your -current working directory. Then ``import websockets`` imports this module -instead of the websockets library. +.. _mypy: https://github.com/python/mypy + +Instead, use the real import paths e.g.:: + + import websockets.client + import websockets.server + + websockets.client.connect(...) + websockets.server.serve(...) Why is websockets slower than another Python library in my benchmark? ..................................................................... @@ -37,3 +45,18 @@ you may need to disable: * UTF-8 decoding: send ``bytes`` rather than ``str`` If websockets is still slower than another Python library, please file a bug. + +Can I use websockets without ``async`` and ``await``? +..................................................... + +No, there is no convenient way to do this. You should use another library. + +Are there ``onopen``, ``onmessage``, ``onerror``, and ``onclose`` callbacks? +............................................................................ + +No, there aren't. + +websockets provides high-level, coroutine-based APIs. Compared to callbacks, +coroutines make it easier to manage control flow in concurrent code. + +If you prefer callback-based APIs, you should use another library. diff --git a/docs/reference/index.rst b/docs/reference/index.rst index 8147c4cf3..51af8e622 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -58,9 +58,14 @@ Public API documented in the API reference are subject to the Anything that isn't listed in the API reference is a private API. There's no guarantees of behavior or backwards-compatibility for private APIs. -For convenience, many public APIs can be imported from the ``websockets`` -package. However, this feature is incompatible with static code analysis. It -breaks auto-completion in an IDE or type checking with mypy_. If you're using -such tools, use the real import paths. +.. admonition:: Convenience imports are incompatible with some development tools. + :class: caution -.. _mypy: https://github.com/python/mypy + For convenience, most public APIs can be imported from the ``websockets`` + package. However, this is incompatible with static code analysis. + + It may break auto-completion and contextual documentation in IDEs, type + checking with mypy_, etc. If you're using such tools, stick to the full + import paths. + + .. _mypy: https://github.com/python/mypy From 580c41744330a83ff0e4eda728e2aab1a10a4e79 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 9 Oct 2022 15:46:31 +0200 Subject: [PATCH 331/762] Remove obsolete justification. --- docs/reference/index.rst | 3 --- 1 file changed, 3 deletions(-) diff --git a/docs/reference/index.rst b/docs/reference/index.rst index 51af8e622..f164dde91 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -37,9 +37,6 @@ are available on client connections created with :func:`~client.connect` and on server connections received in argument by the connection handler of :func:`~server.serve`. -Since websockets provides the same API — and uses the same code — for client -and server connections, common methods are documented in a "Both sides" page. - .. toctree:: :titlesonly: From ad797212ce45bcab7c4cf57d21095a12e8f284ba Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 10 Oct 2022 08:52:29 +0200 Subject: [PATCH 332/762] Don't log when connection drops during handshake. Fix #1237. Refs #984. --- src/websockets/legacy/server.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 0c69d1317..c32c85612 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -598,7 +598,8 @@ async def handshake( # The connection may drop while process_request is running. if self.state is State.CLOSED: - raise self.connection_closed_exc() # pragma: no cover + # This subclass of ConnectionError is silently ignored in handler(). + raise BrokenPipeError("connection closed during opening handshake") # Change the response to a 503 error if the server is shutting down. if not self.ws_server.is_serving(): From cdd30e453f421d65d9650ee10425996ed12912b6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 14 Oct 2022 22:27:50 +0200 Subject: [PATCH 333/762] Make a note to use a new API. --- tests/legacy/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/legacy/utils.py b/tests/legacy/utils.py index fd5dfc294..bc37ee7bf 100644 --- a/tests/legacy/utils.py +++ b/tests/legacy/utils.py @@ -12,6 +12,8 @@ class AsyncioTestCase(unittest.TestCase): """ Base class for tests that sets up an isolated event loop for each test. + Replace with IsolatedAsyncioTestCase when dropping Python < 3.8. + """ def __init_subclass__(cls, **kwargs): From eb86f67bce6942b8ec5e3fba68a5f468d01b6eaf Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 15 Oct 2022 08:48:20 +0200 Subject: [PATCH 334/762] Remove User-Agent/Server from the Sans-I/O layer. It doesn't make sense to set it at the library level. It should be set by the embedding program. Partially reverts 2a07325c. --- src/websockets/client.py | 9 --------- src/websockets/server.py | 12 ------------ tests/test_client.py | 25 ------------------------- tests/test_server.py | 31 ------------------------------- 4 files changed, 77 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index c18c76f22..373e6b751 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -23,7 +23,6 @@ parse_subprotocol, parse_upgrade, ) -from .http import USER_AGENT from .http11 import Request, Response from .typing import ( ConnectionOption, @@ -64,9 +63,6 @@ class ClientConnection(Connection): logger: logger for this connection; defaults to ``logging.getLogger("websockets.client")``; see the :doc:`logging guide <../topics/logging>` for details. - user_agent_header: value of the ``User-Agent`` request header; - defaults to ``"Python/x.y.z websockets/X.Y"``; - :obj:`None` removes the header. """ @@ -76,7 +72,6 @@ def __init__( origin: Optional[Origin] = None, extensions: Optional[Sequence[ClientExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, - user_agent_header: Optional[str] = USER_AGENT, state: State = CONNECTING, max_size: Optional[int] = 2**20, logger: Optional[LoggerLike] = None, @@ -91,7 +86,6 @@ def __init__( self.origin = origin self.available_extensions = extensions self.available_subprotocols = subprotocols - self.user_agent_header = user_agent_header self.key = generate_key() def connect(self) -> Request: # noqa: F811 @@ -136,9 +130,6 @@ def connect(self) -> Request: # noqa: F811 protocol_header = build_subprotocol(self.available_subprotocols) headers["Sec-WebSocket-Protocol"] = protocol_header - if self.user_agent_header is not None: - headers["User-Agent"] = self.user_agent_header - return Request(self.wsuri.resource_name, headers) def process_response(self, response: Response) -> None: diff --git a/src/websockets/server.py b/src/websockets/server.py index ca48449d6..edd1764c3 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -25,7 +25,6 @@ parse_subprotocol, parse_upgrade, ) -from .http import USER_AGENT from .http11 import Request, Response from .typing import ( ConnectionOption, @@ -64,9 +63,6 @@ class ServerConnection(Connection): logger: logger for this connection; defaults to ``logging.getLogger("websockets.client")``; see the :doc:`logging guide <../topics/logging>` for details. - server_header: value of the ``Server`` response header; - defaults to ``"Python/x.y.z websockets/X.Y"``; - :obj:`None` removes the header. """ @@ -75,7 +71,6 @@ def __init__( origins: Optional[Sequence[Optional[Origin]]] = None, extensions: Optional[Sequence[ServerExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, - server_header: Optional[str] = USER_AGENT, state: State = CONNECTING, max_size: Optional[int] = 2**20, logger: Optional[LoggerLike] = None, @@ -89,7 +84,6 @@ def __init__( self.origins = origins self.available_extensions = extensions self.available_subprotocols = subprotocols - self.server_header = server_header def accept(self, request: Request) -> Response: """ @@ -175,9 +169,6 @@ def accept(self, request: Request) -> Response: if protocol_header is not None: headers["Sec-WebSocket-Protocol"] = protocol_header - if self.server_header is not None: - headers["Server"] = self.server_header - self.logger.info("connection open") return Response(101, "Switching Protocols", headers) @@ -477,9 +468,6 @@ def reject( ("Content-Type", "text/plain; charset=utf-8"), ] ) - if self.server_header is not None: - headers["Server"] = self.server_header - response = Response(status.value, status.phrase, headers, body) # When reject() is called from accept(), handshake_exc is already set. # If a user calls reject(), set handshake_exc to guarantee invariant: diff --git a/tests/test_client.py b/tests/test_client.py index 0504f79e6..9ed36c1d4 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -7,7 +7,6 @@ from websockets.datastructures import Headers from websockets.exceptions import InvalidHandshake, InvalidHeader from websockets.frames import OP_TEXT, Frame -from websockets.http import USER_AGENT from websockets.http11 import Request, Response from websockets.uri import parse_uri from websockets.utils import accept_key @@ -38,7 +37,6 @@ def test_send_connect(self): f"Connection: Upgrade\r\n" f"Sec-WebSocket-Key: {KEY}\r\n" f"Sec-WebSocket-Version: 13\r\n" - f"User-Agent: {USER_AGENT}\r\n" f"\r\n".encode() ], ) @@ -58,7 +56,6 @@ def test_connect_request(self): "Connection": "Upgrade", "Sec-WebSocket-Key": KEY, "Sec-WebSocket-Version": "13", - "User-Agent": USER_AGENT, } ), ) @@ -130,7 +127,6 @@ def test_receive_accept(self): f"Connection: Upgrade\r\n" f"Sec-WebSocket-Accept: {ACCEPT}\r\n" f"Date: {DATE}\r\n" - f"Server: {USER_AGENT}\r\n" f"\r\n" ).encode(), ) @@ -148,7 +144,6 @@ def test_receive_reject(self): ( f"HTTP/1.1 404 Not Found\r\n" f"Date: {DATE}\r\n" - f"Server: {USER_AGENT}\r\n" f"Content-Length: 13\r\n" f"Content-Type: text/plain; charset=utf-8\r\n" f"Connection: close\r\n" @@ -173,7 +168,6 @@ def test_accept_response(self): f"Connection: Upgrade\r\n" f"Sec-WebSocket-Accept: {ACCEPT}\r\n" f"Date: {DATE}\r\n" - f"Server: {USER_AGENT}\r\n" f"\r\n" ).encode(), ) @@ -188,7 +182,6 @@ def test_accept_response(self): "Connection": "Upgrade", "Sec-WebSocket-Accept": ACCEPT, "Date": DATE, - "Server": USER_AGENT, } ), ) @@ -202,7 +195,6 @@ def test_reject_response(self): ( f"HTTP/1.1 404 Not Found\r\n" f"Date: {DATE}\r\n" - f"Server: {USER_AGENT}\r\n" f"Content-Length: 13\r\n" f"Content-Type: text/plain; charset=utf-8\r\n" f"Connection: close\r\n" @@ -218,7 +210,6 @@ def test_reject_response(self): Headers( { "Date": DATE, - "Server": USER_AGENT, "Content-Length": "13", "Content-Type": "text/plain; charset=utf-8", "Connection": "close", @@ -574,22 +565,6 @@ def test_unsupported_subprotocol(self): raise client.handshake_exc self.assertEqual(str(raised.exception), "unsupported subprotocol: otherchat") - def test_no_user_agent_header(self): - client = ClientConnection( - parse_uri("wss://example.com/"), - user_agent_header=None, - ) - request = client.connect() - self.assertNotIn("User-Agent", request.headers) - - def test_custom_user_agent_header(self): - client = ClientConnection( - parse_uri("wss://example.com/"), - user_agent_header="websockets", - ) - request = client.connect() - self.assertEqual(request.headers["User-Agent"], "websockets") - class MiscTests(unittest.TestCase): def test_bypass_handshake(self): diff --git a/tests/test_server.py b/tests/test_server.py index c7f398cdf..f1404499b 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -7,7 +7,6 @@ from websockets.datastructures import Headers from websockets.exceptions import InvalidHeader, InvalidOrigin, InvalidUpgrade from websockets.frames import OP_TEXT, Frame -from websockets.http import USER_AGENT from websockets.http11 import Request, Response from websockets.server import * @@ -32,7 +31,6 @@ def test_receive_connect(self): f"Connection: Upgrade\r\n" f"Sec-WebSocket-Key: {KEY}\r\n" f"Sec-WebSocket-Version: 13\r\n" - f"User-Agent: {USER_AGENT}\r\n" f"\r\n" ).encode(), ) @@ -51,7 +49,6 @@ def test_connect_request(self): f"Connection: Upgrade\r\n" f"Sec-WebSocket-Key: {KEY}\r\n" f"Sec-WebSocket-Version: 13\r\n" - f"User-Agent: {USER_AGENT}\r\n" f"\r\n" ).encode(), ) @@ -66,7 +63,6 @@ def test_connect_request(self): "Connection": "Upgrade", "Sec-WebSocket-Key": KEY, "Sec-WebSocket-Version": "13", - "User-Agent": USER_AGENT, } ), ) @@ -83,7 +79,6 @@ def make_request(self): "Connection": "Upgrade", "Sec-WebSocket-Key": KEY, "Sec-WebSocket-Version": "13", - "User-Agent": USER_AGENT, } ), ) @@ -102,7 +97,6 @@ def test_send_accept(self): f"Upgrade: websocket\r\n" f"Connection: Upgrade\r\n" f"Sec-WebSocket-Accept: {ACCEPT}\r\n" - f"Server: {USER_AGENT}\r\n" f"\r\n".encode() ], ) @@ -123,7 +117,6 @@ def test_send_reject(self): f"Connection: close\r\n" f"Content-Length: 13\r\n" f"Content-Type: text/plain; charset=utf-8\r\n" - f"Server: {USER_AGENT}\r\n" f"\r\n" f"Sorry folks.\n".encode(), b"", @@ -147,7 +140,6 @@ def test_accept_response(self): "Upgrade": "websocket", "Connection": "Upgrade", "Sec-WebSocket-Accept": ACCEPT, - "Server": USER_AGENT, } ), ) @@ -168,7 +160,6 @@ def test_reject_response(self): "Connection": "close", "Content-Length": "13", "Content-Type": "text/plain; charset=utf-8", - "Server": USER_AGENT, } ), ) @@ -608,28 +599,6 @@ def test_unsupported_subprotocol(self): self.assertNotIn("Sec-WebSocket-Protocol", response.headers) self.assertIsNone(server.subprotocol) - def test_no_server_header(self): - server = ServerConnection(server_header=None) - request = self.make_request() - response = server.accept(request) - self.assertNotIn("Server", response.headers) - - def test_custom_server_header(self): - server = ServerConnection(server_header="websockets") - request = self.make_request() - response = server.accept(request) - self.assertEqual(response.headers["Server"], "websockets") - - def test_reject_response_no_server_header(self): - server = ServerConnection(server_header=None) - response = server.reject(http.HTTPStatus.OK, "Hello world!\n") - self.assertNotIn("Server", response.headers) - - def test_reject_response_custom_server_header(self): - server = ServerConnection(server_header="websockets") - response = server.reject(http.HTTPStatus.OK, "Hello world!\n") - self.assertEqual(response.headers["Server"], "websockets") - class MiscTests(unittest.TestCase): def test_bypass_handshake(self): From 4ae6d7b2a21b02bffa52f2653d344b6da07413cb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 16 Oct 2022 17:58:49 +0200 Subject: [PATCH 335/762] Include correct files in make coverage & tox -e coverage. --- Makefile | 2 +- tox.ini | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index 578f6b1ae..cd8095ae1 100644 --- a/Makefile +++ b/Makefile @@ -16,7 +16,7 @@ test: python -m unittest coverage: - coverage run --source websockets,tests -m unittest + coverage run --source src/websockets,tests -m unittest coverage html coverage report --show-missing --fail-under=100 diff --git a/tox.ini b/tox.ini index 3a284ed31..2f37cdcbf 100644 --- a/tox.ini +++ b/tox.ini @@ -17,7 +17,7 @@ commands = python -W error::DeprecationWarning -W error::PendingDeprecationWarni [testenv:coverage] commands = python -m coverage erase - python -W default -m coverage run -m unittest {posargs} + python -m coverage run --source {envsitepackagesdir}/websockets,tests -m unittest {posargs} python -m coverage report --show-missing --fail-under=100 deps = coverage From 781f5b8c46070b74d1761c709b67820348eb978b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 23 Oct 2022 09:09:33 +0200 Subject: [PATCH 336/762] Standardize access to State and Side enum values. Access them directly rather than as attributes of the State class. --- tests/test_connection.py | 395 +++++++++++++++++++-------------------- 1 file changed, 196 insertions(+), 199 deletions(-) diff --git a/tests/test_connection.py b/tests/test_connection.py index 3d4d98436..44f5bc35e 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,6 +1,7 @@ import unittest.mock from websockets.connection import * +from websockets.connection import CLIENT, CLOSED, CLOSING, SERVER from websockets.exceptions import ( ConnectionClosedError, ConnectionClosedOK, @@ -38,7 +39,7 @@ def assertFrameSent(self, connection, frame, eof=False): if write is SEND_EOF else self.parse( write, - mask=connection.side is Side.CLIENT, + mask=connection.side is CLIENT, extensions=connection.extensions, ) for write in connection.data_to_send() @@ -71,9 +72,7 @@ def assertConnectionClosing(self, connection, code=None, reason=""): # A close frame was received. self.assertFrameReceived(connection, close_frame) # A close frame and possibly the end of stream were sent. - self.assertFrameSent( - connection, close_frame, eof=connection.side is Side.SERVER - ) + self.assertFrameSent(connection, close_frame, eof=connection.side is SERVER) def assertConnectionFailing(self, connection, code=None, reason=""): """ @@ -87,9 +86,7 @@ def assertConnectionFailing(self, connection, code=None, reason=""): # No frame was received. self.assertFrameReceived(connection, None) # A close frame and possibly the end of stream were sent. - self.assertFrameSent( - connection, close_frame, eof=connection.side is Side.SERVER - ) + self.assertFrameSent(connection, close_frame, eof=connection.side is SERVER) class MaskingTests(ConnectionTestCase): @@ -104,18 +101,18 @@ class MaskingTests(ConnectionTestCase): masked_text_frame_data = b"\x81\x84\x00\xff\x00\xff\x53\x8f\x61\x92" def test_client_sends_masked_frame(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) with self.enforce_mask(b"\x00\xff\x00\xff"): client.send_text(b"Spam", True) self.assertEqual(client.data_to_send(), [self.masked_text_frame_data]) def test_server_sends_unmasked_frame(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.send_text(b"Spam", True) self.assertEqual(server.data_to_send(), [self.unmasked_text_frame_date]) def test_client_receives_unmasked_frame(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(self.unmasked_text_frame_date) self.assertFrameReceived( client, @@ -123,7 +120,7 @@ def test_client_receives_unmasked_frame(self): ) def test_server_receives_masked_frame(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(self.masked_text_frame_data) self.assertFrameReceived( server, @@ -131,14 +128,14 @@ def test_server_receives_masked_frame(self): ) def test_client_receives_masked_frame(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(self.masked_text_frame_data) self.assertIsInstance(client.parser_exc, ProtocolError) self.assertEqual(str(client.parser_exc), "incorrect masking") self.assertConnectionFailing(client, 1002, "incorrect masking") def test_server_receives_unmasked_frame(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(self.unmasked_text_frame_date) self.assertIsInstance(server.parser_exc, ProtocolError) self.assertEqual(str(server.parser_exc), "incorrect masking") @@ -152,33 +149,33 @@ class ContinuationTests(ConnectionTestCase): """ def test_client_sends_unexpected_continuation(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) with self.assertRaises(ProtocolError) as raised: client.send_continuation(b"", fin=False) self.assertEqual(str(raised.exception), "unexpected continuation frame") def test_server_sends_unexpected_continuation(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) with self.assertRaises(ProtocolError) as raised: server.send_continuation(b"", fin=False) self.assertEqual(str(raised.exception), "unexpected continuation frame") def test_client_receives_unexpected_continuation(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x00\x00") self.assertIsInstance(client.parser_exc, ProtocolError) self.assertEqual(str(client.parser_exc), "unexpected continuation frame") self.assertConnectionFailing(client, 1002, "unexpected continuation frame") def test_server_receives_unexpected_continuation(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x00\x80\x00\x00\x00\x00") self.assertIsInstance(server.parser_exc, ProtocolError) self.assertEqual(str(server.parser_exc), "unexpected continuation frame") self.assertConnectionFailing(server, 1002, "unexpected continuation frame") def test_client_sends_continuation_after_sending_close(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) # Since it isn't possible to send a close frame in a fragmented # message (see test_client_send_close_in_fragmented_message), in fact, # this is the same test as test_client_sends_unexpected_continuation. @@ -193,7 +190,7 @@ def test_server_sends_continuation_after_sending_close(self): # Since it isn't possible to send a close frame in a fragmented # message (see test_server_send_close_in_fragmented_message), in fact, # this is the same test as test_server_sends_unexpected_continuation. - server = Connection(Side.SERVER) + server = Connection(SERVER) server.send_close(1000) self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) with self.assertRaises(ProtocolError) as raised: @@ -201,7 +198,7 @@ def test_server_sends_continuation_after_sending_close(self): self.assertEqual(str(raised.exception), "unexpected continuation frame") def test_client_receives_continuation_after_receiving_close(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x88\x02\x03\xe8") self.assertConnectionClosing(client, 1000) client.receive_data(b"\x00\x00") @@ -209,7 +206,7 @@ def test_client_receives_continuation_after_receiving_close(self): self.assertFrameSent(client, None) def test_server_receives_continuation_after_receiving_close(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") self.assertConnectionClosing(server, 1001) server.receive_data(b"\x00\x80\x00\xff\x00\xff") @@ -224,7 +221,7 @@ class TextTests(ConnectionTestCase): """ def test_client_sends_text(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_text("😀".encode()) self.assertEqual( @@ -232,12 +229,12 @@ def test_client_sends_text(self): ) def test_server_sends_text(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.send_text("😀".encode()) self.assertEqual(server.data_to_send(), [b"\x81\x04\xf0\x9f\x98\x80"]) def test_client_receives_text(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x81\x04\xf0\x9f\x98\x80") self.assertFrameReceived( client, @@ -245,7 +242,7 @@ def test_client_receives_text(self): ) def test_server_receives_text(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80") self.assertFrameReceived( server, @@ -253,21 +250,21 @@ def test_server_receives_text(self): ) def test_client_receives_text_over_size_limit(self): - client = Connection(Side.CLIENT, max_size=3) + client = Connection(CLIENT, max_size=3) client.receive_data(b"\x81\x04\xf0\x9f\x98\x80") self.assertIsInstance(client.parser_exc, PayloadTooBig) self.assertEqual(str(client.parser_exc), "over size limit (4 > 3 bytes)") self.assertConnectionFailing(client, 1009, "over size limit (4 > 3 bytes)") def test_server_receives_text_over_size_limit(self): - server = Connection(Side.SERVER, max_size=3) + server = Connection(SERVER, max_size=3) server.receive_data(b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80") self.assertIsInstance(server.parser_exc, PayloadTooBig) self.assertEqual(str(server.parser_exc), "over size limit (4 > 3 bytes)") self.assertConnectionFailing(server, 1009, "over size limit (4 > 3 bytes)") def test_client_receives_text_without_size_limit(self): - client = Connection(Side.CLIENT, max_size=None) + client = Connection(CLIENT, max_size=None) client.receive_data(b"\x81\x04\xf0\x9f\x98\x80") self.assertFrameReceived( client, @@ -275,7 +272,7 @@ def test_client_receives_text_without_size_limit(self): ) def test_server_receives_text_without_size_limit(self): - server = Connection(Side.SERVER, max_size=None) + server = Connection(SERVER, max_size=None) server.receive_data(b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80") self.assertFrameReceived( server, @@ -283,7 +280,7 @@ def test_server_receives_text_without_size_limit(self): ) def test_client_sends_fragmented_text(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_text("😀".encode()[:2], fin=False) self.assertEqual(client.data_to_send(), [b"\x01\x82\x00\x00\x00\x00\xf0\x9f"]) @@ -297,7 +294,7 @@ def test_client_sends_fragmented_text(self): self.assertEqual(client.data_to_send(), [b"\x80\x82\x00\x00\x00\x00\x98\x80"]) def test_server_sends_fragmented_text(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.send_text("😀".encode()[:2], fin=False) self.assertEqual(server.data_to_send(), [b"\x01\x02\xf0\x9f"]) server.send_continuation("😀😀".encode()[2:6], fin=False) @@ -306,7 +303,7 @@ def test_server_sends_fragmented_text(self): self.assertEqual(server.data_to_send(), [b"\x80\x02\x98\x80"]) def test_client_receives_fragmented_text(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x01\x02\xf0\x9f") self.assertFrameReceived( client, @@ -324,7 +321,7 @@ def test_client_receives_fragmented_text(self): ) def test_server_receives_fragmented_text(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x01\x82\x00\x00\x00\x00\xf0\x9f") self.assertFrameReceived( server, @@ -342,7 +339,7 @@ def test_server_receives_fragmented_text(self): ) def test_client_receives_fragmented_text_over_size_limit(self): - client = Connection(Side.CLIENT, max_size=3) + client = Connection(CLIENT, max_size=3) client.receive_data(b"\x01\x02\xf0\x9f") self.assertFrameReceived( client, @@ -354,7 +351,7 @@ def test_client_receives_fragmented_text_over_size_limit(self): self.assertConnectionFailing(client, 1009, "over size limit (2 > 1 bytes)") def test_server_receives_fragmented_text_over_size_limit(self): - server = Connection(Side.SERVER, max_size=3) + server = Connection(SERVER, max_size=3) server.receive_data(b"\x01\x82\x00\x00\x00\x00\xf0\x9f") self.assertFrameReceived( server, @@ -366,7 +363,7 @@ def test_server_receives_fragmented_text_over_size_limit(self): self.assertConnectionFailing(server, 1009, "over size limit (2 > 1 bytes)") def test_client_receives_fragmented_text_without_size_limit(self): - client = Connection(Side.CLIENT, max_size=None) + client = Connection(CLIENT, max_size=None) client.receive_data(b"\x01\x02\xf0\x9f") self.assertFrameReceived( client, @@ -384,7 +381,7 @@ def test_client_receives_fragmented_text_without_size_limit(self): ) def test_server_receives_fragmented_text_without_size_limit(self): - server = Connection(Side.SERVER, max_size=None) + server = Connection(SERVER, max_size=None) server.receive_data(b"\x01\x82\x00\x00\x00\x00\xf0\x9f") self.assertFrameReceived( server, @@ -402,21 +399,21 @@ def test_server_receives_fragmented_text_without_size_limit(self): ) def test_client_sends_unexpected_text(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.send_text(b"", fin=False) with self.assertRaises(ProtocolError) as raised: client.send_text(b"", fin=False) self.assertEqual(str(raised.exception), "expected a continuation frame") def test_server_sends_unexpected_text(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.send_text(b"", fin=False) with self.assertRaises(ProtocolError) as raised: server.send_text(b"", fin=False) self.assertEqual(str(raised.exception), "expected a continuation frame") def test_client_receives_unexpected_text(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x01\x00") self.assertFrameReceived( client, @@ -428,7 +425,7 @@ def test_client_receives_unexpected_text(self): self.assertConnectionFailing(client, 1002, "expected a continuation frame") def test_server_receives_unexpected_text(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x01\x80\x00\x00\x00\x00") self.assertFrameReceived( server, @@ -440,7 +437,7 @@ def test_server_receives_unexpected_text(self): self.assertConnectionFailing(server, 1002, "expected a continuation frame") def test_client_sends_text_after_sending_close(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_close(1001) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) @@ -448,14 +445,14 @@ def test_client_sends_text_after_sending_close(self): client.send_text(b"") def test_server_sends_text_after_sending_close(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.send_close(1000) self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) with self.assertRaises(InvalidState): server.send_text(b"") def test_client_receives_text_after_receiving_close(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x88\x02\x03\xe8") self.assertConnectionClosing(client, 1000) client.receive_data(b"\x81\x00") @@ -463,7 +460,7 @@ def test_client_receives_text_after_receiving_close(self): self.assertFrameSent(client, None) def test_server_receives_text_after_receiving_close(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") self.assertConnectionClosing(server, 1001) server.receive_data(b"\x81\x80\x00\xff\x00\xff") @@ -478,7 +475,7 @@ class BinaryTests(ConnectionTestCase): """ def test_client_sends_binary(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_binary(b"\x01\x02\xfe\xff") self.assertEqual( @@ -486,12 +483,12 @@ def test_client_sends_binary(self): ) def test_server_sends_binary(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.send_binary(b"\x01\x02\xfe\xff") self.assertEqual(server.data_to_send(), [b"\x82\x04\x01\x02\xfe\xff"]) def test_client_receives_binary(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x82\x04\x01\x02\xfe\xff") self.assertFrameReceived( client, @@ -499,7 +496,7 @@ def test_client_receives_binary(self): ) def test_server_receives_binary(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x82\x84\x00\x00\x00\x00\x01\x02\xfe\xff") self.assertFrameReceived( server, @@ -507,21 +504,21 @@ def test_server_receives_binary(self): ) def test_client_receives_binary_over_size_limit(self): - client = Connection(Side.CLIENT, max_size=3) + client = Connection(CLIENT, max_size=3) client.receive_data(b"\x82\x04\x01\x02\xfe\xff") self.assertIsInstance(client.parser_exc, PayloadTooBig) self.assertEqual(str(client.parser_exc), "over size limit (4 > 3 bytes)") self.assertConnectionFailing(client, 1009, "over size limit (4 > 3 bytes)") def test_server_receives_binary_over_size_limit(self): - server = Connection(Side.SERVER, max_size=3) + server = Connection(SERVER, max_size=3) server.receive_data(b"\x82\x84\x00\x00\x00\x00\x01\x02\xfe\xff") self.assertIsInstance(server.parser_exc, PayloadTooBig) self.assertEqual(str(server.parser_exc), "over size limit (4 > 3 bytes)") self.assertConnectionFailing(server, 1009, "over size limit (4 > 3 bytes)") def test_client_sends_fragmented_binary(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_binary(b"\x01\x02", fin=False) self.assertEqual(client.data_to_send(), [b"\x02\x82\x00\x00\x00\x00\x01\x02"]) @@ -535,7 +532,7 @@ def test_client_sends_fragmented_binary(self): self.assertEqual(client.data_to_send(), [b"\x80\x82\x00\x00\x00\x00\xee\xff"]) def test_server_sends_fragmented_binary(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.send_binary(b"\x01\x02", fin=False) self.assertEqual(server.data_to_send(), [b"\x02\x02\x01\x02"]) server.send_continuation(b"\xee\xff\x01\x02", fin=False) @@ -544,7 +541,7 @@ def test_server_sends_fragmented_binary(self): self.assertEqual(server.data_to_send(), [b"\x80\x02\xee\xff"]) def test_client_receives_fragmented_binary(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x02\x02\x01\x02") self.assertFrameReceived( client, @@ -562,7 +559,7 @@ def test_client_receives_fragmented_binary(self): ) def test_server_receives_fragmented_binary(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x02\x82\x00\x00\x00\x00\x01\x02") self.assertFrameReceived( server, @@ -580,7 +577,7 @@ def test_server_receives_fragmented_binary(self): ) def test_client_receives_fragmented_binary_over_size_limit(self): - client = Connection(Side.CLIENT, max_size=3) + client = Connection(CLIENT, max_size=3) client.receive_data(b"\x02\x02\x01\x02") self.assertFrameReceived( client, @@ -592,7 +589,7 @@ def test_client_receives_fragmented_binary_over_size_limit(self): self.assertConnectionFailing(client, 1009, "over size limit (2 > 1 bytes)") def test_server_receives_fragmented_binary_over_size_limit(self): - server = Connection(Side.SERVER, max_size=3) + server = Connection(SERVER, max_size=3) server.receive_data(b"\x02\x82\x00\x00\x00\x00\x01\x02") self.assertFrameReceived( server, @@ -604,21 +601,21 @@ def test_server_receives_fragmented_binary_over_size_limit(self): self.assertConnectionFailing(server, 1009, "over size limit (2 > 1 bytes)") def test_client_sends_unexpected_binary(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.send_binary(b"", fin=False) with self.assertRaises(ProtocolError) as raised: client.send_binary(b"", fin=False) self.assertEqual(str(raised.exception), "expected a continuation frame") def test_server_sends_unexpected_binary(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.send_binary(b"", fin=False) with self.assertRaises(ProtocolError) as raised: server.send_binary(b"", fin=False) self.assertEqual(str(raised.exception), "expected a continuation frame") def test_client_receives_unexpected_binary(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x02\x00") self.assertFrameReceived( client, @@ -630,7 +627,7 @@ def test_client_receives_unexpected_binary(self): self.assertConnectionFailing(client, 1002, "expected a continuation frame") def test_server_receives_unexpected_binary(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x02\x80\x00\x00\x00\x00") self.assertFrameReceived( server, @@ -642,7 +639,7 @@ def test_server_receives_unexpected_binary(self): self.assertConnectionFailing(server, 1002, "expected a continuation frame") def test_client_sends_binary_after_sending_close(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_close(1001) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) @@ -650,14 +647,14 @@ def test_client_sends_binary_after_sending_close(self): client.send_binary(b"") def test_server_sends_binary_after_sending_close(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.send_close(1000) self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) with self.assertRaises(InvalidState): server.send_binary(b"") def test_client_receives_binary_after_receiving_close(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x88\x02\x03\xe8") self.assertConnectionClosing(client, 1000) client.receive_data(b"\x82\x00") @@ -665,7 +662,7 @@ def test_client_receives_binary_after_receiving_close(self): self.assertFrameSent(client, None) def test_server_receives_binary_after_receiving_close(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") self.assertConnectionClosing(server, 1001) server.receive_data(b"\x82\x80\x00\xff\x00\xff") @@ -686,78 +683,78 @@ class CloseTests(ConnectionTestCase): """ def test_close_code(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x88\x04\x03\xe8OK") client.receive_eof() self.assertEqual(client.close_code, 1000) def test_close_reason(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x88\x84\x00\x00\x00\x00\x03\xe8OK") server.receive_eof() self.assertEqual(server.close_reason, "OK") def test_close_code_not_provided(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x88\x80\x00\x00\x00\x00") server.receive_eof() self.assertEqual(server.close_code, 1005) def test_close_reason_not_provided(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x88\x00") client.receive_eof() self.assertEqual(client.close_reason, "") def test_close_code_not_available(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_eof() self.assertEqual(client.close_code, 1006) def test_close_reason_not_available(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_eof() self.assertEqual(server.close_reason, "") def test_close_code_not_available_yet(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) self.assertIsNone(server.close_code) def test_close_reason_not_available_yet(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) self.assertIsNone(client.close_reason) def test_client_sends_close(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) with self.enforce_mask(b"\x3c\x3c\x3c\x3c"): client.send_close() self.assertEqual(client.data_to_send(), [b"\x88\x80\x3c\x3c\x3c\x3c"]) - self.assertIs(client.state, State.CLOSING) + self.assertIs(client.state, CLOSING) def test_server_sends_close(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.send_close() self.assertEqual(server.data_to_send(), [b"\x88\x00"]) - self.assertIs(server.state, State.CLOSING) + self.assertIs(server.state, CLOSING) def test_client_receives_close(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) with self.enforce_mask(b"\x3c\x3c\x3c\x3c"): client.receive_data(b"\x88\x00") self.assertEqual(client.events_received(), [Frame(OP_CLOSE, b"")]) self.assertEqual(client.data_to_send(), [b"\x88\x80\x3c\x3c\x3c\x3c"]) - self.assertIs(client.state, State.CLOSING) + self.assertIs(client.state, CLOSING) def test_server_receives_close(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") self.assertEqual(server.events_received(), [Frame(OP_CLOSE, b"")]) self.assertEqual(server.data_to_send(), [b"\x88\x00", b""]) - self.assertIs(server.state, State.CLOSING) + self.assertIs(server.state, CLOSING) def test_client_sends_close_then_receives_close(self): # Client-initiated close handshake on the client side. - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.send_close() self.assertFrameReceived(client, None) @@ -773,7 +770,7 @@ def test_client_sends_close_then_receives_close(self): def test_server_sends_close_then_receives_close(self): # Server-initiated close handshake on the server side. - server = Connection(Side.SERVER) + server = Connection(SERVER) server.send_close() self.assertFrameReceived(server, None) @@ -789,7 +786,7 @@ def test_server_sends_close_then_receives_close(self): def test_client_receives_close_then_sends_close(self): # Server-initiated close handshake on the client side. - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x88\x00") self.assertFrameReceived(client, Frame(OP_CLOSE, b"")) @@ -801,7 +798,7 @@ def test_client_receives_close_then_sends_close(self): def test_server_receives_close_then_sends_close(self): # Client-initiated close handshake on the server side. - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") self.assertFrameReceived(server, Frame(OP_CLOSE, b"")) @@ -812,87 +809,87 @@ def test_server_receives_close_then_sends_close(self): self.assertFrameSent(server, None) def test_client_sends_close_with_code(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_close(1001) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) - self.assertIs(client.state, State.CLOSING) + self.assertIs(client.state, CLOSING) def test_server_sends_close_with_code(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.send_close(1000) self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) - self.assertIs(server.state, State.CLOSING) + self.assertIs(server.state, CLOSING) def test_client_receives_close_with_code(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x88\x02\x03\xe8") self.assertConnectionClosing(client, 1000, "") - self.assertIs(client.state, State.CLOSING) + self.assertIs(client.state, CLOSING) def test_server_receives_close_with_code(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") self.assertConnectionClosing(server, 1001, "") - self.assertIs(server.state, State.CLOSING) + self.assertIs(server.state, CLOSING) def test_client_sends_close_with_code_and_reason(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_close(1001, "going away") self.assertEqual( client.data_to_send(), [b"\x88\x8c\x00\x00\x00\x00\x03\xe9going away"] ) - self.assertIs(client.state, State.CLOSING) + self.assertIs(client.state, CLOSING) def test_server_sends_close_with_code_and_reason(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.send_close(1000, "OK") self.assertEqual(server.data_to_send(), [b"\x88\x04\x03\xe8OK"]) - self.assertIs(server.state, State.CLOSING) + self.assertIs(server.state, CLOSING) def test_client_receives_close_with_code_and_reason(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x88\x04\x03\xe8OK") self.assertConnectionClosing(client, 1000, "OK") - self.assertIs(client.state, State.CLOSING) + self.assertIs(client.state, CLOSING) def test_server_receives_close_with_code_and_reason(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x88\x8c\x00\x00\x00\x00\x03\xe9going away") self.assertConnectionClosing(server, 1001, "going away") - self.assertIs(server.state, State.CLOSING) + self.assertIs(server.state, CLOSING) def test_client_sends_close_with_reason_only(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) with self.assertRaises(ProtocolError) as raised: client.send_close(reason="going away") self.assertEqual(str(raised.exception), "cannot send a reason without a code") def test_server_sends_close_with_reason_only(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) with self.assertRaises(ProtocolError) as raised: server.send_close(reason="OK") self.assertEqual(str(raised.exception), "cannot send a reason without a code") def test_client_receives_close_with_truncated_code(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x88\x01\x03") self.assertIsInstance(client.parser_exc, ProtocolError) self.assertEqual(str(client.parser_exc), "close frame too short") self.assertConnectionFailing(client, 1002, "close frame too short") - self.assertIs(client.state, State.CLOSING) + self.assertIs(client.state, CLOSING) def test_server_receives_close_with_truncated_code(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x88\x81\x00\x00\x00\x00\x03") self.assertIsInstance(server.parser_exc, ProtocolError) self.assertEqual(str(server.parser_exc), "close frame too short") self.assertConnectionFailing(server, 1002, "close frame too short") - self.assertIs(server.state, State.CLOSING) + self.assertIs(server.state, CLOSING) def test_client_receives_close_with_non_utf8_reason(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x88\x04\x03\xe8\xff\xff") self.assertIsInstance(client.parser_exc, UnicodeDecodeError) @@ -901,10 +898,10 @@ def test_client_receives_close_with_non_utf8_reason(self): "'utf-8' codec can't decode byte 0xff in position 0: invalid start byte", ) self.assertConnectionFailing(client, 1007, "invalid start byte at position 0") - self.assertIs(client.state, State.CLOSING) + self.assertIs(client.state, CLOSING) def test_server_receives_close_with_non_utf8_reason(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x88\x84\x00\x00\x00\x00\x03\xe9\xff\xff") self.assertIsInstance(server.parser_exc, UnicodeDecodeError) @@ -913,7 +910,7 @@ def test_server_receives_close_with_non_utf8_reason(self): "'utf-8' codec can't decode byte 0xff in position 0: invalid start byte", ) self.assertConnectionFailing(server, 1007, "invalid start byte at position 0") - self.assertIs(server.state, State.CLOSING) + self.assertIs(server.state, CLOSING) class PingTests(ConnectionTestCase): @@ -923,18 +920,18 @@ class PingTests(ConnectionTestCase): """ def test_client_sends_ping(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) with self.enforce_mask(b"\x00\x44\x88\xcc"): client.send_ping(b"") self.assertEqual(client.data_to_send(), [b"\x89\x80\x00\x44\x88\xcc"]) def test_server_sends_ping(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.send_ping(b"") self.assertEqual(server.data_to_send(), [b"\x89\x00"]) def test_client_receives_ping(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x89\x00") self.assertFrameReceived( client, @@ -946,7 +943,7 @@ def test_client_receives_ping(self): ) def test_server_receives_ping(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x89\x80\x00\x44\x88\xcc") self.assertFrameReceived( server, @@ -958,7 +955,7 @@ def test_server_receives_ping(self): ) def test_client_sends_ping_with_data(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) with self.enforce_mask(b"\x00\x44\x88\xcc"): client.send_ping(b"\x22\x66\xaa\xee") self.assertEqual( @@ -966,12 +963,12 @@ def test_client_sends_ping_with_data(self): ) def test_server_sends_ping_with_data(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.send_ping(b"\x22\x66\xaa\xee") self.assertEqual(server.data_to_send(), [b"\x89\x04\x22\x66\xaa\xee"]) def test_client_receives_ping_with_data(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x89\x04\x22\x66\xaa\xee") self.assertFrameReceived( client, @@ -983,7 +980,7 @@ def test_client_receives_ping_with_data(self): ) def test_server_receives_ping_with_data(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x89\x84\x00\x44\x88\xcc\x22\x22\x22\x22") self.assertFrameReceived( server, @@ -995,35 +992,35 @@ def test_server_receives_ping_with_data(self): ) def test_client_sends_fragmented_ping_frame(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) # This is only possible through a private API. with self.assertRaises(ProtocolError) as raised: client.send_frame(Frame(OP_PING, b"", fin=False)) self.assertEqual(str(raised.exception), "fragmented control frame") def test_server_sends_fragmented_ping_frame(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) # This is only possible through a private API. with self.assertRaises(ProtocolError) as raised: server.send_frame(Frame(OP_PING, b"", fin=False)) self.assertEqual(str(raised.exception), "fragmented control frame") def test_client_receives_fragmented_ping_frame(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x09\x00") self.assertIsInstance(client.parser_exc, ProtocolError) self.assertEqual(str(client.parser_exc), "fragmented control frame") self.assertConnectionFailing(client, 1002, "fragmented control frame") def test_server_receives_fragmented_ping_frame(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x09\x80\x3c\x3c\x3c\x3c") self.assertIsInstance(server.parser_exc, ProtocolError) self.assertEqual(str(server.parser_exc), "fragmented control frame") self.assertConnectionFailing(server, 1002, "fragmented control frame") def test_client_sends_ping_after_sending_close(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_close(1001) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) @@ -1038,7 +1035,7 @@ def test_client_sends_ping_after_sending_close(self): ) def test_server_sends_ping_after_sending_close(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.send_close(1000) self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) # The spec says: "An endpoint MAY send a Ping frame any time (...) @@ -1052,7 +1049,7 @@ def test_server_sends_ping_after_sending_close(self): ) def test_client_receives_ping_after_receiving_close(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x88\x02\x03\xe8") self.assertConnectionClosing(client, 1000) client.receive_data(b"\x89\x04\x22\x66\xaa\xee") @@ -1060,7 +1057,7 @@ def test_client_receives_ping_after_receiving_close(self): self.assertFrameSent(client, None) def test_server_receives_ping_after_receiving_close(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") self.assertConnectionClosing(server, 1001) server.receive_data(b"\x89\x84\x00\x44\x88\xcc\x22\x22\x22\x22") @@ -1075,18 +1072,18 @@ class PongTests(ConnectionTestCase): """ def test_client_sends_pong(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) with self.enforce_mask(b"\x00\x44\x88\xcc"): client.send_pong(b"") self.assertEqual(client.data_to_send(), [b"\x8a\x80\x00\x44\x88\xcc"]) def test_server_sends_pong(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.send_pong(b"") self.assertEqual(server.data_to_send(), [b"\x8a\x00"]) def test_client_receives_pong(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x8a\x00") self.assertFrameReceived( client, @@ -1094,7 +1091,7 @@ def test_client_receives_pong(self): ) def test_server_receives_pong(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x8a\x80\x00\x44\x88\xcc") self.assertFrameReceived( server, @@ -1102,7 +1099,7 @@ def test_server_receives_pong(self): ) def test_client_sends_pong_with_data(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) with self.enforce_mask(b"\x00\x44\x88\xcc"): client.send_pong(b"\x22\x66\xaa\xee") self.assertEqual( @@ -1110,12 +1107,12 @@ def test_client_sends_pong_with_data(self): ) def test_server_sends_pong_with_data(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.send_pong(b"\x22\x66\xaa\xee") self.assertEqual(server.data_to_send(), [b"\x8a\x04\x22\x66\xaa\xee"]) def test_client_receives_pong_with_data(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x8a\x04\x22\x66\xaa\xee") self.assertFrameReceived( client, @@ -1123,7 +1120,7 @@ def test_client_receives_pong_with_data(self): ) def test_server_receives_pong_with_data(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x8a\x84\x00\x44\x88\xcc\x22\x22\x22\x22") self.assertFrameReceived( server, @@ -1131,35 +1128,35 @@ def test_server_receives_pong_with_data(self): ) def test_client_sends_fragmented_pong_frame(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) # This is only possible through a private API. with self.assertRaises(ProtocolError) as raised: client.send_frame(Frame(OP_PONG, b"", fin=False)) self.assertEqual(str(raised.exception), "fragmented control frame") def test_server_sends_fragmented_pong_frame(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) # This is only possible through a private API. with self.assertRaises(ProtocolError) as raised: server.send_frame(Frame(OP_PONG, b"", fin=False)) self.assertEqual(str(raised.exception), "fragmented control frame") def test_client_receives_fragmented_pong_frame(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x0a\x00") self.assertIsInstance(client.parser_exc, ProtocolError) self.assertEqual(str(client.parser_exc), "fragmented control frame") self.assertConnectionFailing(client, 1002, "fragmented control frame") def test_server_receives_fragmented_pong_frame(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x0a\x80\x3c\x3c\x3c\x3c") self.assertIsInstance(server.parser_exc, ProtocolError) self.assertEqual(str(server.parser_exc), "fragmented control frame") self.assertConnectionFailing(server, 1002, "fragmented control frame") def test_client_sends_pong_after_sending_close(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_close(1001) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) @@ -1168,7 +1165,7 @@ def test_client_sends_pong_after_sending_close(self): client.send_pong(b"") def test_server_sends_pong_after_sending_close(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.send_close(1000) self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) # websockets doesn't support sending a Pong frame after a Close frame. @@ -1176,7 +1173,7 @@ def test_server_sends_pong_after_sending_close(self): server.send_pong(b"") def test_client_receives_pong_after_receiving_close(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x88\x02\x03\xe8") self.assertConnectionClosing(client, 1000) client.receive_data(b"\x8a\x04\x22\x66\xaa\xee") @@ -1184,7 +1181,7 @@ def test_client_receives_pong_after_receiving_close(self): self.assertFrameSent(client, None) def test_server_receives_pong_after_receiving_close(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") self.assertConnectionClosing(server, 1001) server.receive_data(b"\x8a\x84\x00\x44\x88\xcc\x22\x22\x22\x22") @@ -1201,14 +1198,14 @@ class FailTests(ConnectionTestCase): """ def test_client_stops_processing_frames_after_fail(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.fail(1002) self.assertConnectionFailing(client, 1002) client.receive_data(b"\x88\x02\x03\xea") self.assertFrameReceived(client, None) def test_server_stops_processing_frames_after_fail(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.fail(1002) self.assertConnectionFailing(server, 1002) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xea") @@ -1224,7 +1221,7 @@ class FragmentationTests(ConnectionTestCase): """ def test_client_send_ping_pong_in_fragmented_message(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.send_text(b"Spam", fin=False) self.assertFrameSent(client, Frame(OP_TEXT, b"Spam", fin=False)) client.send_ping(b"Ping") @@ -1237,7 +1234,7 @@ def test_client_send_ping_pong_in_fragmented_message(self): self.assertFrameSent(client, Frame(OP_CONT, b"Eggs")) def test_server_send_ping_pong_in_fragmented_message(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.send_text(b"Spam", fin=False) self.assertFrameSent(server, Frame(OP_TEXT, b"Spam", fin=False)) server.send_ping(b"Ping") @@ -1250,7 +1247,7 @@ def test_server_send_ping_pong_in_fragmented_message(self): self.assertFrameSent(server, Frame(OP_CONT, b"Eggs")) def test_client_receive_ping_pong_in_fragmented_message(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x01\x04Spam") self.assertFrameReceived( client, @@ -1282,7 +1279,7 @@ def test_client_receive_ping_pong_in_fragmented_message(self): ) def test_server_receive_ping_pong_in_fragmented_message(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x01\x84\x00\x00\x00\x00Spam") self.assertFrameReceived( server, @@ -1314,7 +1311,7 @@ def test_server_receive_ping_pong_in_fragmented_message(self): ) def test_client_send_close_in_fragmented_message(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.send_text(b"Spam", fin=False) self.assertFrameSent(client, Frame(OP_TEXT, b"Spam", fin=False)) # The spec says: "An endpoint MUST be capable of handling control @@ -1327,7 +1324,7 @@ def test_client_send_close_in_fragmented_message(self): client.send_continuation(b"Eggs", fin=True) def test_server_send_close_in_fragmented_message(self): - server = Connection(Side.CLIENT) + server = Connection(CLIENT) server.send_text(b"Spam", fin=False) self.assertFrameSent(server, Frame(OP_TEXT, b"Spam", fin=False)) # The spec says: "An endpoint MUST be capable of handling control @@ -1339,7 +1336,7 @@ def test_server_send_close_in_fragmented_message(self): self.assertEqual(str(raised.exception), "expected a continuation frame") def test_client_receive_close_in_fragmented_message(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x01\x04Spam") self.assertFrameReceived( client, @@ -1355,7 +1352,7 @@ def test_client_receive_close_in_fragmented_message(self): self.assertConnectionFailing(client, 1002, "incomplete fragmented message") def test_server_receive_close_in_fragmented_message(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x01\x84\x00\x00\x00\x00Spam") self.assertFrameReceived( server, @@ -1378,35 +1375,35 @@ class EOFTests(ConnectionTestCase): """ def test_client_receives_eof(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x88\x00") self.assertConnectionClosing(client) client.receive_eof() - self.assertIs(client.state, State.CLOSED) + self.assertIs(client.state, CLOSED) def test_server_receives_eof(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") self.assertConnectionClosing(server) server.receive_eof() - self.assertIs(server.state, State.CLOSED) + self.assertIs(server.state, CLOSED) def test_client_receives_eof_between_frames(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_eof() self.assertIsInstance(client.parser_exc, EOFError) self.assertEqual(str(client.parser_exc), "unexpected end of stream") - self.assertIs(client.state, State.CLOSED) + self.assertIs(client.state, CLOSED) def test_server_receives_eof_between_frames(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_eof() self.assertIsInstance(server.parser_exc, EOFError) self.assertEqual(str(server.parser_exc), "unexpected end of stream") - self.assertIs(server.state, State.CLOSED) + self.assertIs(server.state, CLOSED) def test_client_receives_eof_inside_frame(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x81") client.receive_eof() self.assertIsInstance(client.parser_exc, EOFError) @@ -1414,10 +1411,10 @@ def test_client_receives_eof_inside_frame(self): str(client.parser_exc), "stream ends after 1 bytes, expected 2 bytes", ) - self.assertIs(client.state, State.CLOSED) + self.assertIs(client.state, CLOSED) def test_server_receives_eof_inside_frame(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x81") server.receive_eof() self.assertIsInstance(server.parser_exc, EOFError) @@ -1425,38 +1422,38 @@ def test_server_receives_eof_inside_frame(self): str(server.parser_exc), "stream ends after 1 bytes, expected 2 bytes", ) - self.assertIs(server.state, State.CLOSED) + self.assertIs(server.state, CLOSED) def test_client_receives_data_after_exception(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\xff\xff") self.assertConnectionFailing(client, 1002, "invalid opcode") client.receive_data(b"\x00\x00") self.assertFrameSent(client, None) def test_server_receives_data_after_exception(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\xff\xff") self.assertConnectionFailing(server, 1002, "invalid opcode") server.receive_data(b"\x00\x00") self.assertFrameSent(server, None) def test_client_receives_eof_after_exception(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\xff\xff") self.assertConnectionFailing(client, 1002, "invalid opcode") client.receive_eof() self.assertFrameSent(client, None, eof=True) def test_server_receives_eof_after_exception(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\xff\xff") self.assertConnectionFailing(server, 1002, "invalid opcode") server.receive_eof() self.assertFrameSent(server, None) def test_client_receives_data_and_eof_after_exception(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\xff\xff") self.assertConnectionFailing(client, 1002, "invalid opcode") client.receive_data(b"\x00\x00") @@ -1464,7 +1461,7 @@ def test_client_receives_data_and_eof_after_exception(self): self.assertFrameSent(client, None, eof=True) def test_server_receives_data_and_eof_after_exception(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\xff\xff") self.assertConnectionFailing(server, 1002, "invalid opcode") server.receive_data(b"\x00\x00") @@ -1472,7 +1469,7 @@ def test_server_receives_data_and_eof_after_exception(self): self.assertFrameSent(server, None) def test_client_receives_data_after_eof(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x88\x00") self.assertConnectionClosing(client) client.receive_eof() @@ -1481,7 +1478,7 @@ def test_client_receives_data_after_eof(self): self.assertEqual(str(raised.exception), "stream ended") def test_server_receives_data_after_eof(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") self.assertConnectionClosing(server) server.receive_eof() @@ -1490,7 +1487,7 @@ def test_server_receives_data_after_eof(self): self.assertEqual(str(raised.exception), "stream ended") def test_client_receives_eof_after_eof(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x88\x00") self.assertConnectionClosing(client) client.receive_eof() @@ -1499,7 +1496,7 @@ def test_client_receives_eof_after_eof(self): self.assertEqual(str(raised.exception), "stream ended") def test_server_receives_eof_after_eof(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") self.assertConnectionClosing(server) server.receive_eof() @@ -1515,52 +1512,52 @@ class TCPCloseTests(ConnectionTestCase): """ def test_client_default(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) self.assertFalse(client.close_expected()) def test_server_default(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) self.assertFalse(server.close_expected()) def test_client_sends_close(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.send_close() self.assertTrue(client.close_expected()) def test_server_sends_close(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.send_close() self.assertTrue(server.close_expected()) def test_client_receives_close(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x88\x00") self.assertTrue(client.close_expected()) def test_client_receives_close_then_eof(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x88\x00") client.receive_eof() self.assertFalse(client.close_expected()) def test_server_receives_close_then_eof(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") server.receive_eof() self.assertFalse(server.close_expected()) def test_server_receives_close(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") self.assertTrue(server.close_expected()) def test_client_fails_connection(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.fail(1002) self.assertTrue(client.close_expected()) def test_server_fails_connection(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.fail(1002) self.assertTrue(server.close_expected()) @@ -1573,7 +1570,7 @@ class ConnectionClosedTests(ConnectionTestCase): def test_client_sends_close_then_receives_close(self): # Client-initiated close handshake on the client side complete. - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.send_close(1000, "") client.receive_data(b"\x88\x02\x03\xe8") client.receive_eof() @@ -1585,7 +1582,7 @@ def test_client_sends_close_then_receives_close(self): def test_server_sends_close_then_receives_close(self): # Server-initiated close handshake on the server side complete. - server = Connection(Side.SERVER) + server = Connection(SERVER) server.send_close(1000, "") server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe8") server.receive_eof() @@ -1597,7 +1594,7 @@ def test_server_sends_close_then_receives_close(self): def test_client_receives_close_then_sends_close(self): # Server-initiated close handshake on the client side complete. - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_data(b"\x88\x02\x03\xe8") client.receive_eof() exc = client.close_exc @@ -1608,7 +1605,7 @@ def test_client_receives_close_then_sends_close(self): def test_server_receives_close_then_sends_close(self): # Client-initiated close handshake on the server side complete. - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe8") server.receive_eof() exc = server.close_exc @@ -1619,7 +1616,7 @@ def test_server_receives_close_then_sends_close(self): def test_client_sends_close_then_receives_eof(self): # Client-initiated close handshake on the client side times out. - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.send_close(1000, "") client.receive_eof() exc = client.close_exc @@ -1630,7 +1627,7 @@ def test_client_sends_close_then_receives_eof(self): def test_server_sends_close_then_receives_eof(self): # Server-initiated close handshake on the server side times out. - server = Connection(Side.SERVER) + server = Connection(SERVER) server.send_close(1000, "") server.receive_eof() exc = server.close_exc @@ -1641,7 +1638,7 @@ def test_server_sends_close_then_receives_eof(self): def test_client_receives_eof(self): # Server-initiated close handshake on the client side times out. - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.receive_eof() exc = client.close_exc self.assertIsInstance(exc, ConnectionClosedError) @@ -1651,7 +1648,7 @@ def test_client_receives_eof(self): def test_server_receives_eof(self): # Client-initiated close handshake on the server side times out. - server = Connection(Side.SERVER) + server = Connection(SERVER) server.receive_eof() exc = server.close_exc self.assertIsInstance(exc, ConnectionClosedError) @@ -1667,7 +1664,7 @@ class ErrorTests(ConnectionTestCase): """ def test_client_hits_internal_error_reading_frame(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) # This isn't supposed to happen, so we're simulating it. with unittest.mock.patch("struct.unpack", side_effect=RuntimeError("BOOM")): client.receive_data(b"\x81\x00") @@ -1676,7 +1673,7 @@ def test_client_hits_internal_error_reading_frame(self): self.assertConnectionFailing(client, 1011, "") def test_server_hits_internal_error_reading_frame(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) # This isn't supposed to happen, so we're simulating it. with unittest.mock.patch("struct.unpack", side_effect=RuntimeError("BOOM")): server.receive_data(b"\x81\x80\x00\x00\x00\x00") @@ -1692,26 +1689,26 @@ class ExtensionsTests(ConnectionTestCase): """ def test_client_extension_encodes_frame(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.extensions = [Rsv2Extension()] with self.enforce_mask(b"\x00\x44\x88\xcc"): client.send_ping(b"") self.assertEqual(client.data_to_send(), [b"\xa9\x80\x00\x44\x88\xcc"]) def test_server_extension_encodes_frame(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.extensions = [Rsv2Extension()] server.send_ping(b"") self.assertEqual(server.data_to_send(), [b"\xa9\x00"]) def test_client_extension_decodes_frame(self): - client = Connection(Side.CLIENT) + client = Connection(CLIENT) client.extensions = [Rsv2Extension()] client.receive_data(b"\xaa\x00") self.assertEqual(client.events_received(), [Frame(OP_PONG, b"")]) def test_server_extension_decodes_frame(self): - server = Connection(Side.SERVER) + server = Connection(SERVER) server.extensions = [Rsv2Extension()] server.receive_data(b"\xaa\x80\x00\x44\x88\xcc") self.assertEqual(server.events_received(), [Frame(OP_PONG, b"")]) From 413ae6ee4e75c03a790400f24d691e3e6badeb45 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 23 Oct 2022 11:24:32 +0200 Subject: [PATCH 337/762] Mark function for future removal. --- tests/legacy/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/legacy/utils.py b/tests/legacy/utils.py index bc37ee7bf..302fd70be 100644 --- a/tests/legacy/utils.py +++ b/tests/legacy/utils.py @@ -57,6 +57,7 @@ def run_loop_once(self): self.loop.call_soon(self.loop.stop) self.loop.run_forever() + # Remove when dropping Python < 3.10 @contextlib.contextmanager def assertNoLogs(self, logger="websockets", level=logging.ERROR): """ From 42ce29ab423233c936e59c0502c44549ec2bdf87 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 17 Aug 2022 20:17:34 +0200 Subject: [PATCH 338/762] Add Python 3.11. --- .github/workflows/tests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 4785e8d20..4c1b1e899 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -43,6 +43,7 @@ jobs: - "3.8" - "3.9" - "3.10" + - "3.11" - "pypy-3.7" - "pypy-3.8" - "pypy-3.9" From d160c1bc51ea83700d6de2d7da928aba3a034b52 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sat, 15 Oct 2022 05:18:59 +0000 Subject: [PATCH 339/762] Bump pypa/cibuildwheel from 2.10.0 to 2.11.1 Bumps [pypa/cibuildwheel](https://github.com/pypa/cibuildwheel) from 2.10.0 to 2.11.1. - [Release notes](https://github.com/pypa/cibuildwheel/releases) - [Changelog](https://github.com/pypa/cibuildwheel/blob/main/docs/changelog.md) - [Commits](https://github.com/pypa/cibuildwheel/compare/v2.10.0...v2.11.1) --- updated-dependencies: - dependency-name: pypa/cibuildwheel dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- .github/workflows/wheels.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index cc22cd4c6..0013ad103 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -48,7 +48,7 @@ jobs: with: platforms: all - name: Build wheels - uses: pypa/cibuildwheel@v2.10.0 + uses: pypa/cibuildwheel@v2.11.1 env: CIBW_ARCHS_MACOS: "x86_64 universal2 arm64" CIBW_ARCHS_LINUX: "auto aarch64" From 9230cca4b5e6e100a57407ac61dbf20008b7225a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 25 Oct 2022 21:48:26 +0200 Subject: [PATCH 340/762] Complete changelog for 10.4. --- docs/project/changelog.rst | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 8ad7bc5ab..c641a050e 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -44,6 +44,13 @@ New features * Supported overriding or removing the ``User-Agent`` header in clients and the ``Server`` header in servers. +* Added deployment guides for more Platform as a Service providers. + +Improvements +............ + +* Improved FAQ. + 10.3 ---- From d8c17625857797db029110d74417ae5f840aeb75 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 25 Oct 2022 21:39:56 +0200 Subject: [PATCH 341/762] Release version 10.4 --- docs/project/changelog.rst | 2 +- src/websockets/version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index c641a050e..8ad509a08 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -28,7 +28,7 @@ They may change at any time. 10.4 ---- -*In development* +*October 25, 2022* New features ............ diff --git a/src/websockets/version.py b/src/websockets/version.py index 29d658ce2..3fdc4fad9 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -16,7 +16,7 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = False +released = True tag = version = commit = "10.4" From b2847c2786a88c8c6f6017081f828032fb42ded3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 25 Oct 2022 22:04:09 +0200 Subject: [PATCH 342/762] Start version 11.0 --- docs/project/changelog.rst | 5 +++++ src/websockets/version.py | 4 ++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 8ad509a08..61136a36c 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -25,6 +25,11 @@ fixing regressions shortly after a release. Only documented APIs are public. Undocumented APIs are considered private. They may change at any time. +11.0 +---- + +*In development* + 10.4 ---- diff --git a/src/websockets/version.py b/src/websockets/version.py index 3fdc4fad9..1a3d884d8 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -16,9 +16,9 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = True +released = False -tag = version = commit = "10.4" +tag = version = commit = "11.0" if not released: # pragma: no cover From 06ffba57f9460ff577eada5ad3fe2558593be4b9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 25 Oct 2022 22:42:28 +0200 Subject: [PATCH 343/762] Fix deprecation warning. --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index a300ce628..9126f55d0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -2,7 +2,7 @@ python-tag = py37.py38.py39.py310.py311 [metadata] -license_file = LICENSE +license_files = LICENSE project_urls = Changelog = https://websockets.readthedocs.io/en/stable/project/changelog.html Documentation = https://websockets.readthedocs.io/ From 2cbb8134acfa8b56aba5c40f0479f12434104a62 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 22 Aug 2022 21:52:48 +0200 Subject: [PATCH 344/762] Add option keep connections open when closing server. Fix #1174. --- docs/faq/server.rst | 15 +++++++ docs/project/changelog.rst | 5 +++ src/websockets/legacy/server.py | 69 +++++++++++++++++------------- tests/legacy/test_client_server.py | 17 +++++++- 4 files changed, 74 insertions(+), 32 deletions(-) diff --git a/docs/faq/server.rst b/docs/faq/server.rst index 68490d755..22a9d3a4c 100644 --- a/docs/faq/server.rst +++ b/docs/faq/server.rst @@ -260,6 +260,21 @@ Here's an example that terminates cleanly when it receives SIGTERM on Unix: .. literalinclude:: ../../example/shutdown_server.py :emphasize-lines: 12-15,18 +How do I stop a server while keeping existing connections open? +--------------------------------------------------------------- + +Call the server's :meth:`~server.WebSocketServer.close` method with +``close_connections=False``. + +Here's how to adapt the example just above:: + + async def server(): + ... + + server = await websockets.serve(echo, "localhost", 8765) + await stop + await server.close(close_connections=False) + How do I run HTTP and WebSocket servers on the same port? --------------------------------------------------------- diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 61136a36c..7851db80d 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -30,6 +30,11 @@ They may change at any time. *In development* +New features +............ + +* Made it possible to close a server without closing existing connections. + 10.4 ---- diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index c32c85612..89e5322bc 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -730,26 +730,30 @@ def unregister(self, protocol: WebSocketServerProtocol) -> None: """ self.websockets.remove(protocol) - def close(self) -> None: + def close(self, close_connections: bool = True) -> None: """ Close the server. - This method: + * Close the underlying :class:`~asyncio.Server`. + * When ``close_connections`` is :obj:`True`, which is the default, + close existing connections. Specifically: - * closes the underlying :class:`~asyncio.Server`; - * rejects new WebSocket connections with an HTTP 503 (service - unavailable) error; this happens when the server accepted the TCP - connection but didn't complete the WebSocket opening handshake prior - to closing; - * closes open WebSocket connections with close code 1001 (going away). + * Reject opening WebSocket connections with an HTTP 503 (service + unavailable) error. This happens when the server accepted the TCP + connection but didn't complete the opening handshake before closing. + * Close open WebSocket connections with close code 1001 (going away). + + * Wait until all connection handlers terminate. :meth:`close` is idempotent. """ if self.close_task is None: - self.close_task = self.get_loop().create_task(self._close()) + self.close_task = self.get_loop().create_task( + self._close(close_connections) + ) - async def _close(self) -> None: + async def _close(self, close_connections: bool) -> None: """ Implementation of :meth:`close`. @@ -770,21 +774,22 @@ async def _close(self) -> None: # register(). See https://bugs.python.org/issue34852 for details. await asyncio.sleep(0, **loop_if_py_lt_38(self.get_loop())) - # Close OPEN connections with status code 1001. Since the server was - # closed, handshake() closes OPENING connections with a HTTP 503 - # error. Wait until all connections are closed. - - close_tasks = [ - asyncio.create_task(websocket.close(1001)) - for websocket in self.websockets - if websocket.state is not State.CONNECTING - ] - # asyncio.wait doesn't accept an empty first argument. - if close_tasks: - await asyncio.wait( - close_tasks, - **loop_if_py_lt_38(self.get_loop()), - ) + if close_connections: + # Close OPEN connections with status code 1001. Since the server was + # closed, handshake() closes OPENING connections with a HTTP 503 + # error. Wait until all connections are closed. + + close_tasks = [ + asyncio.create_task(websocket.close(1001)) + for websocket in self.websockets + if websocket.state is not State.CONNECTING + ] + # asyncio.wait doesn't accept an empty first argument. + if close_tasks: + await asyncio.wait( + close_tasks, + **loop_if_py_lt_38(self.get_loop()), + ) # Wait until all connection handlers are complete. @@ -903,18 +908,22 @@ class Serve: server performs the closing handshake and closes the connection. Awaiting :func:`serve` yields a :class:`WebSocketServer`. This object - provides :meth:`~WebSocketServer.close` and - :meth:`~WebSocketServer.wait_closed` methods for shutting down the server. + provides a :meth:`~WebSocketServer.close` method to shut down the server:: + + stop = asyncio.Future() # set this future to exit the server + + server = await serve(...) + await stop + await server.close() - :func:`serve` can be used as an asynchronous context manager:: + :func:`serve` can be used as an asynchronous context manager. Then, the + server is shut down automatically when exiting the context:: stop = asyncio.Future() # set this future to exit the server async with serve(...): await stop - The server is shut down automatically when exiting the context. - Args: ws_handler: connection handler. It receives the WebSocket connection, which is a :class:`WebSocketServerProtocol`, in argument. diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index d02daede1..9db15c0f1 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -1214,14 +1214,27 @@ def test_server_shuts_down_during_connection_handling(self): server_ws = next(iter(self.server.websockets)) self.server.close() with self.assertRaises(ConnectionClosed): + self.loop.run_until_complete(self.client.send("Hello!")) self.loop.run_until_complete(self.client.recv()) - # Websocket connection closes properly with 1001 Going Away. + # Server closed the connection with 1001 Going Away. self.assertEqual(self.client.close_code, 1001) self.assertEqual(server_ws.close_code, 1001) @with_server() - def test_server_shuts_down_waits_until_handlers_terminate(self): + def test_server_shuts_down_gracefully_during_connection_handling(self): + with self.temp_client(): + server_ws = next(iter(self.server.websockets)) + self.server.close(close_connections=False) + self.loop.run_until_complete(self.client.send("Hello!")) + self.loop.run_until_complete(self.client.recv()) + + # Client closed the connection with 1000 OK. + self.assertEqual(self.client.close_code, 1000) + self.assertEqual(server_ws.close_code, 1000) + + @with_server() + def test_server_shuts_down_and_waits_until_handlers_terminate(self): # This handler waits a bit after the connection is closed in order # to test that wait_closed() really waits for handlers to complete. self.start_client("/slow_stop") From 7a398921d8b41094ad4e13940c91a69db2bcfb1f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 26 Oct 2022 16:39:42 +0200 Subject: [PATCH 345/762] Reduce usage of # pragma: no cover. --- setup.cfg | 1 + src/websockets/connection.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/setup.cfg b/setup.cfg index 9126f55d0..74efc8042 100644 --- a/setup.cfg +++ b/setup.cfg @@ -32,3 +32,4 @@ source = exclude_lines = if self.debug: pragma: no cover + raise AssertionError diff --git a/src/websockets/connection.py b/src/websockets/connection.py index db8b53699..7a9be9f2e 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -568,7 +568,7 @@ def parse(self) -> Generator[None, None, None]: # During an abnormal closure, execution ends here after catching an # exception. At this point, fail() replaced parse() by discard(). yield - raise AssertionError("parse() shouldn't step after error") # pragma: no cover + raise AssertionError("parse() shouldn't step after error") def discard(self) -> Generator[None, None, None]: """ @@ -598,7 +598,7 @@ def discard(self) -> Generator[None, None, None]: yield # Once the reader reaches EOF, its feed_data/eof() methods raise an # error, so our receive_data/eof() methods don't step the generator. - raise AssertionError("discard() shouldn't step after EOF") # pragma: no cover + raise AssertionError("discard() shouldn't step after EOF") def recv_frame(self, frame: Frame) -> None: """ @@ -674,7 +674,7 @@ def recv_frame(self, frame: Frame) -> None: self.parser = self.discard() next(self.parser) # start coroutine - else: # pragma: no cover + else: # This can't happen because Frame.parse() validates opcodes. raise AssertionError(f"unexpected opcode: {frame.opcode:02x}") From c19506b34dbf2bcb796dbe370ed41c269436ce92 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 26 Oct 2022 16:40:23 +0200 Subject: [PATCH 346/762] Ignore tests that no longer run on Python 3.11. --- tests/legacy/test_client_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 9db15c0f1..3004c685d 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -1371,7 +1371,7 @@ def test_checking_lack_of_origin_succeeds_backwards_compatibility(self): @unittest.skipIf( sys.version_info[:2] >= (3, 11), "asyncio.coroutine has been removed in Python 3.11" ) -class YieldFromTests(ClientServerTestsMixin, AsyncioTestCase): +class YieldFromTests(ClientServerTestsMixin, AsyncioTestCase): # pragma: no cover @with_server() def test_client(self): # @asyncio.coroutine is deprecated on Python ≥ 3.8 From d8d2ad5a824f155b1b44c9750ebf0ece792df1cc Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 31 Oct 2022 22:42:44 +0100 Subject: [PATCH 347/762] Raise coverage of websockets.exceptions to 100%. Add tests missing from 62eb267c. --- src/websockets/exceptions.py | 8 ++++++-- tests/test_exceptions.py | 10 ++++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 0c4fc5185..291b6d2cd 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -120,11 +120,15 @@ def __str__(self) -> str: @property def code(self) -> int: - return 1006 if self.rcvd is None else self.rcvd.code + if self.rcvd is None: + return 1006 + return self.rcvd.code @property def reason(self) -> str: - return "" if self.rcvd is None else self.rcvd.reason + if self.rcvd is None: + return "" + return self.rcvd.reason class ConnectionClosedError(ConnectionClosed): diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 3ede25fdb..a0f9dfcd2 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -157,3 +157,13 @@ def test_str(self): ]: with self.subTest(exception=exception): self.assertEqual(str(exception), exception_str) + + def test_connection_closed_attributes_backwards_compatibility(self): + exception = ConnectionClosed(Close(1000, "OK"), None, None) + self.assertEqual(exception.code, 1000) + self.assertEqual(exception.reason, "OK") + + def test_connection_closed_attributes_backwards_compatibility_defaults(self): + exception = ConnectionClosed(None, None, None) + self.assertEqual(exception.code, 1006) + self.assertEqual(exception.reason, "") From ca4968eb607a3a1b763e82d3b683a12b0eba67d5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 31 Oct 2022 22:43:28 +0100 Subject: [PATCH 348/762] Raise coverage of websockets.connection to 100%. Add tests for the logger argument. --- tests/test_connection.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/test_connection.py b/tests/test_connection.py index 44f5bc35e..3858d2521 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,3 +1,4 @@ +import logging import unittest.mock from websockets.connection import * @@ -1712,3 +1713,25 @@ def test_server_extension_decodes_frame(self): server.extensions = [Rsv2Extension()] server.receive_data(b"\xaa\x80\x00\x44\x88\xcc") self.assertEqual(server.events_received(), [Frame(OP_PONG, b"")]) + + +class MiscTests(unittest.TestCase): + def test_client_default_logger(self): + client = Connection(CLIENT) + logger = logging.getLogger("websockets.client") + self.assertIs(client.logger, logger) + + def test_server_default_logger(self): + server = Connection(SERVER) + logger = logging.getLogger("websockets.server") + self.assertIs(server.logger, logger) + + def test_client_custom_logger(self): + logger = logging.getLogger("test") + client = Connection(CLIENT, logger=logger) + self.assertIs(client.logger, logger) + + def test_server_custom_logger(self): + logger = logging.getLogger("test") + server = Connection(SERVER, logger=logger) + self.assertIs(server.logger, logger) From 892c86a017379cb57dd49ea03c142561917d353b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 31 Oct 2022 22:44:28 +0100 Subject: [PATCH 349/762] Move all test extensions to the same module. Some were in test_base.py, others in utils.py. --- tests/extensions/test_base.py | 36 --------------------- tests/extensions/test_permessage_deflate.py | 2 +- tests/extensions/utils.py | 35 ++++++++++++++++++++ tests/legacy/test_client_server.py | 2 +- 4 files changed, 37 insertions(+), 38 deletions(-) diff --git a/tests/extensions/test_base.py b/tests/extensions/test_base.py index 0daa34211..ba8657b65 100644 --- a/tests/extensions/test_base.py +++ b/tests/extensions/test_base.py @@ -1,40 +1,4 @@ -from websockets.exceptions import NegotiationError from websockets.extensions.base import * # noqa # Abstract classes don't provide any behavior to test. - - -class ClientNoOpExtensionFactory: - name = "x-no-op" - - def get_request_params(self): - return [] - - def process_response_params(self, params, accepted_extensions): - if params: - raise NegotiationError() - return NoOpExtension() - - -class ServerNoOpExtensionFactory: - name = "x-no-op" - - def __init__(self, params=None): - self.params = params or [] - - def process_request_params(self, params, accepted_extensions): - return self.params, NoOpExtension() - - -class NoOpExtension: - name = "x-no-op" - - def __repr__(self): - return "NoOpExtension()" - - def decode(self, frame, *, max_size=None): - return frame - - def encode(self, frame): - return frame diff --git a/tests/extensions/test_permessage_deflate.py b/tests/extensions/test_permessage_deflate.py index 3cc9172df..d433762c5 100644 --- a/tests/extensions/test_permessage_deflate.py +++ b/tests/extensions/test_permessage_deflate.py @@ -21,7 +21,7 @@ Frame, ) -from .test_base import ClientNoOpExtensionFactory, ServerNoOpExtensionFactory +from .utils import ClientNoOpExtensionFactory, ServerNoOpExtensionFactory class ExtensionTestsMixin: diff --git a/tests/extensions/utils.py b/tests/extensions/utils.py index 1eabc163f..24fb74b4e 100644 --- a/tests/extensions/utils.py +++ b/tests/extensions/utils.py @@ -46,6 +46,41 @@ def process_request_params(self, params, accepted_extensions): return [("op", self.op)], OpExtension(self.op) +class NoOpExtension: + name = "x-no-op" + + def __repr__(self): + return "NoOpExtension()" + + def decode(self, frame, *, max_size=None): + return frame + + def encode(self, frame): + return frame + + +class ClientNoOpExtensionFactory: + name = "x-no-op" + + def get_request_params(self): + return [] + + def process_response_params(self, params, accepted_extensions): + if params: + raise NegotiationError() + return NoOpExtension() + + +class ServerNoOpExtensionFactory: + name = "x-no-op" + + def __init__(self, params=None): + self.params = params or [] + + def process_request_params(self, params, accepted_extensions): + return self.params, NoOpExtension() + + class Rsv2Extension: name = "x-rsv2" diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 3004c685d..8ee7187eb 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -37,7 +37,7 @@ from websockets.legacy.server import * from websockets.uri import parse_uri -from ..extensions.test_base import ( +from ..extensions.utils import ( ClientNoOpExtensionFactory, NoOpExtension, ServerNoOpExtensionFactory, From 1948936b1af6aba31fa551dd93248f6b6f4db60e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 31 Oct 2022 22:12:21 +0100 Subject: [PATCH 350/762] Add a script to measure coverage per module. This makes it possible to increase coverage threshold to "each module has 100% branch coverage from its own tests". This implementation supports `make maxi_cov` and `tox -e maxi_cov`. It is also enabled in CI. --- .github/workflows/tests.yml | 23 +++++- Makefile | 7 +- setup.cfg | 4 +- tests/maxi_cov.py | 148 ++++++++++++++++++++++++++++++++++++ tox.ini | 6 ++ 5 files changed, 183 insertions(+), 5 deletions(-) create mode 100755 tests/maxi_cov.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 4c1b1e899..4d5cc3cd0 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -9,8 +9,8 @@ on: - main jobs: - main: - name: Run code quality checks + coverage: + name: Run test coverage checks runs-on: ubuntu-latest steps: - name: Check out repository @@ -23,6 +23,21 @@ jobs: run: pip install tox - name: Run tests with coverage run: tox -e coverage + - name: Run tests with per-module coverage + run: tox -e maxi_cov + + quality: + name: Run code quality checks + runs-on: ubuntu-latest + steps: + - name: Check out repository + uses: actions/checkout@v3 + - name: Install Python 3.x + uses: actions/setup-python@v4 + with: + python-version: "3.x" + - name: Install tox + run: pip install tox - name: Check code formatting run: tox -e black - name: Check code style @@ -34,7 +49,9 @@ jobs: matrix: name: Run tests on Python ${{ matrix.python }} - needs: main + needs: + - coverage + - quality runs-on: ubuntu-latest strategy: matrix: diff --git a/Makefile b/Makefile index cd8095ae1..ac5d6a4aa 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: default style test coverage build clean +.PHONY: default style test coverage maxi_cov build clean export PYTHONASYNCIODEBUG=1 export PYTHONPATH=src @@ -20,6 +20,11 @@ coverage: coverage html coverage report --show-missing --fail-under=100 +maxi_cov: + python tests/maxi_cov.py + coverage html + coverage report --show-missing --fail-under=100 + build: python setup.py build_ext --inplace diff --git a/setup.cfg b/setup.cfg index 74efc8042..4c1f6091f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -21,7 +21,9 @@ lines_after_imports = 2 [coverage:run] branch = True omit = - */__main__.py + # */websockets matches src/websockets and .tox/**/site-packages/websockets + */websockets/__main__.py + tests/maxi_cov.py [coverage:paths] source = diff --git a/tests/maxi_cov.py b/tests/maxi_cov.py new file mode 100755 index 000000000..5acb979de --- /dev/null +++ b/tests/maxi_cov.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python + +"""Measure coverage of each module by its test module.""" + +import glob +import os.path +import subprocess +import sys + + +UNMAPPED_SRC_FILES = ["websockets/version.py"] +UNMAPPED_TEST_FILES = ["tests/test_exports.py"] + +IGNORED_FILES = [ + # */websockets matches src/websockets and .tox/**/site-packages/websockets. + # There are no tests for the __main__ module. + "*/websockets/__main__.py", + # This approach isn't applicable to the test suite of the legacy + # implementation, due to the huge test_client_server test module. + "*/websockets/legacy/*", + "tests/legacy/*", + # Test utilities don't fit anywhere because they are shared. + "tests/extensions/utils.py", + "tests/utils.py", + # There is no point measure the coverage of this script. + "tests/maxi_cov.py", +] + + +def check_environment(): + """Check that prerequisites for running this script are met.""" + try: + import websockets # noqa: F401 + except ImportError: + print("failed to import websockets; is src on PYTHONPATH?") + return False + try: + import coverage # noqa: F401 + except ImportError: + print("failed to locate Coverage.py; is it installed?") + return False + return True + + +def get_mapping(src_dir="src"): + """Return a dict mapping each source file to its test file.""" + + # List source and test files. + + src_files = glob.glob( + os.path.join(src_dir, "websockets/**/*.py"), + recursive=True, + ) + + test_files = glob.glob( + "tests/**/*.py", + recursive=True, + ) + + src_files = [ + os.path.relpath(src_file, src_dir) + for src_file in sorted(src_files) + if os.path.basename(src_file) != "__init__.py" + and os.path.basename(src_file) != "__main__.py" + and "legacy" not in os.path.dirname(src_file) + ] + test_files = [ + test_file + for test_file in sorted(test_files) + if os.path.basename(test_file) != "__init__.py" + and os.path.basename(test_file).startswith("test_") + and "legacy" not in os.path.dirname(test_file) + ] + + # Map source files to test files. + + mapping = {} + unmapped_test_files = [] + + for test_file in test_files: + dir_name, file_name = os.path.split(test_file) + assert dir_name.startswith("tests") + assert file_name.startswith("test_") + src_file = os.path.join( + "websockets" + dir_name[len("tests") :], + file_name[len("test_") :], + ) + if src_file in src_files: + mapping[src_file] = test_file + else: + unmapped_test_files.append(test_file) + + unmapped_src_files = list(set(src_files) - set(mapping)) + + # Ensure that all files are mapped. + + assert unmapped_src_files == UNMAPPED_SRC_FILES + assert unmapped_test_files == UNMAPPED_TEST_FILES + + return mapping + + +def run_coverage(mapping, src_dir="src"): + # Initialize a new coverage measurement session. The --source option + # includes all files in the report, even if they're never imported. + print("\nInitializing session\n", flush=True) + subprocess.run( + [ + sys.executable, + "-m", + "coverage", + "run", + "--source", + ",".join([os.path.join(src_dir, "websockets"), "tests"]), + "--omit", + ",".join(IGNORED_FILES), + "-m", + "unittest", + ] + + UNMAPPED_TEST_FILES, + check=True, + ) + # Append coverage of each source module by the corresponding test module. + for src_file, test_file in mapping.items(): + print(f"\nTesting {src_file} with {test_file}\n", flush=True) + subprocess.run( + [ + sys.executable, + "-m", + "coverage", + "run", + "--append", + "--include", + ",".join([os.path.join(src_dir, src_file), test_file]), + "-m", + "unittest", + test_file, + ], + check=True, + ) + + +if __name__ == "__main__": + if not check_environment(): + sys.exit(1) + src_dir = sys.argv[1] if len(sys.argv) == 2 else "src" + mapping = get_mapping(src_dir) + run_coverage(mapping, src_dir) diff --git a/tox.ini b/tox.ini index 2f37cdcbf..0fcab4d79 100644 --- a/tox.ini +++ b/tox.ini @@ -21,6 +21,12 @@ commands = python -m coverage report --show-missing --fail-under=100 deps = coverage +[testenv:maxi_cov] +commands = + python tests/maxi_cov.py {envsitepackagesdir} + python -m coverage report --show-missing --fail-under=100 +deps = coverage + [testenv:black] commands = black --check src tests deps = black From 17d5d41140aad185e543092f98d55ab87e7c2f29 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 1 Nov 2022 07:50:40 +0100 Subject: [PATCH 351/762] Determine excluded test files automatically. --- tests/maxi_cov.py | 47 +++++++++++++++++++++++++++-------------------- 1 file changed, 27 insertions(+), 20 deletions(-) diff --git a/tests/maxi_cov.py b/tests/maxi_cov.py index 5acb979de..77d374f96 100755 --- a/tests/maxi_cov.py +++ b/tests/maxi_cov.py @@ -11,21 +11,6 @@ UNMAPPED_SRC_FILES = ["websockets/version.py"] UNMAPPED_TEST_FILES = ["tests/test_exports.py"] -IGNORED_FILES = [ - # */websockets matches src/websockets and .tox/**/site-packages/websockets. - # There are no tests for the __main__ module. - "*/websockets/__main__.py", - # This approach isn't applicable to the test suite of the legacy - # implementation, due to the huge test_client_server test module. - "*/websockets/legacy/*", - "tests/legacy/*", - # Test utilities don't fit anywhere because they are shared. - "tests/extensions/utils.py", - "tests/utils.py", - # There is no point measure the coverage of this script. - "tests/maxi_cov.py", -] - def check_environment(): """Check that prerequisites for running this script are met.""" @@ -51,7 +36,6 @@ def get_mapping(src_dir="src"): os.path.join(src_dir, "websockets/**/*.py"), recursive=True, ) - test_files = glob.glob( "tests/**/*.py", recursive=True, @@ -60,16 +44,16 @@ def get_mapping(src_dir="src"): src_files = [ os.path.relpath(src_file, src_dir) for src_file in sorted(src_files) + if "legacy" not in os.path.dirname(src_file) if os.path.basename(src_file) != "__init__.py" and os.path.basename(src_file) != "__main__.py" - and "legacy" not in os.path.dirname(src_file) ] test_files = [ test_file for test_file in sorted(test_files) - if os.path.basename(test_file) != "__init__.py" + if "legacy" not in os.path.dirname(test_file) + and os.path.basename(test_file) != "__init__.py" and os.path.basename(test_file).startswith("test_") - and "legacy" not in os.path.dirname(test_file) ] # Map source files to test files. @@ -100,7 +84,30 @@ def get_mapping(src_dir="src"): return mapping +def get_ignored_files(src_dir="src"): + """Return the list of files to exclude from coverage measurement.""" + + return [ + # */websockets matches src/websockets and .tox/**/site-packages/websockets. + # There are no tests for the __main__ module. + "*/websockets/__main__.py", + # This approach isn't applicable to the test suite of the legacy + # implementation, due to the huge test_client_server test module. + "*/websockets/legacy/*", + "tests/legacy/*", + ] + [ + # Exclude test utilities that are shared between several test modules. + # Also excludes this script. + test_file + for test_file in sorted(glob.glob("tests/**/*.py", recursive=True)) + if "legacy" not in os.path.dirname(test_file) + and os.path.basename(test_file) != "__init__.py" + and not os.path.basename(test_file).startswith("test_") + ] + + def run_coverage(mapping, src_dir="src"): + print(get_ignored_files(src_dir)) # Initialize a new coverage measurement session. The --source option # includes all files in the report, even if they're never imported. print("\nInitializing session\n", flush=True) @@ -113,7 +120,7 @@ def run_coverage(mapping, src_dir="src"): "--source", ",".join([os.path.join(src_dir, "websockets"), "tests"]), "--omit", - ",".join(IGNORED_FILES), + ",".join(get_ignored_files(src_dir)), "-m", "unittest", ] From 077e6429df312f549c9ecf348ec63abfd249da28 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 1 Nov 2022 08:29:02 +0100 Subject: [PATCH 352/762] Perform version check at compile time. This is a small performance optimization. --- src/websockets/legacy/compatibility.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/src/websockets/legacy/compatibility.py b/src/websockets/legacy/compatibility.py index df81de9db..296e9c584 100644 --- a/src/websockets/legacy/compatibility.py +++ b/src/websockets/legacy/compatibility.py @@ -5,9 +5,20 @@ from typing import Any, Dict -def loop_if_py_lt_38(loop: asyncio.AbstractEventLoop) -> Dict[str, Any]: - """ - Helper for the removal of the loop argument in Python 3.10. +if sys.version_info[:2] >= (3, 8): - """ - return {"loop": loop} if sys.version_info[:2] < (3, 8) else {} + def loop_if_py_lt_38(loop: asyncio.AbstractEventLoop) -> Dict[str, Any]: + """ + Helper for the removal of the loop argument in Python 3.10. + + """ + return {} + +else: # pragma: no cover + + def loop_if_py_lt_38(loop: asyncio.AbstractEventLoop) -> Dict[str, Any]: + """ + Helper for the removal of the loop argument in Python 3.10. + + """ + return {"loop": loop} From 68c041ec43a6a4d093d2316bf75c2ca0c51f4893 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 1 Nov 2022 13:17:26 +0100 Subject: [PATCH 353/762] Reduce usage of pragma: no cover. --- setup.cfg | 1 + tests/legacy/test_client_server.py | 24 +++++++++--------------- 2 files changed, 10 insertions(+), 15 deletions(-) diff --git a/setup.cfg b/setup.cfg index 4c1f6091f..47acf312c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -35,3 +35,4 @@ exclude_lines = if self.debug: pragma: no cover raise AssertionError + self.fail\(".*"\) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 8ee7187eb..1a668a431 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -267,21 +267,15 @@ async def start_client(): self.assertDeprecationWarnings(recorded_warnings, expected_warnings) def stop_client(self): - try: - self.loop.run_until_complete( - asyncio.wait_for(self.client.close_connection_task, timeout=1) - ) - except asyncio.TimeoutError: # pragma: no cover - self.fail("Client failed to stop") + self.loop.run_until_complete( + asyncio.wait_for(self.client.close_connection_task, timeout=1) + ) def stop_server(self): self.server.close() - try: - self.loop.run_until_complete( - asyncio.wait_for(self.server.wait_closed(), timeout=1) - ) - except asyncio.TimeoutError: # pragma: no cover - self.fail("Server failed to stop") + self.loop.run_until_complete( + asyncio.wait_for(self.server.wait_closed(), timeout=1) + ) @contextlib.contextmanager def temp_server(self, **kwargs): @@ -380,13 +374,13 @@ def test_infinite_redirect(self): with temp_test_redirecting_server(self): with self.assertRaises(InvalidHandshake): with self.temp_client("/infinite"): - self.fail("Did not raise") # pragma: no cover + self.fail("did not raise") def test_redirect_missing_location(self): with temp_test_redirecting_server(self): with self.assertRaises(InvalidHeader): with self.temp_client("/missing_location"): - self.fail("Did not raise") # pragma: no cover + self.fail("did not raise") def test_loop_backwards_compatibility(self): with self.temp_server( @@ -1327,7 +1321,7 @@ def test_redirect_insecure(self): with temp_test_redirecting_server(self): with self.assertRaises(InvalidHandshake): with self.temp_client("/force_insecure"): - self.fail("Did not raise") # pragma: no cover + self.fail("did not raise") class ClientServerOriginTests(ClientServerTestsMixin, AsyncioTestCase): From 591d047884d08d938324c4a484e8138a88a7a4a9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 2 Nov 2022 07:52:29 +0100 Subject: [PATCH 354/762] Remove debug statement. --- tests/maxi_cov.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/maxi_cov.py b/tests/maxi_cov.py index 77d374f96..b7c07b698 100755 --- a/tests/maxi_cov.py +++ b/tests/maxi_cov.py @@ -107,7 +107,6 @@ def get_ignored_files(src_dir="src"): def run_coverage(mapping, src_dir="src"): - print(get_ignored_files(src_dir)) # Initialize a new coverage measurement session. The --source option # includes all files in the report, even if they're never imported. print("\nInitializing session\n", flush=True) From 9e960b508988c4049eb9f3377c505f506a3af060 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 11 Nov 2022 14:26:45 +0100 Subject: [PATCH 355/762] Fix example of shutting down a client. Fix #1261. --- example/shutdown_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/shutdown_client.py b/example/shutdown_client.py index 539dd0304..65cfca41b 100755 --- a/example/shutdown_client.py +++ b/example/shutdown_client.py @@ -10,7 +10,7 @@ async def client(): # Close the connection when receiving SIGTERM. loop = asyncio.get_running_loop() loop.add_signal_handler( - signal.SIGTERM, loop.create_task, websocket.close()) + signal.SIGTERM, loop.create_task, websocket.close) # Process messages received on the connection. async for message in websocket: From f199a31361118f58bc7cf5f928df8fc64575a824 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Nov 2022 15:08:48 +0100 Subject: [PATCH 356/762] Format setup.py with black. --- setup.py | 54 +++++++++++++++++++++++++++--------------------------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/setup.py b/setup.py index 492b1597c..86b8bf9b8 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)" -long_description = (root_dir / 'README.rst').read_text(encoding='utf-8') +long_description = (root_dir / "README.rst").read_text(encoding="utf-8") # PyPI disables the "raw" directive. long_description = re.sub( @@ -18,47 +18,47 @@ flags=re.DOTALL | re.MULTILINE, ) -exec((root_dir / 'src' / 'websockets' / 'version.py').read_text(encoding='utf-8')) +exec((root_dir / "src" / "websockets" / "version.py").read_text(encoding="utf-8")) -packages = ['websockets', 'websockets/legacy', 'websockets/extensions'] +packages = ["websockets", "websockets/legacy", "websockets/extensions"] ext_modules = [ setuptools.Extension( - 'websockets.speedups', - sources=['src/websockets/speedups.c'], - optional=not (root_dir / '.cibuildwheel').exists(), + "websockets.speedups", + sources=["src/websockets/speedups.c"], + optional=not (root_dir / ".cibuildwheel").exists(), ) ] setuptools.setup( - name='websockets', + name="websockets", version=version, description=description, long_description=long_description, - url='https://github.com/aaugustin/websockets', - author='Aymeric Augustin', - author_email='aymeric.augustin@m4x.org', - license='BSD', + url="https://github.com/aaugustin/websockets", + author="Aymeric Augustin", + author_email="aymeric.augustin@m4x.org", + license="BSD", classifiers=[ - 'Development Status :: 5 - Production/Stable', - 'Environment :: Web Environment', - 'Intended Audience :: Developers', - 'License :: OSI Approved :: BSD License', - 'Operating System :: OS Independent', - 'Programming Language :: Python', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Programming Language :: Python :: 3.11', + "Development Status :: 5 - Production/Stable", + "Environment :: Web Environment", + "Intended Audience :: Developers", + "License :: OSI Approved :: BSD License", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", ], - package_dir = {'': 'src'}, - package_data = {'websockets': ['py.typed']}, + package_dir={"": "src"}, + package_data={"websockets": ["py.typed"]}, packages=packages, ext_modules=ext_modules, include_package_data=True, zip_safe=False, - python_requires='>=3.7', - test_loader='unittest:TestLoader', + python_requires=">=3.7", + test_loader="unittest:TestLoader", ) From 26e1946f4d02357c71ca909a8e347f38892900d1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Nov 2022 15:43:05 +0100 Subject: [PATCH 357/762] Update for the latest version of mypy. https://github.com/python/mypy/issues/2350 --- setup.cfg | 1 + src/websockets/extensions/base.py | 5 +++++ src/websockets/legacy/protocol.py | 2 -- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index 47acf312c..48703df87 100644 --- a/setup.cfg +++ b/setup.cfg @@ -35,4 +35,5 @@ exclude_lines = if self.debug: pragma: no cover raise AssertionError + raise NotImplementedError self.fail\(".*"\) diff --git a/src/websockets/extensions/base.py b/src/websockets/extensions/base.py index 060967618..6c481a46c 100644 --- a/src/websockets/extensions/base.py +++ b/src/websockets/extensions/base.py @@ -38,6 +38,7 @@ def decode( PayloadTooBig: if decoding the payload exceeds ``max_size``. """ + raise NotImplementedError def encode(self, frame: frames.Frame) -> frames.Frame: """ @@ -50,6 +51,7 @@ def encode(self, frame: frames.Frame) -> frames.Frame: Frame: Encoded frame. """ + raise NotImplementedError class ClientExtensionFactory: @@ -69,6 +71,7 @@ def get_request_params(self) -> List[ExtensionParameter]: List[ExtensionParameter]: Parameters to send to the server. """ + raise NotImplementedError def process_response_params( self, @@ -91,6 +94,7 @@ def process_response_params( NegotiationError: if parameters aren't acceptable. """ + raise NotImplementedError class ServerExtensionFactory: @@ -126,3 +130,4 @@ def process_request_params( the client aren't acceptable. """ + raise NotImplementedError diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 1b6e58efa..21784eb55 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -220,8 +220,6 @@ def __init__( # Logger or LoggerAdapter for this connection. if logger is None: logger = logging.getLogger("websockets.protocol") - # https://github.com/python/typeshed/issues/5561 - logger = cast(logging.Logger, logger) self.logger: LoggerLike = logging.LoggerAdapter(logger, {"websocket": self}) """Logger for this connection.""" From 0f4ecfc3d0abe4fdc2bc5104ea98e134be399c9a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 20 Nov 2022 22:39:49 +0100 Subject: [PATCH 358/762] Don't treat close code 1005 as an error. Fix #1260. --- docs/project/changelog.rst | 14 ++++++++++++++ src/websockets/exceptions.py | 7 ++++--- src/websockets/frames.py | 6 +++++- src/websockets/legacy/client.py | 2 +- src/websockets/legacy/protocol.py | 6 +++--- src/websockets/legacy/server.py | 2 +- 6 files changed, 28 insertions(+), 9 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 7851db80d..40e090c31 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -30,6 +30,20 @@ They may change at any time. *In development* +Backwards-incompatible changes +.............................. + +.. admonition:: Closing a connection without an empty close frame is OK. + :class: note + + Receiving an empty close frame now results in + :exc:`~exceptions.ConnectionClosedOK` instead of + :exc:`~exceptions.ConnectionClosedError`. + + As a consequence, calling ``WebSocket.close()`` without arguments in a + browser isn't reported as an error anymore. + + New features ............ diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 291b6d2cd..46f314d9e 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -135,8 +135,8 @@ class ConnectionClosedError(ConnectionClosed): """ Like :exc:`ConnectionClosed`, when the connection terminated with an error. - A close code other than 1000 (OK) or 1001 (going away) was received or - sent, or the closing handshake didn't complete properly. + A close frame with a code other than 1000 (OK) or 1001 (going away) was + received or sent, or the closing handshake didn't complete properly. """ @@ -145,7 +145,8 @@ class ConnectionClosedOK(ConnectionClosed): """ Like :exc:`ConnectionClosed`, when the connection terminated properly. - A close code 1000 (OK) or 1001 (going away) was received and sent. + A close code with code 1000 (OK) or 1001 (going away) or without a code was + received and sent. """ diff --git a/src/websockets/frames.py b/src/websockets/frames.py index ec6b8547d..45d006e3f 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -90,7 +90,11 @@ class Opcode(enum.IntEnum): 1014, } -OK_CLOSE_CODES = {1000, 1001} +OK_CLOSE_CODES = { + 1000, + 1001, + 1005, +} BytesLike = bytes, bytearray, memoryview diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 1e3c6e741..9b953df8f 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -65,7 +65,7 @@ class WebSocketClientProtocol(WebSocketCommonProtocol): await process(message) The iterator exits normally when the connection is closed with close code - 1000 (OK) or 1001 (going away). It raises + 1000 (OK) or 1001 (going away) or without a close code. It raises a :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is closed with any other code. diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 21784eb55..9f2bda1ab 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -485,9 +485,9 @@ async def __aiter__(self) -> AsyncIterator[Data]: Iterate on incoming messages. The iterator exits normally when the connection is closed with the - close code 1000 (OK) or 1001(going away). It raises - a :exc:`~websockets.exceptions.ConnectionClosedError` exception when - the connection is closed with any other code. + close code 1000 (OK) or 1001(going away) or without a close code. It + raises a :exc:`~websockets.exceptions.ConnectionClosedError` exception + when the connection is closed with any other code. """ try: diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 89e5322bc..9359472fe 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -73,7 +73,7 @@ class WebSocketServerProtocol(WebSocketCommonProtocol): await process(message) The iterator exits normally when the connection is closed with close code - 1000 (OK) or 1001 (going away). It raises + 1000 (OK) or 1001 (going away) or without a close code. It raises a :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is closed with any other code. From 9c1430379bcb42120e45ee0e9f9dca142b1d7560 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 22 Nov 2022 07:26:59 +0100 Subject: [PATCH 359/762] Use standard licence text and SPDX identifier. --- LICENSE | 9 ++++----- setup.py | 2 +- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/LICENSE b/LICENSE index 119b29ef3..5d61ece22 100644 --- a/LICENSE +++ b/LICENSE @@ -1,5 +1,4 @@ -Copyright (c) 2013-2021 Aymeric Augustin and contributors. -All rights reserved. +Copyright (c) Aymeric Augustin and contributors Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -9,9 +8,9 @@ modification, are permitted provided that the following conditions are met: * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. - * Neither the name of websockets nor the names of its contributors may - be used to endorse or promote products derived from this software without - specific prior written permission. + * Neither the name of the copyright holder nor the names of its contributors + may be used to endorse or promote products derived from this software + without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED diff --git a/setup.py b/setup.py index 86b8bf9b8..564ada85c 100644 --- a/setup.py +++ b/setup.py @@ -38,7 +38,7 @@ url="https://github.com/aaugustin/websockets", author="Aymeric Augustin", author_email="aymeric.augustin@m4x.org", - license="BSD", + license="BSD-3-Clause", classifiers=[ "Development Status :: 5 - Production/Stable", "Environment :: Web Environment", From f1a18247e78f9efcdca34a4b5a616d9733e92ff6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 22 Nov 2022 08:05:49 +0100 Subject: [PATCH 360/762] Rename test class with a more specific name. --- tests/extensions/test_permessage_deflate.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/extensions/test_permessage_deflate.py b/tests/extensions/test_permessage_deflate.py index d433762c5..c341fdb32 100644 --- a/tests/extensions/test_permessage_deflate.py +++ b/tests/extensions/test_permessage_deflate.py @@ -24,7 +24,7 @@ from .utils import ClientNoOpExtensionFactory, ServerNoOpExtensionFactory -class ExtensionTestsMixin: +class PerMessageDeflateTestsMixin: def assertExtensionEqual(self, extension1, extension2): self.assertEqual( extension1.remote_no_context_takeover, @@ -44,7 +44,7 @@ def assertExtensionEqual(self, extension1, extension2): ) -class PerMessageDeflateTests(unittest.TestCase, ExtensionTestsMixin): +class PerMessageDeflateTests(unittest.TestCase, PerMessageDeflateTestsMixin): def setUp(self): # Set up an instance of the permessage-deflate extension with the most # common settings. Since the extension is symmetrical, this instance @@ -278,7 +278,9 @@ def test_decompress_max_size(self): self.extension.decode(enc_frame, max_size=10) -class ClientPerMessageDeflateFactoryTests(unittest.TestCase, ExtensionTestsMixin): +class ClientPerMessageDeflateFactoryTests( + unittest.TestCase, PerMessageDeflateTestsMixin +): def test_name(self): assert ClientPerMessageDeflateFactory.name == "permessage-deflate" @@ -616,7 +618,9 @@ def test_enable_client_permessage_deflate(self): ) -class ServerPerMessageDeflateFactoryTests(unittest.TestCase, ExtensionTestsMixin): +class ServerPerMessageDeflateFactoryTests( + unittest.TestCase, PerMessageDeflateTestsMixin +): def test_name(self): assert ServerPerMessageDeflateFactory.name == "permessage-deflate" From 0b884ed68f2c4b482f9eadbf38adc01f7d869f1a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 22 Nov 2022 08:07:22 +0100 Subject: [PATCH 361/762] Rename test class consistently with others. --- tests/test_exports.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_exports.py b/tests/test_exports.py index 568c50c54..978b1d0e7 100644 --- a/tests/test_exports.py +++ b/tests/test_exports.py @@ -25,7 +25,7 @@ ) -class TestExportsAllSubmodules(unittest.TestCase): +class ExportsTests(unittest.TestCase): def test_top_level_module_reexports_all_submodule_exports(self): self.assertEqual(set(combined_exports), set(websockets.__all__)) From 173aac8e536346155b6bb9edac6ac5b5a351b7e5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 23 Nov 2022 07:52:56 +0100 Subject: [PATCH 362/762] Organize the examples directory. --- docs/faq/client.rst | 2 +- docs/faq/server.rst | 14 +++++++++++++- docs/topics/deployment.rst | 4 ++-- docs/topics/logging.rst | 2 +- example/{ => faq}/health_check_server.py | 0 example/{ => faq}/shutdown_client.py | 0 example/{ => faq}/shutdown_server.py | 0 example/{ => legacy}/basic_auth_client.py | 0 example/{ => legacy}/basic_auth_server.py | 0 example/{ => legacy}/unix_client.py | 0 example/{ => legacy}/unix_server.py | 0 example/{ => logging}/json_log_formatter.py | 0 12 files changed, 17 insertions(+), 5 deletions(-) rename example/{ => faq}/health_check_server.py (100%) rename example/{ => faq}/shutdown_client.py (100%) rename example/{ => faq}/shutdown_server.py (100%) rename example/{ => legacy}/basic_auth_client.py (100%) rename example/{ => legacy}/basic_auth_server.py (100%) rename example/{ => legacy}/unix_client.py (100%) rename example/{ => legacy}/unix_server.py (100%) rename example/{ => logging}/json_log_formatter.py (100%) diff --git a/docs/faq/client.rst b/docs/faq/client.rst index 5bbbd6ded..73825e480 100644 --- a/docs/faq/client.rst +++ b/docs/faq/client.rst @@ -74,7 +74,7 @@ You can close the connection. Here's an example that terminates cleanly when it receives SIGTERM on Unix: -.. literalinclude:: ../../example/shutdown_client.py +.. literalinclude:: ../../example/faq/shutdown_client.py :emphasize-lines: 10-13 How do I disable TLS/SSL certificate verification? diff --git a/docs/faq/server.rst b/docs/faq/server.rst index 22a9d3a4c..feec65a58 100644 --- a/docs/faq/server.rst +++ b/docs/faq/server.rst @@ -257,7 +257,7 @@ Exit the :func:`~server.serve` context manager. Here's an example that terminates cleanly when it receives SIGTERM on Unix: -.. literalinclude:: ../../example/shutdown_server.py +.. literalinclude:: ../../example/faq/shutdown_server.py :emphasize-lines: 12-15,18 How do I stop a server while keeping existing connections open? @@ -275,6 +275,18 @@ Here's how to adapt the example just above:: await stop await server.close(close_connections=False) +How do I implement a health check? +---------------------------------- + +Intercept WebSocket handshake requests with the +:meth:`~server.WebSocketServerProtocol.process_request` hook. + +When a request is sent to the health check endpoint, treat is as an HTTP request +and return a ``(status, headers, body)`` tuple, as in this example: + +.. literalinclude:: ../../example/faq/health_check_server.py + :emphasize-lines: 7-9,18 + How do I run HTTP and WebSocket servers on the same port? --------------------------------------------------------- diff --git a/docs/topics/deployment.rst b/docs/topics/deployment.rst index ac0a8ed4c..2a1fe9a78 100644 --- a/docs/topics/deployment.rst +++ b/docs/topics/deployment.rst @@ -97,7 +97,7 @@ signal and exit the server to ensure a graceful shutdown. Here's an example: -.. literalinclude:: ../../example/shutdown_server.py +.. literalinclude:: ../../example/faq/shutdown_server.py :emphasize-lines: 12-15,18 When exiting the context manager, :func:`~server.serve` closes all connections @@ -177,5 +177,5 @@ websockets provide minimal support for responding to HTTP requests with the Here's an example: -.. literalinclude:: ../../example/health_check_server.py +.. literalinclude:: ../../example/faq/health_check_server.py :emphasize-lines: 7-9,18 diff --git a/docs/topics/logging.rst b/docs/topics/logging.rst index 95acf57ff..e2b4a7be1 100644 --- a/docs/topics/logging.rst +++ b/docs/topics/logging.rst @@ -139,7 +139,7 @@ output logs as JSON with a bit of effort. First, we need a :class:`~logging.Formatter` that renders JSON: -.. literalinclude:: ../../example/json_log_formatter.py +.. literalinclude:: ../../example/logging/json_log_formatter.py Then, we configure logging to apply this formatter:: diff --git a/example/health_check_server.py b/example/faq/health_check_server.py similarity index 100% rename from example/health_check_server.py rename to example/faq/health_check_server.py diff --git a/example/shutdown_client.py b/example/faq/shutdown_client.py similarity index 100% rename from example/shutdown_client.py rename to example/faq/shutdown_client.py diff --git a/example/shutdown_server.py b/example/faq/shutdown_server.py similarity index 100% rename from example/shutdown_server.py rename to example/faq/shutdown_server.py diff --git a/example/basic_auth_client.py b/example/legacy/basic_auth_client.py similarity index 100% rename from example/basic_auth_client.py rename to example/legacy/basic_auth_client.py diff --git a/example/basic_auth_server.py b/example/legacy/basic_auth_server.py similarity index 100% rename from example/basic_auth_server.py rename to example/legacy/basic_auth_server.py diff --git a/example/unix_client.py b/example/legacy/unix_client.py similarity index 100% rename from example/unix_client.py rename to example/legacy/unix_client.py diff --git a/example/unix_server.py b/example/legacy/unix_server.py similarity index 100% rename from example/unix_server.py rename to example/legacy/unix_server.py diff --git a/example/json_log_formatter.py b/example/logging/json_log_formatter.py similarity index 100% rename from example/json_log_formatter.py rename to example/logging/json_log_formatter.py From cac6e8575e7b3c6339817c74fcbba846a4f935dc Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 23 Nov 2022 08:04:15 +0100 Subject: [PATCH 363/762] Standardize to "an HTTP". --- README.rst | 2 +- docs/faq/server.rst | 6 +++--- docs/howto/django.rst | 2 +- docs/howto/sansio.rst | 2 +- docs/intro/tutorial1.rst | 2 +- docs/intro/tutorial2.rst | 2 +- docs/intro/tutorial3.rst | 2 +- docs/reference/limitations.rst | 2 +- docs/topics/authentication.rst | 8 ++++---- docs/topics/design.rst | 10 +++++----- docs/topics/logging.rst | 2 +- src/websockets/exceptions.py | 8 ++++---- src/websockets/legacy/auth.py | 4 ++-- src/websockets/legacy/server.py | 6 +++--- tests/legacy/test_client_server.py | 6 +++--- 15 files changed, 32 insertions(+), 32 deletions(-) diff --git a/README.rst b/README.rst index 7cbccfe13..fa1d91061 100644 --- a/README.rst +++ b/README.rst @@ -123,7 +123,7 @@ Why shouldn't I use ``websockets``? * If you're looking for a mixed HTTP / WebSocket library: ``websockets`` aims at being an excellent implementation of :rfc:`6455`: The WebSocket Protocol and :rfc:`7692`: Compression Extensions for WebSocket. Its support for HTTP - is minimal — just enough for a HTTP health check. + is minimal — just enough for an HTTP health check. If you want to do both in the same server, look at HTTP frameworks that build on top of ``websockets`` to support WebSocket connections, like diff --git a/docs/faq/server.rst b/docs/faq/server.rst index feec65a58..ab665f2b8 100644 --- a/docs/faq/server.rst +++ b/docs/faq/server.rst @@ -295,15 +295,15 @@ You don't. HTTP and WebSocket have widely different operational characteristics. Running them with the same server becomes inconvenient when you scale. -Providing a HTTP server is out of scope for websockets. It only aims at +Providing an HTTP server is out of scope for websockets. It only aims at providing a WebSocket server. There's limited support for returning HTTP responses with the :attr:`~server.WebSocketServerProtocol.process_request` hook. -If you need more, pick a HTTP server and run it separately. +If you need more, pick an HTTP server and run it separately. -Alternatively, pick a HTTP framework that builds on top of ``websockets`` to +Alternatively, pick an HTTP framework that builds on top of ``websockets`` to support WebSocket connections, like Sanic_. .. _Sanic: https://sanicframework.org/en/ diff --git a/docs/howto/django.rst b/docs/howto/django.rst index 5bb2e296b..c955a5ec1 100644 --- a/docs/howto/django.rst +++ b/docs/howto/django.rst @@ -10,7 +10,7 @@ WebSocket, you have two main options. 2. Deploying a separate WebSocket server next to your Django project. This technique is well suited when you need to add a small set of real-time - features — maybe a notification service — to a HTTP application. + features — maybe a notification service — to an HTTP application. .. _Channels: https://channels.readthedocs.io/ diff --git a/docs/howto/sansio.rst b/docs/howto/sansio.rst index 1373c81a5..197e29dc9 100644 --- a/docs/howto/sansio.rst +++ b/docs/howto/sansio.rst @@ -85,7 +85,7 @@ with :meth:`~server.ServerConnection.send_response`:: response = connection.accept(request) connection.send_response(response) -Alternatively, you may reject the WebSocket handshake and return a HTTP +Alternatively, you may reject the WebSocket handshake and return an HTTP response with :meth:`~server.ServerConnection.reject`:: response = connection.reject(status, explanation) diff --git a/docs/intro/tutorial1.rst b/docs/intro/tutorial1.rst index ab4f39d79..ff85003b5 100644 --- a/docs/intro/tutorial1.rst +++ b/docs/intro/tutorial1.rst @@ -163,7 +163,7 @@ page loads, draw the board: createBoard(board); }); -Open a shell, navigate to the directory containing these files, and start a +Open a shell, navigate to the directory containing these files, and start an HTTP server: .. code-block:: console diff --git a/docs/intro/tutorial2.rst b/docs/intro/tutorial2.rst index 53f295d3f..5ac4ae9dd 100644 --- a/docs/intro/tutorial2.rst +++ b/docs/intro/tutorial2.rst @@ -369,7 +369,7 @@ common pattern in servers that handle different clients. it in the URI because URIs end up in logs. For the purposes of this tutorial, both approaches are equivalent because - the join key comes from a HTTP URL. There isn't much at risk anyway! + the join key comes from an HTTP URL. There isn't much at risk anyway! Now you can restore the logic for playing moves and you'll have a fully functional two-player game. diff --git a/docs/intro/tutorial3.rst b/docs/intro/tutorial3.rst index 9a447f39b..4d42447b7 100644 --- a/docs/intro/tutorial3.rst +++ b/docs/intro/tutorial3.rst @@ -13,7 +13,7 @@ Part 3 - Deploy to the web web; you can play from any browser connected to the Internet. In the first and second parts of the tutorial, for local development, you ran -a HTTP server on ``http://localhost:8000/`` with: +an HTTP server on ``http://localhost:8000/`` with: .. code-block:: console diff --git a/docs/reference/limitations.rst b/docs/reference/limitations.rst index 3304bdb8c..696aa38fd 100644 --- a/docs/reference/limitations.rst +++ b/docs/reference/limitations.rst @@ -13,7 +13,7 @@ right layer for enforcing this constraint. It's the caller's responsibility. .. _mandated by RFC 6455: https://www.rfc-editor.org/rfc/rfc6455.html#section-4.1 -The client doesn't support connecting through a HTTP proxy (`issue 364`_) or a +The client doesn't support connecting through an HTTP proxy (`issue 364`_) or a SOCKS proxy (`issue 475`_). .. _issue 364: https://github.com/aaugustin/websockets/issues/364 diff --git a/docs/topics/authentication.rst b/docs/topics/authentication.rst index 31bfd6465..1849d635a 100644 --- a/docs/topics/authentication.rst +++ b/docs/topics/authentication.rst @@ -78,7 +78,7 @@ WebSocket server. 3. **Setting a cookie on the domain of the WebSocket URI.** Cookies are undoubtedly the most common and hardened mechanism for sending - credentials from a web application to a server. In a HTTP application, + credentials from a web application to a server. In an HTTP application, credentials would be a session identifier or a serialized, signed session. Unfortunately, when the WebSocket server runs on a different domain from @@ -208,7 +208,7 @@ opening the connection: // ... The server intercepts the HTTP request, extracts the token and authenticates -the user. If authentication fails, it returns a HTTP 401: +the user. If authentication fails, it returns an HTTP 401: .. code-block:: python @@ -254,7 +254,7 @@ This sequence must be synchronized between the main window and the iframe. This involves several events. Look at the full implementation for details. The server intercepts the HTTP request, extracts the token and authenticates -the user. If authentication fails, it returns a HTTP 401: +the user. If authentication fails, it returns an HTTP 401: .. code-block:: python @@ -295,7 +295,7 @@ Since HTTP Basic Auth is designed to accept a username and a password rather than a token, we send ``token`` as username and the token as password. The server intercepts the HTTP request, extracts the token and authenticates -the user. If authentication fails, it returns a HTTP 401: +the user. If authentication fails, it returns an HTTP 401: .. code-block:: python diff --git a/docs/topics/design.rst b/docs/topics/design.rst index b5c55afc9..33dd187b9 100644 --- a/docs/topics/design.rst +++ b/docs/topics/design.rst @@ -125,23 +125,23 @@ one another. On the client side, :meth:`~client.WebSocketClientProtocol.handshake`: -- builds a HTTP request based on the ``uri`` and parameters passed to +- builds an HTTP request based on the ``uri`` and parameters passed to :meth:`~client.connect`; - writes the HTTP request to the network; -- reads a HTTP response from the network; +- reads an HTTP response from the network; - checks the HTTP response, validates ``extensions`` and ``subprotocol``, and configures the protocol accordingly; - moves to the ``OPEN`` state. On the server side, :meth:`~server.WebSocketServerProtocol.handshake`: -- reads a HTTP request from the network; +- reads an HTTP request from the network; - calls :meth:`~server.WebSocketServerProtocol.process_request` which may - abort the WebSocket handshake and return a HTTP response instead; this + abort the WebSocket handshake and return an HTTP response instead; this hook only makes sense on the server side; - checks the HTTP request, negotiates ``extensions`` and ``subprotocol``, and configures the protocol accordingly; -- builds a HTTP response based on the above and parameters passed to +- builds an HTTP response based on the above and parameters passed to :meth:`~server.serve`; - writes the HTTP response to the network; - moves to the ``OPEN`` state; diff --git a/docs/topics/logging.rst b/docs/topics/logging.rst index e2b4a7be1..294a6cda8 100644 --- a/docs/topics/logging.rst +++ b/docs/topics/logging.rst @@ -104,7 +104,7 @@ can set ``logger`` to a :class:`~logging.LoggerAdapter` that enriches logs. For example, if the server is behind a reverse proxy, :attr:`~legacy.protocol.WebSocketCommonProtocol.remote_address` gives the IP address of the proxy, which isn't useful. IP addresses of clients are -provided in a HTTP header set by the proxy. +provided in an HTTP header set by the proxy. Here's how to include them in logs, assuming they're in the ``X-Forwarded-For`` header:: diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 46f314d9e..1f4b9265c 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -176,7 +176,7 @@ class InvalidMessage(InvalidHandshake): class InvalidHeader(InvalidHandshake): """ - Raised when a HTTP header doesn't have a valid format or value. + Raised when an HTTP header doesn't have a valid format or value. """ @@ -195,7 +195,7 @@ def __str__(self) -> str: class InvalidHeaderFormat(InvalidHeader): """ - Raised when a HTTP header cannot be parsed. + Raised when an HTTP header cannot be parsed. The format of the header doesn't match the grammar for that header. @@ -207,7 +207,7 @@ def __init__(self, name: str, error: str, header: str, pos: int) -> None: class InvalidHeaderValue(InvalidHeader): """ - Raised when a HTTP header has a wrong value. + Raised when an HTTP header has a wrong value. The format of the header is correct but a value isn't acceptable. @@ -315,7 +315,7 @@ def __str__(self) -> str: class AbortHandshake(InvalidHandshake): """ - Raised to abort the handshake on purpose and return a HTTP response. + Raised to abort the handshake on purpose and return an HTTP response. This exception is an implementation detail. diff --git a/src/websockets/legacy/auth.py b/src/websockets/legacy/auth.py index 8825c14ec..ac24c179e 100644 --- a/src/websockets/legacy/auth.py +++ b/src/websockets/legacy/auth.py @@ -67,7 +67,7 @@ async def check_credentials(self, username: str, password: str) -> bool: Returns: bool: :obj:`True` if the handshake should continue; - :obj:`False` if it should fail with a HTTP 401 error. + :obj:`False` if it should fail with an HTTP 401 error. """ if self._check_credentials is not None: @@ -81,7 +81,7 @@ async def process_request( request_headers: Headers, ) -> Optional[HTTPResponse]: """ - Check HTTP Basic Auth and return a HTTP 401 response if needed. + Check HTTP Basic Auth and return an HTTP 401 response if needed. """ try: diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 9359472fe..f37a16c20 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -329,9 +329,9 @@ async def process_request( You may override this method in a :class:`WebSocketServerProtocol` subclass, for example: - * to return a HTTP 200 OK response on a given path; then a load + * to return an HTTP 200 OK response on a given path; then a load balancer can use this path for a health check; - * to authenticate the request and return a HTTP 401 Unauthorized or a + * to authenticate the request and return an HTTP 401 Unauthorized or an HTTP 403 Forbidden when authentication fails. You may also override this method with the ``process_request`` @@ -776,7 +776,7 @@ async def _close(self, close_connections: bool) -> None: if close_connections: # Close OPEN connections with status code 1001. Since the server was - # closed, handshake() closes OPENING connections with a HTTP 503 + # closed, handshake() closes OPENING connections with an HTTP 503 # error. Wait until all connections are closed. close_tasks = [ diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 1a668a431..f8a79027e 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -727,7 +727,7 @@ def test_protocol_custom_server_header(self): @with_server(create_protocol=HealthCheckServerProtocol) def test_http_request_http_endpoint(self): - # Making a HTTP request to a HTTP endpoint succeeds. + # Making an HTTP request to an HTTP endpoint succeeds. response = self.loop.run_until_complete(self.make_http_request("/__health__/")) with contextlib.closing(response): @@ -736,7 +736,7 @@ def test_http_request_http_endpoint(self): @with_server(create_protocol=HealthCheckServerProtocol) def test_http_request_ws_endpoint(self): - # Making a HTTP request to a WS endpoint fails. + # Making an HTTP request to a WS endpoint fails. with self.assertRaises(urllib.error.HTTPError) as raised: self.loop.run_until_complete(self.make_http_request()) @@ -745,7 +745,7 @@ def test_http_request_ws_endpoint(self): @with_server(create_protocol=HealthCheckServerProtocol) def test_ws_connection_http_endpoint(self): - # Making a WS connection to a HTTP endpoint fails. + # Making a WS connection to an HTTP endpoint fails. with self.assertRaises(InvalidStatusCode) as raised: self.start_client("/__health__/") From 39c53fb67d6815cab0a72be03b0bba60df504831 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 27 Nov 2022 19:02:44 +0100 Subject: [PATCH 364/762] websockets CAN handle multiple clients. Ref #1268. --- docs/faq/server.rst | 31 ++++++++++++++++++++++++++++--- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/docs/faq/server.rst b/docs/faq/server.rst index ab665f2b8..4e4622dce 100644 --- a/docs/faq/server.rst +++ b/docs/faq/server.rst @@ -43,19 +43,44 @@ Why can only one client connect at a time? ------------------------------------------ Your connection handler blocks the event loop. Look for blocking calls. + Any call that may take some time must be asynchronous. -For example, if you have:: +For example, this connection handler prevents the event loop from running during +one second:: async def handler(websocket): time.sleep(1) + ... -change it to:: +Change it to:: async def handler(websocket): await asyncio.sleep(1) + ... + +In addition, calling a coroutine doesn't guarantee that it will yield control to +the event loop. + +For example, this connection handler blocks the event loop by sending messages +continuously:: + + async def handler(websocket): + while True: + await websocket.send("firehose!") + +:meth:`~legacy.protocol.WebSocketCommonProtocol.send` completes synchronously as +long as there's space in send buffers. The event loop never runs. (This pattern +is uncommon in real-world applications. It occurs mostly in toy programs.) + +You can avoid the issue by yielding control to the event loop explicitly:: + + async def handler(websocket): + while True: + await websocket.send("firehose!") + await asyncio.sleep(0) -This is part of learning asyncio. It isn't specific to websockets. +All this is part of learning asyncio. It isn't specific to websockets. See also Python's documentation about `running blocking code`_. From f5ea94ab818d873664ffb76274b675fe307e289b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 16 Oct 2022 09:25:50 +0200 Subject: [PATCH 365/762] Rename Connection to Protocol. This makes the following naming possible: * Connection = TCP/TLS connection + pointer to protocol; concerned with opening the network connection, moving bytes, and closing it. * Protocol = Sans-I/O protocol; concerned with parsing and serializing messages and with the state of the connection. Previously, these names were reversed in the sockets & threads branch. The previous choice was influenced by the legacy implementation using "protocol" to describe the two layers together, itself influenced by asyncio using the same word. --- docs/howto/sansio.rst | 130 ++- docs/project/changelog.rst | 17 +- docs/reference/client.rst | 2 +- docs/reference/common.rst | 4 +- docs/reference/server.rst | 2 +- docs/reference/types.rst | 2 +- src/websockets/__init__.py | 8 +- src/websockets/client.py | 18 +- src/websockets/connection.py | 705 +---------- src/websockets/http11.py | 4 +- src/websockets/legacy/protocol.py | 2 +- src/websockets/legacy/server.py | 2 +- src/websockets/protocol.py | 702 +++++++++++ src/websockets/server.py | 18 +- tests/legacy/test_client_server.py | 2 +- tests/legacy/test_protocol.py | 2 +- tests/test_client.py | 98 +- tests/test_connection.py | 1743 +--------------------------- tests/test_protocol.py | 1737 +++++++++++++++++++++++++++ tests/test_server.py | 108 +- tests/utils.py | 29 + 21 files changed, 2722 insertions(+), 2613 deletions(-) create mode 100644 src/websockets/protocol.py create mode 100644 tests/test_protocol.py diff --git a/docs/howto/sansio.rst b/docs/howto/sansio.rst index 197e29dc9..08b09f7ce 100644 --- a/docs/howto/sansio.rst +++ b/docs/howto/sansio.rst @@ -33,35 +33,36 @@ If you're building a client, parse the URI you'd like to connect to:: Open a TCP connection to ``(wsuri.host, wsuri.port)`` and perform a TLS handshake if ``wsuri.secure`` is :obj:`True`. -Initialize a :class:`~client.ClientConnection`:: +Initialize a :class:`~client.ClientProtocol`:: - from websockets.client import ClientConnection + from websockets.client import ClientProtocol - connection = ClientConnection(wsuri) + protocol = ClientProtocol(wsuri) Create a WebSocket handshake request -with :meth:`~client.ClientConnection.connect` and send it -with :meth:`~client.ClientConnection.send_request`:: +with :meth:`~client.ClientProtocol.connect` and send it +with :meth:`~client.ClientProtocol.send_request`:: - request = connection.connect() - connection.send_request(request) + request = protocol.connect() + protocol.send_request(request) -Then, call :meth:`~connection.Connection.data_to_send` and send its output to +Then, call :meth:`~protocol.Protocol.data_to_send` and send its output to the network, as described in `Send data`_ below. -The first event returned by :meth:`~connection.Connection.events_received` is -the WebSocket handshake response. +Once you receive enough data, as explained in `Receive data`_ below, the first +event returned by :meth:`~protocol.Protocol.events_received` is the WebSocket +handshake response. When the handshake fails, the reason is available in -:attr:`~client.ClientConnection.handshake_exc`:: +:attr:`~client.ClientProtocol.handshake_exc`:: - if connection.handshake_exc is not None: - raise connection.handshake_exc + if protocol.handshake_exc is not None: + raise protocol.handshake_exc Else, the WebSocket connection is open. A WebSocket client API usually performs the handshake then returns a wrapper -around the network connection and the :class:`~client.ClientConnection`. +around the network socket and the :class:`~client.ClientProtocol`. Server-side ........... @@ -69,45 +70,46 @@ Server-side If you're building a server, accept network connections from clients and perform a TLS handshake if desired. -For each connection, initialize a :class:`~server.ServerConnection`:: +For each connection, initialize a :class:`~server.ServerProtocol`:: - from websockets.server import ServerConnection + from websockets.server import ServerProtocol - connection = ServerConnection() + protocol = ServerProtocol() -The first event returned by :meth:`~connection.Connection.events_received` is -the WebSocket handshake request. +Once you receive enough data, as explained in `Receive data`_ below, the first +event returned by :meth:`~protocol.Protocol.events_received` is the WebSocket +handshake request. Create a WebSocket handshake response -with :meth:`~server.ServerConnection.accept` and send it -with :meth:`~server.ServerConnection.send_response`:: +with :meth:`~server.ServerProtocol.accept` and send it +with :meth:`~server.ServerProtocol.send_response`:: - response = connection.accept(request) - connection.send_response(response) + response = protocol.accept(request) + protocol.send_response(response) Alternatively, you may reject the WebSocket handshake and return an HTTP -response with :meth:`~server.ServerConnection.reject`:: +response with :meth:`~server.ServerProtocol.reject`:: - response = connection.reject(status, explanation) - connection.send_response(response) + response = protocol.reject(status, explanation) + protocol.send_response(response) -Then, call :meth:`~connection.Connection.data_to_send` and send its output to +Then, call :meth:`~protocol.Protocol.data_to_send` and send its output to the network, as described in `Send data`_ below. -Even when you call :meth:`~server.ServerConnection.accept`, the WebSocket +Even when you call :meth:`~server.ServerProtocol.accept`, the WebSocket handshake may fail if the request is incorrect or unsupported. When the handshake fails, the reason is available in -:attr:`~server.ServerConnection.handshake_exc`:: +:attr:`~server.ServerProtocol.handshake_exc`:: - if connection.handshake_exc is not None: - raise connection.handshake_exc + if protocol.handshake_exc is not None: + raise protocol.handshake_exc Else, the WebSocket connection is open. -A WebSocket server API usually builds a wrapper around the network connection -and the :class:`~server.ServerConnection`. Then it invokes a connection -handler that accepts the wrapper in argument. +A WebSocket server API usually builds a wrapper around the network socket and +the :class:`~server.ServerProtocol`. Then it invokes a connection handler that +accepts the wrapper in argument. It may also provide a way to close all connections and to shut down the server gracefully. @@ -122,11 +124,11 @@ Go through the five steps below until you reach the end of the data stream. Receive data ............ -When receiving data from the network, feed it to the connection's -:meth:`~connection.Connection.receive_data` method. +When receiving data from the network, feed it to the protocol's +:meth:`~protocol.Protocol.receive_data` method. -When reaching the end of the data stream, call the connection's -:meth:`~connection.Connection.receive_eof` method. +When reaching the end of the data stream, call the protocol's +:meth:`~protocol.Protocol.receive_eof` method. For example, if ``sock`` is a :obj:`~socket.socket`:: @@ -135,21 +137,21 @@ For example, if ``sock`` is a :obj:`~socket.socket`:: except OSError: # socket closed data = b"" if data: - connection.receive_data(data) + protocol.receive_data(data) else: - connection.receive_eof() + protocol.receive_eof() These methods aren't expected to raise exceptions — unless you call them again -after calling :meth:`~connection.Connection.receive_eof`, which is an error. +after calling :meth:`~protocol.Protocol.receive_eof`, which is an error. (If you get an exception, please file a bug!) Send data ......... -Then, call :meth:`~connection.Connection.data_to_send` and send its output to +Then, call :meth:`~protocol.Protocol.data_to_send` and send its output to the network:: - for data in connection.data_to_send(): + for data in protocol.data_to_send(): if data: sock.sendall(data) else: @@ -170,7 +172,7 @@ server starts the four-way TCP closing handshake. If the network fails at the wrong point, you can end up waiting until the TCP timeout, which is very long. To prevent dangling TCP connections when you expect the end of the data stream -but you never reach it, call :meth:`~connection.Connection.close_expected` +but you never reach it, call :meth:`~protocol.Protocol.close_expected` and, if it returns :obj:`True`, schedule closing the TCP connection after a short timeout:: @@ -185,11 +187,11 @@ data stream, possibly with an exception. Close TCP connection .................... -If you called :meth:`~connection.Connection.receive_eof`, close the TCP +If you called :meth:`~protocol.Protocol.receive_eof`, close the TCP connection now. This is a clean closure because the receive buffer is empty. -After :meth:`~connection.Connection.receive_eof` signals the end of the read -stream, :meth:`~connection.Connection.data_to_send` always signals the end of +After :meth:`~protocol.Protocol.receive_eof` signals the end of the read +stream, :meth:`~protocol.Protocol.data_to_send` always signals the end of the write stream, unless it already ended. So, at this point, the TCP connection is already half-closed. The only reason for closing it now is to release resources related to the socket. @@ -199,8 +201,8 @@ Now you can exit the loop relaying data from the network to the application. Receive events .............. -Finally, call :meth:`~connection.Connection.events_received` to obtain events -parsed from the data provided to :meth:`~connection.Connection.receive_data`:: +Finally, call :meth:`~protocol.Protocol.events_received` to obtain events +parsed from the data provided to :meth:`~protocol.Protocol.receive_data`:: events = connection.events_received() @@ -224,9 +226,9 @@ The connection object provides one method for each type of WebSocket frame. For sending a data frame: -* :meth:`~connection.Connection.send_continuation` -* :meth:`~connection.Connection.send_text` -* :meth:`~connection.Connection.send_binary` +* :meth:`~protocol.Protocol.send_continuation` +* :meth:`~protocol.Protocol.send_text` +* :meth:`~protocol.Protocol.send_binary` These methods raise :exc:`~exceptions.ProtocolError` if you don't set the :attr:`FIN ` bit correctly in fragmented @@ -234,21 +236,21 @@ messages. For sending a control frame: -* :meth:`~connection.Connection.send_close` -* :meth:`~connection.Connection.send_ping` -* :meth:`~connection.Connection.send_pong` +* :meth:`~protocol.Protocol.send_close` +* :meth:`~protocol.Protocol.send_ping` +* :meth:`~protocol.Protocol.send_pong` -:meth:`~connection.Connection.send_close` initiates the closing handshake. +:meth:`~protocol.Protocol.send_close` initiates the closing handshake. See `Closing a connection`_ below for details. If you encounter an unrecoverable error and you must fail the WebSocket -connection, call :meth:`~connection.Connection.fail`. +connection, call :meth:`~protocol.Protocol.fail`. -After any of the above, call :meth:`~connection.Connection.data_to_send` and +After any of the above, call :meth:`~protocol.Protocol.data_to_send` and send its output to the network, as shown in `Send data`_ above. -If you called :meth:`~connection.Connection.send_close` -or :meth:`~connection.Connection.fail`, you expect the end of the data +If you called :meth:`~protocol.Protocol.send_close` +or :meth:`~protocol.Protocol.fail`, you expect the end of the data stream. You should follow the process described in `Close TCP connection`_ above in order to prevent dangling TCP connections. @@ -272,10 +274,10 @@ When a client wants to close the TCP connection: Applying the rules described earlier in this document gives the intended result. As a reminder, the rules are: -* When :meth:`~connection.Connection.data_to_send` returns the empty +* When :meth:`~protocol.Protocol.data_to_send` returns the empty bytestring, close the write side of the TCP connection. * When you reach the end of the read stream, close the TCP connection. -* When :meth:`~connection.Connection.close_expected` returns :obj:`True`, if +* When :meth:`~protocol.Protocol.close_expected` returns :obj:`True`, if you don't reach the end of the read stream quickly, close the TCP connection. Fragmentation @@ -306,14 +308,14 @@ should happen automatically in a cooperative multitasking environment. However, you still have to make sure you don't break this property by accident. For example, serialize writes to the network -when :meth:`~connection.Connection.data_to_send` returns multiple values to +when :meth:`~protocol.Protocol.data_to_send` returns multiple values to prevent concurrent writes from interleaving incorrectly. Avoid buffers ............. The Sans-I/O layer doesn't do any buffering. It makes events available in -:meth:`~connection.Connection.events_received` as soon as they're received. +:meth:`~protocol.Protocol.events_received` as soon as they're received. You should make incoming messages available to the application immediately and stop further processing until the application fetches them. This will usually diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 40e090c31..532eeffb2 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -33,6 +33,18 @@ They may change at any time. Backwards-incompatible changes .............................. +.. admonition:: The Sans-I/O implementation was moved. + :class: caution + + Aliases provide compatibility for all previously public APIs according to + the `backwards-compatibility policy`_ + + * The ``connection`` module was renamed to ``protocol``. + + * The ``connection.Connection``, ``server.ServerConnection``, and + ``client.ClientConnection`` classes were renamed to ``protocol.Protocol``, + ``server.ServerProtocol``, and ``client.ClientProtocol``. + .. admonition:: Closing a connection without an empty close frame is OK. :class: note @@ -43,7 +55,6 @@ Backwards-incompatible changes As a consequence, calling ``WebSocket.close()`` without arguments in a browser isn't reported as an error anymore. - New features ............ @@ -86,8 +97,8 @@ Backwards-incompatible changes .. admonition:: The ``exception`` attribute of :class:`~http11.Request` and :class:`~http11.Response` is deprecated. :class: note - Use the ``handshake_exc`` attribute of :class:`~server.ServerConnection` and - :class:`~client.ClientConnection` instead. + Use the ``handshake_exc`` attribute of :class:`~server.ServerProtocol` and + :class:`~client.ClientProtocol` instead. See :doc:`../howto/sansio` for details. diff --git a/docs/reference/client.rst b/docs/reference/client.rst index b72f49f5d..44f053b1e 100644 --- a/docs/reference/client.rst +++ b/docs/reference/client.rst @@ -69,7 +69,7 @@ Using a connection Sans-I/O -------- -.. autoclass:: ClientConnection(wsuri, origin=None, extensions=None, subprotocols=None, state=State.CONNECTING, max_size=2 ** 20, logger=None) +.. autoclass:: ClientProtocol(wsuri, origin=None, extensions=None, subprotocols=None, state=State.CONNECTING, max_size=2 ** 20, logger=None) .. automethod:: receive_data diff --git a/docs/reference/common.rst b/docs/reference/common.rst index 6ba11bff5..b42f5ea3e 100644 --- a/docs/reference/common.rst +++ b/docs/reference/common.rst @@ -57,9 +57,9 @@ asyncio Sans-I/O -------- -.. automodule:: websockets.connection +.. automodule:: websockets.protocol -.. autoclass:: Connection(side, state=State.OPEN, max_size=2 ** 20, logger=None) +.. autoclass:: Protocol(side, state=State.OPEN, max_size=2 ** 20, logger=None) .. automethod:: receive_data diff --git a/docs/reference/server.rst b/docs/reference/server.rst index 12fe1f806..50ef4ee3c 100644 --- a/docs/reference/server.rst +++ b/docs/reference/server.rst @@ -115,7 +115,7 @@ websockets supports HTTP Basic Authentication according to Sans-I/O -------- -.. autoclass:: ServerConnection(origins=None, extensions=None, subprotocols=None, state=State.CONNECTING, max_size=2 ** 20, logger=None) +.. autoclass:: ServerProtocol(origins=None, extensions=None, subprotocols=None, state=State.CONNECTING, max_size=2 ** 20, logger=None) .. automethod:: receive_data diff --git a/docs/reference/types.rst b/docs/reference/types.rst index d86429be4..88550d08d 100644 --- a/docs/reference/types.rst +++ b/docs/reference/types.rst @@ -15,7 +15,7 @@ Types .. autodata:: ExtensionParameter -.. autodata:: websockets.connection.Event +.. autodata:: websockets.protocol.Event .. autodata:: websockets.datastructures.HeadersLike diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index ec3484124..826decc48 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -9,7 +9,7 @@ "basic_auth_protocol_factory", "BasicAuthWebSocketServerProtocol", "broadcast", - "ClientConnection", + "ClientProtocol", "connect", "ConnectionClosed", "ConnectionClosedError", @@ -40,7 +40,7 @@ "RedirectHandshake", "SecurityError", "serve", - "ServerConnection", + "ServerProtocol", "Subprotocol", "unix_connect", "unix_serve", @@ -60,7 +60,7 @@ "basic_auth_protocol_factory": ".legacy.auth", "BasicAuthWebSocketServerProtocol": ".legacy.auth", "broadcast": ".legacy.protocol", - "ClientConnection": ".client", + "ClientProtocol": ".client", "connect": ".legacy.client", "unix_connect": ".legacy.client", "WebSocketClientProtocol": ".legacy.client", @@ -93,7 +93,7 @@ "WebSocketProtocolError": ".exceptions", "protocol": ".legacy", "WebSocketCommonProtocol": ".legacy.protocol", - "ServerConnection": ".server", + "ServerProtocol": ".server", "serve": ".legacy.server", "unix_serve": ".legacy.server", "WebSocketServerProtocol": ".legacy.server", diff --git a/src/websockets/client.py b/src/websockets/client.py index 373e6b751..a439ab846 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -1,8 +1,8 @@ from __future__ import annotations -from typing import Generator, List, Optional, Sequence +import warnings +from typing import Any, Generator, List, Optional, Sequence -from .connection import CLIENT, CONNECTING, OPEN, Connection, State from .datastructures import Headers, MultipleValuesError from .exceptions import ( InvalidHandshake, @@ -24,6 +24,7 @@ parse_upgrade, ) from .http11 import Request, Response +from .protocol import CLIENT, CONNECTING, OPEN, Protocol, State from .typing import ( ConnectionOption, ExtensionHeader, @@ -40,10 +41,10 @@ from .legacy.client import * # isort:skip # noqa -__all__ = ["ClientConnection"] +__all__ = ["ClientProtocol"] -class ClientConnection(Connection): +class ClientProtocol(Protocol): """ Sans-I/O implementation of a WebSocket client connection. @@ -342,3 +343,12 @@ def parse(self) -> Generator[None, None, None]: self.events.append(response) yield from super().parse() + + +class ClientConnection(ClientProtocol): + def __init__(self, *args: Any, **kwargs: Any) -> None: + warnings.warn( + "ClientConnection was renamed to ClientProtocol", + DeprecationWarning, + ) + super().__init__(*args, **kwargs) diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 7a9be9f2e..5ce4d6a3b 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -1,702 +1,13 @@ from __future__ import annotations -import enum -import logging -import uuid -from typing import Generator, List, Optional, Type, Union +import warnings -from .exceptions import ( - ConnectionClosed, - ConnectionClosedError, - ConnectionClosedOK, - InvalidState, - PayloadTooBig, - ProtocolError, -) -from .extensions import Extension -from .frames import ( - OK_CLOSE_CODES, - OP_BINARY, - OP_CLOSE, - OP_CONT, - OP_PING, - OP_PONG, - OP_TEXT, - Close, - Frame, -) -from .http11 import Request, Response -from .streams import StreamReader -from .typing import LoggerLike, Origin, Subprotocol - - -__all__ = [ - "Connection", - "Side", - "State", - "SEND_EOF", -] - -Event = Union[Request, Response, Frame] -"""Events that :meth:`~Connection.events_received` may return.""" - - -class Side(enum.IntEnum): - """A WebSocket connection is either a server or a client.""" - - SERVER, CLIENT = range(2) - - -SERVER = Side.SERVER -CLIENT = Side.CLIENT - - -class State(enum.IntEnum): - """A WebSocket connection is in one of these four states.""" - - CONNECTING, OPEN, CLOSING, CLOSED = range(4) - - -CONNECTING = State.CONNECTING -OPEN = State.OPEN -CLOSING = State.CLOSING -CLOSED = State.CLOSED - - -SEND_EOF = b"" -"""Sentinel signaling that the TCP connection must be half-closed.""" - - -class Connection: - """ - Sans-I/O implementation of a WebSocket connection. - - Args: - side: :attr:`~Side.CLIENT` or :attr:`~Side.SERVER`. - state: initial state of the WebSocket connection. - max_size: maximum size of incoming messages in bytes; - :obj:`None` to disable the limit. - logger: logger for this connection; depending on ``side``, - defaults to ``logging.getLogger("websockets.client")`` - or ``logging.getLogger("websockets.server")``; - see the :doc:`logging guide <../topics/logging>` for details. - - """ - - def __init__( - self, - side: Side, - state: State = OPEN, - max_size: Optional[int] = 2**20, - logger: Optional[LoggerLike] = None, - ) -> None: - # Unique identifier. For logs. - self.id: uuid.UUID = uuid.uuid4() - """Unique identifier of the connection. Useful in logs.""" - - # Logger or LoggerAdapter for this connection. - if logger is None: - logger = logging.getLogger(f"websockets.{side.name.lower()}") - self.logger: LoggerLike = logger - """Logger for this connection.""" - - # Track if DEBUG is enabled. Shortcut logging calls if it isn't. - self.debug = logger.isEnabledFor(logging.DEBUG) - - # Connection side. CLIENT or SERVER. - self.side = side - - # Connection state. Initially OPEN because subclasses handle CONNECTING. - self.state = state - - # Maximum size of incoming messages in bytes. - self.max_size = max_size - - # Current size of incoming message in bytes. Only set while reading a - # fragmented message i.e. a data frames with the FIN bit not set. - self.cur_size: Optional[int] = None - - # True while sending a fragmented message i.e. a data frames with the - # FIN bit not set. - self.expect_continuation_frame = False - - # WebSocket protocol parameters. - self.origin: Optional[Origin] = None - self.extensions: List[Extension] = [] - self.subprotocol: Optional[Subprotocol] = None - - # Close code and reason, set when a close frame is sent or received. - self.close_rcvd: Optional[Close] = None - self.close_sent: Optional[Close] = None - self.close_rcvd_then_sent: Optional[bool] = None - - # Track if an exception happened during the handshake. - self.handshake_exc: Optional[Exception] = None - """ - Exception to raise if the opening handshake failed. - - :obj:`None` if the opening handshake succeeded. - - """ - - # Track if send_eof() was called. - self.eof_sent = False - - # Parser state. - self.reader = StreamReader() - self.events: List[Event] = [] - self.writes: List[bytes] = [] - self.parser = self.parse() - next(self.parser) # start coroutine - self.parser_exc: Optional[Exception] = None - - @property - def state(self) -> State: - """ - WebSocket connection state. - - Defined in 4.1, 4.2, 7.1.3, and 7.1.4 of :rfc:`6455`. - - """ - return self._state - - @state.setter - def state(self, state: State) -> None: - if self.debug: - self.logger.debug("= connection is %s", state.name) - self._state = state - - @property - def close_code(self) -> Optional[int]: - """ - `WebSocket close code`_. - - .. _WebSocket close code: - https://www.rfc-editor.org/rfc/rfc6455.html#section-7.1.5 - - :obj:`None` if the connection isn't closed yet. - - """ - if self.state is not CLOSED: - return None - elif self.close_rcvd is None: - return 1006 - else: - return self.close_rcvd.code - - @property - def close_reason(self) -> Optional[str]: - """ - `WebSocket close reason`_. - - .. _WebSocket close reason: - https://www.rfc-editor.org/rfc/rfc6455.html#section-7.1.6 - - :obj:`None` if the connection isn't closed yet. - - """ - if self.state is not CLOSED: - return None - elif self.close_rcvd is None: - return "" - else: - return self.close_rcvd.reason - - @property - def close_exc(self) -> ConnectionClosed: - """ - Exception to raise when trying to interact with a closed connection. - - Don't raise this exception while the connection :attr:`state` - is :attr:`~websockets.connection.State.CLOSING`; wait until - it's :attr:`~websockets.connection.State.CLOSED`. - - Indeed, the exception includes the close code and reason, which are - known only once the connection is closed. - - Raises: - AssertionError: if the connection isn't closed yet. - - """ - assert self.state is CLOSED, "connection isn't closed yet" - exc_type: Type[ConnectionClosed] - if ( - self.close_rcvd is not None - and self.close_sent is not None - and self.close_rcvd.code in OK_CLOSE_CODES - and self.close_sent.code in OK_CLOSE_CODES - ): - exc_type = ConnectionClosedOK - else: - exc_type = ConnectionClosedError - exc: ConnectionClosed = exc_type( - self.close_rcvd, - self.close_sent, - self.close_rcvd_then_sent, - ) - # Chain to the exception raised in the parser, if any. - exc.__cause__ = self.parser_exc - return exc - - # Public methods for receiving data. - - def receive_data(self, data: bytes) -> None: - """ - Receive data from the network. - - After calling this method: - - - You must call :meth:`data_to_send` and send this data to the network. - - You should call :meth:`events_received` and process resulting events. - - Raises: - EOFError: if :meth:`receive_eof` was called earlier. - - """ - self.reader.feed_data(data) - next(self.parser) - - def receive_eof(self) -> None: - """ - Receive the end of the data stream from the network. - - After calling this method: - - - You must call :meth:`data_to_send` and send this data to the network. - - You aren't expected to call :meth:`events_received`; it won't return - any new events. - - Raises: - EOFError: if :meth:`receive_eof` was called earlier. - - """ - self.reader.feed_eof() - next(self.parser) - - # Public methods for sending events. - - def send_continuation(self, data: bytes, fin: bool) -> None: - """ - Send a `Continuation frame`_. - - .. _Continuation frame: - https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 - - Parameters: - data: payload containing the same kind of data - as the initial frame. - fin: FIN bit; set it to :obj:`True` if this is the last frame - of a fragmented message and to :obj:`False` otherwise. - - Raises: - ProtocolError: if a fragmented message isn't in progress. - - """ - if not self.expect_continuation_frame: - raise ProtocolError("unexpected continuation frame") - self.expect_continuation_frame = not fin - self.send_frame(Frame(OP_CONT, data, fin)) - - def send_text(self, data: bytes, fin: bool = True) -> None: - """ - Send a `Text frame`_. - - .. _Text frame: - https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 - - Parameters: - data: payload containing text encoded with UTF-8. - fin: FIN bit; set it to :obj:`False` if this is the first frame of - a fragmented message. - - Raises: - ProtocolError: if a fragmented message is in progress. - - """ - if self.expect_continuation_frame: - raise ProtocolError("expected a continuation frame") - self.expect_continuation_frame = not fin - self.send_frame(Frame(OP_TEXT, data, fin)) - - def send_binary(self, data: bytes, fin: bool = True) -> None: - """ - Send a `Binary frame`_. +# lazy_import doesn't support this use case. +from .protocol import SEND_EOF, Protocol as Connection, Side, State # noqa - .. _Binary frame: - https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 - Parameters: - data: payload containing arbitrary binary data. - fin: FIN bit; set it to :obj:`False` if this is the first frame of - a fragmented message. - - Raises: - ProtocolError: if a fragmented message is in progress. - - """ - if self.expect_continuation_frame: - raise ProtocolError("expected a continuation frame") - self.expect_continuation_frame = not fin - self.send_frame(Frame(OP_BINARY, data, fin)) - - def send_close(self, code: Optional[int] = None, reason: str = "") -> None: - """ - Send a `Close frame`_. - - .. _Close frame: - https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.1 - - Parameters: - code: close code. - reason: close reason. - - Raises: - ProtocolError: if a fragmented message is being sent, if the code - isn't valid, or if a reason is provided without a code - - """ - if self.expect_continuation_frame: - raise ProtocolError("expected a continuation frame") - if code is None: - if reason != "": - raise ProtocolError("cannot send a reason without a code") - close = Close(1005, "") - data = b"" - else: - close = Close(code, reason) - data = close.serialize() - # send_frame() guarantees that self.state is OPEN at this point. - # 7.1.3. The WebSocket Closing Handshake is Started - self.send_frame(Frame(OP_CLOSE, data)) - self.close_sent = close - self.state = CLOSING - - def send_ping(self, data: bytes) -> None: - """ - Send a `Ping frame`_. - - .. _Ping frame: - https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.2 - - Parameters: - data: payload containing arbitrary binary data. - - """ - self.send_frame(Frame(OP_PING, data)) - - def send_pong(self, data: bytes) -> None: - """ - Send a `Pong frame`_. - - .. _Pong frame: - https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3 - - Parameters: - data: payload containing arbitrary binary data. - - """ - self.send_frame(Frame(OP_PONG, data)) - - def fail(self, code: int, reason: str = "") -> None: - """ - `Fail the WebSocket connection`_. - - .. _Fail the WebSocket connection: - https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.7 - - Parameters: - code: close code - reason: close reason - - Raises: - ProtocolError: if the code isn't valid. - """ - # 7.1.7. Fail the WebSocket Connection - - # Send a close frame when the state is OPEN (a close frame was already - # sent if it's CLOSING), except when failing the connection because - # of an error reading from or writing to the network. - if self.state is OPEN: - if code != 1006: - close = Close(code, reason) - data = close.serialize() - self.send_frame(Frame(OP_CLOSE, data)) - self.close_sent = close - self.state = CLOSING - - # When failing the connection, a server closes the TCP connection - # without waiting for the client to complete the handshake, while a - # client waits for the server to close the TCP connection, possibly - # after sending a close frame that the client will ignore. - if self.side is SERVER and not self.eof_sent: - self.send_eof() - - # 7.1.7. Fail the WebSocket Connection "An endpoint MUST NOT continue - # to attempt to process data(including a responding Close frame) from - # the remote endpoint after being instructed to _Fail the WebSocket - # Connection_." - self.parser = self.discard() - next(self.parser) # start coroutine - - # Public method for getting incoming events after receiving data. - - def events_received(self) -> List[Event]: - """ - Fetch events generated from data received from the network. - - Call this method immediately after any of the ``receive_*()`` methods. - - Process resulting events, likely by passing them to the application. - - Returns: - List[Event]: Events read from the connection. - """ - events, self.events = self.events, [] - return events - - # Public method for getting outgoing data after receiving data or sending events. - - def data_to_send(self) -> List[bytes]: - """ - Obtain data to send to the network. - - Call this method immediately after any of the ``receive_*()``, - ``send_*()``, or :meth:`fail` methods. - - Write resulting data to the connection. - - The empty bytestring :data:`~websockets.connection.SEND_EOF` signals - the end of the data stream. When you receive it, half-close the TCP - connection. - - Returns: - List[bytes]: Data to write to the connection. - - """ - writes, self.writes = self.writes, [] - return writes - - def close_expected(self) -> bool: - """ - Tell if the TCP connection is expected to close soon. - - Call this method immediately after any of the ``receive_*()`` or - :meth:`fail` methods. - - If it returns :obj:`True`, schedule closing the TCP connection after a - short timeout if the other side hasn't already closed it. - - Returns: - bool: Whether the TCP connection is expected to close soon. - - """ - # We expect a TCP close if and only if we sent a close frame: - # * Normal closure: once we send a close frame, we expect a TCP close: - # server waits for client to complete the TCP closing handshake; - # client waits for server to initiate the TCP closing handshake. - # * Abnormal closure: we always send a close frame and the same logic - # applies, except on EOFError where we don't send a close frame - # because we already received the TCP close, so we don't expect it. - # We already got a TCP Close if and only if the state is CLOSED. - return self.state is CLOSING or self.handshake_exc is not None - - # Private methods for receiving data. - - def parse(self) -> Generator[None, None, None]: - """ - Parse incoming data into frames. - - :meth:`receive_data` and :meth:`receive_eof` run this generator - coroutine until it needs more data or reaches EOF. - - """ - try: - while True: - if (yield from self.reader.at_eof()): - if self.debug: - self.logger.debug("< EOF") - # If the WebSocket connection is closed cleanly, with a - # closing handhshake, recv_frame() substitutes parse() - # with discard(). This branch is reached only when the - # connection isn't closed cleanly. - raise EOFError("unexpected end of stream") - - if self.max_size is None: - max_size = None - elif self.cur_size is None: - max_size = self.max_size - else: - max_size = self.max_size - self.cur_size - - # During a normal closure, execution ends here on the next - # iteration of the loop after receiving a close frame. At - # this point, recv_frame() replaced parse() by discard(). - frame = yield from Frame.parse( - self.reader.read_exact, - mask=self.side is SERVER, - max_size=max_size, - extensions=self.extensions, - ) - - if self.debug: - self.logger.debug("< %s", frame) - - self.recv_frame(frame) - - except ProtocolError as exc: - self.fail(1002, str(exc)) - self.parser_exc = exc - - except EOFError as exc: - self.fail(1006, str(exc)) - self.parser_exc = exc - - except UnicodeDecodeError as exc: - self.fail(1007, f"{exc.reason} at position {exc.start}") - self.parser_exc = exc - - except PayloadTooBig as exc: - self.fail(1009, str(exc)) - self.parser_exc = exc - - except Exception as exc: - self.logger.error("parser failed", exc_info=True) - # Don't include exception details, which may be security-sensitive. - self.fail(1011) - self.parser_exc = exc - - # During an abnormal closure, execution ends here after catching an - # exception. At this point, fail() replaced parse() by discard(). - yield - raise AssertionError("parse() shouldn't step after error") - - def discard(self) -> Generator[None, None, None]: - """ - Discard incoming data. - - This coroutine replaces :meth:`parse`: - - - after receiving a close frame, during a normal closure (1.4); - - after sending a close frame, during an abnormal closure (7.1.7). - - """ - # The server close the TCP connection in the same circumstances where - # discard() replaces parse(). The client closes the connection later, - # after the server closes the connection or a timeout elapses. - # (The latter case cannot be handled in this Sans-I/O layer.) - assert (self.side is SERVER) == (self.eof_sent) - while not (yield from self.reader.at_eof()): - self.reader.discard() - if self.debug: - self.logger.debug("< EOF") - # A server closes the TCP connection immediately, while a client - # waits for the server to close the TCP connection. - if self.side is CLIENT: - self.send_eof() - self.state = CLOSED - # If discard() completes normally, execution ends here. - yield - # Once the reader reaches EOF, its feed_data/eof() methods raise an - # error, so our receive_data/eof() methods don't step the generator. - raise AssertionError("discard() shouldn't step after EOF") - - def recv_frame(self, frame: Frame) -> None: - """ - Process an incoming frame. - - """ - if frame.opcode is OP_TEXT or frame.opcode is OP_BINARY: - if self.cur_size is not None: - raise ProtocolError("expected a continuation frame") - if frame.fin: - self.cur_size = None - else: - self.cur_size = len(frame.data) - - elif frame.opcode is OP_CONT: - if self.cur_size is None: - raise ProtocolError("unexpected continuation frame") - if frame.fin: - self.cur_size = None - else: - self.cur_size += len(frame.data) - - elif frame.opcode is OP_PING: - # 5.5.2. Ping: "Upon receipt of a Ping frame, an endpoint MUST - # send a Pong frame in response" - pong_frame = Frame(OP_PONG, frame.data) - self.send_frame(pong_frame) - - elif frame.opcode is OP_PONG: - # 5.5.3 Pong: "A response to an unsolicited Pong frame is not - # expected." - pass - - elif frame.opcode is OP_CLOSE: - # 7.1.5. The WebSocket Connection Close Code - # 7.1.6. The WebSocket Connection Close Reason - self.close_rcvd = Close.parse(frame.data) - if self.state is CLOSING: - assert self.close_sent is not None - self.close_rcvd_then_sent = False - - if self.cur_size is not None: - raise ProtocolError("incomplete fragmented message") - - # 5.5.1 Close: "If an endpoint receives a Close frame and did - # not previously send a Close frame, the endpoint MUST send a - # Close frame in response. (When sending a Close frame in - # response, the endpoint typically echos the status code it - # received.)" - - if self.state is OPEN: - # Echo the original data instead of re-serializing it with - # Close.serialize() because that fails when the close frame - # is empty and Close.parse() synthetizes a 1005 close code. - # The rest is identical to send_close(). - self.send_frame(Frame(OP_CLOSE, frame.data)) - self.close_sent = self.close_rcvd - self.close_rcvd_then_sent = True - self.state = CLOSING - - # 7.1.2. Start the WebSocket Closing Handshake: "Once an - # endpoint has both sent and received a Close control frame, - # that endpoint SHOULD _Close the WebSocket Connection_" - - # A server closes the TCP connection immediately, while a client - # waits for the server to close the TCP connection. - if self.side is SERVER: - self.send_eof() - - # 1.4. Closing Handshake: "after receiving a control frame - # indicating the connection should be closed, a peer discards - # any further data received." - self.parser = self.discard() - next(self.parser) # start coroutine - - else: - # This can't happen because Frame.parse() validates opcodes. - raise AssertionError(f"unexpected opcode: {frame.opcode:02x}") - - self.events.append(frame) - - # Private methods for sending events. - - def send_frame(self, frame: Frame) -> None: - if self.state is not OPEN: - raise InvalidState( - f"cannot write to a WebSocket in the {self.state.name} state" - ) - - if self.debug: - self.logger.debug("> %s", frame) - self.writes.append( - frame.serialize(mask=self.side is CLIENT, extensions=self.extensions) - ) - - def send_eof(self) -> None: - assert not self.eof_sent - self.eof_sent = True - if self.debug: - self.logger.debug("> EOF") - self.writes.append(SEND_EOF) +warnings.warn( + "websockets.connection was renamed to websockets.protocol " + "and Connection was renamed to Protocol", + DeprecationWarning, +) diff --git a/src/websockets/http11.py b/src/websockets/http11.py index 68249192c..ec4e3b8b7 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -68,7 +68,7 @@ class Request: def exception(self) -> Optional[Exception]: # pragma: no cover warnings.warn( "Request.exception is deprecated; " - "use ServerConnection.handshake_exc instead", + "use ServerProtocol.handshake_exc instead", DeprecationWarning, ) return self._exception @@ -172,7 +172,7 @@ class Response: def exception(self) -> Optional[Exception]: # pragma: no cover warnings.warn( "Response.exception is deprecated; " - "use ClientConnection.handshake_exc instead", + "use ClientProtocol.handshake_exc instead", DeprecationWarning, ) return self._exception diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 9f2bda1ab..7881b947d 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -27,7 +27,6 @@ cast, ) -from ..connection import State from ..datastructures import Headers from ..exceptions import ( ConnectionClosed, @@ -51,6 +50,7 @@ prepare_ctrl, prepare_data, ) +from ..protocol import State from ..typing import Data, LoggerLike, Subprotocol from .compatibility import loop_if_py_lt_38 from .framing import Frame diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index f37a16c20..eabeb8e96 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -25,7 +25,6 @@ cast, ) -from ..connection import State from ..datastructures import Headers, HeadersLike, MultipleValuesError from ..exceptions import ( AbortHandshake, @@ -45,6 +44,7 @@ validate_subprotocols, ) from ..http import USER_AGENT +from ..protocol import State from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol from .compatibility import loop_if_py_lt_38 from .handshake import build_response, check_request diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py new file mode 100644 index 000000000..29d7e1596 --- /dev/null +++ b/src/websockets/protocol.py @@ -0,0 +1,702 @@ +from __future__ import annotations + +import enum +import logging +import uuid +from typing import Generator, List, Optional, Type, Union + +from .exceptions import ( + ConnectionClosed, + ConnectionClosedError, + ConnectionClosedOK, + InvalidState, + PayloadTooBig, + ProtocolError, +) +from .extensions import Extension +from .frames import ( + OK_CLOSE_CODES, + OP_BINARY, + OP_CLOSE, + OP_CONT, + OP_PING, + OP_PONG, + OP_TEXT, + Close, + Frame, +) +from .http11 import Request, Response +from .streams import StreamReader +from .typing import LoggerLike, Origin, Subprotocol + + +__all__ = [ + "Protocol", + "Side", + "State", + "SEND_EOF", +] + +Event = Union[Request, Response, Frame] +"""Events that :meth:`~Protocol.events_received` may return.""" + + +class Side(enum.IntEnum): + """A WebSocket connection is either a server or a client.""" + + SERVER, CLIENT = range(2) + + +SERVER = Side.SERVER +CLIENT = Side.CLIENT + + +class State(enum.IntEnum): + """A WebSocket connection is in one of these four states.""" + + CONNECTING, OPEN, CLOSING, CLOSED = range(4) + + +CONNECTING = State.CONNECTING +OPEN = State.OPEN +CLOSING = State.CLOSING +CLOSED = State.CLOSED + + +SEND_EOF = b"" +"""Sentinel signaling that the TCP connection must be half-closed.""" + + +class Protocol: + """ + Sans-I/O implementation of a WebSocket connection. + + Args: + side: :attr:`~Side.CLIENT` or :attr:`~Side.SERVER`. + state: initial state of the WebSocket connection. + max_size: maximum size of incoming messages in bytes; + :obj:`None` to disable the limit. + logger: logger for this connection; depending on ``side``, + defaults to ``logging.getLogger("websockets.client")`` + or ``logging.getLogger("websockets.server")``; + see the :doc:`logging guide <../topics/logging>` for details. + + """ + + def __init__( + self, + side: Side, + state: State = OPEN, + max_size: Optional[int] = 2**20, + logger: Optional[LoggerLike] = None, + ) -> None: + # Unique identifier. For logs. + self.id: uuid.UUID = uuid.uuid4() + """Unique identifier of the connection. Useful in logs.""" + + # Logger or LoggerAdapter for this connection. + if logger is None: + logger = logging.getLogger(f"websockets.{side.name.lower()}") + self.logger: LoggerLike = logger + """Logger for this connection.""" + + # Track if DEBUG is enabled. Shortcut logging calls if it isn't. + self.debug = logger.isEnabledFor(logging.DEBUG) + + # Connection side. CLIENT or SERVER. + self.side = side + + # Connection state. Initially OPEN because subclasses handle CONNECTING. + self.state = state + + # Maximum size of incoming messages in bytes. + self.max_size = max_size + + # Current size of incoming message in bytes. Only set while reading a + # fragmented message i.e. a data frames with the FIN bit not set. + self.cur_size: Optional[int] = None + + # True while sending a fragmented message i.e. a data frames with the + # FIN bit not set. + self.expect_continuation_frame = False + + # WebSocket protocol parameters. + self.origin: Optional[Origin] = None + self.extensions: List[Extension] = [] + self.subprotocol: Optional[Subprotocol] = None + + # Close code and reason, set when a close frame is sent or received. + self.close_rcvd: Optional[Close] = None + self.close_sent: Optional[Close] = None + self.close_rcvd_then_sent: Optional[bool] = None + + # Track if an exception happened during the handshake. + self.handshake_exc: Optional[Exception] = None + """ + Exception to raise if the opening handshake failed. + + :obj:`None` if the opening handshake succeeded. + + """ + + # Track if send_eof() was called. + self.eof_sent = False + + # Parser state. + self.reader = StreamReader() + self.events: List[Event] = [] + self.writes: List[bytes] = [] + self.parser = self.parse() + next(self.parser) # start coroutine + self.parser_exc: Optional[Exception] = None + + @property + def state(self) -> State: + """ + WebSocket connection state. + + Defined in 4.1, 4.2, 7.1.3, and 7.1.4 of :rfc:`6455`. + + """ + return self._state + + @state.setter + def state(self, state: State) -> None: + if self.debug: + self.logger.debug("= connection is %s", state.name) + self._state = state + + @property + def close_code(self) -> Optional[int]: + """ + `WebSocket close code`_. + + .. _WebSocket close code: + https://www.rfc-editor.org/rfc/rfc6455.html#section-7.1.5 + + :obj:`None` if the connection isn't closed yet. + + """ + if self.state is not CLOSED: + return None + elif self.close_rcvd is None: + return 1006 + else: + return self.close_rcvd.code + + @property + def close_reason(self) -> Optional[str]: + """ + `WebSocket close reason`_. + + .. _WebSocket close reason: + https://www.rfc-editor.org/rfc/rfc6455.html#section-7.1.6 + + :obj:`None` if the connection isn't closed yet. + + """ + if self.state is not CLOSED: + return None + elif self.close_rcvd is None: + return "" + else: + return self.close_rcvd.reason + + @property + def close_exc(self) -> ConnectionClosed: + """ + Exception to raise when trying to interact with a closed connection. + + Don't raise this exception while the connection :attr:`state` + is :attr:`~websockets.protocol.State.CLOSING`; wait until + it's :attr:`~websockets.protocol.State.CLOSED`. + + Indeed, the exception includes the close code and reason, which are + known only once the connection is closed. + + Raises: + AssertionError: if the connection isn't closed yet. + + """ + assert self.state is CLOSED, "connection isn't closed yet" + exc_type: Type[ConnectionClosed] + if ( + self.close_rcvd is not None + and self.close_sent is not None + and self.close_rcvd.code in OK_CLOSE_CODES + and self.close_sent.code in OK_CLOSE_CODES + ): + exc_type = ConnectionClosedOK + else: + exc_type = ConnectionClosedError + exc: ConnectionClosed = exc_type( + self.close_rcvd, + self.close_sent, + self.close_rcvd_then_sent, + ) + # Chain to the exception raised in the parser, if any. + exc.__cause__ = self.parser_exc + return exc + + # Public methods for receiving data. + + def receive_data(self, data: bytes) -> None: + """ + Receive data from the network. + + After calling this method: + + - You must call :meth:`data_to_send` and send this data to the network. + - You should call :meth:`events_received` and process resulting events. + + Raises: + EOFError: if :meth:`receive_eof` was called earlier. + + """ + self.reader.feed_data(data) + next(self.parser) + + def receive_eof(self) -> None: + """ + Receive the end of the data stream from the network. + + After calling this method: + + - You must call :meth:`data_to_send` and send this data to the network. + - You aren't expected to call :meth:`events_received`; it won't return + any new events. + + Raises: + EOFError: if :meth:`receive_eof` was called earlier. + + """ + self.reader.feed_eof() + next(self.parser) + + # Public methods for sending events. + + def send_continuation(self, data: bytes, fin: bool) -> None: + """ + Send a `Continuation frame`_. + + .. _Continuation frame: + https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + + Parameters: + data: payload containing the same kind of data + as the initial frame. + fin: FIN bit; set it to :obj:`True` if this is the last frame + of a fragmented message and to :obj:`False` otherwise. + + Raises: + ProtocolError: if a fragmented message isn't in progress. + + """ + if not self.expect_continuation_frame: + raise ProtocolError("unexpected continuation frame") + self.expect_continuation_frame = not fin + self.send_frame(Frame(OP_CONT, data, fin)) + + def send_text(self, data: bytes, fin: bool = True) -> None: + """ + Send a `Text frame`_. + + .. _Text frame: + https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + + Parameters: + data: payload containing text encoded with UTF-8. + fin: FIN bit; set it to :obj:`False` if this is the first frame of + a fragmented message. + + Raises: + ProtocolError: if a fragmented message is in progress. + + """ + if self.expect_continuation_frame: + raise ProtocolError("expected a continuation frame") + self.expect_continuation_frame = not fin + self.send_frame(Frame(OP_TEXT, data, fin)) + + def send_binary(self, data: bytes, fin: bool = True) -> None: + """ + Send a `Binary frame`_. + + .. _Binary frame: + https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + + Parameters: + data: payload containing arbitrary binary data. + fin: FIN bit; set it to :obj:`False` if this is the first frame of + a fragmented message. + + Raises: + ProtocolError: if a fragmented message is in progress. + + """ + if self.expect_continuation_frame: + raise ProtocolError("expected a continuation frame") + self.expect_continuation_frame = not fin + self.send_frame(Frame(OP_BINARY, data, fin)) + + def send_close(self, code: Optional[int] = None, reason: str = "") -> None: + """ + Send a `Close frame`_. + + .. _Close frame: + https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.1 + + Parameters: + code: close code. + reason: close reason. + + Raises: + ProtocolError: if a fragmented message is being sent, if the code + isn't valid, or if a reason is provided without a code + + """ + if self.expect_continuation_frame: + raise ProtocolError("expected a continuation frame") + if code is None: + if reason != "": + raise ProtocolError("cannot send a reason without a code") + close = Close(1005, "") + data = b"" + else: + close = Close(code, reason) + data = close.serialize() + # send_frame() guarantees that self.state is OPEN at this point. + # 7.1.3. The WebSocket Closing Handshake is Started + self.send_frame(Frame(OP_CLOSE, data)) + self.close_sent = close + self.state = CLOSING + + def send_ping(self, data: bytes) -> None: + """ + Send a `Ping frame`_. + + .. _Ping frame: + https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.2 + + Parameters: + data: payload containing arbitrary binary data. + + """ + self.send_frame(Frame(OP_PING, data)) + + def send_pong(self, data: bytes) -> None: + """ + Send a `Pong frame`_. + + .. _Pong frame: + https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3 + + Parameters: + data: payload containing arbitrary binary data. + + """ + self.send_frame(Frame(OP_PONG, data)) + + def fail(self, code: int, reason: str = "") -> None: + """ + `Fail the WebSocket connection`_. + + .. _Fail the WebSocket connection: + https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.7 + + Parameters: + code: close code + reason: close reason + + Raises: + ProtocolError: if the code isn't valid. + """ + # 7.1.7. Fail the WebSocket Connection + + # Send a close frame when the state is OPEN (a close frame was already + # sent if it's CLOSING), except when failing the connection because + # of an error reading from or writing to the network. + if self.state is OPEN: + if code != 1006: + close = Close(code, reason) + data = close.serialize() + self.send_frame(Frame(OP_CLOSE, data)) + self.close_sent = close + self.state = CLOSING + + # When failing the connection, a server closes the TCP connection + # without waiting for the client to complete the handshake, while a + # client waits for the server to close the TCP connection, possibly + # after sending a close frame that the client will ignore. + if self.side is SERVER and not self.eof_sent: + self.send_eof() + + # 7.1.7. Fail the WebSocket Connection "An endpoint MUST NOT continue + # to attempt to process data(including a responding Close frame) from + # the remote endpoint after being instructed to _Fail the WebSocket + # Connection_." + self.parser = self.discard() + next(self.parser) # start coroutine + + # Public method for getting incoming events after receiving data. + + def events_received(self) -> List[Event]: + """ + Fetch events generated from data received from the network. + + Call this method immediately after any of the ``receive_*()`` methods. + + Process resulting events, likely by passing them to the application. + + Returns: + List[Event]: Events read from the connection. + """ + events, self.events = self.events, [] + return events + + # Public method for getting outgoing data after receiving data or sending events. + + def data_to_send(self) -> List[bytes]: + """ + Obtain data to send to the network. + + Call this method immediately after any of the ``receive_*()``, + ``send_*()``, or :meth:`fail` methods. + + Write resulting data to the connection. + + The empty bytestring :data:`~websockets.protocol.SEND_EOF` signals + the end of the data stream. When you receive it, half-close the TCP + connection. + + Returns: + List[bytes]: Data to write to the connection. + + """ + writes, self.writes = self.writes, [] + return writes + + def close_expected(self) -> bool: + """ + Tell if the TCP connection is expected to close soon. + + Call this method immediately after any of the ``receive_*()`` or + :meth:`fail` methods. + + If it returns :obj:`True`, schedule closing the TCP connection after a + short timeout if the other side hasn't already closed it. + + Returns: + bool: Whether the TCP connection is expected to close soon. + + """ + # We expect a TCP close if and only if we sent a close frame: + # * Normal closure: once we send a close frame, we expect a TCP close: + # server waits for client to complete the TCP closing handshake; + # client waits for server to initiate the TCP closing handshake. + # * Abnormal closure: we always send a close frame and the same logic + # applies, except on EOFError where we don't send a close frame + # because we already received the TCP close, so we don't expect it. + # We already got a TCP Close if and only if the state is CLOSED. + return self.state is CLOSING or self.handshake_exc is not None + + # Private methods for receiving data. + + def parse(self) -> Generator[None, None, None]: + """ + Parse incoming data into frames. + + :meth:`receive_data` and :meth:`receive_eof` run this generator + coroutine until it needs more data or reaches EOF. + + """ + try: + while True: + if (yield from self.reader.at_eof()): + if self.debug: + self.logger.debug("< EOF") + # If the WebSocket connection is closed cleanly, with a + # closing handhshake, recv_frame() substitutes parse() + # with discard(). This branch is reached only when the + # connection isn't closed cleanly. + raise EOFError("unexpected end of stream") + + if self.max_size is None: + max_size = None + elif self.cur_size is None: + max_size = self.max_size + else: + max_size = self.max_size - self.cur_size + + # During a normal closure, execution ends here on the next + # iteration of the loop after receiving a close frame. At + # this point, recv_frame() replaced parse() by discard(). + frame = yield from Frame.parse( + self.reader.read_exact, + mask=self.side is SERVER, + max_size=max_size, + extensions=self.extensions, + ) + + if self.debug: + self.logger.debug("< %s", frame) + + self.recv_frame(frame) + + except ProtocolError as exc: + self.fail(1002, str(exc)) + self.parser_exc = exc + + except EOFError as exc: + self.fail(1006, str(exc)) + self.parser_exc = exc + + except UnicodeDecodeError as exc: + self.fail(1007, f"{exc.reason} at position {exc.start}") + self.parser_exc = exc + + except PayloadTooBig as exc: + self.fail(1009, str(exc)) + self.parser_exc = exc + + except Exception as exc: + self.logger.error("parser failed", exc_info=True) + # Don't include exception details, which may be security-sensitive. + self.fail(1011) + self.parser_exc = exc + + # During an abnormal closure, execution ends here after catching an + # exception. At this point, fail() replaced parse() by discard(). + yield + raise AssertionError("parse() shouldn't step after error") + + def discard(self) -> Generator[None, None, None]: + """ + Discard incoming data. + + This coroutine replaces :meth:`parse`: + + - after receiving a close frame, during a normal closure (1.4); + - after sending a close frame, during an abnormal closure (7.1.7). + + """ + # The server close the TCP connection in the same circumstances where + # discard() replaces parse(). The client closes the connection later, + # after the server closes the connection or a timeout elapses. + # (The latter case cannot be handled in this Sans-I/O layer.) + assert (self.side is SERVER) == (self.eof_sent) + while not (yield from self.reader.at_eof()): + self.reader.discard() + if self.debug: + self.logger.debug("< EOF") + # A server closes the TCP connection immediately, while a client + # waits for the server to close the TCP connection. + if self.side is CLIENT: + self.send_eof() + self.state = CLOSED + # If discard() completes normally, execution ends here. + yield + # Once the reader reaches EOF, its feed_data/eof() methods raise an + # error, so our receive_data/eof() methods don't step the generator. + raise AssertionError("discard() shouldn't step after EOF") + + def recv_frame(self, frame: Frame) -> None: + """ + Process an incoming frame. + + """ + if frame.opcode is OP_TEXT or frame.opcode is OP_BINARY: + if self.cur_size is not None: + raise ProtocolError("expected a continuation frame") + if frame.fin: + self.cur_size = None + else: + self.cur_size = len(frame.data) + + elif frame.opcode is OP_CONT: + if self.cur_size is None: + raise ProtocolError("unexpected continuation frame") + if frame.fin: + self.cur_size = None + else: + self.cur_size += len(frame.data) + + elif frame.opcode is OP_PING: + # 5.5.2. Ping: "Upon receipt of a Ping frame, an endpoint MUST + # send a Pong frame in response" + pong_frame = Frame(OP_PONG, frame.data) + self.send_frame(pong_frame) + + elif frame.opcode is OP_PONG: + # 5.5.3 Pong: "A response to an unsolicited Pong frame is not + # expected." + pass + + elif frame.opcode is OP_CLOSE: + # 7.1.5. The WebSocket Connection Close Code + # 7.1.6. The WebSocket Connection Close Reason + self.close_rcvd = Close.parse(frame.data) + if self.state is CLOSING: + assert self.close_sent is not None + self.close_rcvd_then_sent = False + + if self.cur_size is not None: + raise ProtocolError("incomplete fragmented message") + + # 5.5.1 Close: "If an endpoint receives a Close frame and did + # not previously send a Close frame, the endpoint MUST send a + # Close frame in response. (When sending a Close frame in + # response, the endpoint typically echos the status code it + # received.)" + + if self.state is OPEN: + # Echo the original data instead of re-serializing it with + # Close.serialize() because that fails when the close frame + # is empty and Close.parse() synthetizes a 1005 close code. + # The rest is identical to send_close(). + self.send_frame(Frame(OP_CLOSE, frame.data)) + self.close_sent = self.close_rcvd + self.close_rcvd_then_sent = True + self.state = CLOSING + + # 7.1.2. Start the WebSocket Closing Handshake: "Once an + # endpoint has both sent and received a Close control frame, + # that endpoint SHOULD _Close the WebSocket Connection_" + + # A server closes the TCP connection immediately, while a client + # waits for the server to close the TCP connection. + if self.side is SERVER: + self.send_eof() + + # 1.4. Closing Handshake: "after receiving a control frame + # indicating the connection should be closed, a peer discards + # any further data received." + self.parser = self.discard() + next(self.parser) # start coroutine + + else: + # This can't happen because Frame.parse() validates opcodes. + raise AssertionError(f"unexpected opcode: {frame.opcode:02x}") + + self.events.append(frame) + + # Private methods for sending events. + + def send_frame(self, frame: Frame) -> None: + if self.state is not OPEN: + raise InvalidState( + f"cannot write to a WebSocket in the {self.state.name} state" + ) + + if self.debug: + self.logger.debug("> %s", frame) + self.writes.append( + frame.serialize(mask=self.side is CLIENT, extensions=self.extensions) + ) + + def send_eof(self) -> None: + assert not self.eof_sent + self.eof_sent = True + if self.debug: + self.logger.debug("> EOF") + self.writes.append(SEND_EOF) diff --git a/src/websockets/server.py b/src/websockets/server.py index edd1764c3..548048c92 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -4,9 +4,9 @@ import binascii import email.utils import http -from typing import Generator, List, Optional, Sequence, Tuple, cast +import warnings +from typing import Any, Generator, List, Optional, Sequence, Tuple, cast -from .connection import CONNECTING, OPEN, SERVER, Connection, State from .datastructures import Headers, MultipleValuesError from .exceptions import ( InvalidHandshake, @@ -26,6 +26,7 @@ parse_upgrade, ) from .http11 import Request, Response +from .protocol import CONNECTING, OPEN, SERVER, Protocol, State from .typing import ( ConnectionOption, ExtensionHeader, @@ -41,10 +42,10 @@ from .legacy.server import * # isort:skip # noqa -__all__ = ["ServerConnection"] +__all__ = ["ServerProtocol"] -class ServerConnection(Connection): +class ServerProtocol(Protocol): """ Sans-I/O implementation of a WebSocket server connection. @@ -515,3 +516,12 @@ def parse(self) -> Generator[None, None, None]: self.events.append(request) yield from super().parse() + + +class ServerConnection(ServerProtocol): + def __init__(self, *args: Any, **kwargs: Any) -> None: + warnings.warn( + "ServerConnection was renamed to ServerProtocol", + DeprecationWarning, + ) + super().__init__(*args, **kwargs) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index f8a79027e..4a4510536 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -16,7 +16,6 @@ import urllib.request import warnings -from websockets.connection import State from websockets.datastructures import Headers from websockets.exceptions import ( ConnectionClosed, @@ -35,6 +34,7 @@ from websockets.legacy.handshake import build_response from websockets.legacy.http import read_response from websockets.legacy.server import * +from websockets.protocol import State from websockets.uri import parse_uri from ..extensions.utils import ( diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index 1f830ebee..e85402a39 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -4,7 +4,6 @@ import unittest.mock import warnings -from websockets.connection import State from websockets.exceptions import ConnectionClosed, InvalidState from websockets.frames import ( OP_BINARY, @@ -18,6 +17,7 @@ from websockets.legacy.compatibility import loop_if_py_lt_38 from websockets.legacy.framing import Frame from websockets.legacy.protocol import WebSocketCommonProtocol, broadcast +from websockets.protocol import State from .utils import MS, AsyncioTestCase diff --git a/tests/test_client.py b/tests/test_client.py index 9ed36c1d4..718219b9d 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -3,11 +3,11 @@ import unittest.mock from websockets.client import * -from websockets.connection import CONNECTING, OPEN from websockets.datastructures import Headers from websockets.exceptions import InvalidHandshake, InvalidHeader from websockets.frames import OP_TEXT, Frame from websockets.http11 import Request, Response +from websockets.protocol import CONNECTING, OPEN from websockets.uri import parse_uri from websockets.utils import accept_key @@ -18,13 +18,13 @@ Rsv2Extension, ) from .test_utils import ACCEPT, KEY -from .utils import DATE +from .utils import DATE, DeprecationTestCase class ConnectTests(unittest.TestCase): def test_send_connect(self): with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientConnection(parse_uri("wss://example.com/test")) + client = ClientProtocol(parse_uri("wss://example.com/test")) request = client.connect() self.assertIsInstance(request, Request) client.send_request(request) @@ -44,7 +44,7 @@ def test_send_connect(self): def test_connect_request(self): with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientConnection(parse_uri("wss://example.com/test")) + client = ClientProtocol(parse_uri("wss://example.com/test")) request = client.connect() self.assertEqual(request.path, "/test") self.assertEqual( @@ -61,7 +61,7 @@ def test_connect_request(self): ) def test_path(self): - client = ClientConnection(parse_uri("wss://example.com/endpoint?test=1")) + client = ClientProtocol(parse_uri("wss://example.com/endpoint?test=1")) request = client.connect() self.assertEqual(request.path, "/endpoint?test=1") @@ -76,19 +76,19 @@ def test_port(self): ("wss://example.com:8443/", "example.com:8443"), ]: with self.subTest(uri=uri): - client = ClientConnection(parse_uri(uri)) + client = ClientProtocol(parse_uri(uri)) request = client.connect() self.assertEqual(request.headers["Host"], host) def test_user_info(self): - client = ClientConnection(parse_uri("wss://hello:iloveyou@example.com/")) + client = ClientProtocol(parse_uri("wss://hello:iloveyou@example.com/")) request = client.connect() self.assertEqual(request.headers["Authorization"], "Basic aGVsbG86aWxvdmV5b3U=") def test_origin(self): - client = ClientConnection( + client = ClientProtocol( parse_uri("wss://example.com/"), origin="https://example.com", ) @@ -97,7 +97,7 @@ def test_origin(self): self.assertEqual(request.headers["Origin"], "https://example.com") def test_extensions(self): - client = ClientConnection( + client = ClientProtocol( parse_uri("wss://example.com/"), extensions=[ClientOpExtensionFactory()], ) @@ -106,7 +106,7 @@ def test_extensions(self): self.assertEqual(request.headers["Sec-WebSocket-Extensions"], "x-op; op") def test_subprotocols(self): - client = ClientConnection( + client = ClientProtocol( parse_uri("wss://example.com/"), subprotocols=["chat"], ) @@ -118,7 +118,7 @@ def test_subprotocols(self): class AcceptRejectTests(unittest.TestCase): def test_receive_accept(self): with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientConnection(parse_uri("ws://example.com/test")) + client = ClientProtocol(parse_uri("ws://example.com/test")) client.connect() client.receive_data( ( @@ -138,7 +138,7 @@ def test_receive_accept(self): def test_receive_reject(self): with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientConnection(parse_uri("ws://example.com/test")) + client = ClientProtocol(parse_uri("ws://example.com/test")) client.connect() client.receive_data( ( @@ -159,7 +159,7 @@ def test_receive_reject(self): def test_accept_response(self): with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientConnection(parse_uri("ws://example.com/test")) + client = ClientProtocol(parse_uri("ws://example.com/test")) client.connect() client.receive_data( ( @@ -189,7 +189,7 @@ def test_accept_response(self): def test_reject_response(self): with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientConnection(parse_uri("ws://example.com/test")) + client = ClientProtocol(parse_uri("ws://example.com/test")) client.connect() client.receive_data( ( @@ -235,7 +235,7 @@ def make_accept_response(self, client): ) def test_basic(self): - client = ClientConnection(parse_uri("wss://example.com/")) + client = ClientProtocol(parse_uri("wss://example.com/")) response = self.make_accept_response(client) client.receive_data(response.serialize()) [response] = client.events_received() @@ -243,7 +243,7 @@ def test_basic(self): self.assertEqual(client.state, OPEN) def test_missing_connection(self): - client = ClientConnection(parse_uri("wss://example.com/")) + client = ClientProtocol(parse_uri("wss://example.com/")) response = self.make_accept_response(client) del response.headers["Connection"] client.receive_data(response.serialize()) @@ -255,7 +255,7 @@ def test_missing_connection(self): self.assertEqual(str(raised.exception), "missing Connection header") def test_invalid_connection(self): - client = ClientConnection(parse_uri("wss://example.com/")) + client = ClientProtocol(parse_uri("wss://example.com/")) response = self.make_accept_response(client) del response.headers["Connection"] response.headers["Connection"] = "close" @@ -268,7 +268,7 @@ def test_invalid_connection(self): self.assertEqual(str(raised.exception), "invalid Connection header: close") def test_missing_upgrade(self): - client = ClientConnection(parse_uri("wss://example.com/")) + client = ClientProtocol(parse_uri("wss://example.com/")) response = self.make_accept_response(client) del response.headers["Upgrade"] client.receive_data(response.serialize()) @@ -280,7 +280,7 @@ def test_missing_upgrade(self): self.assertEqual(str(raised.exception), "missing Upgrade header") def test_invalid_upgrade(self): - client = ClientConnection(parse_uri("wss://example.com/")) + client = ClientProtocol(parse_uri("wss://example.com/")) response = self.make_accept_response(client) del response.headers["Upgrade"] response.headers["Upgrade"] = "h2c" @@ -293,7 +293,7 @@ def test_invalid_upgrade(self): self.assertEqual(str(raised.exception), "invalid Upgrade header: h2c") def test_missing_accept(self): - client = ClientConnection(parse_uri("wss://example.com/")) + client = ClientProtocol(parse_uri("wss://example.com/")) response = self.make_accept_response(client) del response.headers["Sec-WebSocket-Accept"] client.receive_data(response.serialize()) @@ -305,7 +305,7 @@ def test_missing_accept(self): self.assertEqual(str(raised.exception), "missing Sec-WebSocket-Accept header") def test_multiple_accept(self): - client = ClientConnection(parse_uri("wss://example.com/")) + client = ClientProtocol(parse_uri("wss://example.com/")) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Accept"] = ACCEPT client.receive_data(response.serialize()) @@ -321,7 +321,7 @@ def test_multiple_accept(self): ) def test_invalid_accept(self): - client = ClientConnection(parse_uri("wss://example.com/")) + client = ClientProtocol(parse_uri("wss://example.com/")) response = self.make_accept_response(client) del response.headers["Sec-WebSocket-Accept"] response.headers["Sec-WebSocket-Accept"] = ACCEPT @@ -336,7 +336,7 @@ def test_invalid_accept(self): ) def test_no_extensions(self): - client = ClientConnection(parse_uri("wss://example.com/")) + client = ClientProtocol(parse_uri("wss://example.com/")) response = self.make_accept_response(client) client.receive_data(response.serialize()) [response] = client.events_received() @@ -345,7 +345,7 @@ def test_no_extensions(self): self.assertEqual(client.extensions, []) def test_no_extension(self): - client = ClientConnection( + client = ClientProtocol( parse_uri("wss://example.com/"), extensions=[ClientOpExtensionFactory()], ) @@ -358,7 +358,7 @@ def test_no_extension(self): self.assertEqual(client.extensions, [OpExtension()]) def test_extension(self): - client = ClientConnection( + client = ClientProtocol( parse_uri("wss://example.com/"), extensions=[ClientRsv2ExtensionFactory()], ) @@ -371,7 +371,7 @@ def test_extension(self): self.assertEqual(client.extensions, [Rsv2Extension()]) def test_unexpected_extension(self): - client = ClientConnection(parse_uri("wss://example.com/")) + client = ClientProtocol(parse_uri("wss://example.com/")) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Extensions"] = "x-op; op" client.receive_data(response.serialize()) @@ -383,7 +383,7 @@ def test_unexpected_extension(self): self.assertEqual(str(raised.exception), "no extensions supported") def test_unsupported_extension(self): - client = ClientConnection( + client = ClientProtocol( parse_uri("wss://example.com/"), extensions=[ClientRsv2ExtensionFactory()], ) @@ -401,7 +401,7 @@ def test_unsupported_extension(self): ) def test_supported_extension_parameters(self): - client = ClientConnection( + client = ClientProtocol( parse_uri("wss://example.com/"), extensions=[ClientOpExtensionFactory("this")], ) @@ -414,7 +414,7 @@ def test_supported_extension_parameters(self): self.assertEqual(client.extensions, [OpExtension("this")]) def test_unsupported_extension_parameters(self): - client = ClientConnection( + client = ClientProtocol( parse_uri("wss://example.com/"), extensions=[ClientOpExtensionFactory("this")], ) @@ -432,7 +432,7 @@ def test_unsupported_extension_parameters(self): ) def test_multiple_supported_extension_parameters(self): - client = ClientConnection( + client = ClientProtocol( parse_uri("wss://example.com/"), extensions=[ ClientOpExtensionFactory("this"), @@ -448,7 +448,7 @@ def test_multiple_supported_extension_parameters(self): self.assertEqual(client.extensions, [OpExtension("that")]) def test_multiple_extensions(self): - client = ClientConnection( + client = ClientProtocol( parse_uri("wss://example.com/"), extensions=[ClientOpExtensionFactory(), ClientRsv2ExtensionFactory()], ) @@ -462,7 +462,7 @@ def test_multiple_extensions(self): self.assertEqual(client.extensions, [OpExtension(), Rsv2Extension()]) def test_multiple_extensions_order(self): - client = ClientConnection( + client = ClientProtocol( parse_uri("wss://example.com/"), extensions=[ClientOpExtensionFactory(), ClientRsv2ExtensionFactory()], ) @@ -476,7 +476,7 @@ def test_multiple_extensions_order(self): self.assertEqual(client.extensions, [Rsv2Extension(), OpExtension()]) def test_no_subprotocols(self): - client = ClientConnection(parse_uri("wss://example.com/")) + client = ClientProtocol(parse_uri("wss://example.com/")) response = self.make_accept_response(client) client.receive_data(response.serialize()) [response] = client.events_received() @@ -485,9 +485,7 @@ def test_no_subprotocols(self): self.assertIsNone(client.subprotocol) def test_no_subprotocol(self): - client = ClientConnection( - parse_uri("wss://example.com/"), subprotocols=["chat"] - ) + client = ClientProtocol(parse_uri("wss://example.com/"), subprotocols=["chat"]) response = self.make_accept_response(client) client.receive_data(response.serialize()) [response] = client.events_received() @@ -496,9 +494,7 @@ def test_no_subprotocol(self): self.assertIsNone(client.subprotocol) def test_subprotocol(self): - client = ClientConnection( - parse_uri("wss://example.com/"), subprotocols=["chat"] - ) + client = ClientProtocol(parse_uri("wss://example.com/"), subprotocols=["chat"]) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Protocol"] = "chat" client.receive_data(response.serialize()) @@ -508,7 +504,7 @@ def test_subprotocol(self): self.assertEqual(client.subprotocol, "chat") def test_unexpected_subprotocol(self): - client = ClientConnection(parse_uri("wss://example.com/")) + client = ClientProtocol(parse_uri("wss://example.com/")) response = self.make_accept_response(client) response.headers["Sec-WebSocket-Protocol"] = "chat" client.receive_data(response.serialize()) @@ -520,7 +516,7 @@ def test_unexpected_subprotocol(self): self.assertEqual(str(raised.exception), "no subprotocols supported") def test_multiple_subprotocols(self): - client = ClientConnection( + client = ClientProtocol( parse_uri("wss://example.com/"), subprotocols=["superchat", "chat"], ) @@ -538,7 +534,7 @@ def test_multiple_subprotocols(self): ) def test_supported_subprotocol(self): - client = ClientConnection( + client = ClientProtocol( parse_uri("wss://example.com/"), subprotocols=["superchat", "chat"], ) @@ -551,7 +547,7 @@ def test_supported_subprotocol(self): self.assertEqual(client.subprotocol, "chat") def test_unsupported_subprotocol(self): - client = ClientConnection( + client = ClientProtocol( parse_uri("wss://example.com/"), subprotocols=["superchat", "chat"], ) @@ -568,7 +564,7 @@ def test_unsupported_subprotocol(self): class MiscTests(unittest.TestCase): def test_bypass_handshake(self): - client = ClientConnection(parse_uri("ws://example.com/test"), state=OPEN) + client = ClientProtocol(parse_uri("ws://example.com/test"), state=OPEN) client.receive_data(b"\x81\x06Hello!") [frame] = client.events_received() self.assertEqual(frame, Frame(OP_TEXT, b"Hello!")) @@ -576,5 +572,17 @@ def test_bypass_handshake(self): def test_custom_logger(self): logger = logging.getLogger("test") with self.assertLogs("test", logging.DEBUG) as logs: - ClientConnection(parse_uri("wss://example.com/test"), logger=logger) + ClientProtocol(parse_uri("wss://example.com/test"), logger=logger) self.assertEqual(len(logs.records), 1) + + +class BackwardsCompatibilityTests(DeprecationTestCase): + def test_client_connection_class(self): + with self.assertDeprecationWarning( + "ClientConnection was renamed to ClientProtocol" + ): + from websockets.client import ClientConnection + + client = ClientConnection("ws://localhost/") + + self.assertIsInstance(client, ClientProtocol) diff --git a/tests/test_connection.py b/tests/test_connection.py index 3858d2521..6592d67d0 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,1737 +1,14 @@ -import logging -import unittest.mock +from websockets.protocol import Protocol -from websockets.connection import * -from websockets.connection import CLIENT, CLOSED, CLOSING, SERVER -from websockets.exceptions import ( - ConnectionClosedError, - ConnectionClosedOK, - InvalidState, - PayloadTooBig, - ProtocolError, -) -from websockets.frames import ( - OP_BINARY, - OP_CLOSE, - OP_CONT, - OP_PING, - OP_PONG, - OP_TEXT, - Close, - Frame, -) +from .utils import DeprecationTestCase -from .extensions.utils import Rsv2Extension -from .test_frames import FramesTestCase +class BackwardsCompatibilityTests(DeprecationTestCase): + def test_connection_class(self): + with self.assertDeprecationWarning( + "websockets.connection was renamed to websockets.protocol " + "and Connection was renamed to Protocol" + ): + from websockets.connection import Connection -class ConnectionTestCase(FramesTestCase): - def assertFrameSent(self, connection, frame, eof=False): - """ - Outgoing data for ``connection`` contains the given frame. - - ``frame`` may be ``None`` if no frame is expected. - - When ``eof`` is ``True``, the end of the stream is also expected. - - """ - frames_sent = [ - None - if write is SEND_EOF - else self.parse( - write, - mask=connection.side is CLIENT, - extensions=connection.extensions, - ) - for write in connection.data_to_send() - ] - frames_expected = [] if frame is None else [frame] - if eof: - frames_expected += [None] - self.assertEqual(frames_sent, frames_expected) - - def assertFrameReceived(self, connection, frame): - """ - Incoming data for ``connection`` contains the given frame. - - ``frame`` may be ``None`` if no frame is expected. - - """ - frames_received = connection.events_received() - frames_expected = [] if frame is None else [frame] - self.assertEqual(frames_received, frames_expected) - - def assertConnectionClosing(self, connection, code=None, reason=""): - """ - Incoming data caused the "Start the WebSocket Closing Handshake" process. - - """ - close_frame = Frame( - OP_CLOSE, - b"" if code is None else Close(code, reason).serialize(), - ) - # A close frame was received. - self.assertFrameReceived(connection, close_frame) - # A close frame and possibly the end of stream were sent. - self.assertFrameSent(connection, close_frame, eof=connection.side is SERVER) - - def assertConnectionFailing(self, connection, code=None, reason=""): - """ - Incoming data caused the "Fail the WebSocket Connection" process. - - """ - close_frame = Frame( - OP_CLOSE, - b"" if code is None else Close(code, reason).serialize(), - ) - # No frame was received. - self.assertFrameReceived(connection, None) - # A close frame and possibly the end of stream were sent. - self.assertFrameSent(connection, close_frame, eof=connection.side is SERVER) - - -class MaskingTests(ConnectionTestCase): - """ - Test frame masking. - - 5.1. Overview - - """ - - unmasked_text_frame_date = b"\x81\x04Spam" - masked_text_frame_data = b"\x81\x84\x00\xff\x00\xff\x53\x8f\x61\x92" - - def test_client_sends_masked_frame(self): - client = Connection(CLIENT) - with self.enforce_mask(b"\x00\xff\x00\xff"): - client.send_text(b"Spam", True) - self.assertEqual(client.data_to_send(), [self.masked_text_frame_data]) - - def test_server_sends_unmasked_frame(self): - server = Connection(SERVER) - server.send_text(b"Spam", True) - self.assertEqual(server.data_to_send(), [self.unmasked_text_frame_date]) - - def test_client_receives_unmasked_frame(self): - client = Connection(CLIENT) - client.receive_data(self.unmasked_text_frame_date) - self.assertFrameReceived( - client, - Frame(OP_TEXT, b"Spam"), - ) - - def test_server_receives_masked_frame(self): - server = Connection(SERVER) - server.receive_data(self.masked_text_frame_data) - self.assertFrameReceived( - server, - Frame(OP_TEXT, b"Spam"), - ) - - def test_client_receives_masked_frame(self): - client = Connection(CLIENT) - client.receive_data(self.masked_text_frame_data) - self.assertIsInstance(client.parser_exc, ProtocolError) - self.assertEqual(str(client.parser_exc), "incorrect masking") - self.assertConnectionFailing(client, 1002, "incorrect masking") - - def test_server_receives_unmasked_frame(self): - server = Connection(SERVER) - server.receive_data(self.unmasked_text_frame_date) - self.assertIsInstance(server.parser_exc, ProtocolError) - self.assertEqual(str(server.parser_exc), "incorrect masking") - self.assertConnectionFailing(server, 1002, "incorrect masking") - - -class ContinuationTests(ConnectionTestCase): - """ - Test continuation frames without text or binary frames. - - """ - - def test_client_sends_unexpected_continuation(self): - client = Connection(CLIENT) - with self.assertRaises(ProtocolError) as raised: - client.send_continuation(b"", fin=False) - self.assertEqual(str(raised.exception), "unexpected continuation frame") - - def test_server_sends_unexpected_continuation(self): - server = Connection(SERVER) - with self.assertRaises(ProtocolError) as raised: - server.send_continuation(b"", fin=False) - self.assertEqual(str(raised.exception), "unexpected continuation frame") - - def test_client_receives_unexpected_continuation(self): - client = Connection(CLIENT) - client.receive_data(b"\x00\x00") - self.assertIsInstance(client.parser_exc, ProtocolError) - self.assertEqual(str(client.parser_exc), "unexpected continuation frame") - self.assertConnectionFailing(client, 1002, "unexpected continuation frame") - - def test_server_receives_unexpected_continuation(self): - server = Connection(SERVER) - server.receive_data(b"\x00\x80\x00\x00\x00\x00") - self.assertIsInstance(server.parser_exc, ProtocolError) - self.assertEqual(str(server.parser_exc), "unexpected continuation frame") - self.assertConnectionFailing(server, 1002, "unexpected continuation frame") - - def test_client_sends_continuation_after_sending_close(self): - client = Connection(CLIENT) - # Since it isn't possible to send a close frame in a fragmented - # message (see test_client_send_close_in_fragmented_message), in fact, - # this is the same test as test_client_sends_unexpected_continuation. - with self.enforce_mask(b"\x00\x00\x00\x00"): - client.send_close(1001) - self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) - with self.assertRaises(ProtocolError) as raised: - client.send_continuation(b"", fin=False) - self.assertEqual(str(raised.exception), "unexpected continuation frame") - - def test_server_sends_continuation_after_sending_close(self): - # Since it isn't possible to send a close frame in a fragmented - # message (see test_server_send_close_in_fragmented_message), in fact, - # this is the same test as test_server_sends_unexpected_continuation. - server = Connection(SERVER) - server.send_close(1000) - self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) - with self.assertRaises(ProtocolError) as raised: - server.send_continuation(b"", fin=False) - self.assertEqual(str(raised.exception), "unexpected continuation frame") - - def test_client_receives_continuation_after_receiving_close(self): - client = Connection(CLIENT) - client.receive_data(b"\x88\x02\x03\xe8") - self.assertConnectionClosing(client, 1000) - client.receive_data(b"\x00\x00") - self.assertFrameReceived(client, None) - self.assertFrameSent(client, None) - - def test_server_receives_continuation_after_receiving_close(self): - server = Connection(SERVER) - server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") - self.assertConnectionClosing(server, 1001) - server.receive_data(b"\x00\x80\x00\xff\x00\xff") - self.assertFrameReceived(server, None) - self.assertFrameSent(server, None) - - -class TextTests(ConnectionTestCase): - """ - Test text frames and continuation frames. - - """ - - def test_client_sends_text(self): - client = Connection(CLIENT) - with self.enforce_mask(b"\x00\x00\x00\x00"): - client.send_text("😀".encode()) - self.assertEqual( - client.data_to_send(), [b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80"] - ) - - def test_server_sends_text(self): - server = Connection(SERVER) - server.send_text("😀".encode()) - self.assertEqual(server.data_to_send(), [b"\x81\x04\xf0\x9f\x98\x80"]) - - def test_client_receives_text(self): - client = Connection(CLIENT) - client.receive_data(b"\x81\x04\xf0\x9f\x98\x80") - self.assertFrameReceived( - client, - Frame(OP_TEXT, "😀".encode()), - ) - - def test_server_receives_text(self): - server = Connection(SERVER) - server.receive_data(b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80") - self.assertFrameReceived( - server, - Frame(OP_TEXT, "😀".encode()), - ) - - def test_client_receives_text_over_size_limit(self): - client = Connection(CLIENT, max_size=3) - client.receive_data(b"\x81\x04\xf0\x9f\x98\x80") - self.assertIsInstance(client.parser_exc, PayloadTooBig) - self.assertEqual(str(client.parser_exc), "over size limit (4 > 3 bytes)") - self.assertConnectionFailing(client, 1009, "over size limit (4 > 3 bytes)") - - def test_server_receives_text_over_size_limit(self): - server = Connection(SERVER, max_size=3) - server.receive_data(b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80") - self.assertIsInstance(server.parser_exc, PayloadTooBig) - self.assertEqual(str(server.parser_exc), "over size limit (4 > 3 bytes)") - self.assertConnectionFailing(server, 1009, "over size limit (4 > 3 bytes)") - - def test_client_receives_text_without_size_limit(self): - client = Connection(CLIENT, max_size=None) - client.receive_data(b"\x81\x04\xf0\x9f\x98\x80") - self.assertFrameReceived( - client, - Frame(OP_TEXT, "😀".encode()), - ) - - def test_server_receives_text_without_size_limit(self): - server = Connection(SERVER, max_size=None) - server.receive_data(b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80") - self.assertFrameReceived( - server, - Frame(OP_TEXT, "😀".encode()), - ) - - def test_client_sends_fragmented_text(self): - client = Connection(CLIENT) - with self.enforce_mask(b"\x00\x00\x00\x00"): - client.send_text("😀".encode()[:2], fin=False) - self.assertEqual(client.data_to_send(), [b"\x01\x82\x00\x00\x00\x00\xf0\x9f"]) - with self.enforce_mask(b"\x00\x00\x00\x00"): - client.send_continuation("😀😀".encode()[2:6], fin=False) - self.assertEqual( - client.data_to_send(), [b"\x00\x84\x00\x00\x00\x00\x98\x80\xf0\x9f"] - ) - with self.enforce_mask(b"\x00\x00\x00\x00"): - client.send_continuation("😀".encode()[2:], fin=True) - self.assertEqual(client.data_to_send(), [b"\x80\x82\x00\x00\x00\x00\x98\x80"]) - - def test_server_sends_fragmented_text(self): - server = Connection(SERVER) - server.send_text("😀".encode()[:2], fin=False) - self.assertEqual(server.data_to_send(), [b"\x01\x02\xf0\x9f"]) - server.send_continuation("😀😀".encode()[2:6], fin=False) - self.assertEqual(server.data_to_send(), [b"\x00\x04\x98\x80\xf0\x9f"]) - server.send_continuation("😀".encode()[2:], fin=True) - self.assertEqual(server.data_to_send(), [b"\x80\x02\x98\x80"]) - - def test_client_receives_fragmented_text(self): - client = Connection(CLIENT) - client.receive_data(b"\x01\x02\xf0\x9f") - self.assertFrameReceived( - client, - Frame(OP_TEXT, "😀".encode()[:2], fin=False), - ) - client.receive_data(b"\x00\x04\x98\x80\xf0\x9f") - self.assertFrameReceived( - client, - Frame(OP_CONT, "😀😀".encode()[2:6], fin=False), - ) - client.receive_data(b"\x80\x02\x98\x80") - self.assertFrameReceived( - client, - Frame(OP_CONT, "😀".encode()[2:]), - ) - - def test_server_receives_fragmented_text(self): - server = Connection(SERVER) - server.receive_data(b"\x01\x82\x00\x00\x00\x00\xf0\x9f") - self.assertFrameReceived( - server, - Frame(OP_TEXT, "😀".encode()[:2], fin=False), - ) - server.receive_data(b"\x00\x84\x00\x00\x00\x00\x98\x80\xf0\x9f") - self.assertFrameReceived( - server, - Frame(OP_CONT, "😀😀".encode()[2:6], fin=False), - ) - server.receive_data(b"\x80\x82\x00\x00\x00\x00\x98\x80") - self.assertFrameReceived( - server, - Frame(OP_CONT, "😀".encode()[2:]), - ) - - def test_client_receives_fragmented_text_over_size_limit(self): - client = Connection(CLIENT, max_size=3) - client.receive_data(b"\x01\x02\xf0\x9f") - self.assertFrameReceived( - client, - Frame(OP_TEXT, "😀".encode()[:2], fin=False), - ) - client.receive_data(b"\x80\x02\x98\x80") - self.assertIsInstance(client.parser_exc, PayloadTooBig) - self.assertEqual(str(client.parser_exc), "over size limit (2 > 1 bytes)") - self.assertConnectionFailing(client, 1009, "over size limit (2 > 1 bytes)") - - def test_server_receives_fragmented_text_over_size_limit(self): - server = Connection(SERVER, max_size=3) - server.receive_data(b"\x01\x82\x00\x00\x00\x00\xf0\x9f") - self.assertFrameReceived( - server, - Frame(OP_TEXT, "😀".encode()[:2], fin=False), - ) - server.receive_data(b"\x80\x82\x00\x00\x00\x00\x98\x80") - self.assertIsInstance(server.parser_exc, PayloadTooBig) - self.assertEqual(str(server.parser_exc), "over size limit (2 > 1 bytes)") - self.assertConnectionFailing(server, 1009, "over size limit (2 > 1 bytes)") - - def test_client_receives_fragmented_text_without_size_limit(self): - client = Connection(CLIENT, max_size=None) - client.receive_data(b"\x01\x02\xf0\x9f") - self.assertFrameReceived( - client, - Frame(OP_TEXT, "😀".encode()[:2], fin=False), - ) - client.receive_data(b"\x00\x04\x98\x80\xf0\x9f") - self.assertFrameReceived( - client, - Frame(OP_CONT, "😀😀".encode()[2:6], fin=False), - ) - client.receive_data(b"\x80\x02\x98\x80") - self.assertFrameReceived( - client, - Frame(OP_CONT, "😀".encode()[2:]), - ) - - def test_server_receives_fragmented_text_without_size_limit(self): - server = Connection(SERVER, max_size=None) - server.receive_data(b"\x01\x82\x00\x00\x00\x00\xf0\x9f") - self.assertFrameReceived( - server, - Frame(OP_TEXT, "😀".encode()[:2], fin=False), - ) - server.receive_data(b"\x00\x84\x00\x00\x00\x00\x98\x80\xf0\x9f") - self.assertFrameReceived( - server, - Frame(OP_CONT, "😀😀".encode()[2:6], fin=False), - ) - server.receive_data(b"\x80\x82\x00\x00\x00\x00\x98\x80") - self.assertFrameReceived( - server, - Frame(OP_CONT, "😀".encode()[2:]), - ) - - def test_client_sends_unexpected_text(self): - client = Connection(CLIENT) - client.send_text(b"", fin=False) - with self.assertRaises(ProtocolError) as raised: - client.send_text(b"", fin=False) - self.assertEqual(str(raised.exception), "expected a continuation frame") - - def test_server_sends_unexpected_text(self): - server = Connection(SERVER) - server.send_text(b"", fin=False) - with self.assertRaises(ProtocolError) as raised: - server.send_text(b"", fin=False) - self.assertEqual(str(raised.exception), "expected a continuation frame") - - def test_client_receives_unexpected_text(self): - client = Connection(CLIENT) - client.receive_data(b"\x01\x00") - self.assertFrameReceived( - client, - Frame(OP_TEXT, b"", fin=False), - ) - client.receive_data(b"\x01\x00") - self.assertIsInstance(client.parser_exc, ProtocolError) - self.assertEqual(str(client.parser_exc), "expected a continuation frame") - self.assertConnectionFailing(client, 1002, "expected a continuation frame") - - def test_server_receives_unexpected_text(self): - server = Connection(SERVER) - server.receive_data(b"\x01\x80\x00\x00\x00\x00") - self.assertFrameReceived( - server, - Frame(OP_TEXT, b"", fin=False), - ) - server.receive_data(b"\x01\x80\x00\x00\x00\x00") - self.assertIsInstance(server.parser_exc, ProtocolError) - self.assertEqual(str(server.parser_exc), "expected a continuation frame") - self.assertConnectionFailing(server, 1002, "expected a continuation frame") - - def test_client_sends_text_after_sending_close(self): - client = Connection(CLIENT) - with self.enforce_mask(b"\x00\x00\x00\x00"): - client.send_close(1001) - self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) - with self.assertRaises(InvalidState): - client.send_text(b"") - - def test_server_sends_text_after_sending_close(self): - server = Connection(SERVER) - server.send_close(1000) - self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) - with self.assertRaises(InvalidState): - server.send_text(b"") - - def test_client_receives_text_after_receiving_close(self): - client = Connection(CLIENT) - client.receive_data(b"\x88\x02\x03\xe8") - self.assertConnectionClosing(client, 1000) - client.receive_data(b"\x81\x00") - self.assertFrameReceived(client, None) - self.assertFrameSent(client, None) - - def test_server_receives_text_after_receiving_close(self): - server = Connection(SERVER) - server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") - self.assertConnectionClosing(server, 1001) - server.receive_data(b"\x81\x80\x00\xff\x00\xff") - self.assertFrameReceived(server, None) - self.assertFrameSent(server, None) - - -class BinaryTests(ConnectionTestCase): - """ - Test binary frames and continuation frames. - - """ - - def test_client_sends_binary(self): - client = Connection(CLIENT) - with self.enforce_mask(b"\x00\x00\x00\x00"): - client.send_binary(b"\x01\x02\xfe\xff") - self.assertEqual( - client.data_to_send(), [b"\x82\x84\x00\x00\x00\x00\x01\x02\xfe\xff"] - ) - - def test_server_sends_binary(self): - server = Connection(SERVER) - server.send_binary(b"\x01\x02\xfe\xff") - self.assertEqual(server.data_to_send(), [b"\x82\x04\x01\x02\xfe\xff"]) - - def test_client_receives_binary(self): - client = Connection(CLIENT) - client.receive_data(b"\x82\x04\x01\x02\xfe\xff") - self.assertFrameReceived( - client, - Frame(OP_BINARY, b"\x01\x02\xfe\xff"), - ) - - def test_server_receives_binary(self): - server = Connection(SERVER) - server.receive_data(b"\x82\x84\x00\x00\x00\x00\x01\x02\xfe\xff") - self.assertFrameReceived( - server, - Frame(OP_BINARY, b"\x01\x02\xfe\xff"), - ) - - def test_client_receives_binary_over_size_limit(self): - client = Connection(CLIENT, max_size=3) - client.receive_data(b"\x82\x04\x01\x02\xfe\xff") - self.assertIsInstance(client.parser_exc, PayloadTooBig) - self.assertEqual(str(client.parser_exc), "over size limit (4 > 3 bytes)") - self.assertConnectionFailing(client, 1009, "over size limit (4 > 3 bytes)") - - def test_server_receives_binary_over_size_limit(self): - server = Connection(SERVER, max_size=3) - server.receive_data(b"\x82\x84\x00\x00\x00\x00\x01\x02\xfe\xff") - self.assertIsInstance(server.parser_exc, PayloadTooBig) - self.assertEqual(str(server.parser_exc), "over size limit (4 > 3 bytes)") - self.assertConnectionFailing(server, 1009, "over size limit (4 > 3 bytes)") - - def test_client_sends_fragmented_binary(self): - client = Connection(CLIENT) - with self.enforce_mask(b"\x00\x00\x00\x00"): - client.send_binary(b"\x01\x02", fin=False) - self.assertEqual(client.data_to_send(), [b"\x02\x82\x00\x00\x00\x00\x01\x02"]) - with self.enforce_mask(b"\x00\x00\x00\x00"): - client.send_continuation(b"\xee\xff\x01\x02", fin=False) - self.assertEqual( - client.data_to_send(), [b"\x00\x84\x00\x00\x00\x00\xee\xff\x01\x02"] - ) - with self.enforce_mask(b"\x00\x00\x00\x00"): - client.send_continuation(b"\xee\xff", fin=True) - self.assertEqual(client.data_to_send(), [b"\x80\x82\x00\x00\x00\x00\xee\xff"]) - - def test_server_sends_fragmented_binary(self): - server = Connection(SERVER) - server.send_binary(b"\x01\x02", fin=False) - self.assertEqual(server.data_to_send(), [b"\x02\x02\x01\x02"]) - server.send_continuation(b"\xee\xff\x01\x02", fin=False) - self.assertEqual(server.data_to_send(), [b"\x00\x04\xee\xff\x01\x02"]) - server.send_continuation(b"\xee\xff", fin=True) - self.assertEqual(server.data_to_send(), [b"\x80\x02\xee\xff"]) - - def test_client_receives_fragmented_binary(self): - client = Connection(CLIENT) - client.receive_data(b"\x02\x02\x01\x02") - self.assertFrameReceived( - client, - Frame(OP_BINARY, b"\x01\x02", fin=False), - ) - client.receive_data(b"\x00\x04\xfe\xff\x01\x02") - self.assertFrameReceived( - client, - Frame(OP_CONT, b"\xfe\xff\x01\x02", fin=False), - ) - client.receive_data(b"\x80\x02\xfe\xff") - self.assertFrameReceived( - client, - Frame(OP_CONT, b"\xfe\xff"), - ) - - def test_server_receives_fragmented_binary(self): - server = Connection(SERVER) - server.receive_data(b"\x02\x82\x00\x00\x00\x00\x01\x02") - self.assertFrameReceived( - server, - Frame(OP_BINARY, b"\x01\x02", fin=False), - ) - server.receive_data(b"\x00\x84\x00\x00\x00\x00\xee\xff\x01\x02") - self.assertFrameReceived( - server, - Frame(OP_CONT, b"\xee\xff\x01\x02", fin=False), - ) - server.receive_data(b"\x80\x82\x00\x00\x00\x00\xfe\xff") - self.assertFrameReceived( - server, - Frame(OP_CONT, b"\xfe\xff"), - ) - - def test_client_receives_fragmented_binary_over_size_limit(self): - client = Connection(CLIENT, max_size=3) - client.receive_data(b"\x02\x02\x01\x02") - self.assertFrameReceived( - client, - Frame(OP_BINARY, b"\x01\x02", fin=False), - ) - client.receive_data(b"\x80\x02\xfe\xff") - self.assertIsInstance(client.parser_exc, PayloadTooBig) - self.assertEqual(str(client.parser_exc), "over size limit (2 > 1 bytes)") - self.assertConnectionFailing(client, 1009, "over size limit (2 > 1 bytes)") - - def test_server_receives_fragmented_binary_over_size_limit(self): - server = Connection(SERVER, max_size=3) - server.receive_data(b"\x02\x82\x00\x00\x00\x00\x01\x02") - self.assertFrameReceived( - server, - Frame(OP_BINARY, b"\x01\x02", fin=False), - ) - server.receive_data(b"\x80\x82\x00\x00\x00\x00\xfe\xff") - self.assertIsInstance(server.parser_exc, PayloadTooBig) - self.assertEqual(str(server.parser_exc), "over size limit (2 > 1 bytes)") - self.assertConnectionFailing(server, 1009, "over size limit (2 > 1 bytes)") - - def test_client_sends_unexpected_binary(self): - client = Connection(CLIENT) - client.send_binary(b"", fin=False) - with self.assertRaises(ProtocolError) as raised: - client.send_binary(b"", fin=False) - self.assertEqual(str(raised.exception), "expected a continuation frame") - - def test_server_sends_unexpected_binary(self): - server = Connection(SERVER) - server.send_binary(b"", fin=False) - with self.assertRaises(ProtocolError) as raised: - server.send_binary(b"", fin=False) - self.assertEqual(str(raised.exception), "expected a continuation frame") - - def test_client_receives_unexpected_binary(self): - client = Connection(CLIENT) - client.receive_data(b"\x02\x00") - self.assertFrameReceived( - client, - Frame(OP_BINARY, b"", fin=False), - ) - client.receive_data(b"\x02\x00") - self.assertIsInstance(client.parser_exc, ProtocolError) - self.assertEqual(str(client.parser_exc), "expected a continuation frame") - self.assertConnectionFailing(client, 1002, "expected a continuation frame") - - def test_server_receives_unexpected_binary(self): - server = Connection(SERVER) - server.receive_data(b"\x02\x80\x00\x00\x00\x00") - self.assertFrameReceived( - server, - Frame(OP_BINARY, b"", fin=False), - ) - server.receive_data(b"\x02\x80\x00\x00\x00\x00") - self.assertIsInstance(server.parser_exc, ProtocolError) - self.assertEqual(str(server.parser_exc), "expected a continuation frame") - self.assertConnectionFailing(server, 1002, "expected a continuation frame") - - def test_client_sends_binary_after_sending_close(self): - client = Connection(CLIENT) - with self.enforce_mask(b"\x00\x00\x00\x00"): - client.send_close(1001) - self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) - with self.assertRaises(InvalidState): - client.send_binary(b"") - - def test_server_sends_binary_after_sending_close(self): - server = Connection(SERVER) - server.send_close(1000) - self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) - with self.assertRaises(InvalidState): - server.send_binary(b"") - - def test_client_receives_binary_after_receiving_close(self): - client = Connection(CLIENT) - client.receive_data(b"\x88\x02\x03\xe8") - self.assertConnectionClosing(client, 1000) - client.receive_data(b"\x82\x00") - self.assertFrameReceived(client, None) - self.assertFrameSent(client, None) - - def test_server_receives_binary_after_receiving_close(self): - server = Connection(SERVER) - server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") - self.assertConnectionClosing(server, 1001) - server.receive_data(b"\x82\x80\x00\xff\x00\xff") - self.assertFrameReceived(server, None) - self.assertFrameSent(server, None) - - -class CloseTests(ConnectionTestCase): - """ - Test close frames. - - See RFC 6544: - - 5.5.1. Close - 7.1.6. The WebSocket Connection Close Reason - 7.1.7. Fail the WebSocket Connection - - """ - - def test_close_code(self): - client = Connection(CLIENT) - client.receive_data(b"\x88\x04\x03\xe8OK") - client.receive_eof() - self.assertEqual(client.close_code, 1000) - - def test_close_reason(self): - server = Connection(SERVER) - server.receive_data(b"\x88\x84\x00\x00\x00\x00\x03\xe8OK") - server.receive_eof() - self.assertEqual(server.close_reason, "OK") - - def test_close_code_not_provided(self): - server = Connection(SERVER) - server.receive_data(b"\x88\x80\x00\x00\x00\x00") - server.receive_eof() - self.assertEqual(server.close_code, 1005) - - def test_close_reason_not_provided(self): - client = Connection(CLIENT) - client.receive_data(b"\x88\x00") - client.receive_eof() - self.assertEqual(client.close_reason, "") - - def test_close_code_not_available(self): - client = Connection(CLIENT) - client.receive_eof() - self.assertEqual(client.close_code, 1006) - - def test_close_reason_not_available(self): - server = Connection(SERVER) - server.receive_eof() - self.assertEqual(server.close_reason, "") - - def test_close_code_not_available_yet(self): - server = Connection(SERVER) - self.assertIsNone(server.close_code) - - def test_close_reason_not_available_yet(self): - client = Connection(CLIENT) - self.assertIsNone(client.close_reason) - - def test_client_sends_close(self): - client = Connection(CLIENT) - with self.enforce_mask(b"\x3c\x3c\x3c\x3c"): - client.send_close() - self.assertEqual(client.data_to_send(), [b"\x88\x80\x3c\x3c\x3c\x3c"]) - self.assertIs(client.state, CLOSING) - - def test_server_sends_close(self): - server = Connection(SERVER) - server.send_close() - self.assertEqual(server.data_to_send(), [b"\x88\x00"]) - self.assertIs(server.state, CLOSING) - - def test_client_receives_close(self): - client = Connection(CLIENT) - with self.enforce_mask(b"\x3c\x3c\x3c\x3c"): - client.receive_data(b"\x88\x00") - self.assertEqual(client.events_received(), [Frame(OP_CLOSE, b"")]) - self.assertEqual(client.data_to_send(), [b"\x88\x80\x3c\x3c\x3c\x3c"]) - self.assertIs(client.state, CLOSING) - - def test_server_receives_close(self): - server = Connection(SERVER) - server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") - self.assertEqual(server.events_received(), [Frame(OP_CLOSE, b"")]) - self.assertEqual(server.data_to_send(), [b"\x88\x00", b""]) - self.assertIs(server.state, CLOSING) - - def test_client_sends_close_then_receives_close(self): - # Client-initiated close handshake on the client side. - client = Connection(CLIENT) - - client.send_close() - self.assertFrameReceived(client, None) - self.assertFrameSent(client, Frame(OP_CLOSE, b"")) - - client.receive_data(b"\x88\x00") - self.assertFrameReceived(client, Frame(OP_CLOSE, b"")) - self.assertFrameSent(client, None) - - client.receive_eof() - self.assertFrameReceived(client, None) - self.assertFrameSent(client, None, eof=True) - - def test_server_sends_close_then_receives_close(self): - # Server-initiated close handshake on the server side. - server = Connection(SERVER) - - server.send_close() - self.assertFrameReceived(server, None) - self.assertFrameSent(server, Frame(OP_CLOSE, b"")) - - server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") - self.assertFrameReceived(server, Frame(OP_CLOSE, b"")) - self.assertFrameSent(server, None, eof=True) - - server.receive_eof() - self.assertFrameReceived(server, None) - self.assertFrameSent(server, None) - - def test_client_receives_close_then_sends_close(self): - # Server-initiated close handshake on the client side. - client = Connection(CLIENT) - - client.receive_data(b"\x88\x00") - self.assertFrameReceived(client, Frame(OP_CLOSE, b"")) - self.assertFrameSent(client, Frame(OP_CLOSE, b"")) - - client.receive_eof() - self.assertFrameReceived(client, None) - self.assertFrameSent(client, None, eof=True) - - def test_server_receives_close_then_sends_close(self): - # Client-initiated close handshake on the server side. - server = Connection(SERVER) - - server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") - self.assertFrameReceived(server, Frame(OP_CLOSE, b"")) - self.assertFrameSent(server, Frame(OP_CLOSE, b""), eof=True) - - server.receive_eof() - self.assertFrameReceived(server, None) - self.assertFrameSent(server, None) - - def test_client_sends_close_with_code(self): - client = Connection(CLIENT) - with self.enforce_mask(b"\x00\x00\x00\x00"): - client.send_close(1001) - self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) - self.assertIs(client.state, CLOSING) - - def test_server_sends_close_with_code(self): - server = Connection(SERVER) - server.send_close(1000) - self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) - self.assertIs(server.state, CLOSING) - - def test_client_receives_close_with_code(self): - client = Connection(CLIENT) - client.receive_data(b"\x88\x02\x03\xe8") - self.assertConnectionClosing(client, 1000, "") - self.assertIs(client.state, CLOSING) - - def test_server_receives_close_with_code(self): - server = Connection(SERVER) - server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") - self.assertConnectionClosing(server, 1001, "") - self.assertIs(server.state, CLOSING) - - def test_client_sends_close_with_code_and_reason(self): - client = Connection(CLIENT) - with self.enforce_mask(b"\x00\x00\x00\x00"): - client.send_close(1001, "going away") - self.assertEqual( - client.data_to_send(), [b"\x88\x8c\x00\x00\x00\x00\x03\xe9going away"] - ) - self.assertIs(client.state, CLOSING) - - def test_server_sends_close_with_code_and_reason(self): - server = Connection(SERVER) - server.send_close(1000, "OK") - self.assertEqual(server.data_to_send(), [b"\x88\x04\x03\xe8OK"]) - self.assertIs(server.state, CLOSING) - - def test_client_receives_close_with_code_and_reason(self): - client = Connection(CLIENT) - client.receive_data(b"\x88\x04\x03\xe8OK") - self.assertConnectionClosing(client, 1000, "OK") - self.assertIs(client.state, CLOSING) - - def test_server_receives_close_with_code_and_reason(self): - server = Connection(SERVER) - server.receive_data(b"\x88\x8c\x00\x00\x00\x00\x03\xe9going away") - self.assertConnectionClosing(server, 1001, "going away") - self.assertIs(server.state, CLOSING) - - def test_client_sends_close_with_reason_only(self): - client = Connection(CLIENT) - with self.assertRaises(ProtocolError) as raised: - client.send_close(reason="going away") - self.assertEqual(str(raised.exception), "cannot send a reason without a code") - - def test_server_sends_close_with_reason_only(self): - server = Connection(SERVER) - with self.assertRaises(ProtocolError) as raised: - server.send_close(reason="OK") - self.assertEqual(str(raised.exception), "cannot send a reason without a code") - - def test_client_receives_close_with_truncated_code(self): - client = Connection(CLIENT) - client.receive_data(b"\x88\x01\x03") - self.assertIsInstance(client.parser_exc, ProtocolError) - self.assertEqual(str(client.parser_exc), "close frame too short") - self.assertConnectionFailing(client, 1002, "close frame too short") - self.assertIs(client.state, CLOSING) - - def test_server_receives_close_with_truncated_code(self): - server = Connection(SERVER) - server.receive_data(b"\x88\x81\x00\x00\x00\x00\x03") - self.assertIsInstance(server.parser_exc, ProtocolError) - self.assertEqual(str(server.parser_exc), "close frame too short") - self.assertConnectionFailing(server, 1002, "close frame too short") - self.assertIs(server.state, CLOSING) - - def test_client_receives_close_with_non_utf8_reason(self): - client = Connection(CLIENT) - - client.receive_data(b"\x88\x04\x03\xe8\xff\xff") - self.assertIsInstance(client.parser_exc, UnicodeDecodeError) - self.assertEqual( - str(client.parser_exc), - "'utf-8' codec can't decode byte 0xff in position 0: invalid start byte", - ) - self.assertConnectionFailing(client, 1007, "invalid start byte at position 0") - self.assertIs(client.state, CLOSING) - - def test_server_receives_close_with_non_utf8_reason(self): - server = Connection(SERVER) - - server.receive_data(b"\x88\x84\x00\x00\x00\x00\x03\xe9\xff\xff") - self.assertIsInstance(server.parser_exc, UnicodeDecodeError) - self.assertEqual( - str(server.parser_exc), - "'utf-8' codec can't decode byte 0xff in position 0: invalid start byte", - ) - self.assertConnectionFailing(server, 1007, "invalid start byte at position 0") - self.assertIs(server.state, CLOSING) - - -class PingTests(ConnectionTestCase): - """ - Test ping. See 5.5.2. Ping in RFC 6544. - - """ - - def test_client_sends_ping(self): - client = Connection(CLIENT) - with self.enforce_mask(b"\x00\x44\x88\xcc"): - client.send_ping(b"") - self.assertEqual(client.data_to_send(), [b"\x89\x80\x00\x44\x88\xcc"]) - - def test_server_sends_ping(self): - server = Connection(SERVER) - server.send_ping(b"") - self.assertEqual(server.data_to_send(), [b"\x89\x00"]) - - def test_client_receives_ping(self): - client = Connection(CLIENT) - client.receive_data(b"\x89\x00") - self.assertFrameReceived( - client, - Frame(OP_PING, b""), - ) - self.assertFrameSent( - client, - Frame(OP_PONG, b""), - ) - - def test_server_receives_ping(self): - server = Connection(SERVER) - server.receive_data(b"\x89\x80\x00\x44\x88\xcc") - self.assertFrameReceived( - server, - Frame(OP_PING, b""), - ) - self.assertFrameSent( - server, - Frame(OP_PONG, b""), - ) - - def test_client_sends_ping_with_data(self): - client = Connection(CLIENT) - with self.enforce_mask(b"\x00\x44\x88\xcc"): - client.send_ping(b"\x22\x66\xaa\xee") - self.assertEqual( - client.data_to_send(), [b"\x89\x84\x00\x44\x88\xcc\x22\x22\x22\x22"] - ) - - def test_server_sends_ping_with_data(self): - server = Connection(SERVER) - server.send_ping(b"\x22\x66\xaa\xee") - self.assertEqual(server.data_to_send(), [b"\x89\x04\x22\x66\xaa\xee"]) - - def test_client_receives_ping_with_data(self): - client = Connection(CLIENT) - client.receive_data(b"\x89\x04\x22\x66\xaa\xee") - self.assertFrameReceived( - client, - Frame(OP_PING, b"\x22\x66\xaa\xee"), - ) - self.assertFrameSent( - client, - Frame(OP_PONG, b"\x22\x66\xaa\xee"), - ) - - def test_server_receives_ping_with_data(self): - server = Connection(SERVER) - server.receive_data(b"\x89\x84\x00\x44\x88\xcc\x22\x22\x22\x22") - self.assertFrameReceived( - server, - Frame(OP_PING, b"\x22\x66\xaa\xee"), - ) - self.assertFrameSent( - server, - Frame(OP_PONG, b"\x22\x66\xaa\xee"), - ) - - def test_client_sends_fragmented_ping_frame(self): - client = Connection(CLIENT) - # This is only possible through a private API. - with self.assertRaises(ProtocolError) as raised: - client.send_frame(Frame(OP_PING, b"", fin=False)) - self.assertEqual(str(raised.exception), "fragmented control frame") - - def test_server_sends_fragmented_ping_frame(self): - server = Connection(SERVER) - # This is only possible through a private API. - with self.assertRaises(ProtocolError) as raised: - server.send_frame(Frame(OP_PING, b"", fin=False)) - self.assertEqual(str(raised.exception), "fragmented control frame") - - def test_client_receives_fragmented_ping_frame(self): - client = Connection(CLIENT) - client.receive_data(b"\x09\x00") - self.assertIsInstance(client.parser_exc, ProtocolError) - self.assertEqual(str(client.parser_exc), "fragmented control frame") - self.assertConnectionFailing(client, 1002, "fragmented control frame") - - def test_server_receives_fragmented_ping_frame(self): - server = Connection(SERVER) - server.receive_data(b"\x09\x80\x3c\x3c\x3c\x3c") - self.assertIsInstance(server.parser_exc, ProtocolError) - self.assertEqual(str(server.parser_exc), "fragmented control frame") - self.assertConnectionFailing(server, 1002, "fragmented control frame") - - def test_client_sends_ping_after_sending_close(self): - client = Connection(CLIENT) - with self.enforce_mask(b"\x00\x00\x00\x00"): - client.send_close(1001) - self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) - # The spec says: "An endpoint MAY send a Ping frame any time (...) - # before the connection is closed" but websockets doesn't support - # sending a Ping frame after a Close frame. - with self.assertRaises(InvalidState) as raised: - client.send_ping(b"") - self.assertEqual( - str(raised.exception), - "cannot write to a WebSocket in the CLOSING state", - ) - - def test_server_sends_ping_after_sending_close(self): - server = Connection(SERVER) - server.send_close(1000) - self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) - # The spec says: "An endpoint MAY send a Ping frame any time (...) - # before the connection is closed" but websockets doesn't support - # sending a Ping frame after a Close frame. - with self.assertRaises(InvalidState) as raised: - server.send_ping(b"") - self.assertEqual( - str(raised.exception), - "cannot write to a WebSocket in the CLOSING state", - ) - - def test_client_receives_ping_after_receiving_close(self): - client = Connection(CLIENT) - client.receive_data(b"\x88\x02\x03\xe8") - self.assertConnectionClosing(client, 1000) - client.receive_data(b"\x89\x04\x22\x66\xaa\xee") - self.assertFrameReceived(client, None) - self.assertFrameSent(client, None) - - def test_server_receives_ping_after_receiving_close(self): - server = Connection(SERVER) - server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") - self.assertConnectionClosing(server, 1001) - server.receive_data(b"\x89\x84\x00\x44\x88\xcc\x22\x22\x22\x22") - self.assertFrameReceived(server, None) - self.assertFrameSent(server, None) - - -class PongTests(ConnectionTestCase): - """ - Test pong frames. See 5.5.3. Pong in RFC 6544. - - """ - - def test_client_sends_pong(self): - client = Connection(CLIENT) - with self.enforce_mask(b"\x00\x44\x88\xcc"): - client.send_pong(b"") - self.assertEqual(client.data_to_send(), [b"\x8a\x80\x00\x44\x88\xcc"]) - - def test_server_sends_pong(self): - server = Connection(SERVER) - server.send_pong(b"") - self.assertEqual(server.data_to_send(), [b"\x8a\x00"]) - - def test_client_receives_pong(self): - client = Connection(CLIENT) - client.receive_data(b"\x8a\x00") - self.assertFrameReceived( - client, - Frame(OP_PONG, b""), - ) - - def test_server_receives_pong(self): - server = Connection(SERVER) - server.receive_data(b"\x8a\x80\x00\x44\x88\xcc") - self.assertFrameReceived( - server, - Frame(OP_PONG, b""), - ) - - def test_client_sends_pong_with_data(self): - client = Connection(CLIENT) - with self.enforce_mask(b"\x00\x44\x88\xcc"): - client.send_pong(b"\x22\x66\xaa\xee") - self.assertEqual( - client.data_to_send(), [b"\x8a\x84\x00\x44\x88\xcc\x22\x22\x22\x22"] - ) - - def test_server_sends_pong_with_data(self): - server = Connection(SERVER) - server.send_pong(b"\x22\x66\xaa\xee") - self.assertEqual(server.data_to_send(), [b"\x8a\x04\x22\x66\xaa\xee"]) - - def test_client_receives_pong_with_data(self): - client = Connection(CLIENT) - client.receive_data(b"\x8a\x04\x22\x66\xaa\xee") - self.assertFrameReceived( - client, - Frame(OP_PONG, b"\x22\x66\xaa\xee"), - ) - - def test_server_receives_pong_with_data(self): - server = Connection(SERVER) - server.receive_data(b"\x8a\x84\x00\x44\x88\xcc\x22\x22\x22\x22") - self.assertFrameReceived( - server, - Frame(OP_PONG, b"\x22\x66\xaa\xee"), - ) - - def test_client_sends_fragmented_pong_frame(self): - client = Connection(CLIENT) - # This is only possible through a private API. - with self.assertRaises(ProtocolError) as raised: - client.send_frame(Frame(OP_PONG, b"", fin=False)) - self.assertEqual(str(raised.exception), "fragmented control frame") - - def test_server_sends_fragmented_pong_frame(self): - server = Connection(SERVER) - # This is only possible through a private API. - with self.assertRaises(ProtocolError) as raised: - server.send_frame(Frame(OP_PONG, b"", fin=False)) - self.assertEqual(str(raised.exception), "fragmented control frame") - - def test_client_receives_fragmented_pong_frame(self): - client = Connection(CLIENT) - client.receive_data(b"\x0a\x00") - self.assertIsInstance(client.parser_exc, ProtocolError) - self.assertEqual(str(client.parser_exc), "fragmented control frame") - self.assertConnectionFailing(client, 1002, "fragmented control frame") - - def test_server_receives_fragmented_pong_frame(self): - server = Connection(SERVER) - server.receive_data(b"\x0a\x80\x3c\x3c\x3c\x3c") - self.assertIsInstance(server.parser_exc, ProtocolError) - self.assertEqual(str(server.parser_exc), "fragmented control frame") - self.assertConnectionFailing(server, 1002, "fragmented control frame") - - def test_client_sends_pong_after_sending_close(self): - client = Connection(CLIENT) - with self.enforce_mask(b"\x00\x00\x00\x00"): - client.send_close(1001) - self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) - # websockets doesn't support sending a Pong frame after a Close frame. - with self.assertRaises(InvalidState): - client.send_pong(b"") - - def test_server_sends_pong_after_sending_close(self): - server = Connection(SERVER) - server.send_close(1000) - self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) - # websockets doesn't support sending a Pong frame after a Close frame. - with self.assertRaises(InvalidState): - server.send_pong(b"") - - def test_client_receives_pong_after_receiving_close(self): - client = Connection(CLIENT) - client.receive_data(b"\x88\x02\x03\xe8") - self.assertConnectionClosing(client, 1000) - client.receive_data(b"\x8a\x04\x22\x66\xaa\xee") - self.assertFrameReceived(client, None) - self.assertFrameSent(client, None) - - def test_server_receives_pong_after_receiving_close(self): - server = Connection(SERVER) - server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") - self.assertConnectionClosing(server, 1001) - server.receive_data(b"\x8a\x84\x00\x44\x88\xcc\x22\x22\x22\x22") - self.assertFrameReceived(server, None) - self.assertFrameSent(server, None) - - -class FailTests(ConnectionTestCase): - """ - Test failing the connection. - - See 7.1.7. Fail the WebSocket Connection in RFC 6544. - - """ - - def test_client_stops_processing_frames_after_fail(self): - client = Connection(CLIENT) - client.fail(1002) - self.assertConnectionFailing(client, 1002) - client.receive_data(b"\x88\x02\x03\xea") - self.assertFrameReceived(client, None) - - def test_server_stops_processing_frames_after_fail(self): - server = Connection(SERVER) - server.fail(1002) - self.assertConnectionFailing(server, 1002) - server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xea") - self.assertFrameReceived(server, None) - - -class FragmentationTests(ConnectionTestCase): - """ - Test message fragmentation. - - See 5.4. Fragmentation in RFC 6544. - - """ - - def test_client_send_ping_pong_in_fragmented_message(self): - client = Connection(CLIENT) - client.send_text(b"Spam", fin=False) - self.assertFrameSent(client, Frame(OP_TEXT, b"Spam", fin=False)) - client.send_ping(b"Ping") - self.assertFrameSent(client, Frame(OP_PING, b"Ping")) - client.send_continuation(b"Ham", fin=False) - self.assertFrameSent(client, Frame(OP_CONT, b"Ham", fin=False)) - client.send_pong(b"Pong") - self.assertFrameSent(client, Frame(OP_PONG, b"Pong")) - client.send_continuation(b"Eggs", fin=True) - self.assertFrameSent(client, Frame(OP_CONT, b"Eggs")) - - def test_server_send_ping_pong_in_fragmented_message(self): - server = Connection(SERVER) - server.send_text(b"Spam", fin=False) - self.assertFrameSent(server, Frame(OP_TEXT, b"Spam", fin=False)) - server.send_ping(b"Ping") - self.assertFrameSent(server, Frame(OP_PING, b"Ping")) - server.send_continuation(b"Ham", fin=False) - self.assertFrameSent(server, Frame(OP_CONT, b"Ham", fin=False)) - server.send_pong(b"Pong") - self.assertFrameSent(server, Frame(OP_PONG, b"Pong")) - server.send_continuation(b"Eggs", fin=True) - self.assertFrameSent(server, Frame(OP_CONT, b"Eggs")) - - def test_client_receive_ping_pong_in_fragmented_message(self): - client = Connection(CLIENT) - client.receive_data(b"\x01\x04Spam") - self.assertFrameReceived( - client, - Frame(OP_TEXT, b"Spam", fin=False), - ) - client.receive_data(b"\x89\x04Ping") - self.assertFrameReceived( - client, - Frame(OP_PING, b"Ping"), - ) - self.assertFrameSent( - client, - Frame(OP_PONG, b"Ping"), - ) - client.receive_data(b"\x00\x03Ham") - self.assertFrameReceived( - client, - Frame(OP_CONT, b"Ham", fin=False), - ) - client.receive_data(b"\x8a\x04Pong") - self.assertFrameReceived( - client, - Frame(OP_PONG, b"Pong"), - ) - client.receive_data(b"\x80\x04Eggs") - self.assertFrameReceived( - client, - Frame(OP_CONT, b"Eggs"), - ) - - def test_server_receive_ping_pong_in_fragmented_message(self): - server = Connection(SERVER) - server.receive_data(b"\x01\x84\x00\x00\x00\x00Spam") - self.assertFrameReceived( - server, - Frame(OP_TEXT, b"Spam", fin=False), - ) - server.receive_data(b"\x89\x84\x00\x00\x00\x00Ping") - self.assertFrameReceived( - server, - Frame(OP_PING, b"Ping"), - ) - self.assertFrameSent( - server, - Frame(OP_PONG, b"Ping"), - ) - server.receive_data(b"\x00\x83\x00\x00\x00\x00Ham") - self.assertFrameReceived( - server, - Frame(OP_CONT, b"Ham", fin=False), - ) - server.receive_data(b"\x8a\x84\x00\x00\x00\x00Pong") - self.assertFrameReceived( - server, - Frame(OP_PONG, b"Pong"), - ) - server.receive_data(b"\x80\x84\x00\x00\x00\x00Eggs") - self.assertFrameReceived( - server, - Frame(OP_CONT, b"Eggs"), - ) - - def test_client_send_close_in_fragmented_message(self): - client = Connection(CLIENT) - client.send_text(b"Spam", fin=False) - self.assertFrameSent(client, Frame(OP_TEXT, b"Spam", fin=False)) - # The spec says: "An endpoint MUST be capable of handling control - # frames in the middle of a fragmented message." However, since the - # endpoint must not send a data frame after a close frame, a close - # frame can't be "in the middle" of a fragmented message. - with self.assertRaises(ProtocolError) as raised: - client.send_close(1001) - self.assertEqual(str(raised.exception), "expected a continuation frame") - client.send_continuation(b"Eggs", fin=True) - - def test_server_send_close_in_fragmented_message(self): - server = Connection(CLIENT) - server.send_text(b"Spam", fin=False) - self.assertFrameSent(server, Frame(OP_TEXT, b"Spam", fin=False)) - # The spec says: "An endpoint MUST be capable of handling control - # frames in the middle of a fragmented message." However, since the - # endpoint must not send a data frame after a close frame, a close - # frame can't be "in the middle" of a fragmented message. - with self.assertRaises(ProtocolError) as raised: - server.send_close(1000) - self.assertEqual(str(raised.exception), "expected a continuation frame") - - def test_client_receive_close_in_fragmented_message(self): - client = Connection(CLIENT) - client.receive_data(b"\x01\x04Spam") - self.assertFrameReceived( - client, - Frame(OP_TEXT, b"Spam", fin=False), - ) - # The spec says: "An endpoint MUST be capable of handling control - # frames in the middle of a fragmented message." However, since the - # endpoint must not send a data frame after a close frame, a close - # frame can't be "in the middle" of a fragmented message. - client.receive_data(b"\x88\x02\x03\xe8") - self.assertIsInstance(client.parser_exc, ProtocolError) - self.assertEqual(str(client.parser_exc), "incomplete fragmented message") - self.assertConnectionFailing(client, 1002, "incomplete fragmented message") - - def test_server_receive_close_in_fragmented_message(self): - server = Connection(SERVER) - server.receive_data(b"\x01\x84\x00\x00\x00\x00Spam") - self.assertFrameReceived( - server, - Frame(OP_TEXT, b"Spam", fin=False), - ) - # The spec says: "An endpoint MUST be capable of handling control - # frames in the middle of a fragmented message." However, since the - # endpoint must not send a data frame after a close frame, a close - # frame can't be "in the middle" of a fragmented message. - server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") - self.assertIsInstance(server.parser_exc, ProtocolError) - self.assertEqual(str(server.parser_exc), "incomplete fragmented message") - self.assertConnectionFailing(server, 1002, "incomplete fragmented message") - - -class EOFTests(ConnectionTestCase): - """ - Test half-closes on connection termination. - - """ - - def test_client_receives_eof(self): - client = Connection(CLIENT) - client.receive_data(b"\x88\x00") - self.assertConnectionClosing(client) - client.receive_eof() - self.assertIs(client.state, CLOSED) - - def test_server_receives_eof(self): - server = Connection(SERVER) - server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") - self.assertConnectionClosing(server) - server.receive_eof() - self.assertIs(server.state, CLOSED) - - def test_client_receives_eof_between_frames(self): - client = Connection(CLIENT) - client.receive_eof() - self.assertIsInstance(client.parser_exc, EOFError) - self.assertEqual(str(client.parser_exc), "unexpected end of stream") - self.assertIs(client.state, CLOSED) - - def test_server_receives_eof_between_frames(self): - server = Connection(SERVER) - server.receive_eof() - self.assertIsInstance(server.parser_exc, EOFError) - self.assertEqual(str(server.parser_exc), "unexpected end of stream") - self.assertIs(server.state, CLOSED) - - def test_client_receives_eof_inside_frame(self): - client = Connection(CLIENT) - client.receive_data(b"\x81") - client.receive_eof() - self.assertIsInstance(client.parser_exc, EOFError) - self.assertEqual( - str(client.parser_exc), - "stream ends after 1 bytes, expected 2 bytes", - ) - self.assertIs(client.state, CLOSED) - - def test_server_receives_eof_inside_frame(self): - server = Connection(SERVER) - server.receive_data(b"\x81") - server.receive_eof() - self.assertIsInstance(server.parser_exc, EOFError) - self.assertEqual( - str(server.parser_exc), - "stream ends after 1 bytes, expected 2 bytes", - ) - self.assertIs(server.state, CLOSED) - - def test_client_receives_data_after_exception(self): - client = Connection(CLIENT) - client.receive_data(b"\xff\xff") - self.assertConnectionFailing(client, 1002, "invalid opcode") - client.receive_data(b"\x00\x00") - self.assertFrameSent(client, None) - - def test_server_receives_data_after_exception(self): - server = Connection(SERVER) - server.receive_data(b"\xff\xff") - self.assertConnectionFailing(server, 1002, "invalid opcode") - server.receive_data(b"\x00\x00") - self.assertFrameSent(server, None) - - def test_client_receives_eof_after_exception(self): - client = Connection(CLIENT) - client.receive_data(b"\xff\xff") - self.assertConnectionFailing(client, 1002, "invalid opcode") - client.receive_eof() - self.assertFrameSent(client, None, eof=True) - - def test_server_receives_eof_after_exception(self): - server = Connection(SERVER) - server.receive_data(b"\xff\xff") - self.assertConnectionFailing(server, 1002, "invalid opcode") - server.receive_eof() - self.assertFrameSent(server, None) - - def test_client_receives_data_and_eof_after_exception(self): - client = Connection(CLIENT) - client.receive_data(b"\xff\xff") - self.assertConnectionFailing(client, 1002, "invalid opcode") - client.receive_data(b"\x00\x00") - client.receive_eof() - self.assertFrameSent(client, None, eof=True) - - def test_server_receives_data_and_eof_after_exception(self): - server = Connection(SERVER) - server.receive_data(b"\xff\xff") - self.assertConnectionFailing(server, 1002, "invalid opcode") - server.receive_data(b"\x00\x00") - server.receive_eof() - self.assertFrameSent(server, None) - - def test_client_receives_data_after_eof(self): - client = Connection(CLIENT) - client.receive_data(b"\x88\x00") - self.assertConnectionClosing(client) - client.receive_eof() - with self.assertRaises(EOFError) as raised: - client.receive_data(b"\x88\x00") - self.assertEqual(str(raised.exception), "stream ended") - - def test_server_receives_data_after_eof(self): - server = Connection(SERVER) - server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") - self.assertConnectionClosing(server) - server.receive_eof() - with self.assertRaises(EOFError) as raised: - server.receive_data(b"\x88\x80\x00\x00\x00\x00") - self.assertEqual(str(raised.exception), "stream ended") - - def test_client_receives_eof_after_eof(self): - client = Connection(CLIENT) - client.receive_data(b"\x88\x00") - self.assertConnectionClosing(client) - client.receive_eof() - with self.assertRaises(EOFError) as raised: - client.receive_eof() - self.assertEqual(str(raised.exception), "stream ended") - - def test_server_receives_eof_after_eof(self): - server = Connection(SERVER) - server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") - self.assertConnectionClosing(server) - server.receive_eof() - with self.assertRaises(EOFError) as raised: - server.receive_eof() - self.assertEqual(str(raised.exception), "stream ended") - - -class TCPCloseTests(ConnectionTestCase): - """ - Test expectation of TCP close on connection termination. - - """ - - def test_client_default(self): - client = Connection(CLIENT) - self.assertFalse(client.close_expected()) - - def test_server_default(self): - server = Connection(SERVER) - self.assertFalse(server.close_expected()) - - def test_client_sends_close(self): - client = Connection(CLIENT) - client.send_close() - self.assertTrue(client.close_expected()) - - def test_server_sends_close(self): - server = Connection(SERVER) - server.send_close() - self.assertTrue(server.close_expected()) - - def test_client_receives_close(self): - client = Connection(CLIENT) - client.receive_data(b"\x88\x00") - self.assertTrue(client.close_expected()) - - def test_client_receives_close_then_eof(self): - client = Connection(CLIENT) - client.receive_data(b"\x88\x00") - client.receive_eof() - self.assertFalse(client.close_expected()) - - def test_server_receives_close_then_eof(self): - server = Connection(SERVER) - server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") - server.receive_eof() - self.assertFalse(server.close_expected()) - - def test_server_receives_close(self): - server = Connection(SERVER) - server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") - self.assertTrue(server.close_expected()) - - def test_client_fails_connection(self): - client = Connection(CLIENT) - client.fail(1002) - self.assertTrue(client.close_expected()) - - def test_server_fails_connection(self): - server = Connection(SERVER) - server.fail(1002) - self.assertTrue(server.close_expected()) - - -class ConnectionClosedTests(ConnectionTestCase): - """ - Test connection closed exception. - - """ - - def test_client_sends_close_then_receives_close(self): - # Client-initiated close handshake on the client side complete. - client = Connection(CLIENT) - client.send_close(1000, "") - client.receive_data(b"\x88\x02\x03\xe8") - client.receive_eof() - exc = client.close_exc - self.assertIsInstance(exc, ConnectionClosedOK) - self.assertEqual(exc.rcvd, Close(1000, "")) - self.assertEqual(exc.sent, Close(1000, "")) - self.assertFalse(exc.rcvd_then_sent) - - def test_server_sends_close_then_receives_close(self): - # Server-initiated close handshake on the server side complete. - server = Connection(SERVER) - server.send_close(1000, "") - server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe8") - server.receive_eof() - exc = server.close_exc - self.assertIsInstance(exc, ConnectionClosedOK) - self.assertEqual(exc.rcvd, Close(1000, "")) - self.assertEqual(exc.sent, Close(1000, "")) - self.assertFalse(exc.rcvd_then_sent) - - def test_client_receives_close_then_sends_close(self): - # Server-initiated close handshake on the client side complete. - client = Connection(CLIENT) - client.receive_data(b"\x88\x02\x03\xe8") - client.receive_eof() - exc = client.close_exc - self.assertIsInstance(exc, ConnectionClosedOK) - self.assertEqual(exc.rcvd, Close(1000, "")) - self.assertEqual(exc.sent, Close(1000, "")) - self.assertTrue(exc.rcvd_then_sent) - - def test_server_receives_close_then_sends_close(self): - # Client-initiated close handshake on the server side complete. - server = Connection(SERVER) - server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe8") - server.receive_eof() - exc = server.close_exc - self.assertIsInstance(exc, ConnectionClosedOK) - self.assertEqual(exc.rcvd, Close(1000, "")) - self.assertEqual(exc.sent, Close(1000, "")) - self.assertTrue(exc.rcvd_then_sent) - - def test_client_sends_close_then_receives_eof(self): - # Client-initiated close handshake on the client side times out. - client = Connection(CLIENT) - client.send_close(1000, "") - client.receive_eof() - exc = client.close_exc - self.assertIsInstance(exc, ConnectionClosedError) - self.assertIsNone(exc.rcvd) - self.assertEqual(exc.sent, Close(1000, "")) - self.assertIsNone(exc.rcvd_then_sent) - - def test_server_sends_close_then_receives_eof(self): - # Server-initiated close handshake on the server side times out. - server = Connection(SERVER) - server.send_close(1000, "") - server.receive_eof() - exc = server.close_exc - self.assertIsInstance(exc, ConnectionClosedError) - self.assertIsNone(exc.rcvd) - self.assertEqual(exc.sent, Close(1000, "")) - self.assertIsNone(exc.rcvd_then_sent) - - def test_client_receives_eof(self): - # Server-initiated close handshake on the client side times out. - client = Connection(CLIENT) - client.receive_eof() - exc = client.close_exc - self.assertIsInstance(exc, ConnectionClosedError) - self.assertIsNone(exc.rcvd) - self.assertIsNone(exc.sent) - self.assertIsNone(exc.rcvd_then_sent) - - def test_server_receives_eof(self): - # Client-initiated close handshake on the server side times out. - server = Connection(SERVER) - server.receive_eof() - exc = server.close_exc - self.assertIsInstance(exc, ConnectionClosedError) - self.assertIsNone(exc.rcvd) - self.assertIsNone(exc.sent) - self.assertIsNone(exc.rcvd_then_sent) - - -class ErrorTests(ConnectionTestCase): - """ - Test other error cases. - - """ - - def test_client_hits_internal_error_reading_frame(self): - client = Connection(CLIENT) - # This isn't supposed to happen, so we're simulating it. - with unittest.mock.patch("struct.unpack", side_effect=RuntimeError("BOOM")): - client.receive_data(b"\x81\x00") - self.assertIsInstance(client.parser_exc, RuntimeError) - self.assertEqual(str(client.parser_exc), "BOOM") - self.assertConnectionFailing(client, 1011, "") - - def test_server_hits_internal_error_reading_frame(self): - server = Connection(SERVER) - # This isn't supposed to happen, so we're simulating it. - with unittest.mock.patch("struct.unpack", side_effect=RuntimeError("BOOM")): - server.receive_data(b"\x81\x80\x00\x00\x00\x00") - self.assertIsInstance(server.parser_exc, RuntimeError) - self.assertEqual(str(server.parser_exc), "BOOM") - self.assertConnectionFailing(server, 1011, "") - - -class ExtensionsTests(ConnectionTestCase): - """ - Test how extensions affect frames. - - """ - - def test_client_extension_encodes_frame(self): - client = Connection(CLIENT) - client.extensions = [Rsv2Extension()] - with self.enforce_mask(b"\x00\x44\x88\xcc"): - client.send_ping(b"") - self.assertEqual(client.data_to_send(), [b"\xa9\x80\x00\x44\x88\xcc"]) - - def test_server_extension_encodes_frame(self): - server = Connection(SERVER) - server.extensions = [Rsv2Extension()] - server.send_ping(b"") - self.assertEqual(server.data_to_send(), [b"\xa9\x00"]) - - def test_client_extension_decodes_frame(self): - client = Connection(CLIENT) - client.extensions = [Rsv2Extension()] - client.receive_data(b"\xaa\x00") - self.assertEqual(client.events_received(), [Frame(OP_PONG, b"")]) - - def test_server_extension_decodes_frame(self): - server = Connection(SERVER) - server.extensions = [Rsv2Extension()] - server.receive_data(b"\xaa\x80\x00\x44\x88\xcc") - self.assertEqual(server.events_received(), [Frame(OP_PONG, b"")]) - - -class MiscTests(unittest.TestCase): - def test_client_default_logger(self): - client = Connection(CLIENT) - logger = logging.getLogger("websockets.client") - self.assertIs(client.logger, logger) - - def test_server_default_logger(self): - server = Connection(SERVER) - logger = logging.getLogger("websockets.server") - self.assertIs(server.logger, logger) - - def test_client_custom_logger(self): - logger = logging.getLogger("test") - client = Connection(CLIENT, logger=logger) - self.assertIs(client.logger, logger) - - def test_server_custom_logger(self): - logger = logging.getLogger("test") - server = Connection(SERVER, logger=logger) - self.assertIs(server.logger, logger) + self.assertIs(Connection, Protocol) diff --git a/tests/test_protocol.py b/tests/test_protocol.py new file mode 100644 index 000000000..7321d2594 --- /dev/null +++ b/tests/test_protocol.py @@ -0,0 +1,1737 @@ +import logging +import unittest.mock + +from websockets.exceptions import ( + ConnectionClosedError, + ConnectionClosedOK, + InvalidState, + PayloadTooBig, + ProtocolError, +) +from websockets.frames import ( + OP_BINARY, + OP_CLOSE, + OP_CONT, + OP_PING, + OP_PONG, + OP_TEXT, + Close, + Frame, +) +from websockets.protocol import * +from websockets.protocol import CLIENT, CLOSED, CLOSING, SERVER + +from .extensions.utils import Rsv2Extension +from .test_frames import FramesTestCase + + +class ProtocolTestCase(FramesTestCase): + def assertFrameSent(self, connection, frame, eof=False): + """ + Outgoing data for ``connection`` contains the given frame. + + ``frame`` may be ``None`` if no frame is expected. + + When ``eof`` is ``True``, the end of the stream is also expected. + + """ + frames_sent = [ + None + if write is SEND_EOF + else self.parse( + write, + mask=connection.side is CLIENT, + extensions=connection.extensions, + ) + for write in connection.data_to_send() + ] + frames_expected = [] if frame is None else [frame] + if eof: + frames_expected += [None] + self.assertEqual(frames_sent, frames_expected) + + def assertFrameReceived(self, connection, frame): + """ + Incoming data for ``connection`` contains the given frame. + + ``frame`` may be ``None`` if no frame is expected. + + """ + frames_received = connection.events_received() + frames_expected = [] if frame is None else [frame] + self.assertEqual(frames_received, frames_expected) + + def assertConnectionClosing(self, connection, code=None, reason=""): + """ + Incoming data caused the "Start the WebSocket Closing Handshake" process. + + """ + close_frame = Frame( + OP_CLOSE, + b"" if code is None else Close(code, reason).serialize(), + ) + # A close frame was received. + self.assertFrameReceived(connection, close_frame) + # A close frame and possibly the end of stream were sent. + self.assertFrameSent(connection, close_frame, eof=connection.side is SERVER) + + def assertConnectionFailing(self, connection, code=None, reason=""): + """ + Incoming data caused the "Fail the WebSocket Connection" process. + + """ + close_frame = Frame( + OP_CLOSE, + b"" if code is None else Close(code, reason).serialize(), + ) + # No frame was received. + self.assertFrameReceived(connection, None) + # A close frame and possibly the end of stream were sent. + self.assertFrameSent(connection, close_frame, eof=connection.side is SERVER) + + +class MaskingTests(ProtocolTestCase): + """ + Test frame masking. + + 5.1. Overview + + """ + + unmasked_text_frame_date = b"\x81\x04Spam" + masked_text_frame_data = b"\x81\x84\x00\xff\x00\xff\x53\x8f\x61\x92" + + def test_client_sends_masked_frame(self): + client = Protocol(CLIENT) + with self.enforce_mask(b"\x00\xff\x00\xff"): + client.send_text(b"Spam", True) + self.assertEqual(client.data_to_send(), [self.masked_text_frame_data]) + + def test_server_sends_unmasked_frame(self): + server = Protocol(SERVER) + server.send_text(b"Spam", True) + self.assertEqual(server.data_to_send(), [self.unmasked_text_frame_date]) + + def test_client_receives_unmasked_frame(self): + client = Protocol(CLIENT) + client.receive_data(self.unmasked_text_frame_date) + self.assertFrameReceived( + client, + Frame(OP_TEXT, b"Spam"), + ) + + def test_server_receives_masked_frame(self): + server = Protocol(SERVER) + server.receive_data(self.masked_text_frame_data) + self.assertFrameReceived( + server, + Frame(OP_TEXT, b"Spam"), + ) + + def test_client_receives_masked_frame(self): + client = Protocol(CLIENT) + client.receive_data(self.masked_text_frame_data) + self.assertIsInstance(client.parser_exc, ProtocolError) + self.assertEqual(str(client.parser_exc), "incorrect masking") + self.assertConnectionFailing(client, 1002, "incorrect masking") + + def test_server_receives_unmasked_frame(self): + server = Protocol(SERVER) + server.receive_data(self.unmasked_text_frame_date) + self.assertIsInstance(server.parser_exc, ProtocolError) + self.assertEqual(str(server.parser_exc), "incorrect masking") + self.assertConnectionFailing(server, 1002, "incorrect masking") + + +class ContinuationTests(ProtocolTestCase): + """ + Test continuation frames without text or binary frames. + + """ + + def test_client_sends_unexpected_continuation(self): + client = Protocol(CLIENT) + with self.assertRaises(ProtocolError) as raised: + client.send_continuation(b"", fin=False) + self.assertEqual(str(raised.exception), "unexpected continuation frame") + + def test_server_sends_unexpected_continuation(self): + server = Protocol(SERVER) + with self.assertRaises(ProtocolError) as raised: + server.send_continuation(b"", fin=False) + self.assertEqual(str(raised.exception), "unexpected continuation frame") + + def test_client_receives_unexpected_continuation(self): + client = Protocol(CLIENT) + client.receive_data(b"\x00\x00") + self.assertIsInstance(client.parser_exc, ProtocolError) + self.assertEqual(str(client.parser_exc), "unexpected continuation frame") + self.assertConnectionFailing(client, 1002, "unexpected continuation frame") + + def test_server_receives_unexpected_continuation(self): + server = Protocol(SERVER) + server.receive_data(b"\x00\x80\x00\x00\x00\x00") + self.assertIsInstance(server.parser_exc, ProtocolError) + self.assertEqual(str(server.parser_exc), "unexpected continuation frame") + self.assertConnectionFailing(server, 1002, "unexpected continuation frame") + + def test_client_sends_continuation_after_sending_close(self): + client = Protocol(CLIENT) + # Since it isn't possible to send a close frame in a fragmented + # message (see test_client_send_close_in_fragmented_message), in fact, + # this is the same test as test_client_sends_unexpected_continuation. + with self.enforce_mask(b"\x00\x00\x00\x00"): + client.send_close(1001) + self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) + with self.assertRaises(ProtocolError) as raised: + client.send_continuation(b"", fin=False) + self.assertEqual(str(raised.exception), "unexpected continuation frame") + + def test_server_sends_continuation_after_sending_close(self): + # Since it isn't possible to send a close frame in a fragmented + # message (see test_server_send_close_in_fragmented_message), in fact, + # this is the same test as test_server_sends_unexpected_continuation. + server = Protocol(SERVER) + server.send_close(1000) + self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) + with self.assertRaises(ProtocolError) as raised: + server.send_continuation(b"", fin=False) + self.assertEqual(str(raised.exception), "unexpected continuation frame") + + def test_client_receives_continuation_after_receiving_close(self): + client = Protocol(CLIENT) + client.receive_data(b"\x88\x02\x03\xe8") + self.assertConnectionClosing(client, 1000) + client.receive_data(b"\x00\x00") + self.assertFrameReceived(client, None) + self.assertFrameSent(client, None) + + def test_server_receives_continuation_after_receiving_close(self): + server = Protocol(SERVER) + server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") + self.assertConnectionClosing(server, 1001) + server.receive_data(b"\x00\x80\x00\xff\x00\xff") + self.assertFrameReceived(server, None) + self.assertFrameSent(server, None) + + +class TextTests(ProtocolTestCase): + """ + Test text frames and continuation frames. + + """ + + def test_client_sends_text(self): + client = Protocol(CLIENT) + with self.enforce_mask(b"\x00\x00\x00\x00"): + client.send_text("😀".encode()) + self.assertEqual( + client.data_to_send(), [b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80"] + ) + + def test_server_sends_text(self): + server = Protocol(SERVER) + server.send_text("😀".encode()) + self.assertEqual(server.data_to_send(), [b"\x81\x04\xf0\x9f\x98\x80"]) + + def test_client_receives_text(self): + client = Protocol(CLIENT) + client.receive_data(b"\x81\x04\xf0\x9f\x98\x80") + self.assertFrameReceived( + client, + Frame(OP_TEXT, "😀".encode()), + ) + + def test_server_receives_text(self): + server = Protocol(SERVER) + server.receive_data(b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80") + self.assertFrameReceived( + server, + Frame(OP_TEXT, "😀".encode()), + ) + + def test_client_receives_text_over_size_limit(self): + client = Protocol(CLIENT, max_size=3) + client.receive_data(b"\x81\x04\xf0\x9f\x98\x80") + self.assertIsInstance(client.parser_exc, PayloadTooBig) + self.assertEqual(str(client.parser_exc), "over size limit (4 > 3 bytes)") + self.assertConnectionFailing(client, 1009, "over size limit (4 > 3 bytes)") + + def test_server_receives_text_over_size_limit(self): + server = Protocol(SERVER, max_size=3) + server.receive_data(b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80") + self.assertIsInstance(server.parser_exc, PayloadTooBig) + self.assertEqual(str(server.parser_exc), "over size limit (4 > 3 bytes)") + self.assertConnectionFailing(server, 1009, "over size limit (4 > 3 bytes)") + + def test_client_receives_text_without_size_limit(self): + client = Protocol(CLIENT, max_size=None) + client.receive_data(b"\x81\x04\xf0\x9f\x98\x80") + self.assertFrameReceived( + client, + Frame(OP_TEXT, "😀".encode()), + ) + + def test_server_receives_text_without_size_limit(self): + server = Protocol(SERVER, max_size=None) + server.receive_data(b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80") + self.assertFrameReceived( + server, + Frame(OP_TEXT, "😀".encode()), + ) + + def test_client_sends_fragmented_text(self): + client = Protocol(CLIENT) + with self.enforce_mask(b"\x00\x00\x00\x00"): + client.send_text("😀".encode()[:2], fin=False) + self.assertEqual(client.data_to_send(), [b"\x01\x82\x00\x00\x00\x00\xf0\x9f"]) + with self.enforce_mask(b"\x00\x00\x00\x00"): + client.send_continuation("😀😀".encode()[2:6], fin=False) + self.assertEqual( + client.data_to_send(), [b"\x00\x84\x00\x00\x00\x00\x98\x80\xf0\x9f"] + ) + with self.enforce_mask(b"\x00\x00\x00\x00"): + client.send_continuation("😀".encode()[2:], fin=True) + self.assertEqual(client.data_to_send(), [b"\x80\x82\x00\x00\x00\x00\x98\x80"]) + + def test_server_sends_fragmented_text(self): + server = Protocol(SERVER) + server.send_text("😀".encode()[:2], fin=False) + self.assertEqual(server.data_to_send(), [b"\x01\x02\xf0\x9f"]) + server.send_continuation("😀😀".encode()[2:6], fin=False) + self.assertEqual(server.data_to_send(), [b"\x00\x04\x98\x80\xf0\x9f"]) + server.send_continuation("😀".encode()[2:], fin=True) + self.assertEqual(server.data_to_send(), [b"\x80\x02\x98\x80"]) + + def test_client_receives_fragmented_text(self): + client = Protocol(CLIENT) + client.receive_data(b"\x01\x02\xf0\x9f") + self.assertFrameReceived( + client, + Frame(OP_TEXT, "😀".encode()[:2], fin=False), + ) + client.receive_data(b"\x00\x04\x98\x80\xf0\x9f") + self.assertFrameReceived( + client, + Frame(OP_CONT, "😀😀".encode()[2:6], fin=False), + ) + client.receive_data(b"\x80\x02\x98\x80") + self.assertFrameReceived( + client, + Frame(OP_CONT, "😀".encode()[2:]), + ) + + def test_server_receives_fragmented_text(self): + server = Protocol(SERVER) + server.receive_data(b"\x01\x82\x00\x00\x00\x00\xf0\x9f") + self.assertFrameReceived( + server, + Frame(OP_TEXT, "😀".encode()[:2], fin=False), + ) + server.receive_data(b"\x00\x84\x00\x00\x00\x00\x98\x80\xf0\x9f") + self.assertFrameReceived( + server, + Frame(OP_CONT, "😀😀".encode()[2:6], fin=False), + ) + server.receive_data(b"\x80\x82\x00\x00\x00\x00\x98\x80") + self.assertFrameReceived( + server, + Frame(OP_CONT, "😀".encode()[2:]), + ) + + def test_client_receives_fragmented_text_over_size_limit(self): + client = Protocol(CLIENT, max_size=3) + client.receive_data(b"\x01\x02\xf0\x9f") + self.assertFrameReceived( + client, + Frame(OP_TEXT, "😀".encode()[:2], fin=False), + ) + client.receive_data(b"\x80\x02\x98\x80") + self.assertIsInstance(client.parser_exc, PayloadTooBig) + self.assertEqual(str(client.parser_exc), "over size limit (2 > 1 bytes)") + self.assertConnectionFailing(client, 1009, "over size limit (2 > 1 bytes)") + + def test_server_receives_fragmented_text_over_size_limit(self): + server = Protocol(SERVER, max_size=3) + server.receive_data(b"\x01\x82\x00\x00\x00\x00\xf0\x9f") + self.assertFrameReceived( + server, + Frame(OP_TEXT, "😀".encode()[:2], fin=False), + ) + server.receive_data(b"\x80\x82\x00\x00\x00\x00\x98\x80") + self.assertIsInstance(server.parser_exc, PayloadTooBig) + self.assertEqual(str(server.parser_exc), "over size limit (2 > 1 bytes)") + self.assertConnectionFailing(server, 1009, "over size limit (2 > 1 bytes)") + + def test_client_receives_fragmented_text_without_size_limit(self): + client = Protocol(CLIENT, max_size=None) + client.receive_data(b"\x01\x02\xf0\x9f") + self.assertFrameReceived( + client, + Frame(OP_TEXT, "😀".encode()[:2], fin=False), + ) + client.receive_data(b"\x00\x04\x98\x80\xf0\x9f") + self.assertFrameReceived( + client, + Frame(OP_CONT, "😀😀".encode()[2:6], fin=False), + ) + client.receive_data(b"\x80\x02\x98\x80") + self.assertFrameReceived( + client, + Frame(OP_CONT, "😀".encode()[2:]), + ) + + def test_server_receives_fragmented_text_without_size_limit(self): + server = Protocol(SERVER, max_size=None) + server.receive_data(b"\x01\x82\x00\x00\x00\x00\xf0\x9f") + self.assertFrameReceived( + server, + Frame(OP_TEXT, "😀".encode()[:2], fin=False), + ) + server.receive_data(b"\x00\x84\x00\x00\x00\x00\x98\x80\xf0\x9f") + self.assertFrameReceived( + server, + Frame(OP_CONT, "😀😀".encode()[2:6], fin=False), + ) + server.receive_data(b"\x80\x82\x00\x00\x00\x00\x98\x80") + self.assertFrameReceived( + server, + Frame(OP_CONT, "😀".encode()[2:]), + ) + + def test_client_sends_unexpected_text(self): + client = Protocol(CLIENT) + client.send_text(b"", fin=False) + with self.assertRaises(ProtocolError) as raised: + client.send_text(b"", fin=False) + self.assertEqual(str(raised.exception), "expected a continuation frame") + + def test_server_sends_unexpected_text(self): + server = Protocol(SERVER) + server.send_text(b"", fin=False) + with self.assertRaises(ProtocolError) as raised: + server.send_text(b"", fin=False) + self.assertEqual(str(raised.exception), "expected a continuation frame") + + def test_client_receives_unexpected_text(self): + client = Protocol(CLIENT) + client.receive_data(b"\x01\x00") + self.assertFrameReceived( + client, + Frame(OP_TEXT, b"", fin=False), + ) + client.receive_data(b"\x01\x00") + self.assertIsInstance(client.parser_exc, ProtocolError) + self.assertEqual(str(client.parser_exc), "expected a continuation frame") + self.assertConnectionFailing(client, 1002, "expected a continuation frame") + + def test_server_receives_unexpected_text(self): + server = Protocol(SERVER) + server.receive_data(b"\x01\x80\x00\x00\x00\x00") + self.assertFrameReceived( + server, + Frame(OP_TEXT, b"", fin=False), + ) + server.receive_data(b"\x01\x80\x00\x00\x00\x00") + self.assertIsInstance(server.parser_exc, ProtocolError) + self.assertEqual(str(server.parser_exc), "expected a continuation frame") + self.assertConnectionFailing(server, 1002, "expected a continuation frame") + + def test_client_sends_text_after_sending_close(self): + client = Protocol(CLIENT) + with self.enforce_mask(b"\x00\x00\x00\x00"): + client.send_close(1001) + self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) + with self.assertRaises(InvalidState): + client.send_text(b"") + + def test_server_sends_text_after_sending_close(self): + server = Protocol(SERVER) + server.send_close(1000) + self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) + with self.assertRaises(InvalidState): + server.send_text(b"") + + def test_client_receives_text_after_receiving_close(self): + client = Protocol(CLIENT) + client.receive_data(b"\x88\x02\x03\xe8") + self.assertConnectionClosing(client, 1000) + client.receive_data(b"\x81\x00") + self.assertFrameReceived(client, None) + self.assertFrameSent(client, None) + + def test_server_receives_text_after_receiving_close(self): + server = Protocol(SERVER) + server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") + self.assertConnectionClosing(server, 1001) + server.receive_data(b"\x81\x80\x00\xff\x00\xff") + self.assertFrameReceived(server, None) + self.assertFrameSent(server, None) + + +class BinaryTests(ProtocolTestCase): + """ + Test binary frames and continuation frames. + + """ + + def test_client_sends_binary(self): + client = Protocol(CLIENT) + with self.enforce_mask(b"\x00\x00\x00\x00"): + client.send_binary(b"\x01\x02\xfe\xff") + self.assertEqual( + client.data_to_send(), [b"\x82\x84\x00\x00\x00\x00\x01\x02\xfe\xff"] + ) + + def test_server_sends_binary(self): + server = Protocol(SERVER) + server.send_binary(b"\x01\x02\xfe\xff") + self.assertEqual(server.data_to_send(), [b"\x82\x04\x01\x02\xfe\xff"]) + + def test_client_receives_binary(self): + client = Protocol(CLIENT) + client.receive_data(b"\x82\x04\x01\x02\xfe\xff") + self.assertFrameReceived( + client, + Frame(OP_BINARY, b"\x01\x02\xfe\xff"), + ) + + def test_server_receives_binary(self): + server = Protocol(SERVER) + server.receive_data(b"\x82\x84\x00\x00\x00\x00\x01\x02\xfe\xff") + self.assertFrameReceived( + server, + Frame(OP_BINARY, b"\x01\x02\xfe\xff"), + ) + + def test_client_receives_binary_over_size_limit(self): + client = Protocol(CLIENT, max_size=3) + client.receive_data(b"\x82\x04\x01\x02\xfe\xff") + self.assertIsInstance(client.parser_exc, PayloadTooBig) + self.assertEqual(str(client.parser_exc), "over size limit (4 > 3 bytes)") + self.assertConnectionFailing(client, 1009, "over size limit (4 > 3 bytes)") + + def test_server_receives_binary_over_size_limit(self): + server = Protocol(SERVER, max_size=3) + server.receive_data(b"\x82\x84\x00\x00\x00\x00\x01\x02\xfe\xff") + self.assertIsInstance(server.parser_exc, PayloadTooBig) + self.assertEqual(str(server.parser_exc), "over size limit (4 > 3 bytes)") + self.assertConnectionFailing(server, 1009, "over size limit (4 > 3 bytes)") + + def test_client_sends_fragmented_binary(self): + client = Protocol(CLIENT) + with self.enforce_mask(b"\x00\x00\x00\x00"): + client.send_binary(b"\x01\x02", fin=False) + self.assertEqual(client.data_to_send(), [b"\x02\x82\x00\x00\x00\x00\x01\x02"]) + with self.enforce_mask(b"\x00\x00\x00\x00"): + client.send_continuation(b"\xee\xff\x01\x02", fin=False) + self.assertEqual( + client.data_to_send(), [b"\x00\x84\x00\x00\x00\x00\xee\xff\x01\x02"] + ) + with self.enforce_mask(b"\x00\x00\x00\x00"): + client.send_continuation(b"\xee\xff", fin=True) + self.assertEqual(client.data_to_send(), [b"\x80\x82\x00\x00\x00\x00\xee\xff"]) + + def test_server_sends_fragmented_binary(self): + server = Protocol(SERVER) + server.send_binary(b"\x01\x02", fin=False) + self.assertEqual(server.data_to_send(), [b"\x02\x02\x01\x02"]) + server.send_continuation(b"\xee\xff\x01\x02", fin=False) + self.assertEqual(server.data_to_send(), [b"\x00\x04\xee\xff\x01\x02"]) + server.send_continuation(b"\xee\xff", fin=True) + self.assertEqual(server.data_to_send(), [b"\x80\x02\xee\xff"]) + + def test_client_receives_fragmented_binary(self): + client = Protocol(CLIENT) + client.receive_data(b"\x02\x02\x01\x02") + self.assertFrameReceived( + client, + Frame(OP_BINARY, b"\x01\x02", fin=False), + ) + client.receive_data(b"\x00\x04\xfe\xff\x01\x02") + self.assertFrameReceived( + client, + Frame(OP_CONT, b"\xfe\xff\x01\x02", fin=False), + ) + client.receive_data(b"\x80\x02\xfe\xff") + self.assertFrameReceived( + client, + Frame(OP_CONT, b"\xfe\xff"), + ) + + def test_server_receives_fragmented_binary(self): + server = Protocol(SERVER) + server.receive_data(b"\x02\x82\x00\x00\x00\x00\x01\x02") + self.assertFrameReceived( + server, + Frame(OP_BINARY, b"\x01\x02", fin=False), + ) + server.receive_data(b"\x00\x84\x00\x00\x00\x00\xee\xff\x01\x02") + self.assertFrameReceived( + server, + Frame(OP_CONT, b"\xee\xff\x01\x02", fin=False), + ) + server.receive_data(b"\x80\x82\x00\x00\x00\x00\xfe\xff") + self.assertFrameReceived( + server, + Frame(OP_CONT, b"\xfe\xff"), + ) + + def test_client_receives_fragmented_binary_over_size_limit(self): + client = Protocol(CLIENT, max_size=3) + client.receive_data(b"\x02\x02\x01\x02") + self.assertFrameReceived( + client, + Frame(OP_BINARY, b"\x01\x02", fin=False), + ) + client.receive_data(b"\x80\x02\xfe\xff") + self.assertIsInstance(client.parser_exc, PayloadTooBig) + self.assertEqual(str(client.parser_exc), "over size limit (2 > 1 bytes)") + self.assertConnectionFailing(client, 1009, "over size limit (2 > 1 bytes)") + + def test_server_receives_fragmented_binary_over_size_limit(self): + server = Protocol(SERVER, max_size=3) + server.receive_data(b"\x02\x82\x00\x00\x00\x00\x01\x02") + self.assertFrameReceived( + server, + Frame(OP_BINARY, b"\x01\x02", fin=False), + ) + server.receive_data(b"\x80\x82\x00\x00\x00\x00\xfe\xff") + self.assertIsInstance(server.parser_exc, PayloadTooBig) + self.assertEqual(str(server.parser_exc), "over size limit (2 > 1 bytes)") + self.assertConnectionFailing(server, 1009, "over size limit (2 > 1 bytes)") + + def test_client_sends_unexpected_binary(self): + client = Protocol(CLIENT) + client.send_binary(b"", fin=False) + with self.assertRaises(ProtocolError) as raised: + client.send_binary(b"", fin=False) + self.assertEqual(str(raised.exception), "expected a continuation frame") + + def test_server_sends_unexpected_binary(self): + server = Protocol(SERVER) + server.send_binary(b"", fin=False) + with self.assertRaises(ProtocolError) as raised: + server.send_binary(b"", fin=False) + self.assertEqual(str(raised.exception), "expected a continuation frame") + + def test_client_receives_unexpected_binary(self): + client = Protocol(CLIENT) + client.receive_data(b"\x02\x00") + self.assertFrameReceived( + client, + Frame(OP_BINARY, b"", fin=False), + ) + client.receive_data(b"\x02\x00") + self.assertIsInstance(client.parser_exc, ProtocolError) + self.assertEqual(str(client.parser_exc), "expected a continuation frame") + self.assertConnectionFailing(client, 1002, "expected a continuation frame") + + def test_server_receives_unexpected_binary(self): + server = Protocol(SERVER) + server.receive_data(b"\x02\x80\x00\x00\x00\x00") + self.assertFrameReceived( + server, + Frame(OP_BINARY, b"", fin=False), + ) + server.receive_data(b"\x02\x80\x00\x00\x00\x00") + self.assertIsInstance(server.parser_exc, ProtocolError) + self.assertEqual(str(server.parser_exc), "expected a continuation frame") + self.assertConnectionFailing(server, 1002, "expected a continuation frame") + + def test_client_sends_binary_after_sending_close(self): + client = Protocol(CLIENT) + with self.enforce_mask(b"\x00\x00\x00\x00"): + client.send_close(1001) + self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) + with self.assertRaises(InvalidState): + client.send_binary(b"") + + def test_server_sends_binary_after_sending_close(self): + server = Protocol(SERVER) + server.send_close(1000) + self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) + with self.assertRaises(InvalidState): + server.send_binary(b"") + + def test_client_receives_binary_after_receiving_close(self): + client = Protocol(CLIENT) + client.receive_data(b"\x88\x02\x03\xe8") + self.assertConnectionClosing(client, 1000) + client.receive_data(b"\x82\x00") + self.assertFrameReceived(client, None) + self.assertFrameSent(client, None) + + def test_server_receives_binary_after_receiving_close(self): + server = Protocol(SERVER) + server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") + self.assertConnectionClosing(server, 1001) + server.receive_data(b"\x82\x80\x00\xff\x00\xff") + self.assertFrameReceived(server, None) + self.assertFrameSent(server, None) + + +class CloseTests(ProtocolTestCase): + """ + Test close frames. + + See RFC 6544: + + 5.5.1. Close + 7.1.6. The WebSocket Connection Close Reason + 7.1.7. Fail the WebSocket Connection + + """ + + def test_close_code(self): + client = Protocol(CLIENT) + client.receive_data(b"\x88\x04\x03\xe8OK") + client.receive_eof() + self.assertEqual(client.close_code, 1000) + + def test_close_reason(self): + server = Protocol(SERVER) + server.receive_data(b"\x88\x84\x00\x00\x00\x00\x03\xe8OK") + server.receive_eof() + self.assertEqual(server.close_reason, "OK") + + def test_close_code_not_provided(self): + server = Protocol(SERVER) + server.receive_data(b"\x88\x80\x00\x00\x00\x00") + server.receive_eof() + self.assertEqual(server.close_code, 1005) + + def test_close_reason_not_provided(self): + client = Protocol(CLIENT) + client.receive_data(b"\x88\x00") + client.receive_eof() + self.assertEqual(client.close_reason, "") + + def test_close_code_not_available(self): + client = Protocol(CLIENT) + client.receive_eof() + self.assertEqual(client.close_code, 1006) + + def test_close_reason_not_available(self): + server = Protocol(SERVER) + server.receive_eof() + self.assertEqual(server.close_reason, "") + + def test_close_code_not_available_yet(self): + server = Protocol(SERVER) + self.assertIsNone(server.close_code) + + def test_close_reason_not_available_yet(self): + client = Protocol(CLIENT) + self.assertIsNone(client.close_reason) + + def test_client_sends_close(self): + client = Protocol(CLIENT) + with self.enforce_mask(b"\x3c\x3c\x3c\x3c"): + client.send_close() + self.assertEqual(client.data_to_send(), [b"\x88\x80\x3c\x3c\x3c\x3c"]) + self.assertIs(client.state, CLOSING) + + def test_server_sends_close(self): + server = Protocol(SERVER) + server.send_close() + self.assertEqual(server.data_to_send(), [b"\x88\x00"]) + self.assertIs(server.state, CLOSING) + + def test_client_receives_close(self): + client = Protocol(CLIENT) + with self.enforce_mask(b"\x3c\x3c\x3c\x3c"): + client.receive_data(b"\x88\x00") + self.assertEqual(client.events_received(), [Frame(OP_CLOSE, b"")]) + self.assertEqual(client.data_to_send(), [b"\x88\x80\x3c\x3c\x3c\x3c"]) + self.assertIs(client.state, CLOSING) + + def test_server_receives_close(self): + server = Protocol(SERVER) + server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") + self.assertEqual(server.events_received(), [Frame(OP_CLOSE, b"")]) + self.assertEqual(server.data_to_send(), [b"\x88\x00", b""]) + self.assertIs(server.state, CLOSING) + + def test_client_sends_close_then_receives_close(self): + # Client-initiated close handshake on the client side. + client = Protocol(CLIENT) + + client.send_close() + self.assertFrameReceived(client, None) + self.assertFrameSent(client, Frame(OP_CLOSE, b"")) + + client.receive_data(b"\x88\x00") + self.assertFrameReceived(client, Frame(OP_CLOSE, b"")) + self.assertFrameSent(client, None) + + client.receive_eof() + self.assertFrameReceived(client, None) + self.assertFrameSent(client, None, eof=True) + + def test_server_sends_close_then_receives_close(self): + # Server-initiated close handshake on the server side. + server = Protocol(SERVER) + + server.send_close() + self.assertFrameReceived(server, None) + self.assertFrameSent(server, Frame(OP_CLOSE, b"")) + + server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") + self.assertFrameReceived(server, Frame(OP_CLOSE, b"")) + self.assertFrameSent(server, None, eof=True) + + server.receive_eof() + self.assertFrameReceived(server, None) + self.assertFrameSent(server, None) + + def test_client_receives_close_then_sends_close(self): + # Server-initiated close handshake on the client side. + client = Protocol(CLIENT) + + client.receive_data(b"\x88\x00") + self.assertFrameReceived(client, Frame(OP_CLOSE, b"")) + self.assertFrameSent(client, Frame(OP_CLOSE, b"")) + + client.receive_eof() + self.assertFrameReceived(client, None) + self.assertFrameSent(client, None, eof=True) + + def test_server_receives_close_then_sends_close(self): + # Client-initiated close handshake on the server side. + server = Protocol(SERVER) + + server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") + self.assertFrameReceived(server, Frame(OP_CLOSE, b"")) + self.assertFrameSent(server, Frame(OP_CLOSE, b""), eof=True) + + server.receive_eof() + self.assertFrameReceived(server, None) + self.assertFrameSent(server, None) + + def test_client_sends_close_with_code(self): + client = Protocol(CLIENT) + with self.enforce_mask(b"\x00\x00\x00\x00"): + client.send_close(1001) + self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) + self.assertIs(client.state, CLOSING) + + def test_server_sends_close_with_code(self): + server = Protocol(SERVER) + server.send_close(1000) + self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) + self.assertIs(server.state, CLOSING) + + def test_client_receives_close_with_code(self): + client = Protocol(CLIENT) + client.receive_data(b"\x88\x02\x03\xe8") + self.assertConnectionClosing(client, 1000, "") + self.assertIs(client.state, CLOSING) + + def test_server_receives_close_with_code(self): + server = Protocol(SERVER) + server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") + self.assertConnectionClosing(server, 1001, "") + self.assertIs(server.state, CLOSING) + + def test_client_sends_close_with_code_and_reason(self): + client = Protocol(CLIENT) + with self.enforce_mask(b"\x00\x00\x00\x00"): + client.send_close(1001, "going away") + self.assertEqual( + client.data_to_send(), [b"\x88\x8c\x00\x00\x00\x00\x03\xe9going away"] + ) + self.assertIs(client.state, CLOSING) + + def test_server_sends_close_with_code_and_reason(self): + server = Protocol(SERVER) + server.send_close(1000, "OK") + self.assertEqual(server.data_to_send(), [b"\x88\x04\x03\xe8OK"]) + self.assertIs(server.state, CLOSING) + + def test_client_receives_close_with_code_and_reason(self): + client = Protocol(CLIENT) + client.receive_data(b"\x88\x04\x03\xe8OK") + self.assertConnectionClosing(client, 1000, "OK") + self.assertIs(client.state, CLOSING) + + def test_server_receives_close_with_code_and_reason(self): + server = Protocol(SERVER) + server.receive_data(b"\x88\x8c\x00\x00\x00\x00\x03\xe9going away") + self.assertConnectionClosing(server, 1001, "going away") + self.assertIs(server.state, CLOSING) + + def test_client_sends_close_with_reason_only(self): + client = Protocol(CLIENT) + with self.assertRaises(ProtocolError) as raised: + client.send_close(reason="going away") + self.assertEqual(str(raised.exception), "cannot send a reason without a code") + + def test_server_sends_close_with_reason_only(self): + server = Protocol(SERVER) + with self.assertRaises(ProtocolError) as raised: + server.send_close(reason="OK") + self.assertEqual(str(raised.exception), "cannot send a reason without a code") + + def test_client_receives_close_with_truncated_code(self): + client = Protocol(CLIENT) + client.receive_data(b"\x88\x01\x03") + self.assertIsInstance(client.parser_exc, ProtocolError) + self.assertEqual(str(client.parser_exc), "close frame too short") + self.assertConnectionFailing(client, 1002, "close frame too short") + self.assertIs(client.state, CLOSING) + + def test_server_receives_close_with_truncated_code(self): + server = Protocol(SERVER) + server.receive_data(b"\x88\x81\x00\x00\x00\x00\x03") + self.assertIsInstance(server.parser_exc, ProtocolError) + self.assertEqual(str(server.parser_exc), "close frame too short") + self.assertConnectionFailing(server, 1002, "close frame too short") + self.assertIs(server.state, CLOSING) + + def test_client_receives_close_with_non_utf8_reason(self): + client = Protocol(CLIENT) + + client.receive_data(b"\x88\x04\x03\xe8\xff\xff") + self.assertIsInstance(client.parser_exc, UnicodeDecodeError) + self.assertEqual( + str(client.parser_exc), + "'utf-8' codec can't decode byte 0xff in position 0: invalid start byte", + ) + self.assertConnectionFailing(client, 1007, "invalid start byte at position 0") + self.assertIs(client.state, CLOSING) + + def test_server_receives_close_with_non_utf8_reason(self): + server = Protocol(SERVER) + + server.receive_data(b"\x88\x84\x00\x00\x00\x00\x03\xe9\xff\xff") + self.assertIsInstance(server.parser_exc, UnicodeDecodeError) + self.assertEqual( + str(server.parser_exc), + "'utf-8' codec can't decode byte 0xff in position 0: invalid start byte", + ) + self.assertConnectionFailing(server, 1007, "invalid start byte at position 0") + self.assertIs(server.state, CLOSING) + + +class PingTests(ProtocolTestCase): + """ + Test ping. See 5.5.2. Ping in RFC 6544. + + """ + + def test_client_sends_ping(self): + client = Protocol(CLIENT) + with self.enforce_mask(b"\x00\x44\x88\xcc"): + client.send_ping(b"") + self.assertEqual(client.data_to_send(), [b"\x89\x80\x00\x44\x88\xcc"]) + + def test_server_sends_ping(self): + server = Protocol(SERVER) + server.send_ping(b"") + self.assertEqual(server.data_to_send(), [b"\x89\x00"]) + + def test_client_receives_ping(self): + client = Protocol(CLIENT) + client.receive_data(b"\x89\x00") + self.assertFrameReceived( + client, + Frame(OP_PING, b""), + ) + self.assertFrameSent( + client, + Frame(OP_PONG, b""), + ) + + def test_server_receives_ping(self): + server = Protocol(SERVER) + server.receive_data(b"\x89\x80\x00\x44\x88\xcc") + self.assertFrameReceived( + server, + Frame(OP_PING, b""), + ) + self.assertFrameSent( + server, + Frame(OP_PONG, b""), + ) + + def test_client_sends_ping_with_data(self): + client = Protocol(CLIENT) + with self.enforce_mask(b"\x00\x44\x88\xcc"): + client.send_ping(b"\x22\x66\xaa\xee") + self.assertEqual( + client.data_to_send(), [b"\x89\x84\x00\x44\x88\xcc\x22\x22\x22\x22"] + ) + + def test_server_sends_ping_with_data(self): + server = Protocol(SERVER) + server.send_ping(b"\x22\x66\xaa\xee") + self.assertEqual(server.data_to_send(), [b"\x89\x04\x22\x66\xaa\xee"]) + + def test_client_receives_ping_with_data(self): + client = Protocol(CLIENT) + client.receive_data(b"\x89\x04\x22\x66\xaa\xee") + self.assertFrameReceived( + client, + Frame(OP_PING, b"\x22\x66\xaa\xee"), + ) + self.assertFrameSent( + client, + Frame(OP_PONG, b"\x22\x66\xaa\xee"), + ) + + def test_server_receives_ping_with_data(self): + server = Protocol(SERVER) + server.receive_data(b"\x89\x84\x00\x44\x88\xcc\x22\x22\x22\x22") + self.assertFrameReceived( + server, + Frame(OP_PING, b"\x22\x66\xaa\xee"), + ) + self.assertFrameSent( + server, + Frame(OP_PONG, b"\x22\x66\xaa\xee"), + ) + + def test_client_sends_fragmented_ping_frame(self): + client = Protocol(CLIENT) + # This is only possible through a private API. + with self.assertRaises(ProtocolError) as raised: + client.send_frame(Frame(OP_PING, b"", fin=False)) + self.assertEqual(str(raised.exception), "fragmented control frame") + + def test_server_sends_fragmented_ping_frame(self): + server = Protocol(SERVER) + # This is only possible through a private API. + with self.assertRaises(ProtocolError) as raised: + server.send_frame(Frame(OP_PING, b"", fin=False)) + self.assertEqual(str(raised.exception), "fragmented control frame") + + def test_client_receives_fragmented_ping_frame(self): + client = Protocol(CLIENT) + client.receive_data(b"\x09\x00") + self.assertIsInstance(client.parser_exc, ProtocolError) + self.assertEqual(str(client.parser_exc), "fragmented control frame") + self.assertConnectionFailing(client, 1002, "fragmented control frame") + + def test_server_receives_fragmented_ping_frame(self): + server = Protocol(SERVER) + server.receive_data(b"\x09\x80\x3c\x3c\x3c\x3c") + self.assertIsInstance(server.parser_exc, ProtocolError) + self.assertEqual(str(server.parser_exc), "fragmented control frame") + self.assertConnectionFailing(server, 1002, "fragmented control frame") + + def test_client_sends_ping_after_sending_close(self): + client = Protocol(CLIENT) + with self.enforce_mask(b"\x00\x00\x00\x00"): + client.send_close(1001) + self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) + # The spec says: "An endpoint MAY send a Ping frame any time (...) + # before the connection is closed" but websockets doesn't support + # sending a Ping frame after a Close frame. + with self.assertRaises(InvalidState) as raised: + client.send_ping(b"") + self.assertEqual( + str(raised.exception), + "cannot write to a WebSocket in the CLOSING state", + ) + + def test_server_sends_ping_after_sending_close(self): + server = Protocol(SERVER) + server.send_close(1000) + self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) + # The spec says: "An endpoint MAY send a Ping frame any time (...) + # before the connection is closed" but websockets doesn't support + # sending a Ping frame after a Close frame. + with self.assertRaises(InvalidState) as raised: + server.send_ping(b"") + self.assertEqual( + str(raised.exception), + "cannot write to a WebSocket in the CLOSING state", + ) + + def test_client_receives_ping_after_receiving_close(self): + client = Protocol(CLIENT) + client.receive_data(b"\x88\x02\x03\xe8") + self.assertConnectionClosing(client, 1000) + client.receive_data(b"\x89\x04\x22\x66\xaa\xee") + self.assertFrameReceived(client, None) + self.assertFrameSent(client, None) + + def test_server_receives_ping_after_receiving_close(self): + server = Protocol(SERVER) + server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") + self.assertConnectionClosing(server, 1001) + server.receive_data(b"\x89\x84\x00\x44\x88\xcc\x22\x22\x22\x22") + self.assertFrameReceived(server, None) + self.assertFrameSent(server, None) + + +class PongTests(ProtocolTestCase): + """ + Test pong frames. See 5.5.3. Pong in RFC 6544. + + """ + + def test_client_sends_pong(self): + client = Protocol(CLIENT) + with self.enforce_mask(b"\x00\x44\x88\xcc"): + client.send_pong(b"") + self.assertEqual(client.data_to_send(), [b"\x8a\x80\x00\x44\x88\xcc"]) + + def test_server_sends_pong(self): + server = Protocol(SERVER) + server.send_pong(b"") + self.assertEqual(server.data_to_send(), [b"\x8a\x00"]) + + def test_client_receives_pong(self): + client = Protocol(CLIENT) + client.receive_data(b"\x8a\x00") + self.assertFrameReceived( + client, + Frame(OP_PONG, b""), + ) + + def test_server_receives_pong(self): + server = Protocol(SERVER) + server.receive_data(b"\x8a\x80\x00\x44\x88\xcc") + self.assertFrameReceived( + server, + Frame(OP_PONG, b""), + ) + + def test_client_sends_pong_with_data(self): + client = Protocol(CLIENT) + with self.enforce_mask(b"\x00\x44\x88\xcc"): + client.send_pong(b"\x22\x66\xaa\xee") + self.assertEqual( + client.data_to_send(), [b"\x8a\x84\x00\x44\x88\xcc\x22\x22\x22\x22"] + ) + + def test_server_sends_pong_with_data(self): + server = Protocol(SERVER) + server.send_pong(b"\x22\x66\xaa\xee") + self.assertEqual(server.data_to_send(), [b"\x8a\x04\x22\x66\xaa\xee"]) + + def test_client_receives_pong_with_data(self): + client = Protocol(CLIENT) + client.receive_data(b"\x8a\x04\x22\x66\xaa\xee") + self.assertFrameReceived( + client, + Frame(OP_PONG, b"\x22\x66\xaa\xee"), + ) + + def test_server_receives_pong_with_data(self): + server = Protocol(SERVER) + server.receive_data(b"\x8a\x84\x00\x44\x88\xcc\x22\x22\x22\x22") + self.assertFrameReceived( + server, + Frame(OP_PONG, b"\x22\x66\xaa\xee"), + ) + + def test_client_sends_fragmented_pong_frame(self): + client = Protocol(CLIENT) + # This is only possible through a private API. + with self.assertRaises(ProtocolError) as raised: + client.send_frame(Frame(OP_PONG, b"", fin=False)) + self.assertEqual(str(raised.exception), "fragmented control frame") + + def test_server_sends_fragmented_pong_frame(self): + server = Protocol(SERVER) + # This is only possible through a private API. + with self.assertRaises(ProtocolError) as raised: + server.send_frame(Frame(OP_PONG, b"", fin=False)) + self.assertEqual(str(raised.exception), "fragmented control frame") + + def test_client_receives_fragmented_pong_frame(self): + client = Protocol(CLIENT) + client.receive_data(b"\x0a\x00") + self.assertIsInstance(client.parser_exc, ProtocolError) + self.assertEqual(str(client.parser_exc), "fragmented control frame") + self.assertConnectionFailing(client, 1002, "fragmented control frame") + + def test_server_receives_fragmented_pong_frame(self): + server = Protocol(SERVER) + server.receive_data(b"\x0a\x80\x3c\x3c\x3c\x3c") + self.assertIsInstance(server.parser_exc, ProtocolError) + self.assertEqual(str(server.parser_exc), "fragmented control frame") + self.assertConnectionFailing(server, 1002, "fragmented control frame") + + def test_client_sends_pong_after_sending_close(self): + client = Protocol(CLIENT) + with self.enforce_mask(b"\x00\x00\x00\x00"): + client.send_close(1001) + self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) + # websockets doesn't support sending a Pong frame after a Close frame. + with self.assertRaises(InvalidState): + client.send_pong(b"") + + def test_server_sends_pong_after_sending_close(self): + server = Protocol(SERVER) + server.send_close(1000) + self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) + # websockets doesn't support sending a Pong frame after a Close frame. + with self.assertRaises(InvalidState): + server.send_pong(b"") + + def test_client_receives_pong_after_receiving_close(self): + client = Protocol(CLIENT) + client.receive_data(b"\x88\x02\x03\xe8") + self.assertConnectionClosing(client, 1000) + client.receive_data(b"\x8a\x04\x22\x66\xaa\xee") + self.assertFrameReceived(client, None) + self.assertFrameSent(client, None) + + def test_server_receives_pong_after_receiving_close(self): + server = Protocol(SERVER) + server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") + self.assertConnectionClosing(server, 1001) + server.receive_data(b"\x8a\x84\x00\x44\x88\xcc\x22\x22\x22\x22") + self.assertFrameReceived(server, None) + self.assertFrameSent(server, None) + + +class FailTests(ProtocolTestCase): + """ + Test failing the connection. + + See 7.1.7. Fail the WebSocket Connection in RFC 6544. + + """ + + def test_client_stops_processing_frames_after_fail(self): + client = Protocol(CLIENT) + client.fail(1002) + self.assertConnectionFailing(client, 1002) + client.receive_data(b"\x88\x02\x03\xea") + self.assertFrameReceived(client, None) + + def test_server_stops_processing_frames_after_fail(self): + server = Protocol(SERVER) + server.fail(1002) + self.assertConnectionFailing(server, 1002) + server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xea") + self.assertFrameReceived(server, None) + + +class FragmentationTests(ProtocolTestCase): + """ + Test message fragmentation. + + See 5.4. Fragmentation in RFC 6544. + + """ + + def test_client_send_ping_pong_in_fragmented_message(self): + client = Protocol(CLIENT) + client.send_text(b"Spam", fin=False) + self.assertFrameSent(client, Frame(OP_TEXT, b"Spam", fin=False)) + client.send_ping(b"Ping") + self.assertFrameSent(client, Frame(OP_PING, b"Ping")) + client.send_continuation(b"Ham", fin=False) + self.assertFrameSent(client, Frame(OP_CONT, b"Ham", fin=False)) + client.send_pong(b"Pong") + self.assertFrameSent(client, Frame(OP_PONG, b"Pong")) + client.send_continuation(b"Eggs", fin=True) + self.assertFrameSent(client, Frame(OP_CONT, b"Eggs")) + + def test_server_send_ping_pong_in_fragmented_message(self): + server = Protocol(SERVER) + server.send_text(b"Spam", fin=False) + self.assertFrameSent(server, Frame(OP_TEXT, b"Spam", fin=False)) + server.send_ping(b"Ping") + self.assertFrameSent(server, Frame(OP_PING, b"Ping")) + server.send_continuation(b"Ham", fin=False) + self.assertFrameSent(server, Frame(OP_CONT, b"Ham", fin=False)) + server.send_pong(b"Pong") + self.assertFrameSent(server, Frame(OP_PONG, b"Pong")) + server.send_continuation(b"Eggs", fin=True) + self.assertFrameSent(server, Frame(OP_CONT, b"Eggs")) + + def test_client_receive_ping_pong_in_fragmented_message(self): + client = Protocol(CLIENT) + client.receive_data(b"\x01\x04Spam") + self.assertFrameReceived( + client, + Frame(OP_TEXT, b"Spam", fin=False), + ) + client.receive_data(b"\x89\x04Ping") + self.assertFrameReceived( + client, + Frame(OP_PING, b"Ping"), + ) + self.assertFrameSent( + client, + Frame(OP_PONG, b"Ping"), + ) + client.receive_data(b"\x00\x03Ham") + self.assertFrameReceived( + client, + Frame(OP_CONT, b"Ham", fin=False), + ) + client.receive_data(b"\x8a\x04Pong") + self.assertFrameReceived( + client, + Frame(OP_PONG, b"Pong"), + ) + client.receive_data(b"\x80\x04Eggs") + self.assertFrameReceived( + client, + Frame(OP_CONT, b"Eggs"), + ) + + def test_server_receive_ping_pong_in_fragmented_message(self): + server = Protocol(SERVER) + server.receive_data(b"\x01\x84\x00\x00\x00\x00Spam") + self.assertFrameReceived( + server, + Frame(OP_TEXT, b"Spam", fin=False), + ) + server.receive_data(b"\x89\x84\x00\x00\x00\x00Ping") + self.assertFrameReceived( + server, + Frame(OP_PING, b"Ping"), + ) + self.assertFrameSent( + server, + Frame(OP_PONG, b"Ping"), + ) + server.receive_data(b"\x00\x83\x00\x00\x00\x00Ham") + self.assertFrameReceived( + server, + Frame(OP_CONT, b"Ham", fin=False), + ) + server.receive_data(b"\x8a\x84\x00\x00\x00\x00Pong") + self.assertFrameReceived( + server, + Frame(OP_PONG, b"Pong"), + ) + server.receive_data(b"\x80\x84\x00\x00\x00\x00Eggs") + self.assertFrameReceived( + server, + Frame(OP_CONT, b"Eggs"), + ) + + def test_client_send_close_in_fragmented_message(self): + client = Protocol(CLIENT) + client.send_text(b"Spam", fin=False) + self.assertFrameSent(client, Frame(OP_TEXT, b"Spam", fin=False)) + # The spec says: "An endpoint MUST be capable of handling control + # frames in the middle of a fragmented message." However, since the + # endpoint must not send a data frame after a close frame, a close + # frame can't be "in the middle" of a fragmented message. + with self.assertRaises(ProtocolError) as raised: + client.send_close(1001) + self.assertEqual(str(raised.exception), "expected a continuation frame") + client.send_continuation(b"Eggs", fin=True) + + def test_server_send_close_in_fragmented_message(self): + server = Protocol(CLIENT) + server.send_text(b"Spam", fin=False) + self.assertFrameSent(server, Frame(OP_TEXT, b"Spam", fin=False)) + # The spec says: "An endpoint MUST be capable of handling control + # frames in the middle of a fragmented message." However, since the + # endpoint must not send a data frame after a close frame, a close + # frame can't be "in the middle" of a fragmented message. + with self.assertRaises(ProtocolError) as raised: + server.send_close(1000) + self.assertEqual(str(raised.exception), "expected a continuation frame") + + def test_client_receive_close_in_fragmented_message(self): + client = Protocol(CLIENT) + client.receive_data(b"\x01\x04Spam") + self.assertFrameReceived( + client, + Frame(OP_TEXT, b"Spam", fin=False), + ) + # The spec says: "An endpoint MUST be capable of handling control + # frames in the middle of a fragmented message." However, since the + # endpoint must not send a data frame after a close frame, a close + # frame can't be "in the middle" of a fragmented message. + client.receive_data(b"\x88\x02\x03\xe8") + self.assertIsInstance(client.parser_exc, ProtocolError) + self.assertEqual(str(client.parser_exc), "incomplete fragmented message") + self.assertConnectionFailing(client, 1002, "incomplete fragmented message") + + def test_server_receive_close_in_fragmented_message(self): + server = Protocol(SERVER) + server.receive_data(b"\x01\x84\x00\x00\x00\x00Spam") + self.assertFrameReceived( + server, + Frame(OP_TEXT, b"Spam", fin=False), + ) + # The spec says: "An endpoint MUST be capable of handling control + # frames in the middle of a fragmented message." However, since the + # endpoint must not send a data frame after a close frame, a close + # frame can't be "in the middle" of a fragmented message. + server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") + self.assertIsInstance(server.parser_exc, ProtocolError) + self.assertEqual(str(server.parser_exc), "incomplete fragmented message") + self.assertConnectionFailing(server, 1002, "incomplete fragmented message") + + +class EOFTests(ProtocolTestCase): + """ + Test half-closes on connection termination. + + """ + + def test_client_receives_eof(self): + client = Protocol(CLIENT) + client.receive_data(b"\x88\x00") + self.assertConnectionClosing(client) + client.receive_eof() + self.assertIs(client.state, CLOSED) + + def test_server_receives_eof(self): + server = Protocol(SERVER) + server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") + self.assertConnectionClosing(server) + server.receive_eof() + self.assertIs(server.state, CLOSED) + + def test_client_receives_eof_between_frames(self): + client = Protocol(CLIENT) + client.receive_eof() + self.assertIsInstance(client.parser_exc, EOFError) + self.assertEqual(str(client.parser_exc), "unexpected end of stream") + self.assertIs(client.state, CLOSED) + + def test_server_receives_eof_between_frames(self): + server = Protocol(SERVER) + server.receive_eof() + self.assertIsInstance(server.parser_exc, EOFError) + self.assertEqual(str(server.parser_exc), "unexpected end of stream") + self.assertIs(server.state, CLOSED) + + def test_client_receives_eof_inside_frame(self): + client = Protocol(CLIENT) + client.receive_data(b"\x81") + client.receive_eof() + self.assertIsInstance(client.parser_exc, EOFError) + self.assertEqual( + str(client.parser_exc), + "stream ends after 1 bytes, expected 2 bytes", + ) + self.assertIs(client.state, CLOSED) + + def test_server_receives_eof_inside_frame(self): + server = Protocol(SERVER) + server.receive_data(b"\x81") + server.receive_eof() + self.assertIsInstance(server.parser_exc, EOFError) + self.assertEqual( + str(server.parser_exc), + "stream ends after 1 bytes, expected 2 bytes", + ) + self.assertIs(server.state, CLOSED) + + def test_client_receives_data_after_exception(self): + client = Protocol(CLIENT) + client.receive_data(b"\xff\xff") + self.assertConnectionFailing(client, 1002, "invalid opcode") + client.receive_data(b"\x00\x00") + self.assertFrameSent(client, None) + + def test_server_receives_data_after_exception(self): + server = Protocol(SERVER) + server.receive_data(b"\xff\xff") + self.assertConnectionFailing(server, 1002, "invalid opcode") + server.receive_data(b"\x00\x00") + self.assertFrameSent(server, None) + + def test_client_receives_eof_after_exception(self): + client = Protocol(CLIENT) + client.receive_data(b"\xff\xff") + self.assertConnectionFailing(client, 1002, "invalid opcode") + client.receive_eof() + self.assertFrameSent(client, None, eof=True) + + def test_server_receives_eof_after_exception(self): + server = Protocol(SERVER) + server.receive_data(b"\xff\xff") + self.assertConnectionFailing(server, 1002, "invalid opcode") + server.receive_eof() + self.assertFrameSent(server, None) + + def test_client_receives_data_and_eof_after_exception(self): + client = Protocol(CLIENT) + client.receive_data(b"\xff\xff") + self.assertConnectionFailing(client, 1002, "invalid opcode") + client.receive_data(b"\x00\x00") + client.receive_eof() + self.assertFrameSent(client, None, eof=True) + + def test_server_receives_data_and_eof_after_exception(self): + server = Protocol(SERVER) + server.receive_data(b"\xff\xff") + self.assertConnectionFailing(server, 1002, "invalid opcode") + server.receive_data(b"\x00\x00") + server.receive_eof() + self.assertFrameSent(server, None) + + def test_client_receives_data_after_eof(self): + client = Protocol(CLIENT) + client.receive_data(b"\x88\x00") + self.assertConnectionClosing(client) + client.receive_eof() + with self.assertRaises(EOFError) as raised: + client.receive_data(b"\x88\x00") + self.assertEqual(str(raised.exception), "stream ended") + + def test_server_receives_data_after_eof(self): + server = Protocol(SERVER) + server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") + self.assertConnectionClosing(server) + server.receive_eof() + with self.assertRaises(EOFError) as raised: + server.receive_data(b"\x88\x80\x00\x00\x00\x00") + self.assertEqual(str(raised.exception), "stream ended") + + def test_client_receives_eof_after_eof(self): + client = Protocol(CLIENT) + client.receive_data(b"\x88\x00") + self.assertConnectionClosing(client) + client.receive_eof() + with self.assertRaises(EOFError) as raised: + client.receive_eof() + self.assertEqual(str(raised.exception), "stream ended") + + def test_server_receives_eof_after_eof(self): + server = Protocol(SERVER) + server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") + self.assertConnectionClosing(server) + server.receive_eof() + with self.assertRaises(EOFError) as raised: + server.receive_eof() + self.assertEqual(str(raised.exception), "stream ended") + + +class TCPCloseTests(ProtocolTestCase): + """ + Test expectation of TCP close on connection termination. + + """ + + def test_client_default(self): + client = Protocol(CLIENT) + self.assertFalse(client.close_expected()) + + def test_server_default(self): + server = Protocol(SERVER) + self.assertFalse(server.close_expected()) + + def test_client_sends_close(self): + client = Protocol(CLIENT) + client.send_close() + self.assertTrue(client.close_expected()) + + def test_server_sends_close(self): + server = Protocol(SERVER) + server.send_close() + self.assertTrue(server.close_expected()) + + def test_client_receives_close(self): + client = Protocol(CLIENT) + client.receive_data(b"\x88\x00") + self.assertTrue(client.close_expected()) + + def test_client_receives_close_then_eof(self): + client = Protocol(CLIENT) + client.receive_data(b"\x88\x00") + client.receive_eof() + self.assertFalse(client.close_expected()) + + def test_server_receives_close_then_eof(self): + server = Protocol(SERVER) + server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") + server.receive_eof() + self.assertFalse(server.close_expected()) + + def test_server_receives_close(self): + server = Protocol(SERVER) + server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") + self.assertTrue(server.close_expected()) + + def test_client_fails_connection(self): + client = Protocol(CLIENT) + client.fail(1002) + self.assertTrue(client.close_expected()) + + def test_server_fails_connection(self): + server = Protocol(SERVER) + server.fail(1002) + self.assertTrue(server.close_expected()) + + +class ConnectionClosedTests(ProtocolTestCase): + """ + Test connection closed exception. + + """ + + def test_client_sends_close_then_receives_close(self): + # Client-initiated close handshake on the client side complete. + client = Protocol(CLIENT) + client.send_close(1000, "") + client.receive_data(b"\x88\x02\x03\xe8") + client.receive_eof() + exc = client.close_exc + self.assertIsInstance(exc, ConnectionClosedOK) + self.assertEqual(exc.rcvd, Close(1000, "")) + self.assertEqual(exc.sent, Close(1000, "")) + self.assertFalse(exc.rcvd_then_sent) + + def test_server_sends_close_then_receives_close(self): + # Server-initiated close handshake on the server side complete. + server = Protocol(SERVER) + server.send_close(1000, "") + server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe8") + server.receive_eof() + exc = server.close_exc + self.assertIsInstance(exc, ConnectionClosedOK) + self.assertEqual(exc.rcvd, Close(1000, "")) + self.assertEqual(exc.sent, Close(1000, "")) + self.assertFalse(exc.rcvd_then_sent) + + def test_client_receives_close_then_sends_close(self): + # Server-initiated close handshake on the client side complete. + client = Protocol(CLIENT) + client.receive_data(b"\x88\x02\x03\xe8") + client.receive_eof() + exc = client.close_exc + self.assertIsInstance(exc, ConnectionClosedOK) + self.assertEqual(exc.rcvd, Close(1000, "")) + self.assertEqual(exc.sent, Close(1000, "")) + self.assertTrue(exc.rcvd_then_sent) + + def test_server_receives_close_then_sends_close(self): + # Client-initiated close handshake on the server side complete. + server = Protocol(SERVER) + server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe8") + server.receive_eof() + exc = server.close_exc + self.assertIsInstance(exc, ConnectionClosedOK) + self.assertEqual(exc.rcvd, Close(1000, "")) + self.assertEqual(exc.sent, Close(1000, "")) + self.assertTrue(exc.rcvd_then_sent) + + def test_client_sends_close_then_receives_eof(self): + # Client-initiated close handshake on the client side times out. + client = Protocol(CLIENT) + client.send_close(1000, "") + client.receive_eof() + exc = client.close_exc + self.assertIsInstance(exc, ConnectionClosedError) + self.assertIsNone(exc.rcvd) + self.assertEqual(exc.sent, Close(1000, "")) + self.assertIsNone(exc.rcvd_then_sent) + + def test_server_sends_close_then_receives_eof(self): + # Server-initiated close handshake on the server side times out. + server = Protocol(SERVER) + server.send_close(1000, "") + server.receive_eof() + exc = server.close_exc + self.assertIsInstance(exc, ConnectionClosedError) + self.assertIsNone(exc.rcvd) + self.assertEqual(exc.sent, Close(1000, "")) + self.assertIsNone(exc.rcvd_then_sent) + + def test_client_receives_eof(self): + # Server-initiated close handshake on the client side times out. + client = Protocol(CLIENT) + client.receive_eof() + exc = client.close_exc + self.assertIsInstance(exc, ConnectionClosedError) + self.assertIsNone(exc.rcvd) + self.assertIsNone(exc.sent) + self.assertIsNone(exc.rcvd_then_sent) + + def test_server_receives_eof(self): + # Client-initiated close handshake on the server side times out. + server = Protocol(SERVER) + server.receive_eof() + exc = server.close_exc + self.assertIsInstance(exc, ConnectionClosedError) + self.assertIsNone(exc.rcvd) + self.assertIsNone(exc.sent) + self.assertIsNone(exc.rcvd_then_sent) + + +class ErrorTests(ProtocolTestCase): + """ + Test other error cases. + + """ + + def test_client_hits_internal_error_reading_frame(self): + client = Protocol(CLIENT) + # This isn't supposed to happen, so we're simulating it. + with unittest.mock.patch("struct.unpack", side_effect=RuntimeError("BOOM")): + client.receive_data(b"\x81\x00") + self.assertIsInstance(client.parser_exc, RuntimeError) + self.assertEqual(str(client.parser_exc), "BOOM") + self.assertConnectionFailing(client, 1011, "") + + def test_server_hits_internal_error_reading_frame(self): + server = Protocol(SERVER) + # This isn't supposed to happen, so we're simulating it. + with unittest.mock.patch("struct.unpack", side_effect=RuntimeError("BOOM")): + server.receive_data(b"\x81\x80\x00\x00\x00\x00") + self.assertIsInstance(server.parser_exc, RuntimeError) + self.assertEqual(str(server.parser_exc), "BOOM") + self.assertConnectionFailing(server, 1011, "") + + +class ExtensionsTests(ProtocolTestCase): + """ + Test how extensions affect frames. + + """ + + def test_client_extension_encodes_frame(self): + client = Protocol(CLIENT) + client.extensions = [Rsv2Extension()] + with self.enforce_mask(b"\x00\x44\x88\xcc"): + client.send_ping(b"") + self.assertEqual(client.data_to_send(), [b"\xa9\x80\x00\x44\x88\xcc"]) + + def test_server_extension_encodes_frame(self): + server = Protocol(SERVER) + server.extensions = [Rsv2Extension()] + server.send_ping(b"") + self.assertEqual(server.data_to_send(), [b"\xa9\x00"]) + + def test_client_extension_decodes_frame(self): + client = Protocol(CLIENT) + client.extensions = [Rsv2Extension()] + client.receive_data(b"\xaa\x00") + self.assertEqual(client.events_received(), [Frame(OP_PONG, b"")]) + + def test_server_extension_decodes_frame(self): + server = Protocol(SERVER) + server.extensions = [Rsv2Extension()] + server.receive_data(b"\xaa\x80\x00\x44\x88\xcc") + self.assertEqual(server.events_received(), [Frame(OP_PONG, b"")]) + + +class MiscTests(unittest.TestCase): + def test_client_default_logger(self): + client = Protocol(CLIENT) + logger = logging.getLogger("websockets.client") + self.assertIs(client.logger, logger) + + def test_server_default_logger(self): + server = Protocol(SERVER) + logger = logging.getLogger("websockets.server") + self.assertIs(server.logger, logger) + + def test_client_custom_logger(self): + logger = logging.getLogger("test") + client = Protocol(CLIENT, logger=logger) + self.assertIs(client.logger, logger) + + def test_server_custom_logger(self): + logger = logging.getLogger("test") + server = Protocol(SERVER, logger=logger) + self.assertIs(server.logger, logger) diff --git a/tests/test_server.py b/tests/test_server.py index f1404499b..ba7e05df2 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -3,11 +3,11 @@ import unittest import unittest.mock -from websockets.connection import CONNECTING, OPEN from websockets.datastructures import Headers from websockets.exceptions import InvalidHeader, InvalidOrigin, InvalidUpgrade from websockets.frames import OP_TEXT, Frame from websockets.http11 import Request, Response +from websockets.protocol import CONNECTING, OPEN from websockets.server import * from .extensions.utils import ( @@ -17,12 +17,12 @@ ServerRsv2ExtensionFactory, ) from .test_utils import ACCEPT, KEY -from .utils import DATE +from .utils import DATE, DeprecationTestCase class ConnectTests(unittest.TestCase): def test_receive_connect(self): - server = ServerConnection() + server = ServerProtocol() server.receive_data( ( f"GET /test HTTP/1.1\r\n" @@ -40,7 +40,7 @@ def test_receive_connect(self): self.assertFalse(server.close_expected()) def test_connect_request(self): - server = ServerConnection() + server = ServerProtocol() server.receive_data( ( f"GET /test HTTP/1.1\r\n" @@ -84,7 +84,7 @@ def make_request(self): ) def test_send_accept(self): - server = ServerConnection() + server = ServerProtocol() with unittest.mock.patch("email.utils.formatdate", return_value=DATE): response = server.accept(self.make_request()) self.assertIsInstance(response, Response) @@ -104,7 +104,7 @@ def test_send_accept(self): self.assertEqual(server.state, OPEN) def test_send_reject(self): - server = ServerConnection() + server = ServerProtocol() with unittest.mock.patch("email.utils.formatdate", return_value=DATE): response = server.reject(http.HTTPStatus.NOT_FOUND, "Sorry folks.\n") self.assertIsInstance(response, Response) @@ -126,7 +126,7 @@ def test_send_reject(self): self.assertEqual(server.state, CONNECTING) def test_accept_response(self): - server = ServerConnection() + server = ServerProtocol() with unittest.mock.patch("email.utils.formatdate", return_value=DATE): response = server.accept(self.make_request()) self.assertIsInstance(response, Response) @@ -146,7 +146,7 @@ def test_accept_response(self): self.assertIsNone(response.body) def test_reject_response(self): - server = ServerConnection() + server = ServerProtocol() with unittest.mock.patch("email.utils.formatdate", return_value=DATE): response = server.reject(http.HTTPStatus.NOT_FOUND, "Sorry folks.\n") self.assertIsInstance(response, Response) @@ -166,17 +166,17 @@ def test_reject_response(self): self.assertEqual(response.body, b"Sorry folks.\n") def test_basic(self): - server = ServerConnection() + server = ServerProtocol() request = self.make_request() response = server.accept(request) self.assertEqual(response.status_code, 101) def test_unexpected_exception(self): - server = ServerConnection() + server = ServerProtocol() request = self.make_request() with unittest.mock.patch( - "websockets.server.ServerConnection.process_request", + "websockets.server.ServerProtocol.process_request", side_effect=Exception("BOOM"), ): response = server.accept(request) @@ -187,7 +187,7 @@ def test_unexpected_exception(self): self.assertEqual(str(raised.exception), "BOOM") def test_missing_connection(self): - server = ServerConnection() + server = ServerProtocol() request = self.make_request() del request.headers["Connection"] response = server.accept(request) @@ -199,7 +199,7 @@ def test_missing_connection(self): self.assertEqual(str(raised.exception), "missing Connection header") def test_invalid_connection(self): - server = ServerConnection() + server = ServerProtocol() request = self.make_request() del request.headers["Connection"] request.headers["Connection"] = "close" @@ -212,7 +212,7 @@ def test_invalid_connection(self): self.assertEqual(str(raised.exception), "invalid Connection header: close") def test_missing_upgrade(self): - server = ServerConnection() + server = ServerProtocol() request = self.make_request() del request.headers["Upgrade"] response = server.accept(request) @@ -224,7 +224,7 @@ def test_missing_upgrade(self): self.assertEqual(str(raised.exception), "missing Upgrade header") def test_invalid_upgrade(self): - server = ServerConnection() + server = ServerProtocol() request = self.make_request() del request.headers["Upgrade"] request.headers["Upgrade"] = "h2c" @@ -237,7 +237,7 @@ def test_invalid_upgrade(self): self.assertEqual(str(raised.exception), "invalid Upgrade header: h2c") def test_missing_key(self): - server = ServerConnection() + server = ServerProtocol() request = self.make_request() del request.headers["Sec-WebSocket-Key"] response = server.accept(request) @@ -248,7 +248,7 @@ def test_missing_key(self): self.assertEqual(str(raised.exception), "missing Sec-WebSocket-Key header") def test_multiple_key(self): - server = ServerConnection() + server = ServerProtocol() request = self.make_request() request.headers["Sec-WebSocket-Key"] = KEY response = server.accept(request) @@ -263,7 +263,7 @@ def test_multiple_key(self): ) def test_invalid_key(self): - server = ServerConnection() + server = ServerProtocol() request = self.make_request() del request.headers["Sec-WebSocket-Key"] request.headers["Sec-WebSocket-Key"] = "not Base64 data!" @@ -277,7 +277,7 @@ def test_invalid_key(self): ) def test_truncated_key(self): - server = ServerConnection() + server = ServerProtocol() request = self.make_request() del request.headers["Sec-WebSocket-Key"] request.headers["Sec-WebSocket-Key"] = KEY[ @@ -293,7 +293,7 @@ def test_truncated_key(self): ) def test_missing_version(self): - server = ServerConnection() + server = ServerProtocol() request = self.make_request() del request.headers["Sec-WebSocket-Version"] response = server.accept(request) @@ -304,7 +304,7 @@ def test_missing_version(self): self.assertEqual(str(raised.exception), "missing Sec-WebSocket-Version header") def test_multiple_version(self): - server = ServerConnection() + server = ServerProtocol() request = self.make_request() request.headers["Sec-WebSocket-Version"] = "11" response = server.accept(request) @@ -319,7 +319,7 @@ def test_multiple_version(self): ) def test_invalid_version(self): - server = ServerConnection() + server = ServerProtocol() request = self.make_request() del request.headers["Sec-WebSocket-Version"] request.headers["Sec-WebSocket-Version"] = "11" @@ -333,7 +333,7 @@ def test_invalid_version(self): ) def test_no_origin(self): - server = ServerConnection(origins=["https://example.com"]) + server = ServerProtocol(origins=["https://example.com"]) request = self.make_request() response = server.accept(request) @@ -343,7 +343,7 @@ def test_no_origin(self): self.assertEqual(str(raised.exception), "missing Origin header") def test_origin(self): - server = ServerConnection(origins=["https://example.com"]) + server = ServerProtocol(origins=["https://example.com"]) request = self.make_request() request.headers["Origin"] = "https://example.com" response = server.accept(request) @@ -352,7 +352,7 @@ def test_origin(self): self.assertEqual(server.origin, "https://example.com") def test_unexpected_origin(self): - server = ServerConnection(origins=["https://example.com"]) + server = ServerProtocol(origins=["https://example.com"]) request = self.make_request() request.headers["Origin"] = "https://other.example.com" response = server.accept(request) @@ -365,7 +365,7 @@ def test_unexpected_origin(self): ) def test_multiple_origin(self): - server = ServerConnection( + server = ServerProtocol( origins=["https://example.com", "https://other.example.com"] ) request = self.make_request() @@ -384,7 +384,7 @@ def test_multiple_origin(self): ) def test_supported_origin(self): - server = ServerConnection( + server = ServerProtocol( origins=["https://example.com", "https://other.example.com"] ) request = self.make_request() @@ -395,7 +395,7 @@ def test_supported_origin(self): self.assertEqual(server.origin, "https://other.example.com") def test_unsupported_origin(self): - server = ServerConnection( + server = ServerProtocol( origins=["https://example.com", "https://other.example.com"] ) request = self.make_request() @@ -410,7 +410,7 @@ def test_unsupported_origin(self): ) def test_no_origin_accepted(self): - server = ServerConnection(origins=[None]) + server = ServerProtocol(origins=[None]) request = self.make_request() response = server.accept(request) @@ -418,7 +418,7 @@ def test_no_origin_accepted(self): self.assertIsNone(server.origin) def test_no_extensions(self): - server = ServerConnection() + server = ServerProtocol() request = self.make_request() response = server.accept(request) @@ -427,7 +427,7 @@ def test_no_extensions(self): self.assertEqual(server.extensions, []) def test_no_extension(self): - server = ServerConnection(extensions=[ServerOpExtensionFactory()]) + server = ServerProtocol(extensions=[ServerOpExtensionFactory()]) request = self.make_request() response = server.accept(request) @@ -436,7 +436,7 @@ def test_no_extension(self): self.assertEqual(server.extensions, []) def test_extension(self): - server = ServerConnection(extensions=[ServerOpExtensionFactory()]) + server = ServerProtocol(extensions=[ServerOpExtensionFactory()]) request = self.make_request() request.headers["Sec-WebSocket-Extensions"] = "x-op; op" response = server.accept(request) @@ -446,7 +446,7 @@ def test_extension(self): self.assertEqual(server.extensions, [OpExtension()]) def test_unexpected_extension(self): - server = ServerConnection() + server = ServerProtocol() request = self.make_request() request.headers["Sec-WebSocket-Extensions"] = "x-op; op" response = server.accept(request) @@ -456,7 +456,7 @@ def test_unexpected_extension(self): self.assertEqual(server.extensions, []) def test_unsupported_extension(self): - server = ServerConnection(extensions=[ServerRsv2ExtensionFactory()]) + server = ServerProtocol(extensions=[ServerRsv2ExtensionFactory()]) request = self.make_request() request.headers["Sec-WebSocket-Extensions"] = "x-op; op" response = server.accept(request) @@ -466,7 +466,7 @@ def test_unsupported_extension(self): self.assertEqual(server.extensions, []) def test_supported_extension_parameters(self): - server = ServerConnection(extensions=[ServerOpExtensionFactory("this")]) + server = ServerProtocol(extensions=[ServerOpExtensionFactory("this")]) request = self.make_request() request.headers["Sec-WebSocket-Extensions"] = "x-op; op=this" response = server.accept(request) @@ -476,7 +476,7 @@ def test_supported_extension_parameters(self): self.assertEqual(server.extensions, [OpExtension("this")]) def test_unsupported_extension_parameters(self): - server = ServerConnection(extensions=[ServerOpExtensionFactory("this")]) + server = ServerProtocol(extensions=[ServerOpExtensionFactory("this")]) request = self.make_request() request.headers["Sec-WebSocket-Extensions"] = "x-op; op=that" response = server.accept(request) @@ -486,7 +486,7 @@ def test_unsupported_extension_parameters(self): self.assertEqual(server.extensions, []) def test_multiple_supported_extension_parameters(self): - server = ServerConnection( + server = ServerProtocol( extensions=[ ServerOpExtensionFactory("this"), ServerOpExtensionFactory("that"), @@ -501,7 +501,7 @@ def test_multiple_supported_extension_parameters(self): self.assertEqual(server.extensions, [OpExtension("that")]) def test_multiple_extensions(self): - server = ServerConnection( + server = ServerProtocol( extensions=[ServerOpExtensionFactory(), ServerRsv2ExtensionFactory()] ) request = self.make_request() @@ -516,7 +516,7 @@ def test_multiple_extensions(self): self.assertEqual(server.extensions, [OpExtension(), Rsv2Extension()]) def test_multiple_extensions_order(self): - server = ServerConnection( + server = ServerProtocol( extensions=[ServerOpExtensionFactory(), ServerRsv2ExtensionFactory()] ) request = self.make_request() @@ -531,7 +531,7 @@ def test_multiple_extensions_order(self): self.assertEqual(server.extensions, [Rsv2Extension(), OpExtension()]) def test_no_subprotocols(self): - server = ServerConnection() + server = ServerProtocol() request = self.make_request() response = server.accept(request) @@ -540,7 +540,7 @@ def test_no_subprotocols(self): self.assertIsNone(server.subprotocol) def test_no_subprotocol(self): - server = ServerConnection(subprotocols=["chat"]) + server = ServerProtocol(subprotocols=["chat"]) request = self.make_request() response = server.accept(request) @@ -549,7 +549,7 @@ def test_no_subprotocol(self): self.assertIsNone(server.subprotocol) def test_subprotocol(self): - server = ServerConnection(subprotocols=["chat"]) + server = ServerProtocol(subprotocols=["chat"]) request = self.make_request() request.headers["Sec-WebSocket-Protocol"] = "chat" response = server.accept(request) @@ -559,7 +559,7 @@ def test_subprotocol(self): self.assertEqual(server.subprotocol, "chat") def test_unexpected_subprotocol(self): - server = ServerConnection() + server = ServerProtocol() request = self.make_request() request.headers["Sec-WebSocket-Protocol"] = "chat" response = server.accept(request) @@ -569,7 +569,7 @@ def test_unexpected_subprotocol(self): self.assertIsNone(server.subprotocol) def test_multiple_subprotocols(self): - server = ServerConnection(subprotocols=["superchat", "chat"]) + server = ServerProtocol(subprotocols=["superchat", "chat"]) request = self.make_request() request.headers["Sec-WebSocket-Protocol"] = "superchat" request.headers["Sec-WebSocket-Protocol"] = "chat" @@ -580,7 +580,7 @@ def test_multiple_subprotocols(self): self.assertEqual(server.subprotocol, "superchat") def test_supported_subprotocol(self): - server = ServerConnection(subprotocols=["superchat", "chat"]) + server = ServerProtocol(subprotocols=["superchat", "chat"]) request = self.make_request() request.headers["Sec-WebSocket-Protocol"] = "chat" response = server.accept(request) @@ -590,7 +590,7 @@ def test_supported_subprotocol(self): self.assertEqual(server.subprotocol, "chat") def test_unsupported_subprotocol(self): - server = ServerConnection(subprotocols=["superchat", "chat"]) + server = ServerProtocol(subprotocols=["superchat", "chat"]) request = self.make_request() request.headers["Sec-WebSocket-Protocol"] = "otherchat" response = server.accept(request) @@ -602,7 +602,7 @@ def test_unsupported_subprotocol(self): class MiscTests(unittest.TestCase): def test_bypass_handshake(self): - server = ServerConnection(state=OPEN) + server = ServerProtocol(state=OPEN) server.receive_data(b"\x81\x86\x00\x00\x00\x00Hello!") [frame] = server.events_received() self.assertEqual(frame, Frame(OP_TEXT, b"Hello!")) @@ -610,5 +610,17 @@ def test_bypass_handshake(self): def test_custom_logger(self): logger = logging.getLogger("test") with self.assertLogs("test", logging.DEBUG) as logs: - ServerConnection(logger=logger) + ServerProtocol(logger=logger) self.assertEqual(len(logs.records), 1) + + +class BackwardsCompatibilityTests(DeprecationTestCase): + def test_server_connection_class(self): + with self.assertDeprecationWarning( + "ServerConnection was renamed to ServerProtocol" + ): + from websockets.server import ServerConnection + + server = ServerConnection("ws://localhost/") + + self.assertIsInstance(server, ServerProtocol) diff --git a/tests/utils.py b/tests/utils.py index ac891a0fd..92c754810 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,11 +1,18 @@ +import contextlib import email.utils import unittest +import warnings DATE = email.utils.formatdate(usegmt=True) class GeneratorTestCase(unittest.TestCase): + """ + Base class for testing generator-based coroutines. + + """ + def assertGeneratorRunning(self, gen): """ Check that a generator-based coroutine hasn't completed yet. @@ -21,3 +28,25 @@ def assertGeneratorReturns(self, gen): with self.assertRaises(StopIteration) as raised: next(gen) return raised.exception.value + + +class DeprecationTestCase(unittest.TestCase): + """ + Base class for testing deprecations. + + """ + + @contextlib.contextmanager + def assertDeprecationWarning(self, message): + """ + Check that a deprecation warning was raised with the given message. + + """ + with warnings.catch_warnings(record=True) as recorded_warnings: + warnings.simplefilter("always") + yield + + self.assertEqual(len(recorded_warnings), 1) + warning = recorded_warnings[0] + self.assertEqual(warning.category, DeprecationWarning) + self.assertEqual(str(warning.message), message) From 35731196de91b91fa79573ad6235a27858027398 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 16 Oct 2022 11:53:16 +0200 Subject: [PATCH 366/762] Change Sans-I/O constructors to keyword-only. --- docs/project/changelog.rst | 7 +++++++ src/websockets/client.py | 1 + src/websockets/protocol.py | 1 + src/websockets/server.py | 1 + tests/test_server.py | 2 +- 5 files changed, 11 insertions(+), 1 deletion(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 532eeffb2..aab9afdb1 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -45,6 +45,13 @@ Backwards-incompatible changes ``client.ClientConnection`` classes were renamed to ``protocol.Protocol``, ``server.ServerProtocol``, and ``client.ClientProtocol``. +.. admonition:: Sans-I/O protocol constructors now use keyword-only arguments. + :class: caution + + If you instantiate :class:`~server.ServerProtocol` or + :class:`~client.ClientProtocol` directly, make sure you are using keyword + arguments. + .. admonition:: Closing a connection without an empty close frame is OK. :class: note diff --git a/src/websockets/client.py b/src/websockets/client.py index a439ab846..ad4c1a15d 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -70,6 +70,7 @@ class ClientProtocol(Protocol): def __init__( self, wsuri: WebSocketURI, + *, origin: Optional[Origin] = None, extensions: Optional[Sequence[ClientExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 29d7e1596..1f1af4a84 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -86,6 +86,7 @@ class Protocol: def __init__( self, side: Side, + *, state: State = OPEN, max_size: Optional[int] = 2**20, logger: Optional[LoggerLike] = None, diff --git a/src/websockets/server.py b/src/websockets/server.py index 548048c92..cc5ba798a 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -69,6 +69,7 @@ class ServerProtocol(Protocol): def __init__( self, + *, origins: Optional[Sequence[Optional[Origin]]] = None, extensions: Optional[Sequence[ServerExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, diff --git a/tests/test_server.py b/tests/test_server.py index ba7e05df2..2645dc843 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -621,6 +621,6 @@ def test_server_connection_class(self): ): from websockets.server import ServerConnection - server = ServerConnection("ws://localhost/") + server = ServerConnection() self.assertIsInstance(server, ServerProtocol) From 1363dd5be5b55fb3a16d58ca346c9f09e7519393 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 28 Nov 2022 07:28:09 +0100 Subject: [PATCH 367/762] Revert "Fix example of shutting down a client." This reverts commit 9e960b50. The example was correct. --- example/faq/shutdown_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/faq/shutdown_client.py b/example/faq/shutdown_client.py index 65cfca41b..539dd0304 100755 --- a/example/faq/shutdown_client.py +++ b/example/faq/shutdown_client.py @@ -10,7 +10,7 @@ async def client(): # Close the connection when receiving SIGTERM. loop = asyncio.get_running_loop() loop.add_signal_handler( - signal.SIGTERM, loop.create_task, websocket.close) + signal.SIGTERM, loop.create_task, websocket.close()) # Process messages received on the connection. async for message in websocket: From 23a2d3f6dcd9056c94406b7354d0def89e46a720 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 16 Oct 2022 13:37:22 +0200 Subject: [PATCH 368/762] Add API to customize subprotocol selection logic. Also changed the default logic: (1) to reject client connections that don't offer a subprotocol when the server is configured with subprotocols. This is the expected behavior for what I believe to be the default use case: require one particular subprotocol. This change of behavior isn't documented because I don't know anyone embedding the Sans-I/O layer and supporting subprotocol selection at this time. (2) to rely only on the order of preference of the server. Hey, trying to cater to the preferences of clients was nice, but the behavior was so needlessly complex that documentation apologized... This keeps things simple and still falls within the documented behavior. --- docs/project/changelog.rst | 5 +- docs/reference/server.rst | 2 + src/websockets/legacy/server.py | 24 +++---- src/websockets/server.py | 116 +++++++++++++++++++++----------- tests/test_server.py | 46 +++++++++++-- 5 files changed, 136 insertions(+), 57 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index aab9afdb1..34e1f909c 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -37,7 +37,7 @@ Backwards-incompatible changes :class: caution Aliases provide compatibility for all previously public APIs according to - the `backwards-compatibility policy`_ + the `backwards-compatibility policy`_. * The ``connection`` module was renamed to ``protocol``. @@ -67,6 +67,9 @@ New features * Made it possible to close a server without closing existing connections. +* Added :attr:`~protocol.ServerProtocol.select_subprotocol` to customize + negotiation of subprotocols in the Sans-I/O layer. + 10.4 ---- diff --git a/docs/reference/server.rst b/docs/reference/server.rst index 50ef4ee3c..08bfe2f57 100644 --- a/docs/reference/server.rst +++ b/docs/reference/server.rst @@ -123,6 +123,8 @@ Sans-I/O .. automethod:: accept + .. automethod:: select_subprotocol + .. automethod:: reject .. automethod:: send_response diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index eabeb8e96..5ec5ed1a3 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -520,31 +520,29 @@ def select_subprotocol( server_subprotocols: Sequence[Subprotocol], ) -> Optional[Subprotocol]: """ - Pick a subprotocol among those offered by the client. + Pick a subprotocol among those supported by the client and the server. - If several subprotocols are supported by the client and the server, - the default implementation selects the preferred subprotocol by - giving equal value to the priorities of the client and the server. - If no subprotocol is supported by the client and the server, it - proceeds without a subprotocol. + If several subprotocols are available, select the preferred subprotocol + by giving equal weight to the preferences of the client and the server. - This is unlikely to be the most useful implementation in practice. - Many servers providing a subprotocol will require that the client - uses that subprotocol. Such rules can be implemented in a subclass. + If no subprotocol is available, proceed without a subprotocol. - You may also override this method with the ``select_subprotocol`` - argument of :func:`serve` and :class:`WebSocketServerProtocol`. + You may provide a ``select_subprotocol`` argument to :func:`serve` or + :class:`WebSocketServerProtocol` to override this logic. For example, + you could reject the handshake if the client doesn't support a + particular subprotocol, rather than accept the handshake without that + subprotocol. Args: client_subprotocols: list of subprotocols offered by the client. server_subprotocols: list of subprotocols available on the server. Returns: - Optional[Subprotocol]: Selected subprotocol. + Optional[Subprotocol]: Selected subprotocol, if a common subprotocol + was found. :obj:`None` to continue without a subprotocol. - """ if self._select_subprotocol is not None: return self._select_subprotocol(client_subprotocols, server_subprotocols) diff --git a/src/websockets/server.py b/src/websockets/server.py index cc5ba798a..547c516b5 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -5,7 +5,7 @@ import email.utils import http import warnings -from typing import Any, Generator, List, Optional, Sequence, Tuple, cast +from typing import Any, Callable, Generator, List, Optional, Sequence, Tuple, cast from .datastructures import Headers, MultipleValuesError from .exceptions import ( @@ -58,6 +58,10 @@ class ServerProtocol(Protocol): should be tried. subprotocols: list of supported subprotocols, in order of decreasing preference. + select_subprotocol: callback for selecting a subprotocol among + those supported by the client and the server. It has the same + signature as the :meth:`select_subprotocol` method, including a + :class:`ServerProtocol` instance as first argument. state: initial state of the WebSocket connection. max_size: maximum size of incoming messages in bytes; :obj:`None` to disable the limit. @@ -73,6 +77,7 @@ def __init__( origins: Optional[Sequence[Optional[Origin]]] = None, extensions: Optional[Sequence[ServerExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, + select_subprotocol: Optional[SelectSubprotocol] = None, state: State = CONNECTING, max_size: Optional[int] = 2**20, logger: Optional[LoggerLike] = None, @@ -86,6 +91,14 @@ def __init__( self.origins = origins self.available_extensions = extensions self.available_subprotocols = subprotocols + if select_subprotocol is not None: + # Bind select_subprotocol then shadow self.select_subprotocol. + # Use setattr to work around https://github.com/python/mypy/issues/2427. + setattr( + self, + "select_subprotocol", + select_subprotocol.__get__(self, self.__class__), + ) def accept(self, request: Request) -> Response: """ @@ -96,7 +109,7 @@ def accept(self, request: Request) -> Response: You must send the handshake response with :meth:`send_response`. - You can modify it before sending it, for example to add HTTP headers. + You may modify it before sending it, for example to add HTTP headers. Args: request: WebSocket handshake request event received from the client. @@ -175,7 +188,8 @@ def accept(self, request: Request) -> Response: return Response(101, "Switching Protocols", headers) def process_request( - self, request: Request + self, + request: Request, ) -> Tuple[str, Optional[str], Optional[str]]: """ Check a handshake request and negotiate extensions and subprotocol. @@ -273,6 +287,7 @@ def process_origin(self, headers: Headers) -> Optional[Origin]: Optional[Origin]: origin, if it is acceptable. Raises: + InvalidHandshake: if the Origin header is invalid. InvalidOrigin: if the origin isn't acceptable. """ @@ -323,7 +338,7 @@ def process_extensions( HTTP response header and list of accepted extensions. Raises: - InvalidHandshake: to abort the handshake with an HTTP 400 error. + InvalidHandshake: if the Sec-WebSocket-Extensions header is invalid. """ response_header_value: Optional[str] = None @@ -383,60 +398,79 @@ def process_subprotocol(self, headers: Headers) -> Optional[Subprotocol]: also the value of the ``Sec-WebSocket-Protocol`` response header. Raises: - InvalidHandshake: to abort the handshake with an HTTP 400 error. + InvalidHandshake: if the Sec-WebSocket-Subprotocol header is invalid. """ - subprotocol: Optional[Subprotocol] = None - - header_values = headers.get_all("Sec-WebSocket-Protocol") - - if header_values and self.available_subprotocols: - - parsed_header_values: List[Subprotocol] = sum( - [parse_subprotocol(header_value) for header_value in header_values], [] - ) - - subprotocol = self.select_subprotocol( - parsed_header_values, self.available_subprotocols - ) + subprotocols: Sequence[Subprotocol] = sum( + [ + parse_subprotocol(header_value) + for header_value in headers.get_all("Sec-WebSocket-Protocol") + ], + [], + ) - return subprotocol + return self.select_subprotocol(subprotocols) def select_subprotocol( self, - client_subprotocols: Sequence[Subprotocol], - server_subprotocols: Sequence[Subprotocol], + subprotocols: Sequence[Subprotocol], ) -> Optional[Subprotocol]: """ Pick a subprotocol among those offered by the client. - If several subprotocols are supported by the client and the server, - the default implementation selects the preferred subprotocols by - giving equal value to the priorities of the client and the server. + If several subprotocols are supported by both the client and the server, + pick the first one in the list declared the server. + + If the server doesn't support any subprotocols, continue without a + subprotocol, regardless of what the client offers. - If no common subprotocol is supported by the client and the server, it - proceeds without a subprotocol. + If the server supports at least one subprotocol and the client doesn't + offer any, abort the handshake with an HTTP 400 error. - This is unlikely to be the most useful implementation in practice, as - many servers providing a subprotocol will require that the client uses - that subprotocol. + You provide a ``select_subprotocol`` argument to :class:`ServerProtocol` + to override this logic. For example, you could accept the connection + even if client doesn't offer a subprotocol, rather than reject it. + + Here's how to negotiate the ``chat`` subprotocol if the client supports + it and continue without a subprotocol otherwise:: + + def select_subprotocol(protocol, subprotocols): + if "chat" in subprotocols: + return "chat" Args: - client_subprotocols: list of subprotocols offered by the client. - server_subprotocols: list of subprotocols available on the server. + subprotocols: list of subprotocols offered by the client. Returns: - Optional[Subprotocol]: Subprotocol, if a common subprotocol was - found. + Optional[Subprotocol]: Selected subprotocol, if a common subprotocol + was found. + + :obj:`None` to continue without a subprotocol. + + Raises: + NegotiationError: custom implementations may raise this exception + to abort the handshake with an HTTP 400 error. """ - subprotocols = set(client_subprotocols) & set(server_subprotocols) - if not subprotocols: + # Server doesn't offer any subprotocols. + if not self.available_subprotocols: # None or empty list return None - priority = lambda p: ( - client_subprotocols.index(p) + server_subprotocols.index(p) + + # Server offers at least one subprotocol but client doesn't offer any. + if not subprotocols: + raise NegotiationError("missing subprotocol") + + # Server and client both offer subprotocols. Look for a shared one. + proposed_subprotocols = set(subprotocols) + for subprotocol in self.available_subprotocols: + if subprotocol in proposed_subprotocols: + return subprotocol + + # No common subprotocol was found. + raise NegotiationError( + "invalid subprotocol; expected one of " + + ", ".join(self.available_subprotocols) ) - return sorted(subprotocols, key=priority)[0] def reject( self, @@ -519,6 +553,12 @@ def parse(self) -> Generator[None, None, None]: yield from super().parse() +SelectSubprotocol = Callable[ + [ServerProtocol, Sequence[Subprotocol]], + Optional[Subprotocol], +] + + class ServerConnection(ServerProtocol): def __init__(self, *args: Any, **kwargs: Any) -> None: warnings.warn( diff --git a/tests/test_server.py b/tests/test_server.py index 2645dc843..1c0bbb292 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -4,7 +4,12 @@ import unittest.mock from websockets.datastructures import Headers -from websockets.exceptions import InvalidHeader, InvalidOrigin, InvalidUpgrade +from websockets.exceptions import ( + InvalidHeader, + InvalidOrigin, + InvalidUpgrade, + NegotiationError, +) from websockets.frames import OP_TEXT, Frame from websockets.http11 import Request, Response from websockets.protocol import CONNECTING, OPEN @@ -544,9 +549,12 @@ def test_no_subprotocol(self): request = self.make_request() response = server.accept(request) - self.assertEqual(response.status_code, 101) - self.assertNotIn("Sec-WebSocket-Protocol", response.headers) - self.assertIsNone(server.subprotocol) + self.assertEqual(response.status_code, 400) + with self.assertRaisesRegex( + NegotiationError, + r"missing subprotocol", + ): + raise server.handshake_exc def test_subprotocol(self): server = ServerProtocol(subprotocols=["chat"]) @@ -571,8 +579,8 @@ def test_unexpected_subprotocol(self): def test_multiple_subprotocols(self): server = ServerProtocol(subprotocols=["superchat", "chat"]) request = self.make_request() - request.headers["Sec-WebSocket-Protocol"] = "superchat" request.headers["Sec-WebSocket-Protocol"] = "chat" + request.headers["Sec-WebSocket-Protocol"] = "superchat" response = server.accept(request) self.assertEqual(response.status_code, 101) @@ -595,6 +603,34 @@ def test_unsupported_subprotocol(self): request.headers["Sec-WebSocket-Protocol"] = "otherchat" response = server.accept(request) + self.assertEqual(response.status_code, 400) + with self.assertRaisesRegex( + NegotiationError, + r"invalid subprotocol; expected one of superchat, chat", + ): + raise server.handshake_exc + + @staticmethod + def optional_chat(protocol, subprotocols): + if "chat" in subprotocols: + return "chat" + + def test_select_subprotocol(self): + server = ServerProtocol(select_subprotocol=self.optional_chat) + request = self.make_request() + request.headers["Sec-WebSocket-Protocol"] = "chat" + response = server.accept(request) + + self.assertEqual(response.status_code, 101) + self.assertEqual(response.headers["Sec-WebSocket-Protocol"], "chat") + self.assertEqual(server.subprotocol, "chat") + + def test_select_no_subprotocol(self): + server = ServerProtocol(select_subprotocol=self.optional_chat) + request = self.make_request() + request.headers["Sec-WebSocket-Protocol"] = "otherchat" + response = server.accept(request) + self.assertEqual(response.status_code, 101) self.assertNotIn("Sec-WebSocket-Protocol", response.headers) self.assertIsNone(server.subprotocol) From e4fcab16a70344e356268e50dc0c0cf541920c5d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 16 Oct 2022 21:51:39 +0200 Subject: [PATCH 369/762] Handle exceptions when parsing opening handshake. --- src/websockets/client.py | 27 +++++++++++++++++---------- src/websockets/server.py | 11 ++++++++++- tests/test_client.py | 26 ++++++++++++++++++++++++++ tests/test_server.py | 18 ++++++++++++++++++ 4 files changed, 71 insertions(+), 11 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index ad4c1a15d..bfc8080a7 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -316,11 +316,17 @@ def send_request(self, request: Request) -> None: def parse(self) -> Generator[None, None, None]: if self.state is CONNECTING: - response = yield from Response.parse( - self.reader.read_line, - self.reader.read_exact, - self.reader.read_to_eof, - ) + try: + response = yield from Response.parse( + self.reader.read_line, + self.reader.read_exact, + self.reader.read_to_eof, + ) + except Exception as exc: + self.handshake_exc = exc + self.parser = self.discard() + next(self.parser) # start coroutine + yield if self.debug: code, phrase = response.status_code, response.reason_phrase @@ -334,14 +340,15 @@ def parse(self) -> Generator[None, None, None]: self.process_response(response) except InvalidHandshake as exc: response._exception = exc + self.events.append(response) self.handshake_exc = exc self.parser = self.discard() next(self.parser) # start coroutine - else: - assert self.state is CONNECTING - self.state = OPEN - finally: - self.events.append(response) + yield + + assert self.state is CONNECTING + self.state = OPEN + self.events.append(response) yield from super().parse() diff --git a/src/websockets/server.py b/src/websockets/server.py index 547c516b5..214b38bfa 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -541,7 +541,16 @@ def send_response(self, response: Response) -> None: def parse(self) -> Generator[None, None, None]: if self.state is CONNECTING: - request = yield from Request.parse(self.reader.read_line) + try: + request = yield from Request.parse( + self.reader.read_line, + ) + except Exception as exc: + self.handshake_exc = exc + self.send_eof() + self.parser = self.discard() + next(self.parser) # start coroutine + yield if self.debug: self.logger.debug("< GET %s HTTP/1.1", request.path) diff --git a/tests/test_client.py b/tests/test_client.py index 718219b9d..c83c87038 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -218,6 +218,32 @@ def test_reject_response(self): ) self.assertEqual(response.body, b"Sorry folks.\n") + def test_no_response(self): + with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): + client = ClientProtocol(parse_uri("ws://example.com/test")) + client.connect() + client.receive_eof() + self.assertEqual(client.events_received(), []) + + def test_partial_response(self): + with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): + client = ClientProtocol(parse_uri("ws://example.com/test")) + client.connect() + client.receive_data(b"HTTP/1.1 101 Switching Protocols\r\n") + client.receive_eof() + self.assertEqual(client.events_received(), []) + + def test_random_response(self): + with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): + client = ClientProtocol(parse_uri("ws://example.com/test")) + client.connect() + client.receive_data(b"220 smtp.invalid\r\n") + client.receive_data(b"250 Hello relay.invalid\r\n") + client.receive_data(b"250 Ok\r\n") + client.receive_data(b"250 Ok\r\n") + client.receive_eof() + self.assertEqual(client.events_received(), []) + def make_accept_response(self, client): request = client.connect() return Response( diff --git a/tests/test_server.py b/tests/test_server.py index 1c0bbb292..ecf3d4cbe 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -72,6 +72,24 @@ def test_connect_request(self): ), ) + def test_no_request(self): + server = ServerProtocol() + server.receive_eof() + self.assertEqual(server.events_received(), []) + + def test_partial_request(self): + server = ServerProtocol() + server.receive_data(b"GET /test HTTP/1.1\r\n") + server.receive_eof() + self.assertEqual(server.events_received(), []) + + def test_random_request(self): + server = ServerProtocol() + server.receive_data(b"HELO relay.invalid\r\n") + server.receive_data(b"MAIL FROM: \r\n") + server.receive_data(b"RCPT TO: \r\n") + self.assertEqual(server.events_received(), []) + class AcceptRejectTests(unittest.TestCase): def make_request(self): From 3015447f5afbe5e6c913bf0a353777ce3dc45f80 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 8 Jan 2023 08:47:56 +0100 Subject: [PATCH 370/762] Attempt to get a 10 on the OpenSSF check. --- SECURITY.md | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/SECURITY.md b/SECURITY.md index 82024b485..175b20c58 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -1,5 +1,12 @@ -# Security policy +# Security + +## Policy Only the latest version receives security updates. -Please report vulnerabilities [via Tidelift](https://tidelift.com/security). +## Contact information + +Please report security vulnerabilities to the +[Tidelift security team](https://tidelift.com/security). + +Tidelift will coordinate the fix and disclosure. From 38b08fb72fb3c7e8358b3bba7cbe467c2a355aa9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 8 Jan 2023 08:51:24 +0100 Subject: [PATCH 371/762] Increase max header length in legacy module. Fix #1243. Ref #1239. --- src/websockets/legacy/http.py | 4 ++-- tests/legacy/test_http.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/websockets/legacy/http.py b/src/websockets/legacy/http.py index d9e44cc28..7cc3db844 100644 --- a/src/websockets/legacy/http.py +++ b/src/websockets/legacy/http.py @@ -10,8 +10,8 @@ __all__ = ["read_request", "read_response"] -MAX_HEADERS = 256 -MAX_LINE = 4110 +MAX_HEADERS = 128 +MAX_LINE = 8192 def d(value: bytes) -> str: diff --git a/tests/legacy/test_http.py b/tests/legacy/test_http.py index 5c9adc97f..15d53e08d 100644 --- a/tests/legacy/test_http.py +++ b/tests/legacy/test_http.py @@ -119,13 +119,13 @@ async def test_header_value(self): await read_headers(self.stream) async def test_headers_limit(self): - self.stream.feed_data(b"foo: bar\r\n" * 257 + b"\r\n") + self.stream.feed_data(b"foo: bar\r\n" * 129 + b"\r\n") with self.assertRaises(SecurityError): await read_headers(self.stream) async def test_line_limit(self): - # Header line contains 5 + 4104 + 2 = 4111 bytes. - self.stream.feed_data(b"foo: " + b"a" * 4104 + b"\r\n\r\n") + # Header line contains 5 + 8186 + 2 = 8193 bytes. + self.stream.feed_data(b"foo: " + b"a" * 8186 + b"\r\n\r\n") with self.assertRaises(SecurityError): await read_headers(self.stream) From 75bb1cb07a476e899b689c5f50872f90f98a38e5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 8 Jan 2023 09:14:09 +0100 Subject: [PATCH 372/762] Fix incorrect partial binding pattern in docs. Fix #1275. Supersede #1276. --- docs/faq/common.rst | 4 ++-- docs/faq/server.rst | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/faq/common.rst b/docs/faq/common.rst index dff64f67c..66e273056 100644 --- a/docs/faq/common.rst +++ b/docs/faq/common.rst @@ -125,11 +125,11 @@ You can bind additional arguments to the protocol factory with import websockets class MyServerProtocol(websockets.WebSocketServerProtocol): - def __init__(self, extra_argument, *args, **kwargs): + def __init__(self, *args, extra_argument=None, **kwargs): super().__init__(*args, **kwargs) # do something with extra_argument - create_protocol = functools.partial(MyServerProtocol, extra_argument='spam') + create_protocol = functools.partial(MyServerProtocol, extra_argument=42) start_server = websockets.serve(..., create_protocol=create_protocol) This example was for a server. The same pattern applies on a client. diff --git a/docs/faq/server.rst b/docs/faq/server.rst index 4e4622dce..02a248cb8 100644 --- a/docs/faq/server.rst +++ b/docs/faq/server.rst @@ -185,7 +185,7 @@ You can bind additional arguments to the connection handler with async def handler(websocket, extra_argument): ... - bound_handler = functools.partial(handler, extra_argument='spam') + bound_handler = functools.partial(handler, extra_argument=42) start_server = websockets.serve(bound_handler, ...) Another way to achieve this result is to define the ``handler`` coroutine in From f42fd7ba34d40b9d9a800916fe0d27bf21c13656 Mon Sep 17 00:00:00 2001 From: ooliver1 Date: Sat, 10 Dec 2022 20:13:15 +0000 Subject: [PATCH 373/762] Refactor create_protocol to actually allow subclasses --- src/websockets/legacy/auth.py | 2 +- src/websockets/legacy/client.py | 2 +- src/websockets/legacy/server.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/websockets/legacy/auth.py b/src/websockets/legacy/auth.py index ac24c179e..def36a39c 100644 --- a/src/websockets/legacy/auth.py +++ b/src/websockets/legacy/auth.py @@ -118,7 +118,7 @@ def basic_auth_protocol_factory( realm: Optional[str] = None, credentials: Optional[Union[Credentials, Iterable[Credentials]]] = None, check_credentials: Optional[Callable[[str, str], Awaitable[bool]]] = None, - create_protocol: Optional[Callable[[Any], BasicAuthWebSocketServerProtocol]] = None, + create_protocol: Optional[Callable[..., BasicAuthWebSocketServerProtocol]] = None, ) -> Callable[[Any], BasicAuthWebSocketServerProtocol]: """ Protocol factory that enforces HTTP Basic Auth. diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 9b953df8f..4981ac9bd 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -438,7 +438,7 @@ def __init__( self, uri: str, *, - create_protocol: Optional[Callable[[Any], WebSocketClientProtocol]] = None, + create_protocol: Optional[Callable[..., WebSocketClientProtocol]] = None, logger: Optional[LoggerLike] = None, compression: Optional[str] = "deflate", origin: Optional[Origin] = None, diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 5ec5ed1a3..6f8833a88 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -989,7 +989,7 @@ def __init__( host: Optional[Union[str, Sequence[str]]] = None, port: Optional[int] = None, *, - create_protocol: Optional[Callable[[Any], WebSocketServerProtocol]] = None, + create_protocol: Optional[Callable[..., WebSocketServerProtocol]] = None, logger: Optional[LoggerLike] = None, compression: Optional[str] = "deflate", origins: Optional[Sequence[Optional[Origin]]] = None, From 716245215fab4a6937a9514eb0c0e1939dfc4fcd Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 18 Jan 2023 07:24:46 +0100 Subject: [PATCH 374/762] Follow-up on f42fd7ba. --- src/websockets/legacy/auth.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/websockets/legacy/auth.py b/src/websockets/legacy/auth.py index def36a39c..3511469e6 100644 --- a/src/websockets/legacy/auth.py +++ b/src/websockets/legacy/auth.py @@ -119,7 +119,7 @@ def basic_auth_protocol_factory( credentials: Optional[Union[Credentials, Iterable[Credentials]]] = None, check_credentials: Optional[Callable[[str, str], Awaitable[bool]]] = None, create_protocol: Optional[Callable[..., BasicAuthWebSocketServerProtocol]] = None, -) -> Callable[[Any], BasicAuthWebSocketServerProtocol]: +) -> Callable[..., BasicAuthWebSocketServerProtocol]: """ Protocol factory that enforces HTTP Basic Auth. @@ -175,11 +175,7 @@ async def check_credentials(username: str, password: str) -> bool: return hmac.compare_digest(expected_password, password) if create_protocol is None: - # Not sure why mypy cannot figure this out. - create_protocol = cast( - Callable[[Any], BasicAuthWebSocketServerProtocol], - BasicAuthWebSocketServerProtocol, - ) + create_protocol = BasicAuthWebSocketServerProtocol return functools.partial( create_protocol, From f2176ebc682742ec6a00646663c108bc400451bb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 18 Jan 2023 07:36:10 +0100 Subject: [PATCH 375/762] Fix errors in the documentation of 23a2d3f6. --- docs/project/changelog.rst | 2 +- src/websockets/server.py | 13 ++++++------- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 34e1f909c..cc3ebc091 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -67,7 +67,7 @@ New features * Made it possible to close a server without closing existing connections. -* Added :attr:`~protocol.ServerProtocol.select_subprotocol` to customize +* Added :attr:`~server.ServerProtocol.select_subprotocol` to customize negotiation of subprotocols in the Sans-I/O layer. 10.4 diff --git a/src/websockets/server.py b/src/websockets/server.py index 214b38bfa..148b4ec11 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -77,7 +77,12 @@ def __init__( origins: Optional[Sequence[Optional[Origin]]] = None, extensions: Optional[Sequence[ServerExtensionFactory]] = None, subprotocols: Optional[Sequence[Subprotocol]] = None, - select_subprotocol: Optional[SelectSubprotocol] = None, + select_subprotocol: Optional[ + Callable[ + [ServerProtocol, Sequence[Subprotocol]], + Optional[Subprotocol], + ] + ] = None, state: State = CONNECTING, max_size: Optional[int] = 2**20, logger: Optional[LoggerLike] = None, @@ -562,12 +567,6 @@ def parse(self) -> Generator[None, None, None]: yield from super().parse() -SelectSubprotocol = Callable[ - [ServerProtocol, Sequence[Subprotocol]], - Optional[Subprotocol], -] - - class ServerConnection(ServerProtocol): def __init__(self, *args: Any, **kwargs: Any) -> None: warnings.warn( From a525950c84a60151f261f6282fa80ff310954718 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Feb 2023 13:35:29 +0100 Subject: [PATCH 376/762] Fix typos in comments. Fix #1284. Thank you @cclauss! --- src/websockets/frames.py | 4 ++-- src/websockets/legacy/protocol.py | 2 +- src/websockets/protocol.py | 2 +- src/websockets/server.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 45d006e3f..52d81746d 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -127,7 +127,7 @@ class Frame: def __str__(self) -> str: """ - Return a human-readable represention of a frame. + Return a human-readable representation of a frame. """ coding = None @@ -389,7 +389,7 @@ class Close: def __str__(self) -> str: """ - Return a human-readable represention of a close code and reason. + Return a human-readable representation of a close code and reason. """ if 3000 <= self.code < 4000: diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 7881b947d..f6c419c3a 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -1120,7 +1120,7 @@ async def read_data_frame(self, max_size: Optional[int]) -> Optional[Frame]: try: # Echo the original data instead of re-serializing it with # Close.serialize() because that fails when the close frame - # is empty and Close.parse() synthetizes a 1005 close code. + # is empty and Close.parse() synthesizes a 1005 close code. await self.write_close_frame(self.close_rcvd, frame.data) except ConnectionClosed: # Connection closed before we could echo the close frame. diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 1f1af4a84..e5e8826f6 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -653,7 +653,7 @@ def recv_frame(self, frame: Frame) -> None: if self.state is OPEN: # Echo the original data instead of re-serializing it with # Close.serialize() because that fails when the close frame - # is empty and Close.parse() synthetizes a 1005 close code. + # is empty and Close.parse() synthesizes a 1005 close code. # The rest is identical to send_close(). self.send_frame(Frame(OP_CLOSE, frame.data)) self.close_sent = self.close_rcvd diff --git a/src/websockets/server.py b/src/websockets/server.py index 148b4ec11..5c73d7e07 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -512,7 +512,7 @@ def reject( response = Response(status.value, status.phrase, headers, body) # When reject() is called from accept(), handshake_exc is already set. # If a user calls reject(), set handshake_exc to guarantee invariant: - # "handshake_exc is None if and only if opening handshake succeded." + # "handshake_exc is None if and only if opening handshake succeeded." if self.handshake_exc is None: self.handshake_exc = InvalidStatus(response) self.logger.info("connection failed (%d %s)", status.value, status.phrase) From 87657de0edefa9c1d5b305d48c9cb0d28bdfc9d1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Feb 2023 15:09:44 +0100 Subject: [PATCH 377/762] Small FAQ updates. --- docs/faq/asyncio.rst | 2 +- docs/faq/common.rst | 6 ++---- docs/faq/misc.rst | 4 ++-- docs/faq/server.rst | 4 ++-- 4 files changed, 7 insertions(+), 9 deletions(-) diff --git a/docs/faq/asyncio.rst b/docs/faq/asyncio.rst index 7db43e1be..d00cf3f47 100644 --- a/docs/faq/asyncio.rst +++ b/docs/faq/asyncio.rst @@ -1,4 +1,4 @@ -asyncio usage +Using asyncio ============= .. currentmodule:: websockets diff --git a/docs/faq/common.rst b/docs/faq/common.rst index 66e273056..2a512ea90 100644 --- a/docs/faq/common.rst +++ b/docs/faq/common.rst @@ -109,10 +109,8 @@ Use :func:`~asyncio.wait_for`:: await asyncio.wait_for(websocket.recv(), timeout=10) -This technique works for most APIs, except for asynchronous context managers. -See `issue 574`_. - -.. _issue 574: https://github.com/aaugustin/websockets/issues/574 +This technique works for most APIs. When it doesn't, for example with +asynchronous context managers, websockets provides an ``open_timeout`` argument. How can I pass arguments to a custom protocol subclass? ------------------------------------------------------- diff --git a/docs/faq/misc.rst b/docs/faq/misc.rst index 681c5e45a..15e520fdd 100644 --- a/docs/faq/misc.rst +++ b/docs/faq/misc.rst @@ -33,8 +33,8 @@ Instead, use the real import paths e.g.:: websockets.client.connect(...) websockets.server.serve(...) -Why is websockets slower than another Python library in my benchmark? -..................................................................... +Why is websockets slower than another library in my benchmark? +.............................................................. Not all libraries are as feature-complete as websockets. For a fair benchmark, you should disable features that the other library doesn't provide. Typically, diff --git a/docs/faq/server.rst b/docs/faq/server.rst index 02a248cb8..d1388c701 100644 --- a/docs/faq/server.rst +++ b/docs/faq/server.rst @@ -139,7 +139,7 @@ Add error handling according to the behavior you want if the user disconnected before the message could be sent. This example supports only one connection per user. To support concurrent -connects by the same user, you can change ``CONNECTIONS`` to store a set of +connections by the same user, you can change ``CONNECTIONS`` to store a set of connections for each user. If you're running multiple server processes, call ``message_user`` in each @@ -213,7 +213,7 @@ client, as shown in the :doc:`tutorial <../intro/tutorial2>`. When you want to authenticate the connection before routing it, this is usually more convenient. Generally speaking, there is far less emphasis on the request path in WebSocket -servers than in HTTP servers. When a WebSockt server provides a single endpoint, +servers than in HTTP servers. When a WebSocket server provides a single endpoint, it may ignore the request path entirely. How do I access HTTP headers? From bc8b3a85fe3d83378257d5ef2630e9c03714da44 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Feb 2023 16:45:29 +0100 Subject: [PATCH 378/762] Explain the legacy submodule. Fix #1297. --- docs/faq/misc.rst | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/docs/faq/misc.rst b/docs/faq/misc.rst index 15e520fdd..72e7d6d56 100644 --- a/docs/faq/misc.rst +++ b/docs/faq/misc.rst @@ -33,6 +33,19 @@ Instead, use the real import paths e.g.:: websockets.client.connect(...) websockets.server.serve(...) +Why is the default implementation located in ``websockets.legacy``? +................................................................... + +This is an artifact of websockets' history. For its first eight years, only the +:mod:`asyncio`-based implementation existed. Then, the Sans-I/O implementation +was added. Moving the code in a ``legacy`` submodule eased this refactoring and +optimized maintainability. + +All public APIs were kept at their original locations. ``websockets.legacy`` +isn't a public API. It's only visible in the source code and in stack traces. +There is no intent to deprecate this implementation — at least until a superior +alternative exists. + Why is websockets slower than another library in my benchmark? .............................................................. From 8f0a33c5fb962036a2eb389a14c5bfaf58dfffa6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Feb 2023 16:45:59 +0100 Subject: [PATCH 379/762] Add missing words to spellchecker. --- docs/spelling_wordlist.txt | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 1d342515b..9cc2182e1 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -18,6 +18,7 @@ coroutine coroutines cryptocurrencies cryptocurrency +css ctrl deserialize django @@ -27,11 +28,13 @@ formatter fractalideas gunicorn healthz +html hypercorn iframe IPv istio iterable +js keepalive KiB kubernetes @@ -40,12 +43,15 @@ linkerd liveness lookups MiB +mutex mypy nginx Paketo permessage pid +procfile proxying +py pythonic reconnection redis @@ -56,6 +62,7 @@ scalable stateful subclasses subclassing +submodule subpackages subprotocol subprotocols @@ -63,6 +70,7 @@ supervisord tidelift tls tox +txt unregister uple uvicorn From 206d7ef5eea27791684f815f65a01fa23d89d3b2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Feb 2023 16:53:49 +0100 Subject: [PATCH 380/762] Restructure API reference. --- docs/faq/misc.rst | 2 + docs/project/changelog.rst | 5 +- docs/reference/{ => asyncio}/client.rst | 58 +--------- docs/reference/{ => asyncio}/common.rst | 68 +----------- docs/reference/{ => asyncio}/server.rst | 69 ++---------- .../{utilities.rst => datastructures.rst} | 9 +- docs/reference/index.rst | 101 +++++++++++------- docs/reference/sansio/client.rst | 58 ++++++++++ docs/reference/sansio/common.rst | 62 +++++++++++ docs/reference/sansio/server.rst | 62 +++++++++++ src/websockets/client.py | 2 +- src/websockets/legacy/client.py | 4 +- src/websockets/legacy/protocol.py | 6 +- src/websockets/legacy/server.py | 6 +- src/websockets/protocol.py | 2 +- src/websockets/server.py | 2 +- 16 files changed, 278 insertions(+), 238 deletions(-) rename docs/reference/{ => asyncio}/client.rst (69%) rename docs/reference/{ => asyncio}/common.rst (52%) rename docs/reference/{ => asyncio}/server.rst (73%) rename docs/reference/{utilities.rst => datastructures.rst} (90%) create mode 100644 docs/reference/sansio/client.rst create mode 100644 docs/reference/sansio/common.rst create mode 100644 docs/reference/sansio/server.rst diff --git a/docs/faq/misc.rst b/docs/faq/misc.rst index 72e7d6d56..e320cb808 100644 --- a/docs/faq/misc.rst +++ b/docs/faq/misc.rst @@ -10,6 +10,8 @@ Often, this is because you created a script called ``websockets.py`` in your current working directory. Then ``import websockets`` imports this module instead of the websockets library. +.. _real-import-paths: + Why does my IDE fail to show documentation for websockets APIs? ............................................................... diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index cc3ebc091..85a4ac92d 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -22,8 +22,9 @@ When a release contains backwards-incompatible API changes, the major version is increased, else the minor version is increased. Patch versions are only for fixing regressions shortly after a release. -Only documented APIs are public. Undocumented APIs are considered private. -They may change at any time. +Only documented API are public. Undocumented, private API may change without +notice. + 11.0 ---- diff --git a/docs/reference/client.rst b/docs/reference/asyncio/client.rst similarity index 69% rename from docs/reference/client.rst rename to docs/reference/asyncio/client.rst index 44f053b1e..5086015b7 100644 --- a/docs/reference/client.rst +++ b/docs/reference/asyncio/client.rst @@ -1,13 +1,10 @@ -Client -====== +Client (:mod:`asyncio`) +======================= .. automodule:: websockets.client -asyncio -------- - Opening a connection -.................... +-------------------- .. autofunction:: connect(uri, *, create_protocol=None, logger=None, compression="deflate", origin=None, extensions=None, subprotocols=None, extra_headers=None, user_agent_header="Python/x.y.z websockets/X.Y", open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) :async: @@ -16,7 +13,7 @@ Opening a connection :async: Using a connection -.................. +------------------ .. autoclass:: WebSocketClientProtocol(*, logger=None, origin=None, extensions=None, subprotocols=None, extra_headers=None, user_agent_header="Python/x.y.z websockets/X.Y", ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) @@ -65,50 +62,3 @@ Using a connection .. autoproperty:: close_code .. autoproperty:: close_reason - -Sans-I/O --------- - -.. autoclass:: ClientProtocol(wsuri, origin=None, extensions=None, subprotocols=None, state=State.CONNECTING, max_size=2 ** 20, logger=None) - - .. automethod:: receive_data - - .. automethod:: receive_eof - - .. automethod:: connect - - .. automethod:: send_request - - .. automethod:: send_continuation - - .. automethod:: send_text - - .. automethod:: send_binary - - .. automethod:: send_close - - .. automethod:: send_ping - - .. automethod:: send_pong - - .. automethod:: fail - - .. automethod:: events_received - - .. automethod:: data_to_send - - .. automethod:: close_expected - - .. autoattribute:: id - - .. autoattribute:: logger - - .. autoproperty:: state - - .. autoattribute:: handshake_exc - - .. autoproperty:: close_code - - .. autoproperty:: close_reason - - .. autoproperty:: close_exc diff --git a/docs/reference/common.rst b/docs/reference/asyncio/common.rst similarity index 52% rename from docs/reference/common.rst rename to docs/reference/asyncio/common.rst index b42f5ea3e..ee8dc54ac 100644 --- a/docs/reference/common.rst +++ b/docs/reference/asyncio/common.rst @@ -1,8 +1,5 @@ -Both sides -========== - -asyncio -------- +Both sides (:mod:`asyncio`) +=========================== .. automodule:: websockets.legacy.protocol @@ -53,64 +50,3 @@ asyncio .. autoproperty:: close_code .. autoproperty:: close_reason - -Sans-I/O --------- - -.. automodule:: websockets.protocol - -.. autoclass:: Protocol(side, state=State.OPEN, max_size=2 ** 20, logger=None) - - .. automethod:: receive_data - - .. automethod:: receive_eof - - .. automethod:: send_continuation - - .. automethod:: send_text - - .. automethod:: send_binary - - .. automethod:: send_close - - .. automethod:: send_ping - - .. automethod:: send_pong - - .. automethod:: fail - - .. automethod:: events_received - - .. automethod:: data_to_send - - .. automethod:: close_expected - - .. autoattribute:: id - - .. autoattribute:: logger - - .. autoproperty:: state - - .. autoproperty:: close_code - - .. autoproperty:: close_reason - - .. autoproperty:: close_exc - -.. autoclass:: Side - - .. autoattribute:: SERVER - - .. autoattribute:: CLIENT - -.. autoclass:: State - - .. autoattribute:: CONNECTING - - .. autoattribute:: OPEN - - .. autoattribute:: CLOSING - - .. autoattribute:: CLOSED - -.. autodata:: SEND_EOF diff --git a/docs/reference/server.rst b/docs/reference/asyncio/server.rst similarity index 73% rename from docs/reference/server.rst rename to docs/reference/asyncio/server.rst index 08bfe2f57..106317916 100644 --- a/docs/reference/server.rst +++ b/docs/reference/asyncio/server.rst @@ -1,13 +1,10 @@ -Server -====== +Server (:mod:`asyncio`) +======================= .. automodule:: websockets.server -asyncio -------- - Starting a server -................. +----------------- .. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, logger=None, compression="deflate", origins=None, extensions=None, subprotocols=None, extra_headers=None, server_header="Python/x.y.z websockets/X.Y", process_request=None, select_subprotocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) :async: @@ -16,7 +13,7 @@ Starting a server :async: Stopping a server -................. +----------------- .. autoclass:: WebSocketServer @@ -35,7 +32,7 @@ Stopping a server .. autoattribute:: sockets Using a connection -.................. +------------------ .. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, logger=None, origins=None, extensions=None, subprotocols=None, extra_headers=None, server_header="Python/x.y.z websockets/X.Y", process_request=None, select_subprotocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) @@ -93,7 +90,7 @@ Using a connection Basic authentication -.................... +-------------------- .. automodule:: websockets.auth @@ -110,55 +107,7 @@ websockets supports HTTP Basic Authentication according to .. automethod:: check_credentials -.. currentmodule:: websockets.server - -Sans-I/O --------- - -.. autoclass:: ServerProtocol(origins=None, extensions=None, subprotocols=None, state=State.CONNECTING, max_size=2 ** 20, logger=None) - - .. automethod:: receive_data - - .. automethod:: receive_eof - - .. automethod:: accept - - .. automethod:: select_subprotocol - - .. automethod:: reject - - .. automethod:: send_response - - .. automethod:: send_continuation - - .. automethod:: send_text - - .. automethod:: send_binary - - .. automethod:: send_close - - .. automethod:: send_ping - - .. automethod:: send_pong - - .. automethod:: fail - - .. automethod:: events_received - - .. automethod:: data_to_send - - .. automethod:: close_expected - - .. autoattribute:: id - - .. autoattribute:: logger - - .. autoproperty:: state - - .. autoattribute:: handshake_exc - - .. autoproperty:: close_code - - .. autoproperty:: close_reason +Broadcast +--------- - .. autoproperty:: close_exc +.. autofunction:: websockets.broadcast diff --git a/docs/reference/utilities.rst b/docs/reference/datastructures.rst similarity index 90% rename from docs/reference/utilities.rst rename to docs/reference/datastructures.rst index 6b5d402fc..8217052d1 100644 --- a/docs/reference/utilities.rst +++ b/docs/reference/datastructures.rst @@ -1,10 +1,5 @@ -Utilities -========= - -Broadcast ---------- - -.. autofunction:: websockets.broadcast +Data structures +=============== WebSocket events ---------------- diff --git a/docs/reference/index.rst b/docs/reference/index.rst index f164dde91..3b708ef91 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -3,66 +3,91 @@ API reference .. currentmodule:: websockets -websockets provides client and server implementations, as shown in -the :doc:`getting started guide <../intro/index>`. +:mod:`asyncio` +-------------- -The process for opening and closing a WebSocket connection depends on which -side you're implementing. +This is the default implementation. It's ideal for servers that handle many +clients concurrently. -* On the client side, connecting to a server with :func:`~client.connect` - yields a connection object that provides methods for interacting with the - connection. Your code can open a connection, then send or receive messages. +.. toctree:: + :titlesonly: + + asyncio/server + asyncio/client + asyncio/common + +`Sans-I/O`_ +----------- + +This layer is designed for integrating in third-party libraries, typically +application servers. + +.. _Sans-I/O: https://sans-io.readthedocs.io/ + +.. toctree:: + :titlesonly: - If you use :func:`~client.connect` as an asynchronous context manager, - then websockets closes the connection on exit. If not, then your code is - responsible for closing the connection. + sansio/server + sansio/client + sansio/common -* On the server side, :func:`~server.serve` starts listening for client - connections and yields an server object that you can use to shut down - the server. +Extensions +---------- - Then, when a client connects, the server initializes a connection object and - passes it to a handler coroutine, which is where your code can send or - receive messages. This pattern is called `inversion of control`_. It's - common in frameworks implementing servers. +The Per-Message Deflate extension is built in. You may also define custom +extensions. - When the handler coroutine terminates, websockets closes the connection. You - may also close it in the handler coroutine if you'd like. +.. toctree:: + :titlesonly: -.. _inversion of control: https://en.wikipedia.org/wiki/Inversion_of_control + extensions + +Shared +------ -Once the connection is open, the WebSocket protocol is symmetrical, except for -low-level details that websockets manages under the hood. The same methods -are available on client connections created with :func:`~client.connect` and -on server connections received in argument by the connection handler -of :func:`~server.serve`. +These low-level API are shared by all implementations. .. toctree:: :titlesonly: - server - client - common - utilities + datastructures exceptions types - extensions - limitations -Public API documented in the API reference are subject to the +API stability +------------- + +Public API documented in this API reference are subject to the :ref:`backwards-compatibility policy `. Anything that isn't listed in the API reference is a private API. There's no guarantees of behavior or backwards-compatibility for private APIs. +Convenience imports +------------------- + +For convenience, many public APIs can be imported directly from the +``websockets`` package. + + .. admonition:: Convenience imports are incompatible with some development tools. :class: caution - For convenience, most public APIs can be imported from the ``websockets`` - package. However, this is incompatible with static code analysis. - - It may break auto-completion and contextual documentation in IDEs, type - checking with mypy_, etc. If you're using such tools, stick to the full - import paths. + Specifically, static code analysis tools don't understand them. This breaks + auto-completion and contextual documentation in IDEs, type checking with + mypy_, etc. .. _mypy: https://github.com/python/mypy + + If you're using such tools, stick to the full import paths, as explained in + this FAQ: :ref:`real-import-paths` + +Limitations +----------- + +There are a few known limitations in the current API. + +.. toctree:: + :titlesonly: + + limitations diff --git a/docs/reference/sansio/client.rst b/docs/reference/sansio/client.rst new file mode 100644 index 000000000..09bafc745 --- /dev/null +++ b/docs/reference/sansio/client.rst @@ -0,0 +1,58 @@ +Client (`Sans-I/O`_) +==================== + +.. _Sans-I/O: https://sans-io.readthedocs.io/ + +.. currentmodule:: websockets.client + +.. autoclass:: ClientProtocol(wsuri, origin=None, extensions=None, subprotocols=None, state=State.CONNECTING, max_size=2 ** 20, logger=None) + + .. automethod:: receive_data + + .. automethod:: receive_eof + + .. automethod:: connect + + .. automethod:: send_request + + .. automethod:: send_continuation + + .. automethod:: send_text + + .. automethod:: send_binary + + .. automethod:: send_close + + .. automethod:: send_ping + + .. automethod:: send_pong + + .. automethod:: fail + + .. automethod:: events_received + + .. automethod:: data_to_send + + .. automethod:: close_expected + + WebSocket protocol objects also provide these attributes: + + .. autoattribute:: id + + .. autoattribute:: logger + + .. autoproperty:: state + + The following attributes are available after the opening handshake, + once the WebSocket connection is open: + + .. autoattribute:: handshake_exc + + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: + + .. autoproperty:: close_code + + .. autoproperty:: close_reason + + .. autoproperty:: close_exc diff --git a/docs/reference/sansio/common.rst b/docs/reference/sansio/common.rst new file mode 100644 index 000000000..2678c1361 --- /dev/null +++ b/docs/reference/sansio/common.rst @@ -0,0 +1,62 @@ +Both sides (`Sans-I/O`_) +========================= + +.. _Sans-I/O: https://sans-io.readthedocs.io/ + +.. automodule:: websockets.protocol + +.. autoclass:: Protocol(side, state=State.OPEN, max_size=2 ** 20, logger=None) + + .. automethod:: receive_data + + .. automethod:: receive_eof + + .. automethod:: send_continuation + + .. automethod:: send_text + + .. automethod:: send_binary + + .. automethod:: send_close + + .. automethod:: send_ping + + .. automethod:: send_pong + + .. automethod:: fail + + .. automethod:: events_received + + .. automethod:: data_to_send + + .. automethod:: close_expected + + .. autoattribute:: id + + .. autoattribute:: logger + + .. autoproperty:: state + + .. autoproperty:: close_code + + .. autoproperty:: close_reason + + .. autoproperty:: close_exc + +.. autoclass:: Side + + .. autoattribute:: SERVER + + .. autoattribute:: CLIENT + +.. autoclass:: State + + .. autoattribute:: CONNECTING + + .. autoattribute:: OPEN + + .. autoattribute:: CLOSING + + .. autoattribute:: CLOSED + +.. autodata:: SEND_EOF diff --git a/docs/reference/sansio/server.rst b/docs/reference/sansio/server.rst new file mode 100644 index 000000000..d70df6277 --- /dev/null +++ b/docs/reference/sansio/server.rst @@ -0,0 +1,62 @@ +Server (`Sans-I/O`_) +==================== + +.. _Sans-I/O: https://sans-io.readthedocs.io/ + +.. currentmodule:: websockets.server + +.. autoclass:: ServerProtocol(origins=None, extensions=None, subprotocols=None, state=State.CONNECTING, max_size=2 ** 20, logger=None) + + .. automethod:: receive_data + + .. automethod:: receive_eof + + .. automethod:: accept + + .. automethod:: select_subprotocol + + .. automethod:: reject + + .. automethod:: send_response + + .. automethod:: send_continuation + + .. automethod:: send_text + + .. automethod:: send_binary + + .. automethod:: send_close + + .. automethod:: send_ping + + .. automethod:: send_pong + + .. automethod:: fail + + .. automethod:: events_received + + .. automethod:: data_to_send + + .. automethod:: close_expected + + WebSocket protocol objects also provide these attributes: + + .. autoattribute:: id + + .. autoattribute:: logger + + .. autoproperty:: state + + The following attributes are available after the opening handshake, + once the WebSocket connection is open: + + .. autoattribute:: handshake_exc + + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: + + .. autoproperty:: close_code + + .. autoproperty:: close_reason + + .. autoproperty:: close_exc diff --git a/src/websockets/client.py b/src/websockets/client.py index bfc8080a7..df4f067d3 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -63,7 +63,7 @@ class ClientProtocol(Protocol): :obj:`None` to disable the limit. logger: logger for this connection; defaults to ``logging.getLogger("websockets.client")``; - see the :doc:`logging guide <../topics/logging>` for details. + see the :doc:`logging guide <../../topics/logging>` for details. """ diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 4981ac9bd..1e59a56c9 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -385,10 +385,10 @@ class Connect: be set to a wrapper or a subclass to customize connection handling. logger: logger for this connection; defaults to ``logging.getLogger("websockets.client")``; - see the :doc:`logging guide <../topics/logging>` for details. + see the :doc:`logging guide <../../topics/logging>` for details. compression: shortcut that enables the "permessage-deflate" extension by default; may be set to :obj:`None` to disable compression; - see the :doc:`compression guide <../topics/compression>` for details. + see the :doc:`compression guide <../../topics/compression>` for details. origin: value of the ``Origin`` header. This is useful when connecting to a server that validates the ``Origin`` header to defend against Cross-Site WebSocket Hijacking attacks. diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index f6c419c3a..374aa476d 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -98,7 +98,7 @@ class WebSocketCommonProtocol(asyncio.Protocol): reasons, :meth:`close` completes in at most ``5 * close_timeout`` seconds for clients and ``4 * close_timeout`` for servers. - See the discussion of :doc:`timeouts <../topics/timeouts>` for details. + See the discussion of :doc:`timeouts <../../topics/timeouts>` for details. ``close_timeout`` needs to be a parameter of the protocol because websockets usually calls :meth:`close` implicitly upon exit: @@ -141,12 +141,12 @@ class WebSocketCommonProtocol(asyncio.Protocol): The default value is 64 KiB, equal to asyncio's default (based on the current implementation of ``FlowControlMixin``). - See the discussion of :doc:`memory usage <../topics/memory>` for details. + See the discussion of :doc:`memory usage <../../topics/memory>` for details. Args: logger: logger for this connection; defaults to ``logging.getLogger("websockets.protocol")``; - see the :doc:`logging guide <../topics/logging>` for details. + see the :doc:`logging guide <../../topics/logging>` for details. ping_interval: delay between keepalive pings in seconds; :obj:`None` to disable keepalive pings. ping_timeout: timeout for keepalive pings in seconds; diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 6f8833a88..72f0a8203 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -664,7 +664,7 @@ class WebSocketServer: Args: logger: logger for this server; defaults to ``logging.getLogger("websockets.server")``; - see the :doc:`logging guide <../topics/logging>` for details. + see the :doc:`logging guide <../../topics/logging>` for details. """ @@ -934,10 +934,10 @@ class Serve: be set to a wrapper or a subclass to customize connection handling. logger: logger for this server; defaults to ``logging.getLogger("websockets.server")``; - see the :doc:`logging guide <../topics/logging>` for details. + see the :doc:`logging guide <../../topics/logging>` for details. compression: shortcut that enables the "permessage-deflate" extension by default; may be set to :obj:`None` to disable compression; - see the :doc:`compression guide <../topics/compression>` for details. + see the :doc:`compression guide <../../topics/compression>` for details. origins: acceptable values of the ``Origin`` header; include :obj:`None` in the list if the lack of an origin is acceptable. This is useful for defending against Cross-Site WebSocket diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index e5e8826f6..7bfa96f8b 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -79,7 +79,7 @@ class Protocol: logger: logger for this connection; depending on ``side``, defaults to ``logging.getLogger("websockets.client")`` or ``logging.getLogger("websockets.server")``; - see the :doc:`logging guide <../topics/logging>` for details. + see the :doc:`logging guide <../../topics/logging>` for details. """ diff --git a/src/websockets/server.py b/src/websockets/server.py index 5c73d7e07..9415a99a8 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -67,7 +67,7 @@ class ServerProtocol(Protocol): :obj:`None` to disable the limit. logger: logger for this connection; defaults to ``logging.getLogger("websockets.client")``; - see the :doc:`logging guide <../topics/logging>` for details. + see the :doc:`logging guide <../../topics/logging>` for details. """ From 4aa91dc90ccd743209b02e308e498ce0f57616f6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Feb 2023 17:19:50 +0100 Subject: [PATCH 381/762] Upgrade to black 2023 style. Changes stem from https://github.com/psf/black/pull/3035. --- src/websockets/client.py | 4 ---- src/websockets/extensions/permessage_deflate.py | 1 - src/websockets/legacy/client.py | 4 ---- src/websockets/legacy/framing.py | 1 - src/websockets/legacy/protocol.py | 2 -- src/websockets/legacy/server.py | 5 ----- src/websockets/server.py | 3 --- tests/legacy/test_auth.py | 1 - tests/legacy/test_client_server.py | 4 ---- 9 files changed, 25 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index df4f067d3..b5f871571 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -222,7 +222,6 @@ def process_extensions(self, headers: Headers) -> List[Extension]: extensions = headers.get_all("Sec-WebSocket-Extensions") if extensions: - if self.available_extensions is None: raise InvalidHandshake("no extensions supported") @@ -231,9 +230,7 @@ def process_extensions(self, headers: Headers) -> List[Extension]: ) for name, response_params in parsed_extensions: - for extension_factory in self.available_extensions: - # Skip non-matching extensions based on their name. if extension_factory.name != name: continue @@ -280,7 +277,6 @@ def process_subprotocol(self, headers: Headers) -> Optional[Subprotocol]: subprotocols = headers.get_all("Sec-WebSocket-Protocol") if subprotocols: - if self.available_subprotocols is None: raise InvalidHandshake("no subprotocols supported") diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index e0de5e8f8..b391837c6 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -211,7 +211,6 @@ def _extract_parameters( client_max_window_bits: Optional[Union[int, bool]] = None for name, value in params: - if name == "server_no_context_takeover": if server_no_context_takeover: raise exceptions.DuplicateParameter(name) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 1e59a56c9..f8876f59d 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -188,7 +188,6 @@ def process_extensions( header_values = headers.get_all("Sec-WebSocket-Extensions") if header_values: - if available_extensions is None: raise InvalidHandshake("no extensions supported") @@ -197,9 +196,7 @@ def process_extensions( ) for name, response_params in parsed_header_values: - for extension_factory in available_extensions: - # Skip non-matching extensions based on their name. if extension_factory.name != name: continue @@ -245,7 +242,6 @@ def process_subprotocol( header_values = headers.get_all("Sec-WebSocket-Protocol") if header_values: - if available_subprotocols is None: raise InvalidHandshake("no subprotocols supported") diff --git a/src/websockets/legacy/framing.py b/src/websockets/legacy/framing.py index 04cddc0e0..29864b136 100644 --- a/src/websockets/legacy/framing.py +++ b/src/websockets/legacy/framing.py @@ -14,7 +14,6 @@ class Frame(NamedTuple): - fin: bool opcode: frames.Opcode data: bytes diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 374aa476d..d31ec19a8 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -652,7 +652,6 @@ async def send( # Fragmented message -- regular iterator. elif isinstance(message, Iterable): - # Work around https://github.com/python/mypy/issues/6227 message = cast(Iterable[Data], message) @@ -1519,7 +1518,6 @@ def connection_lost(self, exc: Optional[Exception]) -> None: self.connection_lost_waiter.set_result(None) if True: # pragma: no cover - # Copied from asyncio.StreamReaderProtocol if self.reader is not None: if exc is None: diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 72f0a8203..048a270b5 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -160,7 +160,6 @@ async def handler(self) -> None: """ try: - try: await self.handshake( origins=self.origins, @@ -443,15 +442,12 @@ def process_extensions( header_values = headers.get_all("Sec-WebSocket-Extensions") if header_values and available_extensions: - parsed_header_values: List[ExtensionHeader] = sum( [parse_extension(header_value) for header_value in header_values], [] ) for name, request_params in parsed_header_values: - for ext_factory in available_extensions: - # Skip non-matching extensions based on their name. if ext_factory.name != name: continue @@ -503,7 +499,6 @@ def process_subprotocol( header_values = headers.get_all("Sec-WebSocket-Protocol") if header_values and available_subprotocols: - parsed_header_values: List[Subprotocol] = sum( [parse_subprotocol(header_value) for header_value in header_values], [] ) diff --git a/src/websockets/server.py b/src/websockets/server.py index 9415a99a8..16979e7e7 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -354,15 +354,12 @@ def process_extensions( header_values = headers.get_all("Sec-WebSocket-Extensions") if header_values and self.available_extensions: - parsed_header_values: List[ExtensionHeader] = sum( [parse_extension(header_value) for header_value in header_values], [] ) for name, request_params in parsed_header_values: - for ext_factory in self.available_extensions: - # Skip non-matching extensions based on their name. if ext_factory.name != name: continue diff --git a/tests/legacy/test_auth.py b/tests/legacy/test_auth.py index 2b670c31f..3754bcf3a 100644 --- a/tests/legacy/test_auth.py +++ b/tests/legacy/test_auth.py @@ -32,7 +32,6 @@ async def check_credentials(self, username, password): class AuthClientServerTests(ClientServerTestsMixin, AsyncioTestCase): - create_protocol = basic_auth_protocol_factory( realm="auth-tests", credentials=("hello", "iloveyou") ) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 4a4510536..b05d40721 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -212,7 +212,6 @@ class BarClientProtocol(WebSocketClientProtocol): class ClientServerTestsMixin: - secure = False def setUp(self): @@ -309,7 +308,6 @@ def make_http_request(self, path="/", headers=None): class SecureClientServerTestsMixin(ClientServerTestsMixin): - secure = True @property @@ -1299,7 +1297,6 @@ class ClientServerTests( class SecureClientServerTests( CommonClientServerTests, SecureClientServerTestsMixin, AsyncioTestCase ): - # The implementation of this test makes it hard to run it over TLS. test_client_connect_canceled_during_handshake = None @@ -1462,7 +1459,6 @@ async def run_server(path): class AsyncIteratorTests(ClientServerTestsMixin, AsyncioTestCase): - # This is a protocol-level feature, but since it's a high-level API, it is # much easier to exercise at the client or server level. From 1418917e54f56e557b3975a171aada842db8f2e5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Feb 2023 17:59:24 +0100 Subject: [PATCH 382/762] Environment variables are always strings. --- tests/legacy/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/legacy/utils.py b/tests/legacy/utils.py index 302fd70be..6195849bd 100644 --- a/tests/legacy/utils.py +++ b/tests/legacy/utils.py @@ -88,7 +88,7 @@ def assertDeprecationWarnings(self, recorded_warnings, expected_warnings): # Unit for timeouts. May be increased on slow machines by setting the # WEBSOCKETS_TESTS_TIMEOUT_FACTOR environment variable. -MS = 0.001 * int(os.environ.get("WEBSOCKETS_TESTS_TIMEOUT_FACTOR", 1)) +MS = 0.001 * int(os.environ.get("WEBSOCKETS_TESTS_TIMEOUT_FACTOR", "1")) # asyncio's debug mode has a 10x performance penalty for this test suite. if os.environ.get("PYTHONASYNCIODEBUG"): # pragma: no cover From 8e1628a14e0dd2ca98871c7500484b5d42d16b67 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Feb 2023 20:12:35 +0100 Subject: [PATCH 383/762] Ignore files from direnv. --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index 8f9e7dc51..324e77069 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,8 @@ *.pyc *.so .coverage +.direnv +.envrc .idea/ .mypy_cache .tox From ba1ed7a65cc876ff4e0fcd4dd4711402836475e2 Mon Sep 17 00:00:00 2001 From: Sasja Date: Tue, 28 Feb 2023 12:47:07 +0800 Subject: [PATCH 384/762] fix small docs typo --- docs/topics/broadcast.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/topics/broadcast.rst b/docs/topics/broadcast.rst index 6c7ced8b0..9a25cbf7d 100644 --- a/docs/topics/broadcast.rst +++ b/docs/topics/broadcast.rst @@ -344,5 +344,5 @@ All other patterns discussed above yield control to the event loop once per client because messages are sent by different tasks. This makes them slower than the built-in :func:`broadcast` function. -There is no major difference between the performance of per-message queues and +There is no major difference between the performance of per-client queues and publish–subscribe. From 17ffb5c9777f21d35ada85155db8c70be490cc3b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 28 Mar 2023 08:12:27 +0200 Subject: [PATCH 385/762] Disable social cards in docs. --- docs/conf.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/conf.py b/docs/conf.py index fe6282b5a..58f0c2d55 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -133,6 +133,11 @@ def linkcode_resolve(domain, info): return f"{code_url}/{file}#L{start}-L{end}" +# Configure opengraph extension + +# Social cards don't support the SVG logo. Also, the text preview looks bad. +ogp_social_cards = {"enable": False} + # -- Options for HTML output ------------------------------------------------- From 3c96411902d34f87bb8e6515387df7fbafc86869 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Feb 2023 18:04:00 +0100 Subject: [PATCH 386/762] Move MS constant out of legacy directory. --- tests/legacy/test_client_server.py | 3 ++- tests/legacy/test_protocol.py | 3 ++- tests/legacy/utils.py | 19 ------------------- tests/utils.py | 19 +++++++++++++++++++ 4 files changed, 23 insertions(+), 21 deletions(-) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index b05d40721..752f94270 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -42,7 +42,8 @@ NoOpExtension, ServerNoOpExtensionFactory, ) -from .utils import MS, AsyncioTestCase +from ..utils import MS +from .utils import AsyncioTestCase # Generate TLS certificate with: diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index e85402a39..328bc80a2 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -19,7 +19,8 @@ from websockets.legacy.protocol import WebSocketCommonProtocol, broadcast from websockets.protocol import State -from .utils import MS, AsyncioTestCase +from ..utils import MS +from .utils import AsyncioTestCase async def async_iterable(iterable): diff --git a/tests/legacy/utils.py b/tests/legacy/utils.py index 6195849bd..bb4eebb52 100644 --- a/tests/legacy/utils.py +++ b/tests/legacy/utils.py @@ -2,9 +2,6 @@ import contextlib import functools import logging -import os -import platform -import time import unittest @@ -84,19 +81,3 @@ def assertDeprecationWarnings(self, recorded_warnings, expected_warnings): set(str(recorded.message) for recorded in recorded_warnings), set(expected_warnings), ) - - -# Unit for timeouts. May be increased on slow machines by setting the -# WEBSOCKETS_TESTS_TIMEOUT_FACTOR environment variable. -MS = 0.001 * int(os.environ.get("WEBSOCKETS_TESTS_TIMEOUT_FACTOR", "1")) - -# asyncio's debug mode has a 10x performance penalty for this test suite. -if os.environ.get("PYTHONASYNCIODEBUG"): # pragma: no cover - MS *= 10 - -# PyPy has a performance penalty for this test suite. -if platform.python_implementation() == "PyPy": # pragma: no cover - MS *= 5 - -# Ensure that timeouts are larger than the clock's resolution (for Windows). -MS = max(MS, 2.5 * time.get_clock_info("monotonic").resolution) diff --git a/tests/utils.py b/tests/utils.py index 92c754810..5331746f8 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,5 +1,8 @@ import contextlib import email.utils +import os +import platform +import time import unittest import warnings @@ -7,6 +10,22 @@ DATE = email.utils.formatdate(usegmt=True) +# Unit for timeouts. May be increased on slow machines by setting the +# WEBSOCKETS_TESTS_TIMEOUT_FACTOR environment variable. +MS = 0.001 * int(os.environ.get("WEBSOCKETS_TESTS_TIMEOUT_FACTOR", "1")) + +# PyPy has a performance penalty for this test suite. +if platform.python_implementation() == "PyPy": # pragma: no cover + MS *= 5 + +# asyncio's debug mode has a 10x performance penalty for this test suite. +if os.environ.get("PYTHONASYNCIODEBUG"): # pragma: no cover + MS *= 10 + +# Ensure that timeouts are larger than the clock's resolution (for Windows). +MS = max(MS, 2.5 * time.get_clock_info("monotonic").resolution) + + class GeneratorTestCase(unittest.TestCase): """ Base class for testing generator-based coroutines. From 445cdf9b05d7f1e2ac9ff9dd5e254cbc57ab6ced Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Feb 2023 17:54:23 +0100 Subject: [PATCH 387/762] Move test certificate out of legacy directory. Also add a second domain name for tests. --- tests/legacy/test_client_server.py | 15 ++---- tests/test_localhost.cnf | 5 +- tests/test_localhost.pem | 84 +++++++++++++++--------------- tests/utils.py | 10 ++++ 4 files changed, 58 insertions(+), 56 deletions(-) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 752f94270..72f0b021b 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -42,19 +42,10 @@ NoOpExtension, ServerNoOpExtensionFactory, ) -from ..utils import MS +from ..utils import CERTIFICATE, MS from .utils import AsyncioTestCase -# Generate TLS certificate with: -# $ openssl req -x509 -config test_localhost.cnf -days 15340 -newkey rsa:2048 \ -# -out test_localhost.crt -keyout test_localhost.key -# $ cat test_localhost.key test_localhost.crt > test_localhost.pem -# $ rm test_localhost.key test_localhost.crt - -testcert = bytes(pathlib.Path(__file__).parent.with_name("test_localhost.pem")) - - async def default_handler(ws): if ws.path == "/deprecated_attributes": await ws.recv() # delay that allows catching warnings @@ -314,13 +305,13 @@ class SecureClientServerTestsMixin(ClientServerTestsMixin): @property def server_context(self): ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) - ssl_context.load_cert_chain(testcert) + ssl_context.load_cert_chain(CERTIFICATE) return ssl_context @property def client_context(self): ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - ssl_context.load_verify_locations(testcert) + ssl_context.load_verify_locations(CERTIFICATE) return ssl_context def start_server(self, **kwargs): diff --git a/tests/test_localhost.cnf b/tests/test_localhost.cnf index 6dc331ac6..4069e3967 100644 --- a/tests/test_localhost.cnf +++ b/tests/test_localhost.cnf @@ -22,5 +22,6 @@ subjectAltName = @san [ san ] DNS.1 = localhost -IP.2 = 127.0.0.1 -IP.3 = ::1 +DNS.2 = overridden +IP.3 = 127.0.0.1 +IP.4 = ::1 diff --git a/tests/test_localhost.pem b/tests/test_localhost.pem index b8a9ea9ab..8df63ec8f 100644 --- a/tests/test_localhost.pem +++ b/tests/test_localhost.pem @@ -1,48 +1,48 @@ -----BEGIN PRIVATE KEY----- -MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQCUgrQVkNbAWRlo -zZUj14Ufz7YEp2MXmvmhdlfOGLwjy+xPO98aJRv5/nYF2eWM3llcmLe8FbBSK+QF -To4su7ZVnc6qITOHqcSDUw06WarQUMs94bhHUvQp1u8+b2hNiMeGw6+QiBI6OJRO -iGpLRbkN6Uj3AKwi8SYVoLyMiztuwbNyGf8fF3DDpHZtBitGtMSBCMsQsfB465pl -2UoyBrWa2lsbLt3VvBZZvHqfEuPjpjjKN5USIXnaf0NizaR6ps3EyfftWy4i7zIQ -N5uTExvaPDyPn9nH3q/dkT99mSMSU1AvTTpX8PN7DlqE6wZMbQsBPRGW7GElQ+Ox -IKdKOLk5AgMBAAECggEAd3kqzQqnaTiEs4ZoC9yPUUc1pErQ8iWP27Ar9TZ67MVa -B2ggFJV0C0sFwbFI9WnPNCn77gj4vzJmD0riH+SnS/tXThDFtscBu7BtvNp0C4Bj -8RWMvXxjxuENuQnBPFbkRWtZ6wk8uK/Zx9AAyyt9M07Qjz1wPfAIdm/IH7zHBFMA -gsqjnkLh1r0FvjNEbLiuGqYU/GVxaZYd+xy+JU52IxjHUUL9yD0BPWb+Szar6AM2 -gUpmTX6+BcCZwwZ//DzCoWYZ9JbP8akn6edBeZyuMPqYgLzZkPyQ+hRW46VPPw89 -yg4LR9nzgQiBHlac0laB4NrWa+d9QRRLitl1O3gVAQKBgQDDkptxXu7w9Lpc+HeE -N/pJfpCzUuF7ZC4vatdoDzvfB5Ky6W88Poq+I7bB9m7StXdFAbDyUBxvisjTBMVA -OtYqpAk/rhX8MjSAtjoFe2nH+eEiQriuZmtA5CdKEXS4hNbc/HhEPWhk7Zh8OV5v -y7l4r6l4UHqaN9QyE0vlFdmcmQKBgQDCZZR/trJ2/g2OquaS+Zd2h/3NXw0NBq4z -4OBEWqNa/R35jdK6WlWJH7+tKOacr+xtswLpPeZHGwMdk64/erbYWBuJWAjpH72J -DM9+1H5fFHANWpWTNn94enQxwfzZRvdkxq4IWzGhesptYnHIzoAmaqC3lbn/e3u0 -Flng32hFoQKBgQCF3D4K3hib0lYQtnxPgmUMktWF+A+fflViXTWs4uhu4mcVkFNz -n7clJ5q6reryzAQjtmGfqRedfRex340HRn46V2aBMK2Znd9zzcZu5CbmGnFvGs3/ -iNiWZNNDjike9sV+IkxLIODoW/vH4xhxWrbLFSjg0ezoy5ew4qZK2abF2QKBgQC5 -M5efeQpbjTyTUERtf/aKCZOGZmkDoPq0GCjxVjzNQdqd1z0NJ2TYR/QP36idXIlu -FZ7PYZaS5aw5MGpQtfOe94n8dm++0et7t0WzunRO1yTNxCA+aSxWNquegAcJZa/q -RdKlyWPmSRqzzZdDzWCPuQQ3AyF5wkYfUy/7qjwoIQKBgB2v96BV7+lICviIKzzb -1o3A3VzAX5MGd98uLGjlK4qsBC+s7mk2eQztiNZgbA0W6fhQ5Dz3HcXJ5ppy8Okc -jeAktrNRzz15hvi/XkWdO+VMqiHW4l+sWYukjhCyod1oO1KGHq0LYYvv076syxGw -vRKLq7IJ4WIp1VtfaBlrIogq +MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDYOOQyq8yYtn5x +K3yRborFxTFse16JIVb4x/ZhZgGm49eARCi09fmczQxJdQpHz81Ij6z0xi7AUYH7 +9wS8T0Lh3uGFDDS1GzITUVPIqSUi0xim2T6XPzXFVQYI1D/OjUxlHm+3/up+WwbL +sBgBO/lDmzoa3ZN7kt9HQoGc/14oQz1Qsv1QTDQs69r+o7mmBJr/hf/g7S0Csyy3 +iC6aaq+yCUyzDbjXceTI7WJqbTGNnK0/DjdFD/SJS/uSDNEg0AH53eqcCSjm+Ei/ +UF8qR5Pu4sSsNwToOW2MVgjtHFazc+kG3rzD6+3Dp+t6x6uI/npyuudOMCmOtd6z +kX0UPQaNAgMBAAECggEAS4eMBztGC+5rusKTEAZKSY15l0h9HG/d/qdzJFDKsO6T +/8VPZu8pk6F48kwFHFK1hexSYWq9OAcA3fBK4jDZzybZJm2+F6l5U5AsMUMMqt6M +lPP8Tj8RXG433muuIkvvbL82DVLpvNu1Qv+vUvcNOpWFtY7DDv6eKjlMJ3h4/pzh +89MNt26VMCYOlq1NSjuZBzFohL2u9nsFehlOpcVsqNfNfcYCq9+5yoH8fWJP90Op +hqhvqUoGLN7DRKV1f+AWHSA4nmGgvVviV5PQgMhtk5exlN7kG+rDc3LbzhefS1Sp +Tat1qIgm8fK2n+Q/obQPjHOGOGuvE5cIF7E275ZKgQKBgQDt87BqALKWnbkbQnb7 +GS1h6LRcKyZhFbxnO2qbviBWSo15LEF8jPGV33Dj+T56hqufa/rUkbZiUbIR9yOX +dnOwpAVTo+ObAwZfGfHvrnufiIbHFqJBumaYLqjRZ7AC0QtS3G+kjS9dbllrr7ok +fO4JdfKRXzBJKrkQdCn8hR22rQKBgQDon0b49Dxs1EfdSDbDode2TSwE83fI3vmR +SKUkNY8ma6CRbomVRWijhBM458wJeuhpjPZOvjNMsnDzGwrtdAp2VfFlMIDnA8ZC +fEWIAAH2QYKXKGmkoXOcWB2QbvbI154zCm6zFGtzvRKOCGmTXuhFajO8VPwOyJVt +aSJA3bLrYQKBgQDJM2/tAfAAKRdW9GlUwqI8Ep9G+/l0yANJqtTnIemH7XwYhJJO +9YJlPszfB2aMBgliQNSUHy1/jyKpzDYdITyLlPUoFwEilnkxuud2yiuf5rpH51yF +hU6wyWtXvXv3tbkEdH42PmdZcjBMPQeBSN2hxEi6ISncBDL9tau26PwJ9QKBgQCs +cNYl2reoXTzgtpWSNDk6NL769JjJWTFcF6QD0YhKjOI8rNpkw00sWc3+EybXqDr9 +c7dq6+gPZQAB1vwkxi6zRkZqIqiLl+qygnjwtkC+EhYCg7y8g8q2DUPtO7TJcb0e +TQ9+xRZad8B3dZj93A8G1hF//OfU9bB/qL3xo+bsQQKBgC/9YJvgLIWA/UziLcB2 +29Ai0nbPkN5df7z4PifUHHSlbQJHKak8UKbMP+8S064Ul0F7g8UCjZMk2LzSbaNY +XU5+2j0sIOnGUFoSlvcpdowzYrD2LN5PkKBot7AOq/v7HlcOoR8J8RGWAMpCrHsI +a/u/dlZs+/K16RcavQwx8rag -----END PRIVATE KEY----- -----BEGIN CERTIFICATE----- -MIIDTTCCAjWgAwIBAgIJAJ6VG2cQlsepMA0GCSqGSIb3DQEBCwUAMEwxCzAJBgNV +MIIDWTCCAkGgAwIBAgIJAOL9UKiOOxupMA0GCSqGSIb3DQEBCwUAMEwxCzAJBgNV BAYTAkZSMQ4wDAYDVQQHDAVQYXJpczEZMBcGA1UECgwQQXltZXJpYyBBdWd1c3Rp -bjESMBAGA1UEAwwJbG9jYWxob3N0MCAXDTE4MDUwNTE2NTc1NloYDzIwNjAwNTA0 -MTY1NzU2WjBMMQswCQYDVQQGEwJGUjEOMAwGA1UEBwwFUGFyaXMxGTAXBgNVBAoM +bjESMBAGA1UEAwwJbG9jYWxob3N0MCAXDTIyMTAxNTE5Mjg0MVoYDzIwNjQxMDE0 +MTkyODQxWjBMMQswCQYDVQQGEwJGUjEOMAwGA1UEBwwFUGFyaXMxGTAXBgNVBAoM EEF5bWVyaWMgQXVndXN0aW4xEjAQBgNVBAMMCWxvY2FsaG9zdDCCASIwDQYJKoZI -hvcNAQEBBQADggEPADCCAQoCggEBAJSCtBWQ1sBZGWjNlSPXhR/PtgSnYxea+aF2 -V84YvCPL7E873xolG/n+dgXZ5YzeWVyYt7wVsFIr5AVOjiy7tlWdzqohM4epxINT -DTpZqtBQyz3huEdS9CnW7z5vaE2Ix4bDr5CIEjo4lE6IaktFuQ3pSPcArCLxJhWg -vIyLO27Bs3IZ/x8XcMOkdm0GK0a0xIEIyxCx8HjrmmXZSjIGtZraWxsu3dW8Flm8 -ep8S4+OmOMo3lRIhedp/Q2LNpHqmzcTJ9+1bLiLvMhA3m5MTG9o8PI+f2cfer92R -P32ZIxJTUC9NOlfw83sOWoTrBkxtCwE9EZbsYSVD47Egp0o4uTkCAwEAAaMwMC4w -LAYDVR0RBCUwI4IJbG9jYWxob3N0hwR/AAABhxAAAAAAAAAAAAAAAAAAAAABMA0G -CSqGSIb3DQEBCwUAA4IBAQA0imKp/rflfbDCCx78NdsR5rt0jKem2t3YPGT6tbeU -+FQz62SEdeD2OHWxpvfPf+6h3iTXJbkakr2R4lP3z7GHUe61lt3So9VHAvgbtPTH -aB1gOdThA83o0fzQtnIv67jCvE9gwPQInViZLEcm2iQEZLj6AuSvBKmluTR7vNRj -8/f2R4LsDfCWGrzk2W+deGRvSow7irS88NQ8BW8S8otgMiBx4D2UlOmQwqr6X+/r -jYIDuMb6GDKRXtBUGDokfE94hjj9u2mrNRwt8y4tqu8ZNa//yLEQ0Ow2kP3QJPLY -941VZpwRi2v/+JvI7OBYlvbOTFwM8nAk79k+Dgviygd9 +hvcNAQEBBQADggEPADCCAQoCggEBANg45DKrzJi2fnErfJFuisXFMWx7XokhVvjH +9mFmAabj14BEKLT1+ZzNDEl1CkfPzUiPrPTGLsBRgfv3BLxPQuHe4YUMNLUbMhNR +U8ipJSLTGKbZPpc/NcVVBgjUP86NTGUeb7f+6n5bBsuwGAE7+UObOhrdk3uS30dC +gZz/XihDPVCy/VBMNCzr2v6juaYEmv+F/+DtLQKzLLeILppqr7IJTLMNuNdx5Mjt +YmptMY2crT8ON0UP9IlL+5IM0SDQAfnd6pwJKOb4SL9QXypHk+7ixKw3BOg5bYxW +CO0cVrNz6QbevMPr7cOn63rHq4j+enK6504wKY613rORfRQ9Bo0CAwEAAaM8MDow +OAYDVR0RBDEwL4IJbG9jYWxob3N0ggpvdmVycmlkZGVuhwR/AAABhxAAAAAAAAAA +AAAAAAAAAAABMA0GCSqGSIb3DQEBCwUAA4IBAQBPNDGDdl4wsCRlDuyCHBC8o+vW +Vb14thUw9Z6UrlsQRXLONxHOXbNAj1sYQACNwIWuNz36HXu5m8Xw/ID/bOhnIg+b +Y6l/JU/kZQYB7SV1aR3ZdbCK0gjfkE0POBHuKOjUFIOPBCtJ4tIBUX94zlgJrR9v +2rqJC3TIYrR7pVQumHZsI5GZEMpM5NxfreWwxcgltgxmGdm7elcizHfz7k5+szwh +4eZ/rxK9bw1q8BIvVBWelRvUR55mIrCjzfZp5ZObSYQTZlW7PzXBe5Jk+1w31YHM +RSBA2EpPhYlGNqPidi7bg7rnQcsc6+hE0OqzTL/hWxPm9Vbp9dj3HFTik1wa -----END CERTIFICATE----- diff --git a/tests/utils.py b/tests/utils.py index 5331746f8..afc1a460a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,12 +1,22 @@ import contextlib import email.utils import os +import pathlib import platform import time import unittest import warnings +# Generate TLS certificate with: +# $ openssl req -x509 -config test_localhost.cnf -days 15340 -newkey rsa:2048 \ +# -out test_localhost.crt -keyout test_localhost.key +# $ cat test_localhost.key test_localhost.crt > test_localhost.pem +# $ rm test_localhost.key test_localhost.crt + +CERTIFICATE = bytes(pathlib.Path(__file__).with_name("test_localhost.pem")) + + DATE = email.utils.formatdate(usegmt=True) From 7b92fa02d88b6d6e807a653329b85127ce79d5e7 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Feb 2023 18:24:42 +0100 Subject: [PATCH 388/762] Add temp_unix_socket_path context manager for tests. --- tests/legacy/test_client_server.py | 11 +++-------- tests/utils.py | 7 +++++++ 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 72f0b021b..d92338585 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -3,13 +3,11 @@ import functools import http import logging -import pathlib import platform import random import socket import ssl import sys -import tempfile import unittest import unittest.mock import urllib.error @@ -42,7 +40,7 @@ NoOpExtension, ServerNoOpExtensionFactory, ) -from ..utils import CERTIFICATE, MS +from ..utils import CERTIFICATE, MS, temp_unix_socket_path from .utils import AsyncioTestCase @@ -447,9 +445,7 @@ def send(self, *args, **kwargs): @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") def test_unix_socket(self): - with tempfile.TemporaryDirectory() as temp_dir: - path = bytes(pathlib.Path(temp_dir) / "websockets") - + with temp_unix_socket_path() as path: # Like self.start_server() but with unix_serve(). async def start_server(): return await unix_serve(default_handler, path) @@ -1445,8 +1441,7 @@ async def run_server(path): # Check that exiting the context manager closed the server. self.assertFalse(server.sockets) - with tempfile.TemporaryDirectory() as temp_dir: - path = bytes(pathlib.Path(temp_dir) / "websockets") + with temp_unix_socket_path() as path: self.loop.run_until_complete(run_server(path)) diff --git a/tests/utils.py b/tests/utils.py index afc1a460a..4e9ac9f0e 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -3,6 +3,7 @@ import os import pathlib import platform +import tempfile import time import unittest import warnings @@ -79,3 +80,9 @@ def assertDeprecationWarning(self, message): warning = recorded_warnings[0] self.assertEqual(warning.category, DeprecationWarning) self.assertEqual(str(warning.message), message) + + +@contextlib.contextmanager +def temp_unix_socket_path(): + with tempfile.TemporaryDirectory() as temp_dir: + yield str(pathlib.Path(temp_dir) / "websockets") From d9c8694a5cd41287a33f804419c75937420e251b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Feb 2023 18:06:32 +0100 Subject: [PATCH 389/762] Add websockets.sync package. --- setup.py | 7 ++++++- src/websockets/sync/__init__.py | 0 tests/sync/__init__.py | 0 3 files changed, 6 insertions(+), 1 deletion(-) create mode 100644 src/websockets/sync/__init__.py create mode 100644 tests/sync/__init__.py diff --git a/setup.py b/setup.py index 564ada85c..5ed472503 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,12 @@ exec((root_dir / "src" / "websockets" / "version.py").read_text(encoding="utf-8")) -packages = ["websockets", "websockets/legacy", "websockets/extensions"] +packages = [ + "websockets", + "websockets/extensions", + "websockets/legacy", + "websockets/sync", +] ext_modules = [ setuptools.Extension( diff --git a/src/websockets/sync/__init__.py b/src/websockets/sync/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/sync/__init__.py b/tests/sync/__init__.py new file mode 100644 index 000000000..e69de29bb From 2d624fa36ae66fc818e075726f0a4aff683a4ef3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Feb 2023 18:06:50 +0100 Subject: [PATCH 390/762] Add thread-safe message reassembler. --- src/websockets/sync/messages.py | 281 +++++++++++++++++++ tests/sync/test_messages.py | 479 ++++++++++++++++++++++++++++++++ tests/sync/utils.py | 26 ++ 3 files changed, 786 insertions(+) create mode 100644 src/websockets/sync/messages.py create mode 100644 tests/sync/test_messages.py create mode 100644 tests/sync/utils.py diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py new file mode 100644 index 000000000..9265e4e9f --- /dev/null +++ b/src/websockets/sync/messages.py @@ -0,0 +1,281 @@ +from __future__ import annotations + +import codecs +import queue +import threading +from typing import Iterator, List, Optional, cast + +from ..frames import Frame, Opcode +from ..typing import Data + + +__all__ = ["Assembler"] + +UTF8Decoder = codecs.getincrementaldecoder("utf-8") + + +class Assembler: + """ + Assemble messages from frames. + + """ + + def __init__(self) -> None: + # Serialize reads and writes -- except for reads via synchronization + # primitives provided by the threading and queue modules. + self.mutex = threading.Lock() + + # We create a latch with two events to ensure proper interleaving of + # writing and reading messages. + # put() sets this event to tell get() that a message can be fetched. + self.message_complete = threading.Event() + # get() sets this event to let put() that the message was fetched. + self.message_fetched = threading.Event() + + # This flag prevents concurrent calls to get() by user code. + self.get_in_progress = False + # This flag prevents concurrent calls to put() by library code. + self.put_in_progress = False + + # Decoder for text frames, None for binary frames. + self.decoder: Optional[codecs.IncrementalDecoder] = None + + # Buffer of frames belonging to the same message. + self.chunks: List[Data] = [] + + # When switching from "buffering" to "streaming", we use a thread-safe + # queue for transferring frames from the writing thread (library code) + # to the reading thread (user code). We're buffering when chunks_queue + # is None and streaming when it's a SimpleQueue. None is a sentinel + # value marking the end of the stream, superseding message_complete. + + # Stream data from frames belonging to the same message. + # Remove quotes around type when dropping Python < 3.9. + self.chunks_queue: Optional["queue.SimpleQueue[Optional[Data]]"] = None + + # This flag marks the end of the stream. + self.closed = False + + def get(self, timeout: Optional[float] = None) -> Data: + """ + Read the next message. + + :meth:`get` returns a single :class:`str` or :class:`bytes`. + + If the message is fragmented, :meth:`get` waits until the last frame is + received, then it reassembles the message and returns it. To receive + messages frame by frame, use :meth:`get_iter` instead. + + Args: + timeout: if a timeout is provided and elapses before a complete + message is received, :meth:`get` raises :exc:`TimeoutError`. + + Raises: + EOFError: if the stream of frames has ended. + RuntimeError: if two threads run :meth:`get` or :meth:``get_iter` + concurrently. + + """ + with self.mutex: + if self.closed: + raise EOFError("stream of frames ended") + + if self.get_in_progress: + raise RuntimeError("get or get_iter is already running") + + self.get_in_progress = True + + # If the message_complete event isn't set yet, release the lock to + # allow put() to run and eventually set it. + # Locking with get_in_progress ensures only one thread can get here. + completed = self.message_complete.wait(timeout) + + with self.mutex: + self.get_in_progress = False + + # Waiting for a complete message timed out. + if not completed: + raise TimeoutError(f"timed out in {timeout:.1f}s") + + # get() was unblocked by close() rather than put(). + if self.closed: + raise EOFError("stream of frames ended") + + assert self.message_complete.is_set() + self.message_complete.clear() + + joiner: Data = b"" if self.decoder is None else "" + # mypy cannot figure out that chunks have the proper type. + message: Data = joiner.join(self.chunks) # type: ignore + + assert not self.message_fetched.is_set() + self.message_fetched.set() + + self.chunks = [] + assert self.chunks_queue is None + + return message + + def get_iter(self) -> Iterator[Data]: + """ + Stream the next message. + + Iterating the return value of :meth:`get_iter` yields a :class:`str` or + :class:`bytes` for each frame in the message. + + The iterator must be fully consumed before calling :meth:`get_iter` or + :meth:`get` again. Else, :exc:`RuntimeError` is raised. + + This method only makes sense for fragmented messages. If messages aren't + fragmented, use :meth:`get` instead. + + Raises: + EOFError: if the stream of frames has ended. + RuntimeError: if two threads run :meth:`get` or :meth:``get_iter` + concurrently. + + """ + with self.mutex: + if self.closed: + raise EOFError("stream of frames ended") + + if self.get_in_progress: + raise RuntimeError("get or get_iter is already running") + + chunks = self.chunks + self.chunks = [] + self.chunks_queue = cast( + # Remove quotes around type when dropping Python < 3.9. + "queue.SimpleQueue[Optional[Data]]", + queue.SimpleQueue(), + ) + + # Sending None in chunk_queue supersedes setting message_complete + # when switching to "streaming". If message is already complete + # when the switch happens, put() didn't send None, so we have to. + if self.message_complete.is_set(): + self.chunks_queue.put(None) + + self.get_in_progress = True + + # Locking with get_in_progress ensures only one thread can get here. + yield from chunks + while True: + chunk = self.chunks_queue.get() + if chunk is None: + break + yield chunk + + with self.mutex: + self.get_in_progress = False + + assert self.message_complete.is_set() + self.message_complete.clear() + + # get_iter() was unblocked by close() rather than put(). + if self.closed: + raise EOFError("stream of frames ended") + + assert not self.message_fetched.is_set() + self.message_fetched.set() + + assert self.chunks == [] + self.chunks_queue = None + + def put(self, frame: Frame) -> None: + """ + Add ``frame`` to the next message. + + When ``frame`` is the final frame in a message, :meth:`put` waits until + the message is fetched, either by calling :meth:`get` or by fully + consuming the return value of :meth:`get_iter`. + + :meth:`put` assumes that the stream of frames respects the protocol. If + it doesn't, the behavior is undefined. + + Raises: + EOFError: if the stream of frames has ended. + RuntimeError: if two threads run :meth:`put` concurrently. + + """ + with self.mutex: + if self.closed: + raise EOFError("stream of frames ended") + + if self.put_in_progress: + raise RuntimeError("put is already running") + + if frame.opcode is Opcode.TEXT: + self.decoder = UTF8Decoder(errors="strict") + elif frame.opcode is Opcode.BINARY: + self.decoder = None + elif frame.opcode is Opcode.CONT: + pass + else: + # Ignore control frames. + return + + data: Data + if self.decoder is not None: + data = self.decoder.decode(frame.data, frame.fin) + else: + data = frame.data + + if self.chunks_queue is None: + self.chunks.append(data) + else: + self.chunks_queue.put(data) + + if not frame.fin: + return + + # Message is complete. Wait until it's fetched to return. + + assert not self.message_complete.is_set() + self.message_complete.set() + + if self.chunks_queue is not None: + self.chunks_queue.put(None) + + assert not self.message_fetched.is_set() + + self.put_in_progress = True + + # Release the lock to allow get() to run and eventually set the event. + self.message_fetched.wait() + + with self.mutex: + self.put_in_progress = False + + assert self.message_fetched.is_set() + self.message_fetched.clear() + + # put() was unblocked by close() rather than get() or get_iter(). + if self.closed: + raise EOFError("stream of frames ended") + + self.decoder = None + + def close(self) -> None: + """ + End the stream of frames. + + Callling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`, + or :meth:`put` is safe. They will raise :exc:`EOFError`. + + """ + with self.mutex: + if self.closed: + return + + self.closed = True + + # Unblock get or get_iter. + if self.get_in_progress: + self.message_complete.set() + if self.chunks_queue is not None: + self.chunks_queue.put(None) + + # Unblock put(). + if self.put_in_progress: + self.message_fetched.set() diff --git a/tests/sync/test_messages.py b/tests/sync/test_messages.py new file mode 100644 index 000000000..069da784b --- /dev/null +++ b/tests/sync/test_messages.py @@ -0,0 +1,479 @@ +import time + +from websockets.frames import OP_BINARY, OP_CONT, OP_PING, OP_PONG, OP_TEXT, Frame +from websockets.sync.messages import * + +from ..utils import MS +from .utils import ThreadTestCase + + +class AssemblerTests(ThreadTestCase): + """ + Tests in this class interact a lot with hidden synchronization mechanisms: + + - get() / get_iter() and put() must run in separate threads when a final + frame is set because put() waits for get() / get_iter() to fetch the + message before returning. + + - run_in_thread() lets its target run before yielding back control on entry, + which guarantees the intended execution order of test cases. + + - run_in_thread() waits for its target to finish running before yielding + back control on exit, which allows making assertions immediately. + + - When the main thread performs actions that let another thread progress, it + must wait before making assertions, to avoid depending on scheduling. + + """ + + def setUp(self): + self.assembler = Assembler() + + def tearDown(self): + """ + Ensure the assembler goes back to its default state after each test. + + This removes the need for testing various sequences. + + """ + self.assertFalse(self.assembler.mutex.locked()) + self.assertFalse(self.assembler.get_in_progress) + self.assertFalse(self.assembler.put_in_progress) + if not self.assembler.closed: + self.assertFalse(self.assembler.message_complete.is_set()) + self.assertFalse(self.assembler.message_fetched.is_set()) + self.assertIsNone(self.assembler.decoder) + self.assertEqual(self.assembler.chunks, []) + self.assertIsNone(self.assembler.chunks_queue) + + # Test get + + def test_get_text_message_already_received(self): + """get returns a text message that is already received.""" + + def putter(): + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + + with self.run_in_thread(putter): + message = self.assembler.get() + + self.assertEqual(message, "café") + + def test_get_binary_message_already_received(self): + """get returns a binary message that is already received.""" + + def putter(): + self.assembler.put(Frame(OP_BINARY, b"tea")) + + with self.run_in_thread(putter): + message = self.assembler.get() + + self.assertEqual(message, b"tea") + + def test_get_text_message_not_received_yet(self): + """get returns a text message when it is received.""" + message = None + + def getter(): + nonlocal message + message = self.assembler.get() + + with self.run_in_thread(getter): + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + + self.assertEqual(message, "café") + + def test_get_binary_message_not_received_yet(self): + """get returns a binary message when it is received.""" + message = None + + def getter(): + nonlocal message + message = self.assembler.get() + + with self.run_in_thread(getter): + self.assembler.put(Frame(OP_BINARY, b"tea")) + + self.assertEqual(message, b"tea") + + def test_get_fragmented_text_message_already_received(self): + """get reassembles a fragmented a text message that is already received.""" + + def putter(): + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + + with self.run_in_thread(putter): + message = self.assembler.get() + + self.assertEqual(message, "café") + + def test_get_fragmented_binary_message_already_received(self): + """get reassembles a fragmented binary message that is already received.""" + + def putter(): + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + + with self.run_in_thread(putter): + message = self.assembler.get() + + self.assertEqual(message, b"tea") + + def test_get_fragmented_text_message_being_received(self): + """get reassembles a fragmented text message that is partially received.""" + message = None + + def getter(): + nonlocal message + message = self.assembler.get() + + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + with self.run_in_thread(getter): + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + + self.assertEqual(message, "café") + + def test_get_fragmented_binary_message_being_received(self): + """get reassembles a fragmented binary message that is partially received.""" + message = None + + def getter(): + nonlocal message + message = self.assembler.get() + + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + with self.run_in_thread(getter): + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + + self.assertEqual(message, b"tea") + + def test_get_fragmented_text_message_not_received_yet(self): + """get reassembles a fragmented text message when it is received.""" + message = None + + def getter(): + nonlocal message + message = self.assembler.get() + + with self.run_in_thread(getter): + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + + self.assertEqual(message, "café") + + def test_get_fragmented_binary_message_not_received_yet(self): + """get reassembles a fragmented binary message when it is received.""" + message = None + + def getter(): + nonlocal message + message = self.assembler.get() + + with self.run_in_thread(getter): + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + + self.assertEqual(message, b"tea") + + # Test get_iter + + def test_get_iter_text_message_already_received(self): + """get_iter yields a text message that is already received.""" + + def putter(): + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + + with self.run_in_thread(putter): + fragments = list(self.assembler.get_iter()) + + self.assertEqual(fragments, ["café"]) + + def test_get_iter_binary_message_already_received(self): + """get_iter yields a binary message that is already received.""" + + def putter(): + self.assembler.put(Frame(OP_BINARY, b"tea")) + + with self.run_in_thread(putter): + fragments = list(self.assembler.get_iter()) + + self.assertEqual(fragments, [b"tea"]) + + def test_get_iter_text_message_not_received_yet(self): + """get_iter yields a text message when it is received.""" + fragments = [] + + def getter(): + for fragment in self.assembler.get_iter(): + fragments.append(fragment) + + with self.run_in_thread(getter): + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + + self.assertEqual(fragments, ["café"]) + + def test_get_iter_binary_message_not_received_yet(self): + """get_iter yields a binary message when it is received.""" + fragments = [] + + def getter(): + for fragment in self.assembler.get_iter(): + fragments.append(fragment) + + with self.run_in_thread(getter): + self.assembler.put(Frame(OP_BINARY, b"tea")) + + self.assertEqual(fragments, [b"tea"]) + + def test_get_iter_fragmented_text_message_already_received(self): + """get_iter yields a fragmented text message that is already received.""" + + def putter(): + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + + with self.run_in_thread(putter): + fragments = list(self.assembler.get_iter()) + + self.assertEqual(fragments, ["ca", "f", "é"]) + + def test_get_iter_fragmented_binary_message_already_received(self): + """get_iter yields a fragmented binary message that is already received.""" + + def putter(): + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + + with self.run_in_thread(putter): + fragments = list(self.assembler.get_iter()) + + self.assertEqual(fragments, [b"t", b"e", b"a"]) + + def test_get_iter_fragmented_text_message_being_received(self): + """get_iter yields a fragmented text message that is partially received.""" + fragments = [] + + def getter(): + for fragment in self.assembler.get_iter(): + fragments.append(fragment) + + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + with self.run_in_thread(getter): + self.assertEqual(fragments, ["ca"]) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + time.sleep(MS) + self.assertEqual(fragments, ["ca", "f"]) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + + self.assertEqual(fragments, ["ca", "f", "é"]) + + def test_get_iter_fragmented_binary_message_being_received(self): + """get_iter yields a fragmented binary message that is partially received.""" + fragments = [] + + def getter(): + for fragment in self.assembler.get_iter(): + fragments.append(fragment) + + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + with self.run_in_thread(getter): + self.assertEqual(fragments, [b"t"]) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + time.sleep(MS) + self.assertEqual(fragments, [b"t", b"e"]) + self.assembler.put(Frame(OP_CONT, b"a")) + + self.assertEqual(fragments, [b"t", b"e", b"a"]) + + def test_get_iter_fragmented_text_message_not_received_yet(self): + """get_iter yields a fragmented text message when it is received.""" + fragments = [] + + def getter(): + for fragment in self.assembler.get_iter(): + fragments.append(fragment) + + with self.run_in_thread(getter): + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + time.sleep(MS) + self.assertEqual(fragments, ["ca"]) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + time.sleep(MS) + self.assertEqual(fragments, ["ca", "f"]) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + + self.assertEqual(fragments, ["ca", "f", "é"]) + + def test_get_iter_fragmented_binary_message_not_received_yet(self): + """get_iter yields a fragmented binary message when it is received.""" + fragments = [] + + def getter(): + for fragment in self.assembler.get_iter(): + fragments.append(fragment) + + with self.run_in_thread(getter): + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + time.sleep(MS) + self.assertEqual(fragments, [b"t"]) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + time.sleep(MS) + self.assertEqual(fragments, [b"t", b"e"]) + self.assembler.put(Frame(OP_CONT, b"a")) + + self.assertEqual(fragments, [b"t", b"e", b"a"]) + + # Test timeouts + + def test_get_with_timeout_completes(self): + """get returns a message when it is received before the timeout.""" + + def putter(): + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + + with self.run_in_thread(putter): + message = self.assembler.get(MS) + + self.assertEqual(message, "café") + + def test_get_with_timeout_times_out(self): + """get raises TimeoutError when no message is received before the timeout.""" + with self.assertRaises(TimeoutError): + self.assembler.get(MS) + + # Test control frames + + def test_control_frame_before_message_is_ignored(self): + """get ignores control frames between messages.""" + + def putter(): + self.assembler.put(Frame(OP_PING, b"")) + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + + with self.run_in_thread(putter): + message = self.assembler.get() + + self.assertEqual(message, "café") + + def test_control_frame_in_fragmented_message_is_ignored(self): + """get ignores control frames within fragmented messages.""" + + def putter(): + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_PING, b"")) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_PONG, b"")) + self.assembler.put(Frame(OP_CONT, b"a")) + + with self.run_in_thread(putter): + message = self.assembler.get() + + self.assertEqual(message, b"tea") + + # Test concurrency + + def test_get_fails_when_get_is_running(self): + """get cannot be called concurrently with itself.""" + with self.run_in_thread(self.assembler.get): + with self.assertRaises(RuntimeError): + self.assembler.get() + self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread + + def test_get_fails_when_get_iter_is_running(self): + """get cannot be called concurrently with get_iter.""" + with self.run_in_thread(lambda: list(self.assembler.get_iter())): + with self.assertRaises(RuntimeError): + self.assembler.get() + self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread + + def test_get_iter_fails_when_get_is_running(self): + """get_iter cannot be called concurrently with get.""" + with self.run_in_thread(self.assembler.get): + with self.assertRaises(RuntimeError): + list(self.assembler.get_iter()) + self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread + + def test_get_iter_fails_when_get_iter_is_running(self): + """get_iter cannot be called concurrently with itself.""" + with self.run_in_thread(lambda: list(self.assembler.get_iter())): + with self.assertRaises(RuntimeError): + list(self.assembler.get_iter()) + self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread + + def test_put_fails_when_put_is_running(self): + """put cannot be called concurrently with itself.""" + + def putter(): + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + + with self.run_in_thread(putter): + with self.assertRaises(RuntimeError): + self.assembler.put(Frame(OP_BINARY, b"tea")) + self.assembler.get() # unblock other thread + + # Test termination + + def test_get_fails_when_interrupted_by_close(self): + """get raises EOFError when close is called.""" + + def closer(): + time.sleep(2 * MS) + self.assembler.close() + + with self.run_in_thread(closer): + with self.assertRaises(EOFError): + self.assembler.get() + + def test_get_iter_fails_when_interrupted_by_close(self): + """get_iter raises EOFError when close is called.""" + + def closer(): + time.sleep(2 * MS) + self.assembler.close() + + with self.run_in_thread(closer): + with self.assertRaises(EOFError): + list(self.assembler.get_iter()) + + def test_put_fails_when_interrupted_by_close(self): + """put raises EOFError when close is called.""" + + def closer(): + time.sleep(2 * MS) + self.assembler.close() + + with self.run_in_thread(closer): + with self.assertRaises(EOFError): + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + + def test_get_fails_after_close(self): + """get raises EOFError after close is called.""" + self.assembler.close() + with self.assertRaises(EOFError): + self.assembler.get() + + def test_get_iter_fails_after_close(self): + """get_iter raises EOFError after close is called.""" + self.assembler.close() + with self.assertRaises(EOFError): + list(self.assembler.get_iter()) + + def test_put_fails_after_close(self): + """put raises EOFError after close is called.""" + self.assembler.close() + with self.assertRaises(EOFError): + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + + def test_close_is_idempotent(self): + """close can be called multiple times safely.""" + self.assembler.close() + self.assembler.close() diff --git a/tests/sync/utils.py b/tests/sync/utils.py new file mode 100644 index 000000000..8903cd349 --- /dev/null +++ b/tests/sync/utils.py @@ -0,0 +1,26 @@ +import contextlib +import threading +import time +import unittest + +from ..utils import MS + + +class ThreadTestCase(unittest.TestCase): + @contextlib.contextmanager + def run_in_thread(self, target): + """ + Run ``target`` function without arguments in a thread. + + In order to facilitate writing tests, this helper lets the thread run + for 1ms on entry and joins the thread with a 1ms timeout on exit. + + """ + thread = threading.Thread(target=target) + thread.start() + time.sleep(MS) + try: + yield + finally: + thread.join(MS) + self.assertFalse(thread.is_alive()) From 4616405fa50959eedfbe8de6df423f9305f266f8 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 24 Apr 2022 08:16:02 +0200 Subject: [PATCH 391/762] Add deadline for managing timeouts. --- src/websockets/sync/utils.py | 47 ++++++++++++++++++++++++++++++++++++ tests/sync/test_utils.py | 33 +++++++++++++++++++++++++ 2 files changed, 80 insertions(+) create mode 100644 src/websockets/sync/utils.py create mode 100644 tests/sync/test_utils.py diff --git a/src/websockets/sync/utils.py b/src/websockets/sync/utils.py new file mode 100644 index 000000000..8aab6c0d9 --- /dev/null +++ b/src/websockets/sync/utils.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +import time +from typing import Optional + + +__all__ = ["Deadline"] + + +class Deadline: + """ + Manage timeouts across multiple steps. + + Args: + timeout: time available in seconds; :obj:`None` if there is no limit. + + """ + + def __init__(self, timeout: Optional[float]) -> None: + self.deadline: Optional[float] + if timeout is None: + self.deadline = None + else: + self.deadline = time.monotonic() + timeout + + def timeout(self, *, raise_if_elapsed: bool = True) -> Optional[float]: + """ + Calculate a timeout from a deadline. + + Args: + raise_if_elapsed (bool): whether to raise :exc:`TimeoutError` + if the deadline lapsed. + + Raises: + TimeoutError: if the deadline lapsed. + + Returns: + Optional[float]: Time left in seconds; + :obj:`None` if there is no limit. + + """ + if self.deadline is None: + return None + timeout = self.deadline - time.monotonic() + if raise_if_elapsed and timeout <= 0: + raise TimeoutError("timed out") + return timeout diff --git a/tests/sync/test_utils.py b/tests/sync/test_utils.py new file mode 100644 index 000000000..2980a97b4 --- /dev/null +++ b/tests/sync/test_utils.py @@ -0,0 +1,33 @@ +import unittest + +from websockets.sync.utils import * + +from ..utils import MS + + +class DeadlineTests(unittest.TestCase): + def test_timeout_pending(self): + """timeout returns remaining time if deadline is in the future.""" + deadline = Deadline(MS) + timeout = deadline.timeout() + self.assertGreater(timeout, 0) + self.assertLess(timeout, MS) + + def test_timeout_elapsed_exception(self): + """timeout raises TimeoutError if deadline is in the past.""" + deadline = Deadline(-MS) + with self.assertRaises(TimeoutError): + deadline.timeout() + + def test_timeout_elapsed_no_exception(self): + """timeout doesn't raise TimeoutError when raise_if_elapsed is disabled.""" + deadline = Deadline(-MS) + timeout = deadline.timeout(raise_if_elapsed=False) + self.assertGreater(timeout, -2 * MS) + self.assertLess(timeout, -MS) + + def test_no_timeout(self): + """timeout returns None when no deadline is set.""" + deadline = Deadline(None) + timeout = deadline.timeout() + self.assertIsNone(timeout, None) From 125ffe28d0fd1c1a80d6021a7dbe7ee69451bb33 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Feb 2023 17:55:30 +0100 Subject: [PATCH 392/762] Add compatibility shim for socket.create_server. --- src/websockets/sync/compatibility.py | 21 +++++++++++++++++++++ tests/maxi_cov.py | 4 +++- 2 files changed, 24 insertions(+), 1 deletion(-) create mode 100644 src/websockets/sync/compatibility.py diff --git a/src/websockets/sync/compatibility.py b/src/websockets/sync/compatibility.py new file mode 100644 index 000000000..3064263e9 --- /dev/null +++ b/src/websockets/sync/compatibility.py @@ -0,0 +1,21 @@ +from __future__ import annotations + + +try: + from socket import create_server as socket_create_server +except ImportError: # pragma: no cover + import socket + + def socket_create_server(address, family=socket.AF_INET): # type: ignore + """Simplified backport of socket.create_server from Python 3.8.""" + sock = socket.socket(family, socket.SOCK_STREAM) + try: + sock.bind(address) + sock.listen() + return sock + except socket.error: + sock.close() + raise + + +__all__ = ["socket_create_server"] diff --git a/tests/maxi_cov.py b/tests/maxi_cov.py index b7c07b698..2568dcf18 100755 --- a/tests/maxi_cov.py +++ b/tests/maxi_cov.py @@ -47,6 +47,7 @@ def get_mapping(src_dir="src"): if "legacy" not in os.path.dirname(src_file) if os.path.basename(src_file) != "__init__.py" and os.path.basename(src_file) != "__main__.py" + and os.path.basename(src_file) != "compatibility.py" ] test_files = [ test_file @@ -89,8 +90,9 @@ def get_ignored_files(src_dir="src"): return [ # */websockets matches src/websockets and .tox/**/site-packages/websockets. - # There are no tests for the __main__ module. + # There are no tests for the __main__ module and for compatibility modules. "*/websockets/__main__.py", + "*/websockets/*/compatibility.py", # This approach isn't applicable to the test suite of the legacy # implementation, due to the huge test_client_server test module. "*/websockets/legacy/*", From 59fa2e3de50a9c46f4287ac300b7a2c42c7ae0a4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Feb 2023 17:57:42 +0100 Subject: [PATCH 393/762] Add thread-based implementation. --- .github/workflows/tests.yml | 3 + setup.cfg | 1 + src/websockets/server.py | 2 + src/websockets/sync/client.py | 327 +++++++++++++ src/websockets/sync/connection.py | 757 ++++++++++++++++++++++++++++++ src/websockets/sync/messages.py | 14 +- src/websockets/sync/server.py | 525 +++++++++++++++++++++ src/websockets/sync/utils.py | 9 +- tests/protocol.py | 29 ++ tests/sync/client.py | 55 +++ tests/sync/connection.py | 109 +++++ tests/sync/server.py | 67 +++ tests/sync/test_client.py | 271 +++++++++++ tests/sync/test_connection.py | 704 +++++++++++++++++++++++++++ tests/sync/test_messages.py | 2 +- tests/sync/test_server.py | 389 +++++++++++++++ 16 files changed, 3251 insertions(+), 13 deletions(-) create mode 100644 src/websockets/sync/client.py create mode 100644 src/websockets/sync/connection.py create mode 100644 src/websockets/sync/server.py create mode 100644 tests/protocol.py create mode 100644 tests/sync/client.py create mode 100644 tests/sync/connection.py create mode 100644 tests/sync/server.py create mode 100644 tests/sync/test_client.py create mode 100644 tests/sync/test_connection.py create mode 100644 tests/sync/test_server.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 4d5cc3cd0..34f8d8c5c 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -8,6 +8,9 @@ on: branches: - main +env: + WEBSOCKETS_TESTS_TIMEOUT_FACTOR: 10 + jobs: coverage: name: Run test coverage checks diff --git a/setup.cfg b/setup.cfg index 48703df87..3a8321c50 100644 --- a/setup.cfg +++ b/setup.cfg @@ -37,3 +37,4 @@ exclude_lines = raise AssertionError raise NotImplementedError self.fail\(".*"\) + @unittest.skip diff --git a/src/websockets/server.py b/src/websockets/server.py index 16979e7e7..0dd579052 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -164,6 +164,8 @@ def accept(self, request: Request) -> Response: f"Failed to open a WebSocket connection: {exc}.\n", ) except Exception as exc: + # Handle exceptions raised by user-provided select_subprotocol and + # unexpected errors. request._exception = exc self.handshake_exc = exc self.logger.error("opening handshake failed", exc_info=True) diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py new file mode 100644 index 000000000..e5922582d --- /dev/null +++ b/src/websockets/sync/client.py @@ -0,0 +1,327 @@ +from __future__ import annotations + +import socket +import ssl +import threading +from typing import Any, Optional, Sequence, Type + +from ..client import ClientProtocol +from ..datastructures import HeadersLike +from ..extensions.base import ClientExtensionFactory +from ..extensions.permessage_deflate import enable_client_permessage_deflate +from ..headers import validate_subprotocols +from ..http import USER_AGENT +from ..http11 import Response +from ..protocol import CONNECTING, OPEN, Event +from ..typing import LoggerLike, Origin, Subprotocol +from ..uri import parse_uri +from .connection import Connection +from .utils import Deadline + + +__all__ = ["connect", "unix_connect", "ClientConnection"] + + +class ClientConnection(Connection): + """ + Threaded implementation of a WebSocket client connection. + + :class:`ClientConnection` provides :meth:`recv` and :meth:`send` methods for + receiving and sending messages. + + It supports iteration to receive messages:: + + for message in websocket: + process(message) + + The iterator exits normally when the connection is closed with close code + 1000 (OK) or 1001 (going away) or without a close code. It raises a + :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is + closed with any other code. + + Args: + socket: Socket connected to a WebSocket server. + protocol: Sans-I/O connection. + close_timeout: Timeout for closing the connection in seconds. + + """ + + def __init__( + self, + socket: socket.socket, + protocol: ClientProtocol, + *, + close_timeout: Optional[float] = 10, + ) -> None: + self.protocol: ClientProtocol + self.response_rcvd = threading.Event() + super().__init__( + socket, + protocol, + close_timeout=close_timeout, + ) + + def handshake( + self, + additional_headers: Optional[HeadersLike] = None, + user_agent_header: Optional[str] = USER_AGENT, + timeout: Optional[float] = None, + ) -> None: + """ + Perform the opening handshake. + + """ + with self.send_context(expected_state=CONNECTING): + self.request = self.protocol.connect() + if additional_headers is not None: + self.request.headers.update(additional_headers) + if user_agent_header is not None: + self.request.headers["User-Agent"] = user_agent_header + self.protocol.send_request(self.request) + + if not self.response_rcvd.wait(timeout): + self.close_socket() + self.recv_events_thread.join() + raise TimeoutError("timed out during handshake") + + if self.response is None: + self.close_socket() + self.recv_events_thread.join() + raise ConnectionError("connection closed during handshake") + + if self.protocol.state is not OPEN: + self.recv_events_thread.join(self.close_timeout) + self.close_socket() + self.recv_events_thread.join() + + if self.protocol.handshake_exc is not None: + raise self.protocol.handshake_exc + + def process_event(self, event: Event) -> None: + """ + Process one incoming event. + + """ + # First event - handshake response. + if self.response is None: + assert isinstance(event, Response) + self.response = event + self.response_rcvd.set() + # Later events - frames. + else: + super().process_event(event) + + def recv_events(self) -> None: + """ + Read incoming data from the socket and process events. + + """ + try: + super().recv_events() + finally: + # If the connection is closed during the handshake, unblock it. + self.response_rcvd.set() + + +def connect( + uri: str, + *, + # TCP/TLS — unix and path are only for unix_connect() + sock: Optional[socket.socket] = None, + ssl_context: Optional[ssl.SSLContext] = None, + server_hostname: Optional[str] = None, + unix: bool = False, + path: Optional[str] = None, + # WebSocket + origin: Optional[Origin] = None, + extensions: Optional[Sequence[ClientExtensionFactory]] = None, + subprotocols: Optional[Sequence[Subprotocol]] = None, + additional_headers: Optional[HeadersLike] = None, + user_agent_header: Optional[str] = USER_AGENT, + compression: Optional[str] = "deflate", + # Timeouts + open_timeout: Optional[float] = 10, + close_timeout: Optional[float] = 10, + # Limits + max_size: Optional[int] = 2**20, + # Logging + logger: Optional[LoggerLike] = None, + # Escape hatch for advanced customization + create_connection: Optional[Type[ClientConnection]] = None, +) -> ClientConnection: + """ + Connect to the WebSocket server at ``uri``. + + This function returns a :class:`ClientConnection` instance, which you can + use to send and receive messages. + + :func:`connect` may be used as a context manager:: + + async with websockets.sync.client.connect(...) as websocket: + ... + + The connection is closed automatically when exiting the context. + + Args: + uri: URI of the WebSocket server. + sock: Preexisting TCP socket. ``sock`` overrides the host and port + from ``uri``. You may call :func:`socket.create_connection` to + create a suitable TCP socket. + ssl_context: Configuration for enabling TLS on the connection. + server_hostname: Hostname for the TLS handshake. ``server_hostname`` + overrides the hostname from ``uri``. + origin: Value of the ``Origin`` header, for servers that require it. + extensions: List of supported extensions, in order in which they + should be negotiated and run. + subprotocols: List of supported subprotocols, in order of decreasing + preference. + additional_headers (HeadersLike | None): Arbitrary HTTP headers to add + to the handshake request. + user_agent_header: Value of the ``User-Agent`` request header. + It defaults to ``"Python/x.y.z websockets/X.Y"``. + Setting it to :obj:`None` removes the header. + compression: The "permessage-deflate" extension is enabled by default. + Set ``compression`` to :obj:`None` to disable it. See the + :doc:`compression guide <../../topics/compression>` for details. + open_timeout: Timeout for opening the connection in seconds. + :obj:`None` disables the timeout. + close_timeout: Timeout for closing the connection in seconds. + :obj:`None` disables the timeout. + max_size: Maximum size of incoming messages in bytes. + :obj:`None` disables the limit. + logger: Logger for this client. + It defaults to ``logging.getLogger("websockets.client")``. + See the :doc:`logging guide <../../topics/logging>` for details. + create_connection: Factory for the :class:`ClientConnection` managing + the connection. Set it to a wrapper or a subclass to customize + connection handling. + + Raises: + InvalidURI: If ``uri`` isn't a valid WebSocket URI. + InvalidHandshake: If the opening handshake fails. + TimeoutError: If the opening handshake times out. + + """ + + # Process parameters + + wsuri = parse_uri(uri) + if not wsuri.secure and ssl_context is not None: + raise TypeError("ssl_context argument is incompatible with a ws:// URI") + + if unix: + if path is None and sock is None: + raise TypeError("missing path argument") + elif path is not None and sock is not None: + raise TypeError("path and sock arguments are incompatible") + else: + assert path is None # private argument, only set by unix_connect() + + if subprotocols is not None: + validate_subprotocols(subprotocols) + + if compression == "deflate": + extensions = enable_client_permessage_deflate(extensions) + elif compression is not None: + raise ValueError(f"unsupported compression: {compression}") + + # Calculate timeouts on the TCP, TLS, and WebSocket handshakes. + # The TCP and TLS timeouts must be set on the socket, then removed + # to avoid conflicting with the WebSocket timeout in handshake(). + deadline = Deadline(open_timeout) + + if create_connection is None: + create_connection = ClientConnection + + try: + # Connect socket + + if sock is None: + if unix: + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + sock.settimeout(deadline.timeout()) + assert path is not None # validated above -- this is for mpypy + sock.connect(path) + else: + sock = socket.create_connection( + (wsuri.host, wsuri.port), + deadline.timeout(), + ) + sock.settimeout(None) + + # Disable Nagle algorithm + + if not unix: + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True) + + # Initialize TLS wrapper and perform TLS handshake + + if wsuri.secure: + if ssl_context is None: + ssl_context = ssl.create_default_context() + if server_hostname is None: + server_hostname = wsuri.host + sock.settimeout(deadline.timeout()) + sock = ssl_context.wrap_socket(sock, server_hostname=server_hostname) + sock.settimeout(None) + + # Initialize WebSocket connection + + protocol = ClientProtocol( + wsuri, + origin=origin, + extensions=extensions, + subprotocols=subprotocols, + state=CONNECTING, + max_size=max_size, + logger=logger, + ) + + # Initialize WebSocket protocol + + connection = create_connection( + sock, + protocol, + close_timeout=close_timeout, + ) + # On failure, handshake() closes the socket and raises an exception. + connection.handshake( + additional_headers, + user_agent_header, + deadline.timeout(), + ) + + except Exception: + if sock is not None: + sock.close() + raise + + return connection + + +def unix_connect( + path: Optional[str] = None, + uri: Optional[str] = None, + **kwargs: Any, +) -> ClientConnection: + """ + Connect to a WebSocket server listening on a Unix socket. + + This function is identical to :func:`connect`, except for the additional + ``path`` argument. It's only available on Unix. + + It's mainly useful for debugging servers listening on Unix sockets. + + Args: + path: File system path to the Unix socket. + uri: URI of the WebSocket server. ``uri`` defaults to + ``ws://localhost/`` or, when a ``ssl_context`` is provided, to + ``wss://localhost/``. + + """ + if uri is None: + if kwargs.get("ssl_context") is None: + uri = "ws://localhost/" + else: + uri = "wss://localhost/" + return connect(uri=uri, unix=True, path=path, **kwargs) diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py new file mode 100644 index 000000000..64e5c8b44 --- /dev/null +++ b/src/websockets/sync/connection.py @@ -0,0 +1,757 @@ +from __future__ import annotations + +import contextlib +import logging +import random +import socket +import struct +import threading +import uuid +from types import TracebackType +from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, Type, Union + +from ..exceptions import ConnectionClosed, ConnectionClosedOK, ProtocolError +from ..frames import DATA_OPCODES, BytesLike, Frame, Opcode, prepare_ctrl +from ..http11 import Request, Response +from ..protocol import CLOSED, OPEN, Event, Protocol, State +from ..typing import Data, LoggerLike, Subprotocol +from .messages import Assembler +from .utils import Deadline + + +__all__ = ["Connection"] + +logger = logging.getLogger(__name__) + +BUFSIZE = 65536 + + +class Connection: + """ + Threaded implementation of a WebSocket connection. + + :class:`Connection` provides APIs shared between WebSocket servers and + clients. + + You shouldn't use it directly. Instead, use + :class:`~websockets.sync.client.ClientConnection` or + :class:`~websockets.sync.server.ServerConnection`. + + """ + + def __init__( + self, + socket: socket.socket, + protocol: Protocol, + *, + close_timeout: Optional[float] = 10, + ) -> None: + self.socket = socket + self.protocol = protocol + self.close_timeout = close_timeout + + # Inject reference to this instance in the protocol's logger. + self.protocol.logger = logging.LoggerAdapter( + self.protocol.logger, + {"websocket": self}, + ) + + # Copy attributes from the protocol for convenience. + self.id: uuid.UUID = self.protocol.id + """Unique identifier of the connection. Useful in logs.""" + self.logger: LoggerLike = self.protocol.logger + """Logger for this connection.""" + self.debug = self.protocol.debug + + # HTTP handshake request and response. + self.request: Optional[Request] = None + """Opening handshake request.""" + self.response: Optional[Response] = None + """Opening handshake response.""" + + # Mutex serializing interactions with the protocol. + self.protocol_mutex = threading.Lock() + + # Assembler turning frames into messages and serializing reads. + self.recv_messages = Assembler() + + # Whether we are busy sending a fragmented message. + self.send_in_progress = False + + # Deadline for the closing handshake. + self.close_deadline: Optional[Deadline] = None + + # Mapping of ping IDs to pong waiters, in chronological order. + self.pings: Dict[bytes, threading.Event] = {} + + # Receiving events from the socket. + self.recv_events_thread = threading.Thread(target=self.recv_events) + self.recv_events_thread.start() + + # Exception raised in recv_events, to be chained to ConnectionClosed + # in the user thread in order to show why the TCP connection dropped. + self.recv_events_exc: Optional[BaseException] = None + + # Public attributes + + @property + def local_address(self) -> Any: + """ + Local address of the connection. + + For IPv4 connections, this is a ``(host, port)`` tuple. + + The format of the address depends on the address family. + See :meth:`~socket.socket.getsockname`. + + """ + return self.socket.getsockname() + + @property + def remote_address(self) -> Any: + """ + Remote address of the connection. + + For IPv4 connections, this is a ``(host, port)`` tuple. + + The format of the address depends on the address family. + See :meth:`~socket.socket.getpeername`. + + """ + return self.socket.getpeername() + + @property + def subprotocol(self) -> Optional[Subprotocol]: + """ + Subprotocol negotiated during the opening handshake. + + :obj:`None` if no subprotocol was negotiated. + + """ + return self.protocol.subprotocol + + # Public methods + + def __enter__(self) -> Connection: + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + self.close(1000 if exc_type is None else 1011) + + def __iter__(self) -> Iterator[Data]: + """ + Iterate on incoming messages. + + The iterator calls :meth:`recv` and yields messages in an infinite loop. + + It exits when the connection is closed normally. It raises a + :exc:`~websockets.exceptions.ConnectionClosedError` exception after a + protocol error or a network failure. + + """ + try: + while True: + yield self.recv() + except ConnectionClosedOK: + return + + def recv(self, timeout: Optional[float] = None) -> Data: + """ + Receive the next message. + + When the connection is closed, :meth:`recv` raises + :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it raises + :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal closure + and :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol + error or a network failure. This is how you detect the end of the + message stream. + + If ``timeout`` is :obj:`None`, block until a message is received. If + ``timeout`` is set and no message is received within ``timeout`` + seconds, raise :exc:`TimeoutError`. Set ``timeout`` to ``0`` to check if + a message was already received. + + If the message is fragmented, wait until all fragments are received, + reassemble them, and return the whole message. + + Returns: + A string (:class:`str`) for a Text_ frame or a bytestring + (:class:`bytes`) for a Binary_ frame. + + .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + + Raises: + ConnectionClosed: When the connection is closed. + RuntimeError: If two threads call :meth:`recv` or + :meth:`recv_streaming` concurrently. + + """ + try: + return self.recv_messages.get(timeout) + except EOFError: + raise self.protocol.close_exc from self.recv_events_exc + except RuntimeError: + raise RuntimeError( + "cannot call recv while another thread " + "is already running recv or recv_streaming" + ) from None + + def recv_streaming(self) -> Iterator[Data]: + """ + Receive the next message frame by frame. + + If the message is fragmented, yield each fragment as it is received. + The iterator must be fully consumed, or else the connection will become + unusable. + + :meth:`recv_streaming` raises the same exceptions as :meth:`recv`. + + Returns: + An iterator of strings (:class:`str`) for a Text_ frame or + bytestrings (:class:`bytes`) for a Binary_ frame. + + .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + + Raises: + ConnectionClosed: When the connection is closed. + RuntimeError: If two threads call :meth:`recv` or + :meth:`recv_streaming` concurrently. + + """ + try: + yield from self.recv_messages.get_iter() + except EOFError: + raise self.protocol.close_exc from self.recv_events_exc + except RuntimeError: + raise RuntimeError( + "cannot call recv_streaming while another thread " + "is already running recv or recv_streaming" + ) from None + + def send(self, message: Union[Data, Iterable[Data]]) -> None: + """ + Send a message. + + A string (:class:`str`) is sent as a Text_ frame. A bytestring or + bytes-like object (:class:`bytes`, :class:`bytearray`, or + :class:`memoryview`) is sent as a Binary_ frame. + + .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + + :meth:`send` also accepts an iterable of strings, bytestrings, or + bytes-like objects to enable fragmentation_. Each item is treated as a + message fragment and sent in its own frame. All items must be of the + same type, or else :meth:`send` will raise a :exc:`TypeError` and the + connection will be closed. + + .. _fragmentation: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.4 + + :meth:`send` rejects dict-like objects because this is often an error. + (If you really want to send the keys of a dict-like object as fragments, + call its :meth:`~dict.keys` method and pass the result to :meth:`send`.) + + When the connection is closed, :meth:`send` raises + :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it + raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal + connection closure and + :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol + error or a network failure. + + Args: + message: Message to send. + + Raises: + ConnectionClosed: When the connection is closed. + RuntimeError: If a connection is busy sending a fragmented message. + TypeError: If ``message`` doesn't have a supported type. + + """ + # Unfragmented message -- this case must be handled first because + # strings and bytes-like objects are iterable. + + if isinstance(message, str): + with self.send_context(): + if self.send_in_progress: + raise RuntimeError( + "cannot call send while another thread " + "is already running send" + ) + self.protocol.send_text(message.encode("utf-8")) + + elif isinstance(message, BytesLike): + with self.send_context(): + if self.send_in_progress: + raise RuntimeError( + "cannot call send while another thread " + "is already running send" + ) + self.protocol.send_binary(message) + + # Catch a common mistake -- passing a dict to send(). + + elif isinstance(message, Mapping): + raise TypeError("data is a dict-like object") + + # Fragmented message -- regular iterator. + + elif isinstance(message, Iterable): + chunks = iter(message) + try: + chunk = next(chunks) + except StopIteration: + return + + try: + # First fragment. + if isinstance(chunk, str): + text = True + with self.send_context(): + if self.send_in_progress: + raise RuntimeError( + "cannot call send while another thread " + "is already running send" + ) + self.send_in_progress = True + self.protocol.send_text( + chunk.encode("utf-8"), + fin=False, + ) + elif isinstance(chunk, BytesLike): + text = False + with self.send_context(): + if self.send_in_progress: + raise RuntimeError( + "cannot call send while another thread " + "is already running send" + ) + self.send_in_progress = True + self.protocol.send_binary( + chunk, + fin=False, + ) + else: + raise TypeError("data iterable must contain bytes or str") + + # Other fragments + for chunk in chunks: + if isinstance(chunk, str) and text: + with self.send_context(): + assert self.send_in_progress + self.protocol.send_continuation( + chunk.encode("utf-8"), + fin=False, + ) + elif isinstance(chunk, BytesLike) and not text: + with self.send_context(): + assert self.send_in_progress + self.protocol.send_continuation( + chunk, + fin=False, + ) + else: + raise TypeError("data iterable must contain uniform types") + + # Final fragment. + with self.send_context(): + self.protocol.send_continuation(b"", fin=True) + self.send_in_progress = False + + except RuntimeError: + # We didn't start sending a fragmented message. + raise + + except Exception: + # We're half-way through a fragmented message and we can't + # complete it. This makes the connection unusable. + with self.send_context(): + self.protocol.fail(1011, "error in fragmented message") + raise + + else: + raise TypeError("data must be bytes, str, or iterable") + + def close(self, code: int = 1000, reason: str = "") -> None: + """ + Perform the closing handshake. + + :meth:`close` waits for the other end to complete the handshake and + for the TCP connection to terminate. + + :meth:`close` is idempotent: it doesn't do anything once the + connection is closed. + + Args: + code: WebSocket close code. + reason: WebSocket close reason. + + """ + try: + # The context manager takes care of waiting for the TCP connection + # to terminate after calling a method that sends a close frame. + with self.send_context(): + if self.send_in_progress: + self.protocol.fail(1011, "close during fragmented message") + else: + self.protocol.send_close(code, reason) + except ConnectionClosed: + # Ignore ConnectionClosed exceptions raised from send_context(). + # They mean that the connection is closed, which was the goal. + pass + + def ping(self, data: Optional[Data] = None) -> threading.Event: + """ + Send a Ping_. + + .. _Ping: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 + + A ping may serve as a keepalive or as a check that the remote endpoint + received all messages up to this point + + Args: + data: Payload of the ping. A :class:`str` will be encoded to UTF-8. + If ``data`` is :obj:`None`, the payload is four random bytes. + + Returns: + An event that will be set when the corresponding pong is received. + You can ignore it if you don't intend to wait. + + :: + + pong_event = ws.ping() + pong_event.wait() # only if you want to wait for the pong + + Raises: + ConnectionClosed: When the connection is closed. + RuntimeError: If another ping was sent with the same data and + the corresponding pong wasn't received yet. + + """ + if data is not None: + data = prepare_ctrl(data) + + with self.send_context(): + # Protect against duplicates if a payload is explicitly set. + if data in self.pings: + raise RuntimeError("already waiting for a pong with the same data") + + # Generate a unique random payload otherwise. + while data is None or data in self.pings: + data = struct.pack("!I", random.getrandbits(32)) + + pong_waiter = threading.Event() + self.pings[data] = pong_waiter + self.protocol.send_ping(data) + return pong_waiter + + def pong(self, data: Data = b"") -> None: + """ + Send a Pong_. + + .. _Pong: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 + + An unsolicited pong may serve as a unidirectional heartbeat. + + Args: + data: Payload of the pong. A :class:`str` will be encoded to UTF-8. + + Raises: + ConnectionClosed: When the connection is closed. + + """ + data = prepare_ctrl(data) + + with self.send_context(): + self.protocol.send_pong(data) + + # Private methods + + def process_event(self, event: Event) -> None: + """ + Process one incoming event. + + This method is overridden in subclasses to handle the handshake. + + """ + assert isinstance(event, Frame) + if event.opcode in DATA_OPCODES: + self.recv_messages.put(event) + + if event.opcode is Opcode.PONG: + self.acknowledge_pings(bytes(event.data)) + + def acknowledge_pings(self, data: bytes) -> None: + """ + Acknowledge pings when receiving a pong. + + """ + with self.protocol_mutex: + # Ignore unsolicited pong. + if data not in self.pings: + return + # Sending a pong for only the most recent ping is legal. + # Acknowledge all previous pings too in that case. + ping_id = None + ping_ids = [] + for ping_id, ping in self.pings.items(): + ping_ids.append(ping_id) + ping.set() + if ping_id == data: + break + else: + raise AssertionError("solicited pong not found in pings") + # Remove acknowledged pings from self.pings. + for ping_id in ping_ids: + del self.pings[ping_id] + + def recv_events(self) -> None: + """ + Read incoming data from the socket and process events. + + Run this method in a thread as long as the connection is alive. + + ``recv_events()`` exits immediately when the ``self.socket`` is closed. + + """ + try: + while True: + try: + if self.close_deadline is not None: + self.socket.settimeout(self.close_deadline.timeout()) + data = self.socket.recv(BUFSIZE) + except Exception as exc: + if self.debug: + self.logger.debug("error while receiving data", exc_info=True) + # When the closing handshake is initiated by our side, + # recv() may block until send_context() closes the socket. + # In that case, send_context() already set recv_events_exc. + # Calling set_recv_events_exc() avoids overwriting it. + with self.protocol_mutex: + self.set_recv_events_exc(exc) + break + + if data == b"": + break + + # Acquire the connection lock. + with self.protocol_mutex: + # Feed incoming data to the connection. + self.protocol.receive_data(data) + + # This isn't expected to raise an exception. + events = self.protocol.events_received() + + # Write outgoing data to the socket. + try: + self.send_data() + except Exception as exc: + if self.debug: + self.logger.debug("error while sending data", exc_info=True) + # Similarly to the above, avoid overriding an exception + # set by send_context(), in case of a race condition + # i.e. send_context() closes the socket after recv() + # returns above but before send_data() calls send(). + self.set_recv_events_exc(exc) + break + + if self.protocol.close_expected(): + # If the connection is expected to close soon, set the + # close deadline based on the close timeout. + if self.close_deadline is None: + self.close_deadline = Deadline(self.close_timeout) + + # Unlock conn_mutex before processing events. Else, the + # application can't send messages in response to events. + + # If self.send_data raised an exception, then events are lost. + # Given that automatic responses write small amounts of data, + # this should be uncommon, so we don't handle the edge case. + + for event in events: + # This isn't expected to raise an exception. + self.process_event(event) + + # Breaking out of the while True: ... loop means that we believe + # that the socket doesn't work anymore. + with self.protocol_mutex: + # Feed the end of the data stream to the connection. + self.protocol.receive_eof() + + # This isn't expected to generate events. + assert not self.protocol.events_received() + + # There is no error handling because send_data() can only write + # the end of the data stream here and it handles errors itself. + self.send_data() + + except Exception as exc: + # This branch should never run. It's a safety net in case of bugs. + self.logger.error("unexpected internal error", exc_info=True) + with self.protocol_mutex: + self.set_recv_events_exc(exc) + # We don't know where we crashed. Force protocol state to CLOSED. + self.protocol.state = CLOSED + finally: + # This isn't expected to raise an exception. + self.recv_messages.close() + self.close_socket() + + @contextlib.contextmanager + def send_context( + self, + *, + expected_state: State = OPEN, # CONNECTING during the opening handshake + ) -> Iterator[None]: + """ + Create a context for writing to the connection from user code. + + On entry, :meth:`send_context` acquires the connection lock and checks + that the connection is open; on exit, it writes outgoing data to the + socket:: + + with self.send_context(): + self.protocol.send_text(message.encode("utf-8")) + + When the connection isn't open on entry, when the connection is expected + to close on exit, or when an unexpected error happens, terminating the + connection, :meth:`send_context` waits until the connection is closed + then raises :exc:`~websockets.exceptions.ConnectionClosed`. + + """ + # Should we wait until the connection is closed? + wait_for_close = False + # Should we close the socket and raise ConnectionClosed? + raise_close_exc = False + # What exception should we chain ConnectionClosed to? + original_exc: Optional[BaseException] = None + + # Acquire the protocol lock. + with self.protocol_mutex: + if self.protocol.state is expected_state: + # Let the caller interact with the protocol. + try: + yield + except (ProtocolError, RuntimeError): + # The protocol state wasn't changed. Exit immediately. + raise + except Exception as exc: + self.logger.error("unexpected internal error", exc_info=True) + # This branch should never run. It's a safety net in case of + # bugs. Since we don't know what happened, we will close the + # connection and raise the exception to the caller. + wait_for_close = False + raise_close_exc = True + original_exc = exc + else: + # Check if the connection is expected to close soon. + if self.protocol.close_expected(): + wait_for_close = True + # If the connection is expected to close soon, set the + # close deadline based on the close timeout. + + # Since we tested earlier that protocol.state was OPEN + # (or CONNECTING) and we didn't release protocol_mutex, + # it is certain that self.close_deadline is still None. + assert self.close_deadline is None + self.close_deadline = Deadline(self.close_timeout) + # Write outgoing data to the socket. + try: + self.send_data() + except Exception as exc: + if self.debug: + self.logger.debug("error while sending data", exc_info=True) + # While the only expected exception here is OSError, + # other exceptions would be treated identically. + wait_for_close = False + raise_close_exc = True + original_exc = exc + + else: # self.protocol.state is not expected_state + # Minor layering violation: we assume that the connection + # will be closing soon if it isn't in the expected state. + wait_for_close = True + raise_close_exc = True + + # To avoid a deadlock, release the connection lock by exiting the + # context manager before waiting for recv_events() to terminate. + + # If the connection is expected to close soon and the close timeout + # elapses, close the socket to terminate the connection. + if wait_for_close: + if self.close_deadline is None: + timeout = self.close_timeout + else: + # Thread.join() returns immediately if timeout is negative. + timeout = self.close_deadline.timeout(raise_if_elapsed=False) + self.recv_events_thread.join(timeout) + + if self.recv_events_thread.is_alive(): + # There's no risk to overwrite another error because + # original_exc is never set when wait_for_close is True. + assert original_exc is None + original_exc = TimeoutError("timed out while closing connection") + # Set recv_events_exc before closing the socket in order to get + # proper exception reporting. + raise_close_exc = True + with self.protocol_mutex: + self.set_recv_events_exc(original_exc) + + # If an error occurred, close the socket to terminate the connection and + # raise an exception. + if raise_close_exc: + self.close_socket() + self.recv_events_thread.join() + raise self.protocol.close_exc from original_exc + + def send_data(self) -> None: + """ + Send outgoing data. + + This method requires holding protocol_mutex. + + Raises: + OSError: When a socket operations fails. + + """ + assert self.protocol_mutex.locked() + for data in self.protocol.data_to_send(): + if data: + if self.close_deadline is not None: + self.socket.settimeout(self.close_deadline.timeout()) + self.socket.sendall(data) + else: + try: + self.socket.shutdown(socket.SHUT_WR) + except OSError: # socket already closed + pass + + def set_recv_events_exc(self, exc: Optional[BaseException]) -> None: + """ + Set recv_events_exc, if not set yet. + + This method requires holding protocol_mutex. + + """ + assert self.protocol_mutex.locked() + if self.recv_events_exc is None: + self.recv_events_exc = exc + + def close_socket(self) -> None: + """ + Shutdown and close socket. + + shutdown() is required to interrupt recv() on Linux. + + """ + try: + self.socket.shutdown(socket.SHUT_RDWR) + except OSError: + pass # socket is already closed + self.socket.close() diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index 9265e4e9f..67a22313c 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -67,12 +67,12 @@ def get(self, timeout: Optional[float] = None) -> Data: messages frame by frame, use :meth:`get_iter` instead. Args: - timeout: if a timeout is provided and elapses before a complete + timeout: If a timeout is provided and elapses before a complete message is received, :meth:`get` raises :exc:`TimeoutError`. Raises: - EOFError: if the stream of frames has ended. - RuntimeError: if two threads run :meth:`get` or :meth:``get_iter` + EOFError: If the stream of frames has ended. + RuntimeError: If two threads run :meth:`get` or :meth:``get_iter` concurrently. """ @@ -130,8 +130,8 @@ def get_iter(self) -> Iterator[Data]: fragmented, use :meth:`get` instead. Raises: - EOFError: if the stream of frames has ended. - RuntimeError: if two threads run :meth:`get` or :meth:``get_iter` + EOFError: If the stream of frames has ended. + RuntimeError: If two threads run :meth:`get` or :meth:``get_iter` concurrently. """ @@ -194,8 +194,8 @@ def put(self, frame: Frame) -> None: it doesn't, the behavior is undefined. Raises: - EOFError: if the stream of frames has ended. - RuntimeError: if two threads run :meth:`put` concurrently. + EOFError: If the stream of frames has ended. + RuntimeError: If two threads run :meth:`put` concurrently. """ with self.mutex: diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py new file mode 100644 index 000000000..a53ae2b25 --- /dev/null +++ b/src/websockets/sync/server.py @@ -0,0 +1,525 @@ +from __future__ import annotations + +import http +import logging +import os +import select +import socket +import ssl +import threading +from types import TracebackType +from typing import Any, Callable, Optional, Sequence, Type + +from ..extensions.base import ServerExtensionFactory +from ..extensions.permessage_deflate import enable_server_permessage_deflate +from ..headers import validate_subprotocols +from ..http import USER_AGENT +from ..http11 import Request, Response +from ..protocol import CONNECTING, OPEN, Event +from ..server import ServerProtocol +from ..typing import LoggerLike, Origin, Subprotocol +from .compatibility import socket_create_server +from .connection import Connection +from .utils import Deadline + + +__all__ = ["serve", "unix_serve", "ServerConnection", "WebSocketServer"] + + +class ServerConnection(Connection): + """ + Threaded implementation of a WebSocket server connection. + + :class:`ServerConnection` provides :meth:`recv` and :meth:`send` methods for + receiving and sending messages. + + It supports iteration to receive messages:: + + for message in websocket: + process(message) + + The iterator exits normally when the connection is closed with close code + 1000 (OK) or 1001 (going away) or without a close code. It raises a + :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is + closed with any other code. + + Args: + socket: Socket connected to a WebSocket client. + protocol: Sans-I/O connection. + close_timeout: Timeout for closing the connection in seconds. + + """ + + def __init__( + self, + socket: socket.socket, + protocol: ServerProtocol, + *, + close_timeout: Optional[float] = 10, + ) -> None: + self.protocol: ServerProtocol + self.request_rcvd = threading.Event() + super().__init__( + socket, + protocol, + close_timeout=close_timeout, + ) + + def handshake( + self, + process_request: Optional[ + Callable[ + [ServerConnection, Request], + Optional[Response], + ] + ] = None, + process_response: Optional[ + Callable[ + [ServerConnection, Request, Response], + Optional[Response], + ] + ] = None, + server_header: Optional[str] = USER_AGENT, + timeout: Optional[float] = None, + ) -> None: + """ + Perform the opening handshake. + + """ + if not self.request_rcvd.wait(timeout): + self.close_socket() + self.recv_events_thread.join() + raise TimeoutError("timed out during handshake") + + if self.request is None: + self.close_socket() + self.recv_events_thread.join() + raise ConnectionError("connection closed during handshake") + + with self.send_context(expected_state=CONNECTING): + self.response = None + + if process_request is not None: + try: + self.response = process_request(self, self.request) + except Exception as exc: + self.protocol.handshake_exc = exc + self.logger.error("opening handshake failed", exc_info=True) + self.response = self.protocol.reject( + http.HTTPStatus.INTERNAL_SERVER_ERROR, + ( + "Failed to open a WebSocket connection.\n" + "See server log for more information.\n" + ), + ) + + if self.response is None: + self.response = self.protocol.accept(self.request) + + if server_header is not None: + self.response.headers["Server"] = server_header + + if process_response is not None: + try: + response = process_response(self, self.request, self.response) + except Exception as exc: + self.protocol.handshake_exc = exc + self.logger.error("opening handshake failed", exc_info=True) + self.response = self.protocol.reject( + http.HTTPStatus.INTERNAL_SERVER_ERROR, + ( + "Failed to open a WebSocket connection.\n" + "See server log for more information.\n" + ), + ) + else: + if response is not None: + self.response = response + + self.protocol.send_response(self.response) + + if self.protocol.state is not OPEN: + self.recv_events_thread.join(self.close_timeout) + self.close_socket() + self.recv_events_thread.join() + + if self.protocol.handshake_exc is not None: + raise self.protocol.handshake_exc + + def process_event(self, event: Event) -> None: + """ + Process one incoming event. + + """ + # First event - handshake request. + if self.request is None: + assert isinstance(event, Request) + self.request = event + self.request_rcvd.set() + # Later events - frames. + else: + super().process_event(event) + + def recv_events(self) -> None: + """ + Read incoming data from the socket and process events. + + """ + try: + super().recv_events() + finally: + # If the connection is closed during the handshake, unblock it. + self.request_rcvd.set() + + +class WebSocketServer: + """ + WebSocket server returned by :func:`serve`. + + This class mirrors the API of :class:`~socketserver.BaseServer`, notably the + :meth:`~socketserver.BaseServer.serve_forever` and + :meth:`~socketserver.BaseServer.shutdown` methods, as well as the context + manager protocol. + + Args: + socket: Server socket listening for new connections. + handler: Handler for one connection. Receives the socket and address + returned by :meth:`~socket.socket.accept`. + logger: Logger for this server. + + """ + + def __init__( + self, + socket: socket.socket, + handler: Callable[[socket.socket, Any], None], + logger: Optional[LoggerLike] = None, + ): + self.socket = socket + self.handler = handler + if logger is None: + logger = logging.getLogger("websockets.server") + self.logger = logger + self.shutdown_watcher, self.shutdown_notifier = os.pipe() + + def serve_forever(self) -> None: + """ + See :meth:`socketserver.BaseServer.serve_forever`. + + This method doesn't return. Calling :meth:`shutdown` from another thread + stops the server. + + Typical use:: + + with serve(...) as server: + server.serve_forever() + + """ + poller = select.poll() + poller.register(self.socket) + poller.register(self.shutdown_watcher) + + while True: + poller.poll() + try: + # If the socket is closed, this will raise an exception and exit + # the loop. So we don't need to check the return value of poll(). + sock, addr = self.socket.accept() + except OSError: + break + thread = threading.Thread(target=self.handler, args=(sock, addr)) + thread.start() + + def shutdown(self) -> None: + """ + See :meth:`socketserver.BaseServer.shutdown`. + + """ + self.socket.close() + os.write(self.shutdown_notifier, b"x") + + def fileno(self) -> int: + """ + See :meth:`socketserver.BaseServer.fileno`. + + """ + return self.socket.fileno() + + def __enter__(self) -> WebSocketServer: + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + self.shutdown() + + +def serve( + handler: Callable[[ServerConnection], None], + host: Optional[str] = None, + port: Optional[int] = None, + *, + # TCP/TLS — unix and path are only for unix_serve() + sock: Optional[socket.socket] = None, + ssl_context: Optional[ssl.SSLContext] = None, + unix: bool = False, + path: Optional[str] = None, + # WebSocket + origins: Optional[Sequence[Optional[Origin]]] = None, + extensions: Optional[Sequence[ServerExtensionFactory]] = None, + subprotocols: Optional[Sequence[Subprotocol]] = None, + select_subprotocol: Optional[ + Callable[ + [ServerConnection, Sequence[Subprotocol]], + Optional[Subprotocol], + ] + ] = None, + process_request: Optional[ + Callable[ + [ServerConnection, Request], + Optional[Response], + ] + ] = None, + process_response: Optional[ + Callable[ + [ServerConnection, Request, Response], + Optional[Response], + ] + ] = None, + server_header: Optional[str] = USER_AGENT, + compression: Optional[str] = "deflate", + # Timeouts + open_timeout: Optional[float] = 10, + close_timeout: Optional[float] = 10, + # Limits + max_size: Optional[int] = 2**20, + # Logging + logger: Optional[LoggerLike] = None, + # Escape hatch for advanced customization + create_connection: Optional[Type[ServerConnection]] = None, +) -> WebSocketServer: + """ + Create a WebSocket server listening on ``host`` and ``port``. + + Whenever a client connects, the server creates a :class:`ServerConnection`, + performs the opening handshake, and delegates to the ``handler``. + + The handler receives a :class:`ServerConnection` instance, which you can use + to send and receive messages. + + Once the handler completes, either normally or with an exception, the server + performs the closing handshake and closes the connection. + + :class:`WebSocketServer` mirrors the API of + :class:`~socketserver.BaseServer`. Treat it as a context manager to ensure + that it will be closed and call the :meth:`~WebSocketServer.serve_forever` + method to serve requests:: + + def handler(websocket): + ... + + with websockets.sync.server.serve(handler, ...) as server: + server.serve_forever() + + Args: + handler: Connection handler. It receives the WebSocket connection, + which is a :class:`ServerConnection`, in argument. + host: Network interfaces the server binds to. + See :func:`~socket.create_server` for details. + port: TCP port the server listens on. + See :func:`~socket.create_server` for details. + sock: Preexisting TCP socket. ``sock`` replaces ``host`` and ``port``. + You may call :func:`socket.create_server` to create a suitable TCP + socket. + ssl_context: Configuration for enabling TLS on the connection. + origins: Acceptable values of the ``Origin`` header, for defending + against Cross-Site WebSocket Hijacking attacks. Include :obj:`None` + in the list if the lack of an origin is acceptable. + extensions: List of supported extensions, in order in which they + should be negotiated and run. + subprotocols: List of supported subprotocols, in order of decreasing + preference. + select_subprotocol: Callback for selecting a subprotocol among + those supported by the client and the server. It receives a + :class:`ServerConnection` (not a + :class:`~websockets.server.ServerProtocol`!) instance and a list of + subprotocols offered by the client. Other than the first argument, + it has the same behavior as the + :meth:`ServerProtocol.select_subprotocol + ` method. + process_request: Intercept the request during the opening handshake. + Return an HTTP response to force the response or :obj:`None` to + continue normally. When you force a HTTP 101 Continue response, + the handshake is successful. Else, the connection is aborted. + process_response: Intercept the response during the opening handshake. + Return an HTTP response to force the response or :obj:`None` to + continue normally. When you force a HTTP 101 Continue response, + the handshake is successful. Else, the connection is aborted. + server_header: Value of the ``Server`` response header. + It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to + :obj:`None` removes the header. + compression: The "permessage-deflate" extension is enabled by default. + Set ``compression`` to :obj:`None` to disable it. See the + :doc:`compression guide <../../topics/compression>` for details. + open_timeout: Timeout for opening connections in seconds. + :obj:`None` disables the timeout. + close_timeout: Timeout for closing connections in seconds. + :obj:`None` disables the timeout. + max_size: Maximum size of incoming messages in bytes. + :obj:`None` disables the limit. + logger: Logger for this server. + It defaults to ``logging.getLogger("websockets.server")``. See the + :doc:`logging guide <../../topics/logging>` for details. + create_connection: Factory for the :class:`ServerConnection` managing + the connection. Set it to a wrapper or a subclass to customize + connection handling. + """ + + # Process parameters + + if subprotocols is not None: + validate_subprotocols(subprotocols) + + if compression == "deflate": + extensions = enable_server_permessage_deflate(extensions) + elif compression is not None: + raise ValueError(f"unsupported compression: {compression}") + + if create_connection is None: + create_connection = ServerConnection + + # Bind socket and listen + + if sock is None: + if unix: + if path is None: + raise TypeError("missing path argument") + sock = socket_create_server(path, family=socket.AF_UNIX) + else: + sock = socket_create_server((host, port)) + else: + if path is not None: + raise TypeError("path and sock arguments are incompatible") + + # Initialize TLS wrapper + + if ssl_context is not None: + sock = ssl_context.wrap_socket( + sock, + server_side=True, + # Delay TLS handshake until after we set a timeout on the socket. + do_handshake_on_connect=False, + ) + + # Define request handler + + def conn_handler(sock: socket.socket, addr: Any) -> None: + # Calculate timeouts on the TLS and WebSocket handshakes. + # The TLS timeout must be set on the socket, then removed + # to avoid conflicting with the WebSocket timeout in handshake(). + deadline = Deadline(open_timeout) + + try: + # Disable Nagle algorithm + + if not unix: + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True) + + # Perform TLS handshake + + if ssl_context is not None: + sock.settimeout(deadline.timeout()) + assert isinstance(sock, ssl.SSLSocket) # mypy cannot figure this out + sock.do_handshake() + sock.settimeout(None) + + # Create a closure so that select_subprotocol has access to self. + + protocol_select_subprotocol: Optional[ + Callable[ + [ServerProtocol, Sequence[Subprotocol]], + Optional[Subprotocol], + ] + ] = None + + if select_subprotocol is not None: + + def protocol_select_subprotocol( + protocol: ServerProtocol, + subprotocols: Sequence[Subprotocol], + ) -> Optional[Subprotocol]: + # mypy doesn't know that select_subprotocol is immutable. + assert select_subprotocol is not None + # Ensure this function is only used in the intended context. + assert protocol is connection.protocol + return select_subprotocol(connection, subprotocols) + + # Initialize WebSocket connection + + protocol = ServerProtocol( + origins=origins, + extensions=extensions, + subprotocols=subprotocols, + select_subprotocol=protocol_select_subprotocol, + state=CONNECTING, + max_size=max_size, + logger=logger, + ) + + # Initialize WebSocket protocol + + assert create_connection is not None # help mypy + connection = create_connection( + sock, + protocol, + close_timeout=close_timeout, + ) + # On failure, handshake() closes the socket, raises an exception, and + # logs it. + connection.handshake( + process_request, + process_response, + server_header, + deadline.timeout(), + ) + + except Exception: + sock.close() + return + + try: + handler(connection) + except Exception: + protocol.logger.error("connection handler failed", exc_info=True) + connection.close(1011) + else: + connection.close() + + # Initialize server + + return WebSocketServer(sock, conn_handler, logger) + + +def unix_serve( + handler: Callable[[ServerConnection], Any], + path: Optional[str] = None, + **kwargs: Any, +) -> WebSocketServer: + """ + Create a WebSocket server listening on a Unix socket. + + This function is identical to :func:`serve`, except the ``host`` and + ``port`` arguments are replaced by ``path``. It's only available on Unix. + + It's useful for deploying a server behind a reverse proxy such as nginx. + + Args: + handler: Connection handler. It receives the WebSocket connection, + which is a :class:`ServerConnection`, in argument. + path: File system path to the Unix socket. + + """ + return serve(handler, path=path, unix=True, **kwargs) diff --git a/src/websockets/sync/utils.py b/src/websockets/sync/utils.py index 8aab6c0d9..471f32e19 100644 --- a/src/websockets/sync/utils.py +++ b/src/websockets/sync/utils.py @@ -12,7 +12,7 @@ class Deadline: Manage timeouts across multiple steps. Args: - timeout: time available in seconds; :obj:`None` if there is no limit. + timeout: Time available in seconds or :obj:`None` if there is no limit. """ @@ -28,15 +28,14 @@ def timeout(self, *, raise_if_elapsed: bool = True) -> Optional[float]: Calculate a timeout from a deadline. Args: - raise_if_elapsed (bool): whether to raise :exc:`TimeoutError` + raise_if_elapsed (bool): Whether to raise :exc:`TimeoutError` if the deadline lapsed. Raises: - TimeoutError: if the deadline lapsed. + TimeoutError: If the deadline lapsed. Returns: - Optional[float]: Time left in seconds; - :obj:`None` if there is no limit. + Time left in seconds or :obj:`None` if there is no limit. """ if self.deadline is None: diff --git a/tests/protocol.py b/tests/protocol.py new file mode 100644 index 000000000..4e843daab --- /dev/null +++ b/tests/protocol.py @@ -0,0 +1,29 @@ +from websockets.protocol import Protocol + + +class RecordingProtocol(Protocol): + """ + Protocol subclass that records incoming frames. + + By interfacing with this protocol, you can check easily what the component + being testing sends during a test. + + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.frames_rcvd = [] + + def get_frames_rcvd(self): + """ + Get incoming frames received up to this point. + + Calling this method clears the list. Each frame is returned only once. + + """ + frames_rcvd, self.frames_rcvd = self.frames_rcvd, [] + return frames_rcvd + + def recv_frame(self, frame): + self.frames_rcvd.append(frame) + super().recv_frame(frame) diff --git a/tests/sync/client.py b/tests/sync/client.py new file mode 100644 index 000000000..51bbd4388 --- /dev/null +++ b/tests/sync/client.py @@ -0,0 +1,55 @@ +import contextlib +import ssl +import sys +import warnings + +from websockets.sync.client import * +from websockets.sync.server import WebSocketServer + +from ..utils import CERTIFICATE + + +__all__ = [ + "CLIENT_CONTEXT", + "run_client", + "run_unix_client", +] + + +CLIENT_CONTEXT = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) +CLIENT_CONTEXT.load_verify_locations(CERTIFICATE) + +# Work around https://github.com/openssl/openssl/issues/7967 + +# This bug causes connect() to hang in tests for the client. Including this +# workaround acknowledges that the issue could happen outside of the test suite. + +# It shouldn't happen too often, or else OpenSSL 1.1.1 would be unusable. If it +# happens, we can look for a library-level fix, but it won't be easy. + +if sys.version_info[:2] < (3, 8): # pragma: no cover + # ssl.OP_NO_TLSv1_3 was introduced and deprecated on Python 3.7. + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + CLIENT_CONTEXT.options |= ssl.OP_NO_TLSv1_3 + + +@contextlib.contextmanager +def run_client(wsuri_or_server, secure=None, resource_name="/", **kwargs): + if isinstance(wsuri_or_server, str): + wsuri = wsuri_or_server + else: + assert isinstance(wsuri_or_server, WebSocketServer) + if secure is None: + secure = "ssl_context" in kwargs + protocol = "wss" if secure else "ws" + host, port = wsuri_or_server.socket.getsockname() + wsuri = f"{protocol}://{host}:{port}{resource_name}" + with connect(wsuri, **kwargs) as client: + yield client + + +@contextlib.contextmanager +def run_unix_client(path, **kwargs): + with unix_connect(path, **kwargs) as client: + yield client diff --git a/tests/sync/connection.py b/tests/sync/connection.py new file mode 100644 index 000000000..89d4909ee --- /dev/null +++ b/tests/sync/connection.py @@ -0,0 +1,109 @@ +import contextlib +import time + +from websockets.sync.connection import Connection + + +class InterceptingConnection(Connection): + """ + Connection subclass that can intercept outgoing packets. + + By interfacing with this connection, you can simulate network conditions + affecting what the component being tested receives during a test. + + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.socket = InterceptingSocket(self.socket) + + @contextlib.contextmanager + def delay_frames_sent(self, delay): + """ + Add a delay before sending frames. + + Delays cumulate: they're added before every frame or before EOF. + + """ + assert self.socket.delay_sendall is None + self.socket.delay_sendall = delay + try: + yield + finally: + self.socket.delay_sendall = None + + @contextlib.contextmanager + def delay_eof_sent(self, delay): + """ + Add a delay before sending EOF. + + Delays cumulate: they're added before every frame or before EOF. + + """ + assert self.socket.delay_shutdown is None + self.socket.delay_shutdown = delay + try: + yield + finally: + self.socket.delay_shutdown = None + + @contextlib.contextmanager + def drop_frames_sent(self): + """ + Prevent frames from being sent. + + Since TCP is reliable, sending frames or EOF afterwards is unrealistic. + + """ + assert not self.socket.drop_sendall + self.socket.drop_sendall = True + try: + yield + finally: + self.socket.drop_sendall = False + + @contextlib.contextmanager + def drop_eof_sent(self): + """ + Prevent EOF from being sent. + + Since TCP is reliable, sending frames or EOF afterwards is unrealistic. + + """ + assert not self.socket.drop_shutdown + self.socket.drop_shutdown = True + try: + yield + finally: + self.socket.drop_shutdown = False + + +class InterceptingSocket: + """ + Socket wrapper that intercepts calls to sendall and shutdown. + + This is coupled to the implementation, which relies on these two methods. + + """ + + def __init__(self, socket): + self.socket = socket + self.delay_sendall = None + self.delay_shutdown = None + self.drop_sendall = False + self.drop_shutdown = False + + def __getattr__(self, name): + return getattr(self.socket, name) + + def sendall(self, bytes, flags=0): + if self.delay_sendall is not None: + time.sleep(self.delay_sendall) + if not self.drop_sendall: + self.socket.sendall(bytes, flags) + + def shutdown(self, how): + if self.delay_shutdown is not None: + time.sleep(self.delay_shutdown) + if not self.drop_shutdown: + self.socket.shutdown(how) diff --git a/tests/sync/server.py b/tests/sync/server.py new file mode 100644 index 000000000..5f0cd3b07 --- /dev/null +++ b/tests/sync/server.py @@ -0,0 +1,67 @@ +import contextlib +import ssl +import sys +import threading + +from websockets.sync.server import * + +from ..utils import CERTIFICATE + + +SERVER_CONTEXT = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) +SERVER_CONTEXT.load_cert_chain(CERTIFICATE) + +# Work around https://github.com/openssl/openssl/issues/7967 + +# This bug causes connect() to hang in tests for the client. Including this +# workaround acknowledges that the issue could happen outside of the test suite. + +# It shouldn't happen too often, or else OpenSSL 1.1.1 would be unusable. If it +# happens, we can look for a library-level fix, but it won't be easy. + +if sys.version_info[:2] >= (3, 8): # pragma: no cover + SERVER_CONTEXT.num_tickets = 0 + + +def crash(ws): + raise RuntimeError + + +def do_nothing(ws): + pass + + +def eval_shell(ws): + for expr in ws: + value = eval(expr) + ws.send(str(value)) + + +class EvalShellMixin: + def assertEval(self, client, expr, value): + client.send(expr) + self.assertEqual(client.recv(), value) + + +@contextlib.contextmanager +def run_server(ws_handler=eval_shell, host="localhost", port=0, **kwargs): + with serve(ws_handler, host, port, **kwargs) as server: + thread = threading.Thread(target=server.serve_forever) + thread.start() + try: + yield server + finally: + server.shutdown() + thread.join() + + +@contextlib.contextmanager +def run_unix_server(path, ws_handler=eval_shell, **kwargs): + with unix_serve(ws_handler, path, **kwargs) as server: + thread = threading.Thread(target=server.serve_forever) + thread.start() + try: + yield server + finally: + server.shutdown() + thread.join() diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py new file mode 100644 index 000000000..8824ed894 --- /dev/null +++ b/tests/sync/test_client.py @@ -0,0 +1,271 @@ +import socket +import ssl +import threading +import unittest + +from websockets.exceptions import InvalidHandshake +from websockets.extensions.permessage_deflate import PerMessageDeflate +from websockets.sync.client import * + +from ..utils import MS, temp_unix_socket_path +from .client import CLIENT_CONTEXT, run_client, run_unix_client +from .server import SERVER_CONTEXT, do_nothing, run_server, run_unix_server + + +class ClientTests(unittest.TestCase): + def test_connection(self): + """Client connects to server and the handshake succeeds.""" + with run_server() as server: + with run_client(server) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + + def test_connection_fails(self): + """Client connects to server but the handshake fails.""" + + def remove_accept_header(self, request, response): + del response.headers["Sec-WebSocket-Accept"] + + # The connection will be open for the server but failed for the client. + # Use a connection handler that exits immediately to avoid an exception. + with run_server(do_nothing, process_response=remove_accept_header) as server: + with self.assertRaisesRegex( + InvalidHandshake, + "missing Sec-WebSocket-Accept header", + ): + with run_client(server, close_timeout=MS): + self.fail("did not raise") + + def test_tcp_connection_fails(self): + """Client fails to connect to server.""" + with self.assertRaises(OSError): + with run_client("ws://localhost:54321"): # invalid port + self.fail("did not raise") + + def test_existing_socket(self): + """Client connects using a pre-existing socket.""" + with run_server() as server: + with socket.create_connection(server.socket.getsockname()) as sock: + # Use a non-existing domain to ensure we connect to the right socket. + with run_client("ws://invalid/", sock=sock) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + + def test_additional_headers(self): + """Client can set additional headers with additional_headers.""" + with run_server() as server: + with run_client( + server, additional_headers={"Authorization": "Bearer ..."} + ) as client: + self.assertEqual(client.request.headers["Authorization"], "Bearer ...") + + def test_override_user_agent(self): + """Client can override User-Agent header with user_agent_header.""" + with run_server() as server: + with run_client(server, user_agent_header="Smith") as client: + self.assertEqual(client.request.headers["User-Agent"], "Smith") + + def test_remove_user_agent(self): + """Client can remove User-Agent header with user_agent_header.""" + with run_server() as server: + with run_client(server, user_agent_header=None) as client: + self.assertNotIn("User-Agent", client.request.headers) + + def test_compression_is_enabled(self): + """Client enables compression by default.""" + with run_server() as server: + with run_client(server) as client: + self.assertEqual( + [type(ext) for ext in client.protocol.extensions], + [PerMessageDeflate], + ) + + def test_disable_compression(self): + """Client disables compression.""" + with run_server() as server: + with run_client(server, compression=None) as client: + self.assertEqual(client.protocol.extensions, []) + + def test_custom_connection_factory(self): + """Client runs ClientConnection factory provided in create_connection.""" + + def create_connection(*args, **kwargs): + client = ClientConnection(*args, **kwargs) + client.create_connection_ran = True + return client + + with run_server() as server: + with run_client(server, create_connection=create_connection) as client: + self.assertTrue(client.create_connection_ran) + + def test_timeout_during_handshake(self): + """Client times out before receiving handshake response from server.""" + gate = threading.Event() + + def stall_connection(self, request): + gate.wait() + + # The connection will be open for the server but failed for the client. + # Use a connection handler that exits immediately to avoid an exception. + with run_server(do_nothing, process_request=stall_connection) as server: + try: + with self.assertRaisesRegex( + TimeoutError, + "timed out during handshake", + ): + with run_client(server, open_timeout=3 * MS): + self.fail("did not raise") + finally: + gate.set() + + def test_connection_closed_during_handshake(self): + """Client reads EOF before receiving handshake response from server.""" + + def close_connection(self, request): + self.close_socket() + + with run_server(process_request=close_connection) as server: + with self.assertRaisesRegex( + ConnectionError, + "connection closed during handshake", + ): + with run_client(server): + self.fail("did not raise") + + +class SecureClientTests(unittest.TestCase): + def test_connection(self): + """Client connects to server securely.""" + with run_server(ssl_context=SERVER_CONTEXT) as server: + with run_client(server, ssl_context=CLIENT_CONTEXT) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertEqual(client.socket.version()[:3], "TLS") + + def test_set_server_hostname_implicitly(self): + """Client sets server_hostname to the host in the WebSocket URI.""" + with temp_unix_socket_path() as path: + with run_unix_server(path, ssl_context=SERVER_CONTEXT): + with run_unix_client( + path, + ssl_context=CLIENT_CONTEXT, + uri="wss://overridden/", + ) as client: + self.assertEqual(client.socket.server_hostname, "overridden") + + def test_set_server_hostname_explicitly(self): + """Client sets server_hostname to the value provided in argument.""" + with temp_unix_socket_path() as path: + with run_unix_server(path, ssl_context=SERVER_CONTEXT): + with run_unix_client( + path, + ssl_context=CLIENT_CONTEXT, + server_hostname="overridden", + ) as client: + self.assertEqual(client.socket.server_hostname, "overridden") + + def test_reject_invalid_server_certificate(self): + """Client rejects certificate where server certificate isn't trusted.""" + with run_server(ssl_context=SERVER_CONTEXT) as server: + with self.assertRaisesRegex( + ssl.SSLCertVerificationError, + r"certificate verify failed: self[ -]signed certificate", + ): + # The test certificate isn't trusted system-wide. + with run_client(server, secure=True): + self.fail("did not raise") + + def test_reject_invalid_server_hostname(self): + """Client rejects certificate where server hostname doesn't match.""" + with run_server(ssl_context=SERVER_CONTEXT) as server: + with self.assertRaisesRegex( + ssl.SSLCertVerificationError, + r"certificate verify failed: Hostname mismatch", + ): + # This hostname isn't included in the test certificate. + with run_client( + server, ssl_context=CLIENT_CONTEXT, server_hostname="invalid" + ): + self.fail("did not raise") + + +@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") +class UnixClientTests(unittest.TestCase): + def test_connection(self): + """Client connects to server over a Unix socket.""" + with temp_unix_socket_path() as path: + with run_unix_server(path): + with run_unix_client(path) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + + def test_set_host_header(self): + """Client sets the Host header to the host in the WebSocket URI.""" + # This is part of the documented behavior of unix_connect(). + with temp_unix_socket_path() as path: + with run_unix_server(path): + with run_unix_client(path, uri="ws://overridden/") as client: + self.assertEqual(client.request.headers["Host"], "overridden") + + +@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") +class SecureUnixClientTests(unittest.TestCase): + def test_connection(self): + """Client connects to server securely over a Unix socket.""" + with temp_unix_socket_path() as path: + with run_unix_server(path, ssl_context=SERVER_CONTEXT): + with run_unix_client(path, ssl_context=CLIENT_CONTEXT) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertEqual(client.socket.version()[:3], "TLS") + + def test_set_server_hostname(self): + """Client sets server_hostname to the host in the WebSocket URI.""" + # This is part of the documented behavior of unix_connect(). + with temp_unix_socket_path() as path: + with run_unix_server(path, ssl_context=SERVER_CONTEXT): + with run_unix_client( + path, + ssl_context=CLIENT_CONTEXT, + uri="wss://overridden/", + ) as client: + self.assertEqual(client.socket.server_hostname, "overridden") + + +class ClientUsageErrorsTests(unittest.TestCase): + def test_ssl_context_without_secure_uri(self): + """Client rejects ssl_context when URI isn't secure.""" + with self.assertRaisesRegex( + TypeError, + "ssl_context argument is incompatible with a ws:// URI", + ): + connect("ws://localhost/", ssl_context=CLIENT_CONTEXT) + + def test_unix_without_path_or_sock(self): + """Unix client requires path when sock isn't provided.""" + with self.assertRaisesRegex( + TypeError, + "missing path argument", + ): + unix_connect() + + def test_unix_with_path_and_sock(self): + """Unix client rejects path when sock is provided.""" + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self.addCleanup(sock.close) + with self.assertRaisesRegex( + TypeError, + "path and sock arguments are incompatible", + ): + unix_connect(path="/", sock=sock) + + def test_invalid_subprotocol(self): + """Client rejects single value of subprotocols.""" + with self.assertRaisesRegex( + TypeError, + "subprotocols must be a list", + ): + connect("ws://localhost/", subprotocols="chat") + + def test_unsupported_compression(self): + """Client rejects incorrect value of compression.""" + with self.assertRaisesRegex( + ValueError, + "unsupported compression: False", + ): + connect("ws://localhost/", compression=False) diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py new file mode 100644 index 000000000..94850affe --- /dev/null +++ b/tests/sync/test_connection.py @@ -0,0 +1,704 @@ +import contextlib +import logging +import socket +import sys +import threading +import time +import unittest +import uuid +from unittest.mock import patch + +from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK +from websockets.frames import Frame, Opcode +from websockets.protocol import CLIENT, SERVER, Protocol +from websockets.sync.connection import * + +from ..protocol import RecordingProtocol +from ..utils import MS +from .connection import InterceptingConnection + + +# Connection implements symmetrical behavior between clients and servers. +# All tests run on the client side and the server side to validate this. + + +class ClientConnectionTests(unittest.TestCase): + LOCAL = CLIENT + REMOTE = SERVER + + def setUp(self): + socket_, remote_socket = socket.socketpair() + protocol = Protocol(self.LOCAL) + remote_protocol = RecordingProtocol(self.REMOTE) + self.connection = Connection(socket_, protocol, close_timeout=2 * MS) + self.remote_connection = InterceptingConnection(remote_socket, remote_protocol) + + def tearDown(self): + self.remote_connection.close() + self.connection.close() + + # Test helpers built upon RecordingProtocol and InterceptingConnection. + + def assertFrameSent(self, frame): + """Check that a single frame was sent.""" + time.sleep(MS) # let the remote side process messages + self.assertEqual(self.remote_connection.protocol.get_frames_rcvd(), [frame]) + + def assertNoFrameSent(self): + """Check that no frame was sent.""" + time.sleep(MS) # let the remote side process messages + self.assertEqual(self.remote_connection.protocol.get_frames_rcvd(), []) + + @contextlib.contextmanager + def delay_frames_rcvd(self, delay): + """Delay frames before they're received by the connection.""" + with self.remote_connection.delay_frames_sent(delay): + yield + time.sleep(MS) # let the remote side process messages + + @contextlib.contextmanager + def delay_eof_rcvd(self, delay): + """Delay EOF before it's received by the connection.""" + with self.remote_connection.delay_eof_sent(delay): + yield + time.sleep(MS) # let the remote side process messages + + @contextlib.contextmanager + def drop_frames_rcvd(self): + """Drop frames before they're received by the connection.""" + with self.remote_connection.drop_frames_sent(): + yield + time.sleep(MS) # let the remote side process messages + + @contextlib.contextmanager + def drop_eof_rcvd(self): + """Drop EOF before it's received by the connection.""" + with self.remote_connection.drop_eof_sent(): + yield + time.sleep(MS) # let the remote side process messages + + # Test __enter__ and __exit__. + + def test_enter(self): + """__enter__ returns the connection itself.""" + with self.connection as connection: + self.assertIs(connection, self.connection) + + def test_exit(self): + """__exit__ closes the connection with code 1000.""" + with self.connection: + self.assertNoFrameSent() + self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) + + def test_exit_with_exception(self): + """__exit__ with an exception closes the connection with code 1011.""" + with self.assertRaises(RuntimeError): + with self.connection: + raise RuntimeError + self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xf3")) + + # Test __iter__. + + def test_iter_text(self): + """__iter__ yields text messages.""" + iterator = iter(self.connection) + self.remote_connection.send("😀") + self.assertEqual(next(iterator), "😀") + self.remote_connection.send("😀") + self.assertEqual(next(iterator), "😀") + + def test_iter_binary(self): + """__iter__ yields binary messages.""" + iterator = iter(self.connection) + self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(next(iterator), b"\x01\x02\xfe\xff") + self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(next(iterator), b"\x01\x02\xfe\xff") + + def test_iter_mixed(self): + """__iter__ yields a mix of text and binary messages.""" + iterator = iter(self.connection) + self.remote_connection.send("😀") + self.assertEqual(next(iterator), "😀") + self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(next(iterator), b"\x01\x02\xfe\xff") + + def test_iter_connection_closed_ok(self): + """__iter__ terminates after a normal closure.""" + iterator = iter(self.connection) + self.remote_connection.close() + with self.assertRaises(StopIteration): + next(iterator) + + def test_iter_connection_closed_error(self): + """__iter__ raises ConnnectionClosedError after an error.""" + iterator = iter(self.connection) + self.remote_connection.close(code=1011) + with self.assertRaises(ConnectionClosedError): + next(iterator) + + # Test recv. + + def test_recv_text(self): + """recv receives a text message.""" + self.remote_connection.send("😀") + self.assertEqual(self.connection.recv(), "😀") + + def test_recv_binary(self): + """recv receives a binary message.""" + self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(self.connection.recv(), b"\x01\x02\xfe\xff") + + def test_recv_fragmented_text(self): + """recv receives a fragmented text message.""" + self.remote_connection.send(["😀", "😀"]) + self.assertEqual(self.connection.recv(), "😀😀") + + def test_recv_fragmented_binary(self): + """recv receives a fragmented binary message.""" + self.remote_connection.send([b"\x01\x02", b"\xfe\xff"]) + self.assertEqual(self.connection.recv(), b"\x01\x02\xfe\xff") + + def test_recv_connection_closed_ok(self): + """recv raises ConnectionClosedOK after a normal closure.""" + self.remote_connection.close() + with self.assertRaises(ConnectionClosedOK): + self.connection.recv() + + def test_recv_connection_closed_error(self): + """recv raises ConnectionClosedError after an error.""" + self.remote_connection.close(code=1011) + with self.assertRaises(ConnectionClosedError): + self.connection.recv() + + def test_recv_during_recv(self): + """recv raises RuntimeError when called concurrently with itself.""" + recv_thread = threading.Thread(target=self.connection.recv) + recv_thread.start() + + with self.assertRaisesRegex( + RuntimeError, + "cannot call recv while another thread " + "is already running recv or recv_streaming", + ): + self.connection.recv() + + self.remote_connection.send("") + recv_thread.join() + + def test_recv_during_recv_streaming(self): + """recv raises RuntimeError when called concurrently with recv_streaming.""" + recv_streaming_thread = threading.Thread( + target=lambda: list(self.connection.recv_streaming()) + ) + recv_streaming_thread.start() + + with self.assertRaisesRegex( + RuntimeError, + "cannot call recv while another thread " + "is already running recv or recv_streaming", + ): + self.connection.recv() + + self.remote_connection.send("") + recv_streaming_thread.join() + + # Test recv_streaming. + + def test_recv_streaming_text(self): + """recv_streaming receives a text message.""" + self.remote_connection.send("😀") + self.assertEqual( + list(self.connection.recv_streaming()), + ["😀"], + ) + + def test_recv_streaming_binary(self): + """recv_streaming receives a binary message.""" + self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual( + list(self.connection.recv_streaming()), + [b"\x01\x02\xfe\xff"], + ) + + def test_recv_streaming_fragmented_text(self): + """recv_streaming receives a fragmented text message.""" + self.remote_connection.send(["😀", "😀"]) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + list(self.connection.recv_streaming()), + ["😀", "😀", ""], + ) + + def test_recv_streaming_fragmented_binary(self): + """recv_streaming receives a fragmented binary message.""" + self.remote_connection.send([b"\x01\x02", b"\xfe\xff"]) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + list(self.connection.recv_streaming()), + [b"\x01\x02", b"\xfe\xff", b""], + ) + + def test_recv_streaming_connection_closed_ok(self): + """recv_streaming raises ConnectionClosedOK after a normal closure.""" + self.remote_connection.close() + with self.assertRaises(ConnectionClosedOK): + list(self.connection.recv_streaming()) + + def test_recv_streaming_connection_closed_error(self): + """recv_streaming raises ConnectionClosedError after an error.""" + self.remote_connection.close(code=1011) + with self.assertRaises(ConnectionClosedError): + list(self.connection.recv_streaming()) + + def test_recv_streaming_during_recv(self): + """recv_streaming raises RuntimeError when called concurrently with recv.""" + recv_thread = threading.Thread(target=self.connection.recv) + recv_thread.start() + + with self.assertRaisesRegex( + RuntimeError, + "cannot call recv_streaming while another thread " + "is already running recv or recv_streaming", + ): + list(self.connection.recv_streaming()) + + self.remote_connection.send("") + recv_thread.join() + + def test_recv_streaming_during_recv_streaming(self): + """recv_streaming raises RuntimeError when called concurrently with itself.""" + recv_streaming_thread = threading.Thread( + target=lambda: list(self.connection.recv_streaming()) + ) + recv_streaming_thread.start() + + with self.assertRaisesRegex( + RuntimeError, + r"cannot call recv_streaming while another thread " + r"is already running recv or recv_streaming", + ): + list(self.connection.recv_streaming()) + + self.remote_connection.send("") + recv_streaming_thread.join() + + # Test send. + + def test_send_text(self): + """send sends a text message.""" + self.connection.send("😀") + self.assertEqual(self.remote_connection.recv(), "😀") + + def test_send_binary(self): + """send sends a binary message.""" + self.connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(self.remote_connection.recv(), b"\x01\x02\xfe\xff") + + def test_send_fragmented_text(self): + """send sends a fragmented text message.""" + self.connection.send(["😀", "😀"]) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + list(self.remote_connection.recv_streaming()), + ["😀", "😀", ""], + ) + + def test_send_fragmented_binary(self): + """send sends a fragmented binary message.""" + self.connection.send([b"\x01\x02", b"\xfe\xff"]) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + list(self.remote_connection.recv_streaming()), + [b"\x01\x02", b"\xfe\xff", b""], + ) + + def test_send_connection_closed_ok(self): + """send raises ConnectionClosedOK after a normal closure.""" + self.remote_connection.close() + with self.assertRaises(ConnectionClosedOK): + self.connection.send("😀") + + def test_send_connection_closed_error(self): + """send raises ConnectionClosedError after an error.""" + self.remote_connection.close(code=1011) + with self.assertRaises(ConnectionClosedError): + self.connection.send("😀") + + def test_send_during_send(self): + """send raises RuntimeError when called concurrently with itself.""" + recv_thread = threading.Thread(target=self.remote_connection.recv) + recv_thread.start() + + send_gate = threading.Event() + exit_gate = threading.Event() + + def fragments(): + yield "😀" + send_gate.set() + exit_gate.wait() + yield "😀" + + send_thread = threading.Thread( + target=self.connection.send, + args=(fragments(),), + ) + send_thread.start() + + send_gate.wait() + # The check happens in four code paths, depending on the argument. + for message in [ + "😀", + b"\x01\x02\xfe\xff", + ["😀", "😀"], + [b"\x01\x02", b"\xfe\xff"], + ]: + with self.subTest(message=message): + with self.assertRaisesRegex( + RuntimeError, + "cannot call send while another thread is already running send", + ): + self.connection.send(message) + + exit_gate.set() + send_thread.join() + recv_thread.join() + + def test_send_empty_iterable(self): + """send does nothing when called with an empty iterable.""" + self.connection.send([]) + self.connection.close() + self.assertEqual(list(iter(self.remote_connection)), []) + + def test_send_mixed_iterable(self): + """send raises TypeError when called with an iterable of inconsistent types.""" + with self.assertRaises(TypeError): + self.connection.send(["😀", b"\xfe\xff"]) + + def test_send_unsupported_iterable(self): + """send raises TypeError when called with an iterable of unsupported type.""" + with self.assertRaises(TypeError): + self.connection.send([None]) + + def test_send_dict(self): + """send raises TypeError when called with a dict.""" + with self.assertRaises(TypeError): + self.connection.send({"type": "object"}) + + def test_send_unsupported_type(self): + """send raises TypeError when called with an unsupported type.""" + with self.assertRaises(TypeError): + self.connection.send(None) + + # Test close. + + def test_close(self): + """close sends a close frame.""" + self.connection.close() + self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) + + def test_close_explicit_code_reason(self): + """close sends a close frame with a given code and reason.""" + self.connection.close(1001, "bye!") + self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe9bye!")) + + def test_close_waits_for_close_frame(self): + """close waits for a close frame (then EOF) before returning.""" + with self.delay_frames_rcvd(MS): + self.connection.close() + + with self.assertRaises(ConnectionClosedOK) as raised: + self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + def test_close_waits_for_connection_closed(self): + """close waits for EOF before returning.""" + if self.LOCAL is SERVER: + self.skipTest("only relevant on the client-side") + + with self.delay_eof_rcvd(MS): + self.connection.close() + + with self.assertRaises(ConnectionClosedOK) as raised: + self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + def test_close_timeout_waiting_for_close_frame(self): + """close times out if no close frame is received.""" + with self.drop_eof_rcvd(), self.drop_frames_rcvd(): + self.connection.close() + + with self.assertRaises(ConnectionClosedError) as raised: + self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); no close frame received") + self.assertIsInstance(exc.__cause__, TimeoutError) + + def test_close_timeout_waiting_for_connection_closed(self): + """close times out if EOF isn't received.""" + if self.LOCAL is SERVER: + self.skipTest("only relevant on the client-side") + + with self.drop_eof_rcvd(): + self.connection.close() + + with self.assertRaises(ConnectionClosedOK) as raised: + self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + # Remove socket.timeout when dropping Python < 3.10. + self.assertIsInstance(exc.__cause__, (socket.timeout, TimeoutError)) + + def test_close_idempotency(self): + """close does nothing if the connection is already closed.""" + self.connection.close() + self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) + + self.connection.close() + self.assertNoFrameSent() + + def test_close_idempotency_race_condition(self): + """close waits if the connection is already closing.""" + + self.connection.close_timeout = 5 * MS + + def closer(): + with self.delay_frames_rcvd(3 * MS): + self.connection.close() + + close_thread = threading.Thread(target=closer) + close_thread.start() + + # Let closer() initiate the closing handshake and send a close frame. + time.sleep(MS) + self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) + + # Connection isn't closed yet. + with self.assertRaises(TimeoutError): + self.connection.recv(timeout=0) + + self.connection.close() + self.assertNoFrameSent() + + # Connection is closed now. + with self.assertRaises(ConnectionClosedOK): + self.connection.recv(timeout=0) + + close_thread.join() + + def test_close_during_send(self): + """close fails the connection when called concurrently with send.""" + close_gate = threading.Event() + exit_gate = threading.Event() + + def closer(): + close_gate.wait() + self.connection.close() + exit_gate.set() + + def fragments(): + yield "😀" + close_gate.set() + exit_gate.wait() + yield "😀" + + close_thread = threading.Thread(target=closer) + close_thread.start() + + with self.assertRaises(ConnectionClosedError) as raised: + self.connection.send(fragments()) + + exc = raised.exception + self.assertEqual( + str(exc), + "sent 1011 (unexpected error) close during fragmented message; " + "no close frame received", + ) + self.assertIsNone(exc.__cause__) + + close_thread.join() + + # Test ping. + + @patch("random.getrandbits") + def test_ping(self, getrandbits): + """ping sends a ping frame with a random payload.""" + getrandbits.return_value = 1918987876 + self.connection.ping() + getrandbits.assert_called_once_with(32) + self.assertFrameSent(Frame(Opcode.PING, b"rand")) + + def test_ping_explicit_text(self): + """ping sends a ping frame with a payload provided as text.""" + self.connection.ping("ping") + self.assertFrameSent(Frame(Opcode.PING, b"ping")) + + def test_ping_explicit_binary(self): + """ping sends a ping frame with a payload provided as binary.""" + self.connection.ping(b"ping") + self.assertFrameSent(Frame(Opcode.PING, b"ping")) + + def test_ping_duplicate_payload(self): + """ping rejects the same payload until receiving the pong.""" + with self.remote_connection.protocol_mutex: # block response to ping + pong_waiter = self.connection.ping("idem") + with self.assertRaisesRegex( + RuntimeError, + "already waiting for a pong with the same data", + ): + self.connection.ping("idem") + self.assertTrue(pong_waiter.wait(MS)) + self.connection.ping("idem") # doesn't raise an exception + + def test_acknowledge_ping(self): + """ping is acknowledged by a pong with the same payload.""" + with self.drop_frames_rcvd(): + pong_waiter = self.connection.ping("this") + self.assertFalse(pong_waiter.wait(MS)) + self.remote_connection.pong("this") + self.assertTrue(pong_waiter.wait(MS)) + + def test_acknowledge_ping_non_matching_pong(self): + """ping isn't acknowledged by a pong with a different payload.""" + with self.drop_frames_rcvd(): + pong_waiter = self.connection.ping("this") + self.remote_connection.pong("that") + self.assertFalse(pong_waiter.wait(MS)) + + def test_acknowledge_previous_ping(self): + """ping is acknowledged by a pong with the same payload as a later ping.""" + with self.drop_frames_rcvd(): + pong_waiter = self.connection.ping("this") + self.connection.ping("that") + self.remote_connection.pong("that") + self.assertTrue(pong_waiter.wait(MS)) + + # Test pong. + + def test_pong(self): + """pong sends a pong frame.""" + self.connection.pong() + self.assertFrameSent(Frame(Opcode.PONG, b"")) + + def test_pong_explicit_text(self): + """pong sends a pong frame with a payload provided as text.""" + self.connection.pong("pong") + self.assertFrameSent(Frame(Opcode.PONG, b"pong")) + + def test_pong_explicit_binary(self): + """pong sends a pong frame with a payload provided as binary.""" + self.connection.pong(b"pong") + self.assertFrameSent(Frame(Opcode.PONG, b"pong")) + + # Test attributes. + + def test_id(self): + """Connection has an id attribute.""" + self.assertIsInstance(self.connection.id, uuid.UUID) + + def test_logger(self): + """Connection has a logger attribute.""" + self.assertIsInstance(self.connection.logger, logging.LoggerAdapter) + + def test_local_address(self): + """Connection has a local_address attribute.""" + self.assertIsNotNone(self.connection.local_address) + + def test_remote_address(self): + """Connection has a remote_address attribute.""" + self.assertIsNotNone(self.connection.remote_address) + + def test_request(self): + """Connection has a request attribute.""" + self.assertIsNone(self.connection.request) + + def test_response(self): + """Connection has a response attribute.""" + self.assertIsNone(self.connection.response) + + def test_subprotocol(self): + """Connection has a subprotocol attribute.""" + self.assertIsNone(self.connection.subprotocol) + + # Test reporting of network errors. + + @unittest.skipUnless(sys.platform == "darwin", "works only on BSD") + def test_reading_in_recv_events_fails(self): + """Error when reading incoming frames is correctly reported.""" + # Inject a fault by closing the socket. This works only on BSD. + # I cannot find a way to achieve the same effect on Linux. + self.connection.socket.close() + # The connection closed exception reports the injected fault. + with self.assertRaises(ConnectionClosedError) as raised: + self.connection.recv() + self.assertIsInstance(raised.exception.__cause__, IOError) + + def test_writing_in_recv_events_fails(self): + """Error when responding to incoming frames is correctly reported.""" + # Inject a fault by shutting down the socket for writing — but not by + # closing it because that would terminate the connection. + self.connection.socket.shutdown(socket.SHUT_WR) + # Receive a ping. Responding with a pong will fail. + self.remote_connection.ping() + # The connection closed exception reports the injected fault. + with self.assertRaises(ConnectionClosedError) as raised: + self.connection.recv() + self.assertIsInstance(raised.exception.__cause__, BrokenPipeError) + + def test_writing_in_send_context_fails(self): + """Error when sending outgoing frame is correctly reported.""" + # Inject a fault by shutting down the socket for writing — but not by + # closing it because that would terminate the connection. + self.connection.socket.shutdown(socket.SHUT_WR) + # Sending a pong will fail. + # The connection closed exception reports the injected fault. + with self.assertRaises(ConnectionClosedError) as raised: + self.connection.pong() + self.assertIsInstance(raised.exception.__cause__, BrokenPipeError) + + # Test safety nets — catching all exceptions in case of bugs. + + @patch("websockets.protocol.Protocol.events_received") + def test_unexpected_failure_in_recv_events(self, events_received): + """Unexpected internal error in recv_events() is correctly reported.""" + # Inject a fault in a random call in recv_events(). + # This test is tightly coupled to the implementation. + events_received.side_effect = AssertionError + # Receive a message to trigger the fault. + self.remote_connection.send("😀") + + with self.assertRaises(ConnectionClosedError) as raised: + self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "no close frame received or sent") + self.assertIsInstance(exc.__cause__, AssertionError) + + @patch("websockets.protocol.Protocol.send_text") + def test_unexpected_failure_in_send_context(self, send_text): + """Unexpected internal error in send_context() is correctly reported.""" + # Inject a fault in a random call in send_context(). + # This test is tightly coupled to the implementation. + send_text.side_effect = AssertionError + + # Send a message to trigger the fault. + # The connection closed exception reports the injected fault. + with self.assertRaises(ConnectionClosedError) as raised: + self.connection.send("😀") + + exc = raised.exception + self.assertEqual(str(exc), "no close frame received or sent") + self.assertIsInstance(exc.__cause__, AssertionError) + + +class ServerConnectionTests(ClientConnectionTests): + LOCAL = SERVER + REMOTE = CLIENT diff --git a/tests/sync/test_messages.py b/tests/sync/test_messages.py index 069da784b..825eb8797 100644 --- a/tests/sync/test_messages.py +++ b/tests/sync/test_messages.py @@ -31,7 +31,7 @@ def setUp(self): def tearDown(self): """ - Ensure the assembler goes back to its default state after each test. + Check that the assembler goes back to its default state after each test. This removes the need for testing various sequences. diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py new file mode 100644 index 000000000..536858149 --- /dev/null +++ b/tests/sync/test_server.py @@ -0,0 +1,389 @@ +import dataclasses +import http +import logging +import socket +import threading +import unittest + +from websockets.exceptions import ( + ConnectionClosedError, + ConnectionClosedOK, + InvalidStatus, + NegotiationError, +) +from websockets.http11 import Request, Response +from websockets.sync.compatibility import socket_create_server +from websockets.sync.server import * + +from ..utils import MS, temp_unix_socket_path +from .client import CLIENT_CONTEXT, run_client, run_unix_client +from .server import ( + SERVER_CONTEXT, + EvalShellMixin, + crash, + do_nothing, + eval_shell, + run_server, + run_unix_server, +) + + +class ServerTests(EvalShellMixin, unittest.TestCase): + def test_connection(self): + """Server receives connection from client and the handshake succeeds.""" + with run_server() as server: + with run_client(server) as client: + self.assertEval(client, "ws.protocol.state.name", "OPEN") + + def test_connection_fails(self): + """Server receives connection from client but the handshake fails.""" + + def remove_key_header(self, request): + del request.headers["Sec-WebSocket-Key"] + + with run_server(process_request=remove_key_header) as server: + with self.assertRaisesRegex( + InvalidStatus, + "server rejected WebSocket connection: HTTP 400", + ): + with run_client(server): + self.fail("did not raise") + + def test_connection_handler_returns(self): + """Connection handler returns.""" + with run_server(do_nothing) as server: + with run_client(server) as client: + with self.assertRaisesRegex( + ConnectionClosedOK, + r"received 1000 \(OK\); then sent 1000 \(OK\)", + ): + client.recv() + + def test_connection_handler_raises_exception(self): + """Connection handler raises an exception.""" + with run_server(crash) as server: + with run_client(server) as client: + with self.assertRaisesRegex( + ConnectionClosedError, + r"received 1011 \(unexpected error\); " + r"then sent 1011 \(unexpected error\)", + ): + client.recv() + + def test_existing_socket(self): + """Server receives connection using a pre-existing socket.""" + with socket_create_server(("localhost", 0)) as sock: + with run_server(sock=sock): + # Build WebSocket URI to ensure we connect to the right socket. + with run_client("ws://{}:{}/".format(*sock.getsockname())) as client: + self.assertEval(client, "ws.protocol.state.name", "OPEN") + + def test_select_subprotocol(self): + """Server selects a subprotocol with the select_subprotocol callable.""" + + def select_subprotocol(ws, subprotocols): + ws.select_subprotocol_ran = True + assert "chat" in subprotocols + return "chat" + + with run_server( + subprotocols=["chat"], + select_subprotocol=select_subprotocol, + ) as server: + with run_client(server, subprotocols=["chat"]) as client: + self.assertEval(client, "ws.select_subprotocol_ran", "True") + self.assertEval(client, "ws.subprotocol", "chat") + + def test_select_subprotocol_rejects_handshake(self): + """Server rejects handshake if select_subprotocol raises NegotiationError.""" + + def select_subprotocol(ws, subprotocols): + raise NegotiationError + + with run_server(select_subprotocol=select_subprotocol) as server: + with self.assertRaisesRegex( + InvalidStatus, + "server rejected WebSocket connection: HTTP 400", + ): + with run_client(server): + self.fail("did not raise") + + def test_select_subprotocol_raises_exception(self): + """Server returns an error if select_subprotocol raises an exception.""" + + def select_subprotocol(ws, subprotocols): + raise RuntimeError + + with run_server(select_subprotocol=select_subprotocol) as server: + with self.assertRaisesRegex( + InvalidStatus, + "server rejected WebSocket connection: HTTP 500", + ): + with run_client(server): + self.fail("did not raise") + + def test_process_request(self): + """Server runs process_request before processing the handshake.""" + + def process_request(ws, request): + self.assertIsInstance(request, Request) + ws.process_request_ran = True + + with run_server(process_request=process_request) as server: + with run_client(server) as client: + self.assertEval(client, "ws.process_request_ran", "True") + + def test_process_request_abort_handshake(self): + """Server aborts handshake if process_request returns a response.""" + + def process_request(ws, request): + return ws.protocol.reject(http.HTTPStatus.FORBIDDEN, "Forbidden") + + with run_server(process_request=process_request) as server: + with self.assertRaisesRegex( + InvalidStatus, + "server rejected WebSocket connection: HTTP 403", + ): + with run_client(server): + self.fail("did not raise") + + def test_process_request_raises_exception(self): + """Server returns an error if process_request raises an exception.""" + + def process_request(ws, request): + raise RuntimeError + + with run_server(process_request=process_request) as server: + with self.assertRaisesRegex( + InvalidStatus, + "server rejected WebSocket connection: HTTP 500", + ): + with run_client(server): + self.fail("did not raise") + + def test_process_response(self): + """Server runs process_response after processing the handshake.""" + + def process_response(ws, request, response): + self.assertIsInstance(request, Request) + self.assertIsInstance(response, Response) + ws.process_response_ran = True + + with run_server(process_response=process_response) as server: + with run_client(server) as client: + self.assertEval(client, "ws.process_response_ran", "True") + + def test_process_response_override_response(self): + """Server runs process_response after processing the handshake.""" + + def process_response(ws, request, response): + headers = response.headers.copy() + headers["X-ProcessResponse-Ran"] = "true" + return dataclasses.replace(response, headers=headers) + + with run_server(process_response=process_response) as server: + with run_client(server) as client: + self.assertEqual( + client.response.headers["X-ProcessResponse-Ran"], "true" + ) + + def test_process_response_raises_exception(self): + """Server returns an error if process_response raises an exception.""" + + def process_response(ws, request, response): + raise RuntimeError + + with run_server(process_response=process_response) as server: + with self.assertRaisesRegex( + InvalidStatus, + "server rejected WebSocket connection: HTTP 500", + ): + with run_client(server): + self.fail("did not raise") + + def test_override_server(self): + """Server can override Server header with server_header.""" + with run_server(server_header="Neo") as server: + with run_client(server) as client: + self.assertEval(client, "ws.response.headers['Server']", "Neo") + + def test_remove_server(self): + """Server can remove Server header with server_header.""" + with run_server(server_header=None) as server: + with run_client(server) as client: + self.assertEval(client, "'Server' in ws.response.headers", "False") + + def test_compression_is_enabled(self): + """Server enables compression by default.""" + with run_server() as server: + with run_client(server) as client: + self.assertEval( + client, + "[type(ext).__name__ for ext in ws.protocol.extensions]", + "['PerMessageDeflate']", + ) + + def test_disable_compression(self): + """Server disables compression.""" + with run_server(compression=None) as server: + with run_client(server) as client: + self.assertEval(client, "ws.protocol.extensions", "[]") + + def test_custom_connection_factory(self): + """Server runs ServerConnection factory provided in create_connection.""" + + def create_connection(*args, **kwargs): + server = ServerConnection(*args, **kwargs) + server.create_connection_ran = True + return server + + with run_server(create_connection=create_connection) as server: + with run_client(server) as client: + self.assertEval(client, "ws.create_connection_ran", "True") + + def test_timeout_during_handshake(self): + """Server times out before receiving handshake request from client.""" + with run_server(open_timeout=MS) as server: + with socket.create_connection(server.socket.getsockname()) as sock: + self.assertEqual(sock.recv(4096), b"") + + def test_connection_closed_during_handshake(self): + """Server reads EOF before receiving handshake request from client.""" + with run_server() as server: + # Patch handler to record a reference to the thread running it. + server_thread = None + conn_received = threading.Event() + original_handler = server.handler + + def handler(sock, addr): + nonlocal server_thread + server_thread = threading.current_thread() + nonlocal conn_received + conn_received.set() + original_handler(sock, addr) + + server.handler = handler + + with socket.create_connection(server.socket.getsockname()): + # Wait for the server to receive the connection, then close it. + conn_received.wait() + + # Wait for the server thread to terminate. + server_thread.join() + + +class SecureServerTests(EvalShellMixin, unittest.TestCase): + def test_connection(self): + """Server receives secure connection from client.""" + with run_server(ssl_context=SERVER_CONTEXT) as server: + with run_client(server, ssl_context=CLIENT_CONTEXT) as client: + self.assertEval(client, "ws.protocol.state.name", "OPEN") + self.assertEval(client, "ws.socket.version()[:3]", "TLS") + + def test_timeout_during_tls_handshake(self): + """Server times out before receiving TLS handshake request from client.""" + with run_server(ssl_context=SERVER_CONTEXT, open_timeout=MS) as server: + with socket.create_connection(server.socket.getsockname()) as sock: + self.assertEqual(sock.recv(4096), b"") + + def test_connection_closed_during_tls_handshake(self): + """Server reads EOF before receiving TLS handshake request from client.""" + with run_server(ssl_context=SERVER_CONTEXT) as server: + # Patch handler to record a reference to the thread running it. + server_thread = None + conn_received = threading.Event() + original_handler = server.handler + + def handler(sock, addr): + nonlocal server_thread + server_thread = threading.current_thread() + nonlocal conn_received + conn_received.set() + original_handler(sock, addr) + + server.handler = handler + + with socket.create_connection(server.socket.getsockname()): + # Wait for the server to receive the connection, then close it. + conn_received.wait() + + # Wait for the server thread to terminate. + server_thread.join() + + +@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") +class UnixServerTests(EvalShellMixin, unittest.TestCase): + def test_connection(self): + """Server receives connection from client over a Unix socket.""" + with temp_unix_socket_path() as path: + with run_unix_server(path): + with run_unix_client(path) as client: + self.assertEval(client, "ws.protocol.state.name", "OPEN") + + +@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") +class SecureUnixServerTests(EvalShellMixin, unittest.TestCase): + def test_connection(self): + """Server receives secure connection from client over a Unix socket.""" + with temp_unix_socket_path() as path: + with run_unix_server(path, ssl_context=SERVER_CONTEXT): + with run_unix_client(path, ssl_context=CLIENT_CONTEXT) as client: + self.assertEval(client, "ws.protocol.state.name", "OPEN") + self.assertEval(client, "ws.socket.version()[:3]", "TLS") + + +class ServerUsageErrorsTests(unittest.TestCase): + def test_unix_without_path_or_sock(self): + """Unix server requires path when sock isn't provided.""" + with self.assertRaisesRegex( + TypeError, + "missing path argument", + ): + unix_serve(eval_shell) + + def test_unix_with_path_and_sock(self): + """Unix server rejects path when sock is provided.""" + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self.addCleanup(sock.close) + with self.assertRaisesRegex( + TypeError, + "path and sock arguments are incompatible", + ): + unix_serve(eval_shell, path="/", sock=sock) + + def test_invalid_subprotocol(self): + """Server rejects single value of subprotocols.""" + with self.assertRaisesRegex( + TypeError, + "subprotocols must be a list", + ): + serve(eval_shell, subprotocols="chat") + + def test_unsupported_compression(self): + """Server rejects incorrect value of compression.""" + with self.assertRaisesRegex( + ValueError, + "unsupported compression: False", + ): + serve(eval_shell, compression=False) + + +class WebSocketServerTests(unittest.TestCase): + def test_logger(self): + """WebSocketServer accepts a logger argument.""" + logger = logging.getLogger("test") + with run_server(logger=logger) as server: + self.assertIs(server.logger, logger) + + def test_fileno(self): + """WebSocketServer provides a fileno attribute.""" + with run_server() as server: + self.assertIsInstance(server.fileno(), int) + + def test_shutdown(self): + """WebSocketServer provides a shutdown method.""" + with run_server() as server: + server.shutdown() + # Check that the server socket is closed. + with self.assertRaises(OSError): + server.socket.accept() From 6df903cee5b91e1c7aa6c3f9716537a55d0b11cd Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Feb 2023 17:27:25 +0100 Subject: [PATCH 394/762] Document thread-based implementation. --- README.rst | 48 +++++++++------- docs/faq/asyncio.rst | 12 ++-- docs/faq/client.rst | 5 ++ docs/faq/misc.rst | 9 +-- docs/index.rst | 30 ++++++---- docs/project/changelog.rst | 11 ++++ docs/reference/index.rst | 12 ++++ docs/reference/sync/client.rst | 49 ++++++++++++++++ docs/reference/sync/common.rst | 39 +++++++++++++ docs/reference/sync/server.rst | 60 +++++++++++++++++++ docs/spelling_wordlist.txt | 1 + example/echo.py | 4 +- example/hello.py | 13 +++-- src/websockets/client.py | 2 +- src/websockets/legacy/auth.py | 16 +++--- src/websockets/legacy/client.py | 67 ++++++++++------------ src/websockets/legacy/framing.py | 28 ++++----- src/websockets/legacy/handshake.py | 18 +++--- src/websockets/legacy/http.py | 16 +++--- src/websockets/legacy/protocol.py | 92 +++++++++++++++--------------- src/websockets/legacy/server.py | 74 ++++++++++++------------ src/websockets/protocol.py | 12 ++-- src/websockets/server.py | 6 +- src/websockets/sync/client.py | 4 +- 24 files changed, 405 insertions(+), 223 deletions(-) create mode 100644 docs/reference/sync/client.rst create mode 100644 docs/reference/sync/common.rst create mode 100644 docs/reference/sync/server.rst diff --git a/README.rst b/README.rst index fa1d91061..5ba523e8f 100644 --- a/README.rst +++ b/README.rst @@ -30,47 +30,52 @@ with a focus on correctness, simplicity, robustness, and performance. .. _WebSocket: https://developer.mozilla.org/en-US/docs/Web/API/WebSockets_API -Built on top of ``asyncio``, Python's standard asynchronous I/O framework, it -provides an elegant coroutine-based API. +Built on top of ``asyncio``, Python's standard asynchronous I/O framework, the +default implementation provides an elegant coroutine-based API. -`Documentation is available on Read the Docs. `_ +An implementation on top of ``threading`` and a Sans-I/O implementation are also +available. -Here's how a client sends and receives messages: +`Documentation is available on Read the Docs. `_ .. copy-pasted because GitHub doesn't support the include directive +Here's an echo server with the ``asyncio`` API: + .. code:: python #!/usr/bin/env python import asyncio - from websockets import connect + from websockets.server import serve - async def hello(uri): - async with connect(uri) as websocket: - await websocket.send("Hello world!") - await websocket.recv() + async def echo(websocket): + async for message in websocket: + await websocket.send(message) + + async def main(): + async with serve(echo, "localhost", 8765): + await asyncio.Future() # run forever - asyncio.run(hello("ws://localhost:8765")) + asyncio.run(main()) -And here's an echo server: +Here's how a client sends and receives messages with the ``threading`` API: .. code:: python #!/usr/bin/env python import asyncio - from websockets import serve + from websockets.sync.client import connect - async def echo(websocket): - async for message in websocket: - await websocket.send(message) + def hello(): + with connect("ws://localhost:8765") as websocket: + websocket.send("Hello world!") + message = websocket.recv() + print(f"Received: {message}") - async def main(): - async with serve(echo, "localhost", 8765): - await asyncio.Future() # run forever + hello() - asyncio.run(main()) Does that look good? @@ -91,9 +96,8 @@ Why should I use ``websockets``? The development of ``websockets`` is shaped by four principles: -1. **Correctness**: ``websockets`` is heavily tested for compliance - with :rfc:`6455`. Continuous integration fails under 100% branch - coverage. +1. **Correctness**: ``websockets`` is heavily tested for compliance with + :rfc:`6455`. Continuous integration fails under 100% branch coverage. 2. **Simplicity**: all you need to understand is ``msg = await ws.recv()`` and ``await ws.send(msg)``. ``websockets`` takes care of managing connections diff --git a/docs/faq/asyncio.rst b/docs/faq/asyncio.rst index d00cf3f47..e56a42d36 100644 --- a/docs/faq/asyncio.rst +++ b/docs/faq/asyncio.rst @@ -46,17 +46,13 @@ See `issue 867`_. Why am I having problems with threads? -------------------------------------- -You shouldn't use threads. Use tasks instead. - -Indeed, when you chose websockets, you chose :mod:`asyncio` as the primary -framework to handle concurrency. This choice is mutually exclusive with -:mod:`threading`. +If you choose websockets' default implementation based on :mod:`asyncio`, then +you shouldn't use threads. Indeed, choosing :mod:`asyncio` to handle concurrency +is mutually exclusive with :mod:`threading`. If you believe that you need to run websockets in a thread and some logic in another thread, you should run that logic in a :class:`~asyncio.Task` instead. - -If you believe that you cannot run that logic in the same event loop because it -will block websockets, :meth:`~asyncio.loop.run_in_executor` may help. +If it blocks the event loop, :meth:`~asyncio.loop.run_in_executor` will help. This question is really about :mod:`asyncio`. Please review the advice about :ref:`asyncio-multithreading` in the Python documentation. diff --git a/docs/faq/client.rst b/docs/faq/client.rst index 73825e480..c4f5a35b9 100644 --- a/docs/faq/client.rst +++ b/docs/faq/client.rst @@ -43,6 +43,11 @@ To set other HTTP headers, for example the ``Authorization`` header, use the async with connect(..., extra_headers={"Authorization": ...}) as websocket: ... +In the :mod:`threading` API, this argument is named ``additional_headers``:: + + with connect(..., additional_headers={"Authorization": ...}) as websocket: + ... + How do I close a connection? ---------------------------- diff --git a/docs/faq/misc.rst b/docs/faq/misc.rst index e320cb808..4fc271322 100644 --- a/docs/faq/misc.rst +++ b/docs/faq/misc.rst @@ -39,8 +39,8 @@ Why is the default implementation located in ``websockets.legacy``? ................................................................... This is an artifact of websockets' history. For its first eight years, only the -:mod:`asyncio`-based implementation existed. Then, the Sans-I/O implementation -was added. Moving the code in a ``legacy`` submodule eased this refactoring and +:mod:`asyncio` implementation existed. Then, the Sans-I/O implementation was +added. Moving the code in a ``legacy`` submodule eased this refactoring and optimized maintainability. All public APIs were kept at their original locations. ``websockets.legacy`` @@ -61,11 +61,6 @@ you may need to disable: If websockets is still slower than another Python library, please file a bug. -Can I use websockets without ``async`` and ``await``? -..................................................... - -No, there is no convenient way to do this. You should use another library. - Are there ``onopen``, ``onmessage``, ``onerror``, and ``onclose`` callbacks? ............................................................................ diff --git a/docs/index.rst b/docs/index.rst index d64d80ac3..be6a0da05 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -21,26 +21,36 @@ websockets .. |openssf| image:: https://bestpractices.coreinfrastructure.org/projects/6475/badge :target: https://bestpractices.coreinfrastructure.org/projects/6475 -websockets is a library for building WebSocket_ servers and -clients in Python with a focus on correctness, simplicity, robustness, and -performance. +websockets is a library for building WebSocket_ servers and clients in Python +with a focus on correctness, simplicity, robustness, and performance. .. _WebSocket: https://developer.mozilla.org/en-US/docs/Web/API/WebSockets_API -Built on top of :mod:`asyncio`, Python's standard asynchronous I/O framework, -it provides an elegant coroutine-based API. +It supports several network I/O and control flow paradigms: -Here's how a client sends and receives messages: +1. The default implementation builds upon :mod:`asyncio`, Python's standard + asynchronous I/O framework. It provides an elegant coroutine-based API. It's + ideal for servers that handle many clients concurrently. +2. The :mod:`threading` implementation is a good alternative for clients, + especially if you aren't familiar with :mod:`asyncio`. It may also be used + for servers that don't need to serve many clients. +3. The `Sans-I/O`_ implementation is designed for integrating in third-party + libraries, typically application servers, in addition being used internally + by websockets. -.. literalinclude:: ../example/hello.py +.. _Sans-I/O: https://sans-io.readthedocs.io/ -And here's an echo server: +Here's an echo server with the :mod:`asyncio` API: .. literalinclude:: ../example/echo.py +Here's how a client sends and receives messages with the :mod:`threading` API: + +.. literalinclude:: ../example/hello.py + Don't worry about the opening and closing handshakes, pings and pongs, or any -other behavior described in the specification. websockets takes care of this -under the hood so you can focus on your application! +other behavior described in the WebSocket specification. websockets takes care +of this under the hood so you can focus on your application! Also, websockets provides an interactive client: diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 85a4ac92d..95193e780 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -66,6 +66,17 @@ Backwards-incompatible changes New features ............ +.. admonition:: websockets 10.0 introduces a implementation on top of :mod:`threading`. + :class: important + + It may be more convenient if you don't need to manage many connections and + you're more comfortable with :mod:`threading` than :mod:`asyncio`. + + It is particularly suited to client applications that establish only one + connection. It may be used for servers handling few connections. + + See :func:`~sync.client.connect` and :func:`~sync.server.serve` for details. + * Made it possible to close a server without closing existing connections. * Added :attr:`~server.ServerProtocol.select_subprotocol` to customize diff --git a/docs/reference/index.rst b/docs/reference/index.rst index 3b708ef91..fa8047c3d 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -16,6 +16,18 @@ clients concurrently. asyncio/client asyncio/common +:mod:`threading` +---------------- + +This alternative implementation can be a good choice for clients. + +.. toctree:: + :titlesonly: + + sync/server + sync/client + sync/common + `Sans-I/O`_ ----------- diff --git a/docs/reference/sync/client.rst b/docs/reference/sync/client.rst new file mode 100644 index 000000000..6cccd6ec4 --- /dev/null +++ b/docs/reference/sync/client.rst @@ -0,0 +1,49 @@ +Client (:mod:`threading`) +========================= + +.. automodule:: websockets.sync.client + +Opening a connection +-------------------- + +.. autofunction:: connect(uri, *, sock=None, ssl_context=None, server_hostname=None, origin=None, extensions=None, subprotocols=None, additional_headers=None, user_agent_header="Python/x.y.z websockets/X.Y", compression="deflate", open_timeout=10, close_timeout=10, max_size=2 ** 20, logger=None, create_connection=None) + +.. autofunction:: unix_connect(path, uri=None, *, sock=None, ssl_context=None, server_hostname=None, origin=None, extensions=None, subprotocols=None, additional_headers=None, user_agent_header="Python/x.y.z websockets/X.Y", compression="deflate", open_timeout=10, close_timeout=10, max_size=2 ** 20, logger=None, create_connection=None) + +Using a connection +------------------ + +.. autoclass:: ClientConnection + + .. automethod:: __iter__ + + .. automethod:: recv + + .. automethod:: recv_streaming + + .. automethod:: send + + .. automethod:: close + + .. automethod:: ping + + .. automethod:: pong + + WebSocket connection objects also provide these attributes: + + .. autoattribute:: id + + .. autoattribute:: logger + + .. autoproperty:: local_address + + .. autoproperty:: remote_address + + The following attributes are available after the opening handshake, + once the WebSocket connection is open: + + .. autoattribute:: request + + .. autoattribute:: response + + .. autoproperty:: subprotocol diff --git a/docs/reference/sync/common.rst b/docs/reference/sync/common.rst new file mode 100644 index 000000000..8d97ab3c1 --- /dev/null +++ b/docs/reference/sync/common.rst @@ -0,0 +1,39 @@ +Both sides (:mod:`threading`) +============================= + +.. automodule:: websockets.sync.connection + +.. autoclass:: Connection + + .. automethod:: __iter__ + + .. automethod:: recv + + .. automethod:: recv_streaming + + .. automethod:: send + + .. automethod:: close + + .. automethod:: ping + + .. automethod:: pong + + WebSocket connection objects also provide these attributes: + + .. autoattribute:: id + + .. autoattribute:: logger + + .. autoproperty:: local_address + + .. autoproperty:: remote_address + + The following attributes are available after the opening handshake, + once the WebSocket connection is open: + + .. autoattribute:: request + + .. autoattribute:: response + + .. autoproperty:: subprotocol diff --git a/docs/reference/sync/server.rst b/docs/reference/sync/server.rst new file mode 100644 index 000000000..35c112046 --- /dev/null +++ b/docs/reference/sync/server.rst @@ -0,0 +1,60 @@ +Server (:mod:`threading`) +========================= + +.. automodule:: websockets.sync.server + +Creating a server +----------------- + +.. autofunction:: serve(handler, host=None, port=None, *, sock=None, ssl_context=None, origins=None, extensions=None, subprotocols=None, select_subprotocol=None, process_request=None, process_response=None, server_header="Python/x.y.z websockets/X.Y", compression="deflate", open_timeout=10, close_timeout=10, max_size=2 ** 20, logger=None, create_connection=None) + +.. autofunction:: unix_serve(handler, path=None, *, sock=None, ssl_context=None, origins=None, extensions=None, subprotocols=None, select_subprotocol=None, process_request=None, process_response=None, server_header="Python/x.y.z websockets/X.Y", compression="deflate", open_timeout=10, close_timeout=10, max_size=2 ** 20, logger=None, create_connection=None) + +Running a server +---------------- + +.. autoclass:: WebSocketServer + + .. automethod:: serve_forever + + .. automethod:: shutdown + + .. automethod:: fileno + +Using a connection +------------------ + +.. autoclass:: ServerConnection + + .. automethod:: __iter__ + + .. automethod:: recv + + .. automethod:: recv_streaming + + .. automethod:: send + + .. automethod:: close + + .. automethod:: ping + + .. automethod:: pong + + WebSocket connection objects also provide these attributes: + + .. autoattribute:: id + + .. autoattribute:: logger + + .. autoproperty:: local_address + + .. autoproperty:: remote_address + + The following attributes are available after the opening handshake, + once the WebSocket connection is open: + + .. autoattribute:: request + + .. autoattribute:: response + + .. autoproperty:: subprotocol diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 9cc2182e1..dfa7065e7 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -22,6 +22,7 @@ css ctrl deserialize django +dev Dockerfile dyno formatter diff --git a/example/echo.py b/example/echo.py index 4b673cb17..2e47e52d9 100755 --- a/example/echo.py +++ b/example/echo.py @@ -1,14 +1,14 @@ #!/usr/bin/env python import asyncio -import websockets +from websockets.server import serve async def echo(websocket): async for message in websocket: await websocket.send(message) async def main(): - async with websockets.serve(echo, "localhost", 8765): + async with serve(echo, "localhost", 8765): await asyncio.Future() # run forever asyncio.run(main()) diff --git a/example/hello.py b/example/hello.py index 84f55dc52..a3ce0699e 100755 --- a/example/hello.py +++ b/example/hello.py @@ -1,11 +1,12 @@ #!/usr/bin/env python import asyncio -import websockets +from websockets.sync.client import connect -async def hello(): - async with websockets.connect("ws://localhost:8765") as websocket: - await websocket.send("Hello world!") - await websocket.recv() +def hello(): + with connect("ws://localhost:8765") as websocket: + websocket.send("Hello world!") + message = websocket.recv() + print(f"Received: {message}") -asyncio.run(hello()) +hello() diff --git a/src/websockets/client.py b/src/websockets/client.py index b5f871571..a0d077fc2 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -60,7 +60,7 @@ class ClientProtocol(Protocol): preference. state: initial state of the WebSocket connection. max_size: maximum size of incoming messages in bytes; - :obj:`None` to disable the limit. + :obj:`None` disables the limit. logger: logger for this connection; defaults to ``logging.getLogger("websockets.client")``; see the :doc:`logging guide <../../topics/logging>` for details. diff --git a/src/websockets/legacy/auth.py b/src/websockets/legacy/auth.py index 3511469e6..d3425836e 100644 --- a/src/websockets/legacy/auth.py +++ b/src/websockets/legacy/auth.py @@ -135,20 +135,20 @@ def basic_auth_protocol_factory( ) Args: - realm: indicates the scope of protection. It should contain only ASCII - characters because the encoding of non-ASCII characters is - undefined. Refer to section 2.2 of :rfc:`7235` for details. - credentials: defines hard coded authorized credentials. It can be a + realm: Scope of protection. It should contain only ASCII characters + because the encoding of non-ASCII characters is undefined. + Refer to section 2.2 of :rfc:`7235` for details. + credentials: Hard coded authorized credentials. It can be a ``(username, password)`` pair or a list of such pairs. - check_credentials: defines a coroutine that verifies credentials. - This coroutine receives ``username`` and ``password`` arguments + check_credentials: Coroutine that verifies credentials. + It receives ``username`` and ``password`` arguments and returns a :class:`bool`. One of ``credentials`` or ``check_credentials`` must be provided but not both. - create_protocol: factory that creates the protocol. By default, this + create_protocol: Factory that creates the protocol. By default, this is :class:`BasicAuthWebSocketServerProtocol`. It can be replaced by a subclass. Raises: - TypeError: if the ``credentials`` or ``check_credentials`` argument is + TypeError: If the ``credentials`` or ``check_credentials`` argument is wrong. """ diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index f8876f59d..aa71ddb6e 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -130,7 +130,7 @@ async def read_http_response(self) -> Tuple[int, Headers]: after this coroutine returns. Raises: - InvalidMessage: if the HTTP message is malformed or isn't an + InvalidMessage: If the HTTP message is malformed or isn't an HTTP/1.1 GET response. """ @@ -273,15 +273,15 @@ async def handshake( Args: wsuri: URI of the WebSocket server. - origin: value of the ``Origin`` header. - available_extensions: list of supported extensions, in order in - which they should be tried. - available_subprotocols: list of supported subprotocols, in order - of decreasing preference. - extra_headers: arbitrary HTTP headers to add to the request. + origin: Value of the ``Origin`` header. + extensions: List of supported extensions, in order in which they + should be negotiated and run. + subprotocols: List of supported subprotocols, in order of decreasing + preference. + extra_headers: Arbitrary HTTP headers to add to the handshake request. Raises: - InvalidHandshake: if the handshake fails. + InvalidHandshake: If the handshake fails. """ request_headers = Headers() @@ -376,28 +376,26 @@ class Connect: Args: uri: URI of the WebSocket server. - create_protocol: factory for the :class:`asyncio.Protocol` managing - the connection; defaults to :class:`WebSocketClientProtocol`; may - be set to a wrapper or a subclass to customize connection handling. - logger: logger for this connection; - defaults to ``logging.getLogger("websockets.client")``; - see the :doc:`logging guide <../../topics/logging>` for details. - compression: shortcut that enables the "permessage-deflate" extension - by default; may be set to :obj:`None` to disable compression; - see the :doc:`compression guide <../../topics/compression>` for details. - origin: value of the ``Origin`` header. This is useful when connecting - to a server that validates the ``Origin`` header to defend against - Cross-Site WebSocket Hijacking attacks. - extensions: list of supported extensions, in order in which they - should be tried. - subprotocols: list of supported subprotocols, in order of decreasing + create_protocol: Factory for the :class:`asyncio.Protocol` managing + the connection. It defaults to :class:`WebSocketClientProtocol`. + Set it to a wrapper or a subclass to customize connection handling. + logger: Logger for this client. + It defaults to ``logging.getLogger("websockets.client")``. + See the :doc:`logging guide <../../topics/logging>` for details. + compression: The "permessage-deflate" extension is enabled by default. + Set ``compression`` to :obj:`None` to disable it. See the + :doc:`compression guide <../../topics/compression>` for details. + origin: Value of the ``Origin`` header, for servers that require it. + extensions: List of supported extensions, in order in which they + should be negotiated and run. + subprotocols: List of supported subprotocols, in order of decreasing preference. - extra_headers: arbitrary HTTP headers to add to the request. - user_agent_header: value of the ``User-Agent`` request header; - defaults to ``"Python/x.y.z websockets/X.Y"``; - :obj:`None` removes the header. - open_timeout: timeout for opening the connection in seconds; - :obj:`None` to disable the timeout + extra_headers: Arbitrary HTTP headers to add to the handshake request. + user_agent_header: Value of the ``User-Agent`` request header. + It defaults to ``"Python/x.y.z websockets/X.Y"``. + Setting it to :obj:`None` removes the header. + open_timeout: Timeout for opening the connection in seconds. + :obj:`None` disables the timeout. See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``, @@ -418,13 +416,10 @@ class Connect: the TCP connection. The host name from ``uri`` is still used in the TLS handshake for secure connections and in the ``Host`` header. - Returns: - WebSocketClientProtocol: WebSocket connection. - Raises: - InvalidURI: if ``uri`` isn't a valid WebSocket URI. - InvalidHandshake: if the opening handshake fails. - ~asyncio.TimeoutError: if the opening handshake times out. + InvalidURI: If ``uri`` isn't a valid WebSocket URI. + InvalidHandshake: If the opening handshake fails. + ~asyncio.TimeoutError: If the opening handshake times out. """ @@ -705,7 +700,7 @@ def unix_connect( It's mainly useful for debugging servers listening on Unix sockets. Args: - path: file system path to the Unix socket. + path: File system path to the Unix socket. uri: URI of the WebSocket server; the host is used in the TLS handshake for secure connections and in the ``Host`` header. diff --git a/src/websockets/legacy/framing.py b/src/websockets/legacy/framing.py index 29864b136..4836eb284 100644 --- a/src/websockets/legacy/framing.py +++ b/src/websockets/legacy/framing.py @@ -51,16 +51,16 @@ async def read( Read a WebSocket frame. Args: - reader: coroutine that reads exactly the requested number of + reader: Coroutine that reads exactly the requested number of bytes, unless the end of file is reached. - mask: whether the frame should be masked i.e. whether the read + mask: Whether the frame should be masked i.e. whether the read happens on the server side. - max_size: maximum payload size in bytes. - extensions: list of extensions, applied in reverse order. + max_size: Maximum payload size in bytes. + extensions: List of extensions, applied in reverse order. Raises: - PayloadTooBig: if the frame exceeds ``max_size``. - ProtocolError: if the frame contains incorrect values. + PayloadTooBig: If the frame exceeds ``max_size``. + ProtocolError: If the frame contains incorrect values. """ @@ -128,14 +128,14 @@ def write( Write a WebSocket frame. Args: - frame: frame to write. - write: function that writes bytes. - mask: whether the frame should be masked i.e. whether the write + frame: Frame to write. + write: Function that writes bytes. + mask: Whether the frame should be masked i.e. whether the write happens on the client side. - extensions: list of extensions, applied in order. + extensions: List of extensions, applied in order. Raises: - ProtocolError: if the frame contains incorrect values. + ProtocolError: If the frame contains incorrect values. """ # The frame is written in a single call to write in order to prevent @@ -154,11 +154,11 @@ def parse_close(data: bytes) -> Tuple[int, str]: Parse the payload from a close frame. Returns: - Tuple[int, str]: close code and reason. + Close code and reason. Raises: - ProtocolError: if data is ill-formed. - UnicodeDecodeError: if the reason isn't valid UTF-8. + ProtocolError: If data is ill-formed. + UnicodeDecodeError: If the reason isn't valid UTF-8. """ close = Close.parse(data) diff --git a/src/websockets/legacy/handshake.py b/src/websockets/legacy/handshake.py index 569937bb9..ad8faf040 100644 --- a/src/websockets/legacy/handshake.py +++ b/src/websockets/legacy/handshake.py @@ -21,7 +21,7 @@ def build_request(headers: Headers) -> str: Update request headers passed in argument. Args: - headers: handshake request headers. + headers: Handshake request headers. Returns: str: ``key`` that must be passed to :func:`check_response`. @@ -45,14 +45,14 @@ def check_request(headers: Headers) -> str: the responsibility of the caller. Args: - headers: handshake request headers. + headers: Handshake request headers. Returns: str: ``key`` that must be passed to :func:`build_response`. Raises: - InvalidHandshake: if the handshake request is invalid; - then the server must return 400 Bad Request error. + InvalidHandshake: If the handshake request is invalid. + Then, the server must return a 400 Bad Request error. """ connection: List[ConnectionOption] = sum( @@ -110,8 +110,8 @@ def build_response(headers: Headers, key: str) -> None: Update response headers passed in argument. Args: - headers: handshake response headers. - key: returned by :func:`check_request`. + headers: Handshake response headers. + key: Returned by :func:`check_request`. """ headers["Upgrade"] = "websocket" @@ -128,11 +128,11 @@ def check_response(headers: Headers, key: str) -> None: the caller. Args: - headers: handshake response headers. - key: returned by :func:`build_request`. + headers: Handshake response headers. + key: Returned by :func:`build_request`. Raises: - InvalidHandshake: if the handshake response is invalid. + InvalidHandshake: If the handshake response is invalid. """ connection: List[ConnectionOption] = sum( diff --git a/src/websockets/legacy/http.py b/src/websockets/legacy/http.py index 7cc3db844..2ac7f7092 100644 --- a/src/websockets/legacy/http.py +++ b/src/websockets/legacy/http.py @@ -56,12 +56,12 @@ async def read_request(stream: asyncio.StreamReader) -> Tuple[str, Headers]: body, it may be read from ``stream`` after this coroutine returns. Args: - stream: input to read the request from + stream: Input to read the request from. Raises: - EOFError: if the connection is closed without a full HTTP request - SecurityError: if the request exceeds a security limit - ValueError: if the request isn't well formatted + EOFError: If the connection is closed without a full HTTP request. + SecurityError: If the request exceeds a security limit. + ValueError: If the request isn't well formatted. """ # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.1.1 @@ -103,12 +103,12 @@ async def read_response(stream: asyncio.StreamReader) -> Tuple[int, str, Headers body, it may be read from ``stream`` after this coroutine returns. Args: - stream: input to read the response from + stream: Input to read the response from. Raises: - EOFError: if the connection is closed without a full HTTP response - SecurityError: if the response exceeds a security limit - ValueError: if the response isn't well formatted + EOFError: If the connection is closed without a full HTTP response. + SecurityError: If the response exceeds a security limit. + ValueError: If the response isn't well formatted. """ # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.1.2 diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index d31ec19a8..d1979cd12 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -79,33 +79,32 @@ class WebSocketCommonProtocol(asyncio.Protocol): simplicity. Once the connection is open, a Ping_ frame is sent every ``ping_interval`` - seconds. This serves as a keepalive. It helps keeping the connection - open, especially in the presence of proxies with short timeouts on - inactive connections. Set ``ping_interval`` to :obj:`None` to disable - this behavior. + seconds. This serves as a keepalive. It helps keeping the connection open, + especially in the presence of proxies with short timeouts on inactive + connections. Set ``ping_interval`` to :obj:`None` to disable this behavior. .. _Ping: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 If the corresponding Pong_ frame isn't received within ``ping_timeout`` - seconds, the connection is considered unusable and is closed with code - 1011. This ensures that the remote endpoint remains responsive. Set + seconds, the connection is considered unusable and is closed with code 1011. + This ensures that the remote endpoint remains responsive. Set ``ping_timeout`` to :obj:`None` to disable this behavior. .. _Pong: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 + See the discussion of :doc:`timeouts <../../topics/timeouts>` for details. + The ``close_timeout`` parameter defines a maximum wait time for completing the closing handshake and terminating the TCP connection. For legacy reasons, :meth:`close` completes in at most ``5 * close_timeout`` seconds for clients and ``4 * close_timeout`` for servers. - See the discussion of :doc:`timeouts <../../topics/timeouts>` for details. - - ``close_timeout`` needs to be a parameter of the protocol because - websockets usually calls :meth:`close` implicitly upon exit: + ``close_timeout`` is a parameter of the protocol because websockets usually + calls :meth:`close` implicitly upon exit: - * on the client side, when :func:`~websockets.client.connect` is used as a + * on the client side, when using :func:`~websockets.client.connect` as a context manager; - * on the server side, when the connection handler terminates; + * on the server side, when the connection handler terminates. To apply a timeout to any other API, wrap it in :func:`~asyncio.wait_for`. @@ -144,21 +143,21 @@ class WebSocketCommonProtocol(asyncio.Protocol): See the discussion of :doc:`memory usage <../../topics/memory>` for details. Args: - logger: logger for this connection; - defaults to ``logging.getLogger("websockets.protocol")``; - see the :doc:`logging guide <../../topics/logging>` for details. - ping_interval: delay between keepalive pings in seconds; - :obj:`None` to disable keepalive pings. - ping_timeout: timeout for keepalive pings in seconds; - :obj:`None` to disable timeouts. - close_timeout: timeout for closing the connection in seconds; - for legacy reasons, the actual timeout is 4 or 5 times larger. - max_size: maximum size of incoming messages in bytes; - :obj:`None` to disable the limit. - max_queue: maximum number of incoming messages in receive buffer; - :obj:`None` to disable the limit. - read_limit: high-water mark of read buffer in bytes. - write_limit: high-water mark of write buffer in bytes. + logger: Logger for this server. + It defaults to ``logging.getLogger("websockets.protocol")``. + See the :doc:`logging guide <../../topics/logging>` for details. + ping_interval: Delay between keepalive pings in seconds. + :obj:`None` disables keepalive pings. + ping_timeout: Timeout for keepalive pings in seconds. + :obj:`None` disables timeouts. + close_timeout: Timeout for closing the connection in seconds. + For legacy reasons, the actual timeout is 4 or 5 times larger. + max_size: Maximum size of incoming messages in bytes. + :obj:`None` disables the limit. + max_queue: Maximum number of incoming messages in receive buffer. + :obj:`None` disables the limit. + read_limit: High-water mark of read buffer in bytes. + write_limit: High-water mark of write buffer in bytes. """ @@ -484,10 +483,11 @@ async def __aiter__(self) -> AsyncIterator[Data]: """ Iterate on incoming messages. - The iterator exits normally when the connection is closed with the - close code 1000 (OK) or 1001(going away) or without a close code. It - raises a :exc:`~websockets.exceptions.ConnectionClosedError` exception - when the connection is closed with any other code. + The iterator exits normally when the connection is closed with the close + code 1000 (OK) or 1001 (going away) or without a close code. + + It raises a :exc:`~websockets.exceptions.ConnectionClosedError` + exception when the connection is closed with any other code. """ try: @@ -501,8 +501,8 @@ async def recv(self) -> Data: Receive the next message. When the connection is closed, :meth:`recv` raises - :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it - raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal + :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it raises + :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal connection closure and :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol error or a network failure. This is how you detect the end of the @@ -511,8 +511,8 @@ async def recv(self) -> Data: Canceling :meth:`recv` is safe. There's no risk of losing the next message. The next invocation of :meth:`recv` will return it. - This makes it possible to enforce a timeout by wrapping :meth:`recv` - in :func:`~asyncio.wait_for`. + This makes it possible to enforce a timeout by wrapping :meth:`recv` in + :func:`~asyncio.wait_for`. Returns: Data: A string (:class:`str`) for a Text_ frame. A bytestring @@ -522,8 +522,8 @@ async def recv(self) -> Data: .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 Raises: - ConnectionClosed: when the connection is closed. - RuntimeError: if two coroutines call :meth:`recv` concurrently. + ConnectionClosed: When the connection is closed. + RuntimeError: If two coroutines call :meth:`recv` concurrently. """ if self._pop_message_waiter is not None: @@ -626,8 +626,8 @@ async def send( to send. Raises: - ConnectionClosed: when the connection is closed. - TypeError: if ``message`` doesn't have a supported type. + ConnectionClosed: When the connection is closed. + TypeError: If ``message`` doesn't have a supported type. """ await self.ensure_open() @@ -841,8 +841,8 @@ async def ping(self, data: Optional[Data] = None) -> Awaitable[None]: latency = await pong_waiter Raises: - ConnectionClosed: when the connection is closed. - RuntimeError: if another ping was sent with the same data and + ConnectionClosed: When the connection is closed. + RuntimeError: If another ping was sent with the same data and the corresponding pong wasn't received yet. """ @@ -881,11 +881,11 @@ async def pong(self, data: Data = b"") -> None: wait, you should close the connection. Args: - data (Data): payload of the pong; a string will be encoded to + data (Data): Payload of the pong. A string will be encoded to UTF-8. Raises: - ConnectionClosed: when the connection is closed. + ConnectionClosed: When the connection is closed. """ await self.ensure_open() @@ -1604,11 +1604,11 @@ def broadcast(websockets: Iterable[WebSocketCommonProtocol], message: Data) -> N Args: websockets (Iterable[WebSocketCommonProtocol]): WebSocket connections to which the message will be sent. - message (Data): message to send. + message (Data): Message to send. Raises: - RuntimeError: if a connection is busy sending a fragmented message. - TypeError: if ``message`` doesn't have a supported type. + RuntimeError: If a connection is busy sending a fragmented message. + TypeError: If ``message`` doesn't have a supported type. """ if not isinstance(message, (str, bytes, bytearray, memoryview)): diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 048a270b5..399df85d3 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -657,9 +657,9 @@ class WebSocketServer: when shutting down. Args: - logger: logger for this server; - defaults to ``logging.getLogger("websockets.server")``; - see the :doc:`logging guide <../../topics/logging>` for details. + logger: Logger for this server. + It defaults to ``logging.getLogger("websockets.server")``. + See the :doc:`logging guide <../../topics/logging>` for details. """ @@ -918,43 +918,42 @@ class Serve: await stop Args: - ws_handler: connection handler. It receives the WebSocket connection, + ws_handler: Connection handler. It receives the WebSocket connection, which is a :class:`WebSocketServerProtocol`, in argument. - host: network interfaces the server is bound to; - see :meth:`~asyncio.loop.create_server` for details. - port: TCP port the server listens on; - see :meth:`~asyncio.loop.create_server` for details. - create_protocol: factory for the :class:`asyncio.Protocol` managing - the connection; defaults to :class:`WebSocketServerProtocol`; may - be set to a wrapper or a subclass to customize connection handling. - logger: logger for this server; - defaults to ``logging.getLogger("websockets.server")``; - see the :doc:`logging guide <../../topics/logging>` for details. - compression: shortcut that enables the "permessage-deflate" extension - by default; may be set to :obj:`None` to disable compression; - see the :doc:`compression guide <../../topics/compression>` for details. - origins: acceptable values of the ``Origin`` header; include - :obj:`None` in the list if the lack of an origin is acceptable. - This is useful for defending against Cross-Site WebSocket - Hijacking attacks. - extensions: list of supported extensions, in order in which they - should be tried. - subprotocols: list of supported subprotocols, in order of decreasing + host: Network interfaces the server binds to. + See :meth:`~asyncio.loop.create_server` for details. + port: TCP port the server listens on. + See :meth:`~asyncio.loop.create_server` for details. + create_protocol: Factory for the :class:`asyncio.Protocol` managing + the connection. It defaults to :class:`WebSocketServerProtocol`. + Set it to a wrapper or a subclass to customize connection handling. + logger: Logger for this server. + It defaults to ``logging.getLogger("websockets.server")``. + See the :doc:`logging guide <../../topics/logging>` for details. + compression: The "permessage-deflate" extension is enabled by default. + Set ``compression`` to :obj:`None` to disable it. See the + :doc:`compression guide <../../topics/compression>` for details. + origins: Acceptable values of the ``Origin`` header, for defending + against Cross-Site WebSocket Hijacking attacks. Include :obj:`None` + in the list if the lack of an origin is acceptable. + extensions: List of supported extensions, in order in which they + should be negotiated and run. + subprotocols: List of supported subprotocols, in order of decreasing preference. extra_headers (Union[HeadersLike, Callable[[str, Headers], HeadersLike]]): - arbitrary HTTP headers to add to the request; this can be + Arbitrary HTTP headers to add to the response. This can be a :data:`~websockets.datastructures.HeadersLike` or a callable taking the request path and headers in arguments and returning a :data:`~websockets.datastructures.HeadersLike`. - server_header: value of the ``Server`` response header; - defaults to ``"Python/x.y.z websockets/X.Y"``; - :obj:`None` removes the header. + server_header: Value of the ``Server`` response header. + It defaults to ``"Python/x.y.z websockets/X.Y"``. + Setting it to :obj:`None` removes the header. process_request (Optional[Callable[[str, Headers], \ Awaitable[Optional[Tuple[http.HTTPStatus, HeadersLike, bytes]]]]]): - intercept HTTP request before the opening handshake; - see :meth:`~WebSocketServerProtocol.process_request` for details. - select_subprotocol: select a subprotocol supported by the client; - see :meth:`~WebSocketServerProtocol.select_subprotocol` for details. + Intercept HTTP request before the opening handshake. + See :meth:`~WebSocketServerProtocol.process_request` for details. + select_subprotocol: Select a subprotocol supported by the client. + See :meth:`~WebSocketServerProtocol.select_subprotocol` for details. See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``, @@ -1137,17 +1136,18 @@ def unix_serve( **kwargs: Any, ) -> Serve: """ - Similar to :func:`serve`, but for listening on Unix sockets. + Start a WebSocket server listening on a Unix socket. - This function builds upon the event - loop's :meth:`~asyncio.loop.create_unix_server` method. + This function is identical to :func:`serve`, except the ``host`` and + ``port`` arguments are replaced by ``path``. It is only available on Unix. - It is only available on Unix. + Unrecognized keyword arguments are passed the event loop's + :meth:`~asyncio.loop.create_unix_server` method. It's useful for deploying a server behind a reverse proxy such as nginx. Args: - path: file system path to the Unix socket. + path: File system path to the Unix socket. """ return serve(ws_handler, path=path, unix=True, **kwargs) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 7bfa96f8b..3fdd3881c 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -75,7 +75,7 @@ class Protocol: side: :attr:`~Side.CLIENT` or :attr:`~Side.SERVER`. state: initial state of the WebSocket connection. max_size: maximum size of incoming messages in bytes; - :obj:`None` to disable the limit. + :obj:`None` disables the limit. logger: logger for this connection; depending on ``side``, defaults to ``logging.getLogger("websockets.client")`` or ``logging.getLogger("websockets.server")``; @@ -263,7 +263,8 @@ def receive_eof(self) -> None: After calling this method: - - You must call :meth:`data_to_send` and send this data to the network. + - You must call :meth:`data_to_send` and send this data to the network; + it will return ``[b""]``, signaling the end of the stream, or ``[]``. - You aren't expected to call :meth:`events_received`; it won't return any new events. @@ -481,8 +482,8 @@ def close_expected(self) -> bool: """ Tell if the TCP connection is expected to close soon. - Call this method immediately after any of the ``receive_*()`` or - :meth:`fail` methods. + Call this method immediately after any of the ``receive_*()``, + ``send_close()``, or :meth:`fail` methods. If it returns :obj:`True`, schedule closing the TCP connection after a short timeout if the other side hasn't already closed it. @@ -510,6 +511,9 @@ def parse(self) -> Generator[None, None, None]: :meth:`receive_data` and :meth:`receive_eof` run this generator coroutine until it needs more data or reaches EOF. + :meth:`parse` never raises an exception. Instead, it sets the + :attr:`parser_exc` and yields control. + """ try: while True: diff --git a/src/websockets/server.py b/src/websockets/server.py index 0dd579052..e49f30213 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -58,13 +58,13 @@ class ServerProtocol(Protocol): should be tried. subprotocols: list of supported subprotocols, in order of decreasing preference. - select_subprotocol: callback for selecting a subprotocol among + select_subprotocol: Callback for selecting a subprotocol among those supported by the client and the server. It has the same signature as the :meth:`select_subprotocol` method, including a :class:`ServerProtocol` instance as first argument. state: initial state of the WebSocket connection. max_size: maximum size of incoming messages in bytes; - :obj:`None` to disable the limit. + :obj:`None` disables the limit. logger: logger for this connection; defaults to ``logging.getLogger("websockets.client")``; see the :doc:`logging guide <../../topics/logging>` for details. @@ -120,7 +120,7 @@ def accept(self, request: Request) -> Response: request: WebSocket handshake request event received from the client. Returns: - Response: WebSocket handshake response event to send to the client. + WebSocket handshake response event to send to the client. """ try: diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index e5922582d..ec1167115 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -168,8 +168,8 @@ def connect( from ``uri``. You may call :func:`socket.create_connection` to create a suitable TCP socket. ssl_context: Configuration for enabling TLS on the connection. - server_hostname: Hostname for the TLS handshake. ``server_hostname`` - overrides the hostname from ``uri``. + server_hostname: Host name for the TLS handshake. ``server_hostname`` + overrides the host name from ``uri``. origin: Value of the ``Origin`` header, for servers that require it. extensions: List of supported extensions, in order in which they should be negotiated and run. From 399db68dfa31768569b5257e862cd08987286be1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 29 Mar 2023 07:18:02 +0200 Subject: [PATCH 395/762] Ignore compatibility module in coverage measurement. --- setup.cfg | 1 + src/websockets/legacy/compatibility.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 3a8321c50..c2ca2d5dd 100644 --- a/setup.cfg +++ b/setup.cfg @@ -23,6 +23,7 @@ branch = True omit = # */websockets matches src/websockets and .tox/**/site-packages/websockets */websockets/__main__.py + */websockets/legacy/compatibility.py tests/maxi_cov.py [coverage:paths] diff --git a/src/websockets/legacy/compatibility.py b/src/websockets/legacy/compatibility.py index 296e9c584..303e203b4 100644 --- a/src/websockets/legacy/compatibility.py +++ b/src/websockets/legacy/compatibility.py @@ -14,7 +14,7 @@ def loop_if_py_lt_38(loop: asyncio.AbstractEventLoop) -> Dict[str, Any]: """ return {} -else: # pragma: no cover +else: def loop_if_py_lt_38(loop: asyncio.AbstractEventLoop) -> Dict[str, Any]: """ From 45f3b6047a3b09f7685fa1dc585ee34097480e58 Mon Sep 17 00:00:00 2001 From: shafemtol Date: Fri, 24 Feb 2023 18:32:26 +0100 Subject: [PATCH 396/762] Set server_hostname automatically when needed --- src/websockets/legacy/client.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index aa71ddb6e..f2d2c72b7 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -530,6 +530,8 @@ def __init__( else: # If sock is given, host and port shouldn't be specified. host, port = None, None + if kwargs.get("ssl"): + kwargs.setdefault("server_hostname", wsuri.host) # If host and port are given, override values from the URI. host = kwargs.pop("host", host) port = kwargs.pop("port", port) From b0876be3b7770854f1dd933e190205ee884405b1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 29 Mar 2023 07:51:17 +0200 Subject: [PATCH 397/762] Simplify tests thanks to previous commit --- tests/legacy/test_client_server.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index d92338585..2da203029 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -428,11 +428,7 @@ def send(self, *args, **kwargs): self.assertFalse(client_socket.used_for_read) self.assertFalse(client_socket.used_for_write) - with self.temp_client( - sock=client_socket, - # "You must set server_hostname when using ssl without a host" - server_hostname="localhost" if self.secure else None, - ): + with self.temp_client(sock=client_socket): self.loop.run_until_complete(self.client.send("Hello!")) reply = self.loop.run_until_complete(self.client.recv()) self.assertEqual(reply, "Hello!") From 132bf7fa44bc9e0ed3ff29aa6bfd7cb9303e49a6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 29 Mar 2023 07:52:41 +0200 Subject: [PATCH 398/762] Add changelog entry for previous commits --- docs/project/changelog.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 95193e780..5a3afcc43 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -82,6 +82,12 @@ New features * Added :attr:`~server.ServerProtocol.select_subprotocol` to customize negotiation of subprotocols in the Sans-I/O layer. +Improvements +............ + +* Set ``server_hostname`` automatically on TLS connections when providing a + ``sock`` argument to :func:`~sync.client.connect`. + 10.4 ---- From f0e547965a53582b45df45c5b6202dc3a1284240 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 29 Mar 2023 05:47:41 +0000 Subject: [PATCH 399/762] Bump pypa/cibuildwheel from 2.11.1 to 2.12.1 Bumps [pypa/cibuildwheel](https://github.com/pypa/cibuildwheel) from 2.11.1 to 2.12.1. - [Release notes](https://github.com/pypa/cibuildwheel/releases) - [Changelog](https://github.com/pypa/cibuildwheel/blob/main/docs/changelog.md) - [Commits](https://github.com/pypa/cibuildwheel/compare/v2.11.1...v2.12.1) --- updated-dependencies: - dependency-name: pypa/cibuildwheel dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- .github/workflows/wheels.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 0013ad103..90a075843 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -48,7 +48,7 @@ jobs: with: platforms: all - name: Build wheels - uses: pypa/cibuildwheel@v2.11.1 + uses: pypa/cibuildwheel@v2.12.1 env: CIBW_ARCHS_MACOS: "x86_64 universal2 arm64" CIBW_ARCHS_LINUX: "auto aarch64" From 09875a619f53fa7fb566d640943e7ea0c28dffae Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 29 Mar 2023 08:02:43 +0200 Subject: [PATCH 400/762] Attempt to fix #1317. --- tests/sync/test_client.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index 8824ed894..c900f3b0f 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -111,7 +111,10 @@ def stall_connection(self, request): TimeoutError, "timed out during handshake", ): - with run_client(server, open_timeout=3 * MS): + # While it shouldn't take 50ms to open a connection, this + # test becomes flaky in CI when setting a smaller timeout, + # even after increasing WEBSOCKETS_TESTS_TIMEOUT_FACTOR. + with run_client(server, open_timeout=5 * MS): self.fail("did not raise") finally: gate.set() From 25a5252c385d867add96b9bc5df2d537b49d636a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 29 Mar 2023 08:13:20 +0200 Subject: [PATCH 401/762] Skip test that fails randomly on PyPy. Fix #1314. --- tests/sync/test_connection.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index 94850affe..f26ec3f95 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -1,5 +1,6 @@ import contextlib import logging +import platform import socket import sys import threading @@ -465,6 +466,10 @@ def test_close_idempotency(self): self.connection.close() self.assertNoFrameSent() + @unittest.skipIf( + platform.python_implementation() == "PyPy", + "this test fails randomly due to a bug in PyPy", # see #1314 for details + ) def test_close_idempotency_race_condition(self): """close waits if the connection is already closing.""" From ce06dd6e15159dbfb83e33c16ec9afcb05ff1ec7 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Apr 2023 08:49:19 +0200 Subject: [PATCH 402/762] Rewrite interactive client with synchronous API. Fix #1312. --- src/websockets/__main__.py | 139 +++++++++---------------------------- 1 file changed, 34 insertions(+), 105 deletions(-) diff --git a/src/websockets/__main__.py b/src/websockets/__main__.py index c562d21b5..a7dd1aaef 100644 --- a/src/websockets/__main__.py +++ b/src/websockets/__main__.py @@ -1,16 +1,18 @@ from __future__ import annotations import argparse -import asyncio import os import signal import sys import threading -from typing import Any, Set -from .exceptions import ConnectionClosed -from .frames import Close -from .legacy.client import connect + +try: + import readline # noqa +except ImportError: # Windows has no `readline` normally + pass + +from .sync.client import ClientConnection, connect from .version import version as websockets_version @@ -46,21 +48,6 @@ def win_enable_vt100() -> None: raise RuntimeError("unable to set console mode") -def exit_from_event_loop_thread( - loop: asyncio.AbstractEventLoop, - stop: asyncio.Future[None], -) -> None: - loop.stop() - if not stop.done(): - # When exiting the thread that runs the event loop, raise - # KeyboardInterrupt in the main thread to exit the program. - if sys.platform == "win32": - ctrl_c = signal.CTRL_C_EVENT - else: - ctrl_c = signal.SIGINT - os.kill(os.getpid(), ctrl_c) - - def print_during_input(string: str) -> None: sys.stdout.write( # Save cursor position @@ -93,63 +80,20 @@ def print_over_input(string: str) -> None: sys.stdout.flush() -async def run_client( - uri: str, - loop: asyncio.AbstractEventLoop, - inputs: asyncio.Queue[str], - stop: asyncio.Future[None], -) -> None: - try: - websocket = await connect(uri) - except Exception as exc: - print_over_input(f"Failed to connect to {uri}: {exc}.") - exit_from_event_loop_thread(loop, stop) - return - else: - print_during_input(f"Connected to {uri}.") - - try: - while True: - incoming: asyncio.Future[Any] = asyncio.create_task(websocket.recv()) - outgoing: asyncio.Future[Any] = asyncio.create_task(inputs.get()) - done: Set[asyncio.Future[Any]] - pending: Set[asyncio.Future[Any]] - done, pending = await asyncio.wait( - [incoming, outgoing, stop], return_when=asyncio.FIRST_COMPLETED - ) - - # Cancel pending tasks to avoid leaking them. - if incoming in pending: - incoming.cancel() - if outgoing in pending: - outgoing.cancel() - - if incoming in done: - try: - message = incoming.result() - except ConnectionClosed: - break - else: - if isinstance(message, str): - print_during_input("< " + message) - else: - print_during_input("< (binary) " + message.hex()) - - if outgoing in done: - message = outgoing.result() - await websocket.send(message) - - if stop in done: - break - - finally: - await websocket.close() - assert websocket.close_code is not None and websocket.close_reason is not None - close_status = Close(websocket.close_code, websocket.close_reason) - - print_over_input(f"Connection closed: {close_status}.") - - exit_from_event_loop_thread(loop, stop) +def print_incoming_messages(websocket: ClientConnection, stop: threading.Event) -> None: + for message in websocket: + if isinstance(message, str): + print_during_input("< " + message) + else: + print_during_input("< (binary) " + message.hex()) + if not stop.is_set(): + # When the server closes the connection, raise KeyboardInterrupt + # in the main thread to exit the program. + if sys.platform == "win32": + ctrl_c = signal.CTRL_C_EVENT + else: + ctrl_c = signal.SIGINT + os.kill(os.getpid(), ctrl_c) def main() -> None: @@ -184,29 +128,17 @@ def main() -> None: sys.stderr.flush() try: - import readline # noqa - except ImportError: # Windows has no `readline` normally - pass - - # Create an event loop that will run in a background thread. - loop = asyncio.new_event_loop() - - # Due to zealous removal of the loop parameter in the Queue constructor, - # we need a factory coroutine to run in the freshly created event loop. - async def queue_factory() -> asyncio.Queue[str]: - return asyncio.Queue() - - # Create a queue of user inputs. There's no need to limit its size. - inputs: asyncio.Queue[str] = loop.run_until_complete(queue_factory()) - - # Create a stop condition when receiving SIGINT or SIGTERM. - stop: asyncio.Future[None] = loop.create_future() + websocket = connect(args.uri) + except Exception as exc: + print(f"Failed to connect to {args.uri}: {exc}.") + sys.exit(1) + else: + print(f"Connected to {args.uri}.") - # Schedule the task that will manage the connection. - loop.create_task(run_client(args.uri, loop, inputs, stop)) + stop = threading.Event() - # Start the event loop in a background thread. - thread = threading.Thread(target=loop.run_forever) + # Start the thread that reads messages from the connection. + thread = threading.Thread(target=print_incoming_messages, args=(websocket, stop)) thread.start() # Read from stdin in the main thread in order to receive signals. @@ -214,17 +146,14 @@ async def queue_factory() -> asyncio.Queue[str]: while True: # Since there's no size limit, put_nowait is identical to put. message = input("> ") - loop.call_soon_threadsafe(inputs.put_nowait, message) + websocket.send(message) except (KeyboardInterrupt, EOFError): # ^C, ^D - loop.call_soon_threadsafe(stop.set_result, None) + stop.set() + websocket.close() + print_over_input("Connection closed.") - # Wait for the event loop to terminate. thread.join() - # For reasons unclear, even though the loop is closed in the thread, - # it still thinks it's running here. - loop.close() - if __name__ == "__main__": main() From 2fcc4837faca2a05f395805455070dc0347e9ab1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Apr 2023 08:29:51 +0200 Subject: [PATCH 403/762] Improve error handling in broadcast(). Fix #1319. --- docs/project/changelog.rst | 2 + docs/topics/logging.rst | 5 ++ src/websockets/legacy/protocol.py | 74 ++++++++++++++++++++++-------- tests/legacy/test_client_server.py | 4 +- tests/legacy/test_protocol.py | 60 ++++++++++++++++++++++-- 5 files changed, 119 insertions(+), 26 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 5a3afcc43..5a0ea423d 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -85,6 +85,8 @@ New features Improvements ............ +* Improved error handling in :func:`~websockets.broadcast`. + * Set ``server_hostname`` automatically on TLS connections when providing a ``sock`` argument to :func:`~sync.client.connect`. diff --git a/docs/topics/logging.rst b/docs/topics/logging.rst index 294a6cda8..e7abd96ce 100644 --- a/docs/topics/logging.rst +++ b/docs/topics/logging.rst @@ -213,6 +213,11 @@ Here's what websockets logs at each level. * Exceptions raised by connection handler coroutines in servers * Exceptions resulting from bugs in websockets +``WARNING`` +........... + +* Failures in :func:`~websockets.broadcast` + ``INFO`` ........ diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index d1979cd12..67bca0ef4 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -7,6 +7,7 @@ import random import ssl import struct +import sys import time import uuid import warnings @@ -1573,13 +1574,17 @@ def eof_received(self) -> None: self.reader.feed_eof() -def broadcast(websockets: Iterable[WebSocketCommonProtocol], message: Data) -> None: +def broadcast( + websockets: Iterable[WebSocketCommonProtocol], + message: Data, + raise_exceptions: bool = False, +) -> None: """ Broadcast a message to several WebSocket connections. - A string (:class:`str`) is sent as a Text_ frame. A bytestring or - bytes-like object (:class:`bytes`, :class:`bytearray`, or - :class:`memoryview`) is sent as a Binary_ frame. + A string (:class:`str`) is sent as a Text_ frame. A bytestring or bytes-like + object (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) is sent + as a Binary_ frame. .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 @@ -1587,33 +1592,42 @@ def broadcast(websockets: Iterable[WebSocketCommonProtocol], message: Data) -> N :func:`broadcast` pushes the message synchronously to all connections even if their write buffers are overflowing. There's no backpressure. - :func:`broadcast` skips silently connections that aren't open in order to - avoid errors on connections where the closing handshake is in progress. - - If you broadcast messages faster than a connection can handle them, - messages will pile up in its write buffer until the connection times out. - Keep low values for ``ping_interval`` and ``ping_timeout`` to prevent - excessive memory usage by slow connections when you use :func:`broadcast`. + If you broadcast messages faster than a connection can handle them, messages + will pile up in its write buffer until the connection times out. Keep + ``ping_interval`` and ``ping_timeout`` low to prevent excessive memory usage + from slow connections. Unlike :meth:`~websockets.server.WebSocketServerProtocol.send`, :func:`broadcast` doesn't support sending fragmented messages. Indeed, - fragmentation is useful for sending large messages without buffering - them in memory, while :func:`broadcast` buffers one copy per connection - as fast as possible. + fragmentation is useful for sending large messages without buffering them in + memory, while :func:`broadcast` buffers one copy per connection as fast as + possible. + + :func:`broadcast` skips connections that aren't open in order to avoid + errors on connections where the closing handshake is in progress. + + :func:`broadcast` ignores failures to write the message on some connections. + It continues writing to other connections. On Python 3.11 and above, you + may set ``raise_exceptions`` to :obj:`True` to record failures and raise all + exceptions in a :pep:`654` :exc:`ExceptionGroup`. Args: - websockets (Iterable[WebSocketCommonProtocol]): WebSocket connections - to which the message will be sent. - message (Data): Message to send. + websockets: WebSocket connections to which the message will be sent. + message: Message to send. + raise_exceptions: Whether to raise an exception in case of failures. Raises: - RuntimeError: If a connection is busy sending a fragmented message. TypeError: If ``message`` doesn't have a supported type. """ if not isinstance(message, (str, bytes, bytearray, memoryview)): raise TypeError("data must be str or bytes-like") + if raise_exceptions: + if sys.version_info[:2] < (3, 11): # pragma: no cover + raise ValueError("raise_exceptions requires at least Python 3.11") + exceptions = [] + opcode, data = prepare_data(message) for websocket in websockets: @@ -1621,6 +1635,26 @@ def broadcast(websockets: Iterable[WebSocketCommonProtocol], message: Data) -> N continue if websocket._fragmented_message_waiter is not None: - raise RuntimeError("busy sending a fragmented message") + if raise_exceptions: + exception = RuntimeError("sending a fragmented message") + exceptions.append(exception) + else: + websocket.logger.warning( + "skipped broadcast: sending a fragmented message", + ) + + try: + websocket.write_frame_sync(True, opcode, data) + except Exception as write_exception: + if raise_exceptions: + exception = RuntimeError("failed to write message") + exception.__cause__ = write_exception + exceptions.append(exception) + else: + websocket.logger.warning( + "skipped broadcast: failed to write message", + exc_info=True, + ) - websocket.write_frame_sync(True, opcode, data) + if raise_exceptions: + raise ExceptionGroup("skipped broadcast", exceptions) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 2da203029..b8fd259c9 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -1553,11 +1553,11 @@ async def run_client(): await ws.recv() else: # Exit block with an exception. - raise Exception("BOOM!") + raise Exception("BOOM") pass # work around bug in coverage with self.assertLogs("websockets", logging.INFO) as logs: - with self.assertRaisesRegex(Exception, "BOOM!"): + with self.assertRaisesRegex(Exception, "BOOM"): self.loop.run_until_complete(run_client()) # Iteration 1 diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index 328bc80a2..ab8155b9b 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -1,5 +1,7 @@ import asyncio import contextlib +import logging +import sys import unittest import unittest.mock import warnings @@ -1468,26 +1470,76 @@ def test_broadcast_two_clients(self): def test_broadcast_skips_closed_connection(self): self.close_connection() - broadcast([self.protocol], "café") + with self.assertNoLogs(): + broadcast([self.protocol], "café") self.assertNoFrameSent() def test_broadcast_skips_closing_connection(self): close_task = self.half_close_connection_local() - broadcast([self.protocol], "café") + with self.assertNoLogs(): + broadcast([self.protocol], "café") self.assertNoFrameSent() self.loop.run_until_complete(close_task) # cleanup - def test_broadcast_within_fragmented_text(self): + def test_broadcast_skips_connection_sending_fragmented_text(self): self.make_drain_slow() self.loop.create_task(self.protocol.send(["ca", "fé"])) self.run_loop_once() self.assertOneFrameSent(False, OP_TEXT, "ca".encode("utf-8")) - with self.assertRaises(RuntimeError): + with self.assertLogs("websockets", logging.WARNING) as logs: + broadcast([self.protocol], "café") + + self.assertEqual( + [record.getMessage() for record in logs.records][:2], + ["skipped broadcast: sending a fragmented message"], + ) + + @unittest.skipIf( + sys.version_info[:2] < (3, 11), "raise_exceptions requires Python 3.11+" + ) + def test_broadcast_reports_connection_sending_fragmented_text(self): + self.make_drain_slow() + self.loop.create_task(self.protocol.send(["ca", "fé"])) + self.run_loop_once() + self.assertOneFrameSent(False, OP_TEXT, "ca".encode("utf-8")) + + with self.assertRaises(ExceptionGroup) as raised: + broadcast([self.protocol], "café", raise_exceptions=True) + + self.assertEqual(str(raised.exception), "skipped broadcast (1 sub-exception)") + self.assertEqual( + str(raised.exception.exceptions[0]), "sending a fragmented message" + ) + + def test_broadcast_skips_connection_failing_to_send(self): + # Configure mock to raise an exception when writing to the network. + self.protocol.transport.write.side_effect = RuntimeError + + with self.assertLogs("websockets", logging.WARNING) as logs: broadcast([self.protocol], "café") + self.assertEqual( + [record.getMessage() for record in logs.records][:2], + ["skipped broadcast: failed to write message"], + ) + + @unittest.skipIf( + sys.version_info[:2] < (3, 11), "raise_exceptions requires Python 3.11+" + ) + def test_broadcast_reports_connection_failing_to_send(self): + # Configure mock to raise an exception when writing to the network. + self.protocol.transport.write.side_effect = RuntimeError("BOOM") + + with self.assertRaises(ExceptionGroup) as raised: + broadcast([self.protocol], "café", raise_exceptions=True) + + self.assertEqual(str(raised.exception), "skipped broadcast (1 sub-exception)") + self.assertEqual(str(raised.exception.exceptions[0]), "failed to write message") + self.assertEqual(str(raised.exception.exceptions[0].__cause__), "BOOM") + class ServerTests(CommonTests, AsyncioTestCase): def setUp(self): From f269e1e9704ff2776dc5ed54d11e120aa06d62e3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Apr 2023 10:11:26 +0200 Subject: [PATCH 404/762] Document that connect() can raise OSError. Fix #1265. --- src/websockets/legacy/client.py | 1 + src/websockets/sync/client.py | 1 + 2 files changed, 2 insertions(+) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index f2d2c72b7..b79bbab8c 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -418,6 +418,7 @@ class Connect: Raises: InvalidURI: If ``uri`` isn't a valid WebSocket URI. + OSError: If the TCP connection fails. InvalidHandshake: If the opening handshake fails. ~asyncio.TimeoutError: If the opening handshake times out. diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index ec1167115..087ff5f56 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -198,6 +198,7 @@ def connect( Raises: InvalidURI: If ``uri`` isn't a valid WebSocket URI. + OSError: If the TCP connection fails. InvalidHandshake: If the opening handshake fails. TimeoutError: If the opening handshake times out. From 0a58739a54c9de69b23f57e1183e9864e02d5513 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Apr 2023 14:19:02 +0200 Subject: [PATCH 405/762] It's an HTTP, but a URI. --- src/websockets/exceptions.py | 2 +- src/websockets/sync/server.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 1f4b9265c..22a3b583f 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -374,7 +374,7 @@ class InvalidState(WebSocketException, AssertionError): class InvalidURI(WebSocketException): """ - Raised when connecting to an URI that isn't a valid WebSocket URI. + Raised when connecting to a URI that isn't a valid WebSocket URI. """ diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index a53ae2b25..9284c6188 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -352,11 +352,11 @@ def handler(websocket): ` method. process_request: Intercept the request during the opening handshake. Return an HTTP response to force the response or :obj:`None` to - continue normally. When you force a HTTP 101 Continue response, + continue normally. When you force an HTTP 101 Continue response, the handshake is successful. Else, the connection is aborted. process_response: Intercept the response during the opening handshake. Return an HTTP response to force the response or :obj:`None` to - continue normally. When you force a HTTP 101 Continue response, + continue normally. When you force an HTTP 101 Continue response, the handshake is successful. Else, the connection is aborted. server_header: Value of the ``Server`` response header. It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to From 8a6665892051b26d65344a573c1aea89f9c0e64c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Apr 2023 14:41:58 +0200 Subject: [PATCH 406/762] Pluralize API consistently. --- docs/project/changelog.rst | 2 +- docs/reference/index.rst | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 5a0ea423d..fc723a626 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -22,7 +22,7 @@ When a release contains backwards-incompatible API changes, the major version is increased, else the minor version is increased. Patch versions are only for fixing regressions shortly after a release. -Only documented API are public. Undocumented, private API may change without +Only documented APIs are public. Undocumented, private APIs may change without notice. diff --git a/docs/reference/index.rst b/docs/reference/index.rst index fa8047c3d..cc4658dab 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -57,7 +57,7 @@ extensions. Shared ------ -These low-level API are shared by all implementations. +These low-level APIs are shared by all implementations. .. toctree:: :titlesonly: @@ -69,7 +69,7 @@ These low-level API are shared by all implementations. API stability ------------- -Public API documented in this API reference are subject to the +Public APIs documented in this API reference are subject to the :ref:`backwards-compatibility policy `. Anything that isn't listed in the API reference is a private API. There's no From d0a292c5a25fcf4c27deb031e1b7e814494aac0e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Apr 2023 17:58:29 +0200 Subject: [PATCH 407/762] Fix typo in docs. --- src/websockets/server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/websockets/server.py b/src/websockets/server.py index e49f30213..dcdf3b71e 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -319,8 +319,8 @@ def process_extensions( Accept or reject each extension proposed in the client request. Negotiate parameters for accepted extensions. - :rfc:`6455` leaves the rules up to the specification of each - :extension. + Per :rfc:`6455`, negotiation rules are defined by the specification of + each extension. To provide this level of flexibility, for each extension proposed by the client, we check for a match with each extension available in the From 1c9bdceccd9023e88e1cab5a1044f42c341930ee Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Apr 2023 18:05:34 +0200 Subject: [PATCH 408/762] Add open_timeout to serve(). --- docs/project/changelog.rst | 8 ++++++++ src/websockets/legacy/server.py | 21 ++++++++++++++++----- 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index fc723a626..68fb2709f 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -63,6 +63,12 @@ Backwards-incompatible changes As a consequence, calling ``WebSocket.close()`` without arguments in a browser isn't reported as an error anymore. +.. admonition:: :func:`~server.serve` times out on the opening handshake after 10 seconds by default. + :class: note + + You can adjust the timeout with the ``open_timeout`` parameter. Set it to + :obj:`None` to disable the timeout entirely. + New features ............ @@ -77,6 +83,8 @@ New features See :func:`~sync.client.connect` and :func:`~sync.server.serve` for details. +* Added ``open_timeout`` to :func:`~server.serve`. + * Made it possible to close a server without closing existing connections. * Added :attr:`~server.ServerProtocol.select_subprotocol` to customize diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 399df85d3..3be86e45c 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -115,6 +115,7 @@ def __init__( select_subprotocol: Optional[ Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] ] = None, + open_timeout: Optional[float] = 10, **kwargs: Any, ) -> None: if logger is None: @@ -136,6 +137,7 @@ def __init__( self.server_header = server_header self._process_request = process_request self._select_subprotocol = select_subprotocol + self.open_timeout = open_timeout def connection_made(self, transport: asyncio.BaseTransport) -> None: """ @@ -161,16 +163,21 @@ async def handler(self) -> None: """ try: try: - await self.handshake( - origins=self.origins, - available_extensions=self.available_extensions, - available_subprotocols=self.available_subprotocols, - extra_headers=self.extra_headers, + await asyncio.wait_for( + self.handshake( + origins=self.origins, + available_extensions=self.available_extensions, + available_subprotocols=self.available_subprotocols, + extra_headers=self.extra_headers, + ), + self.open_timeout, ) # Remove this branch when dropping support for Python < 3.8 # because CancelledError no longer inherits Exception. except asyncio.CancelledError: # pragma: no cover raise + except asyncio.TimeoutError: # pragma: no cover + raise except ConnectionError: raise except Exception as exc: @@ -954,6 +961,8 @@ class Serve: See :meth:`~WebSocketServerProtocol.process_request` for details. select_subprotocol: Select a subprotocol supported by the client. See :meth:`~WebSocketServerProtocol.select_subprotocol` for details. + open_timeout: Timeout for opening connections in seconds. + :obj:`None` disables the timeout. See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``, @@ -997,6 +1006,7 @@ def __init__( select_subprotocol: Optional[ Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] ] = None, + open_timeout: Optional[float] = 10, ping_interval: Optional[float] = 20, ping_timeout: Optional[float] = 20, close_timeout: Optional[float] = None, @@ -1059,6 +1069,7 @@ def __init__( host=host, port=port, secure=secure, + open_timeout=open_timeout, ping_interval=ping_interval, ping_timeout=ping_timeout, close_timeout=close_timeout, From db6b1a892ce8508cf362d541b45430966ab2bc02 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Apr 2023 18:30:51 +0200 Subject: [PATCH 409/762] Add feature support matrices. Fix #1313. --- docs/reference/features.rst | 182 +++++++++++++++++++++++++++++++++ docs/reference/index.rst | 21 ++-- docs/reference/limitations.rst | 39 ------- 3 files changed, 193 insertions(+), 49 deletions(-) create mode 100644 docs/reference/features.rst delete mode 100644 docs/reference/limitations.rst diff --git a/docs/reference/features.rst b/docs/reference/features.rst new file mode 100644 index 000000000..f4b592a4b --- /dev/null +++ b/docs/reference/features.rst @@ -0,0 +1,182 @@ +Features +======== + +.. currentmodule:: websockets + +Feature support matrices summarize which implementations support which features. + +.. raw:: html + + + +.. |aio| replace:: :mod:`asyncio` +.. |sync| replace:: :mod:`threading` +.. |sans| replace:: `Sans-I/O`_ +.. _Sans-I/O: https://sans-io.readthedocs.io/ + +Both sides +---------- + +.. table:: + :class: support-matrix-table + + +------------------------------------+--------+--------+--------+ + | | |aio| | |sync| | |sans| | + +====================================+========+========+========+ + | Perfom the opening handshake | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Send a message | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Receive a message | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Iterate over received messages | ✅ | ✅ | ❌ | + +------------------------------------+--------+--------+--------+ + | Send a fragmented message | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Receive a fragmented message after | ✅ | ✅ | ❌ | + | reassembly | | | | + +------------------------------------+--------+--------+--------+ + | Receive a fragmented message frame | ❌ | ✅ | ✅ | + | by frame (`#479`_) | | | | + +------------------------------------+--------+--------+--------+ + | Send a ping | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Respond to pings automatically | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Send a pong | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Perfom the closing handshake | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Report close codes and reasons | ❌ | ✅ | ✅ | + | from both sides | | | | + +------------------------------------+--------+--------+--------+ + | Compress messages (:rfc:`7692`) | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Tune memory usage for compression | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Negotiate extensions | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Implement custom extensions | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Negotiate a subprotocol | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Enforce security limits | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Log events | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Enforce opening timeout | ✅ | ✅ | — | + +------------------------------------+--------+--------+--------+ + | Enforce closing timeout | ✅ | ✅ | — | + +------------------------------------+--------+--------+--------+ + | Keepalive | ✅ | ❌ | — | + +------------------------------------+--------+--------+--------+ + | Heartbeat | ✅ | ❌ | — | + +------------------------------------+--------+--------+--------+ + +.. _#479: https://github.com/aaugustin/websockets/issues/479 + +Server +------ + +.. table:: + :class: support-matrix-table + + +------------------------------------+--------+--------+--------+ + | | |aio| | |sync| | |sans| | + +====================================+========+========+========+ + | Listen on a TCP socket | ✅ | ✅ | — | + +------------------------------------+--------+--------+--------+ + | Listen on a Unix socket | ✅ | ✅ | — | + +------------------------------------+--------+--------+--------+ + | Listen using a preexisting socket | ✅ | ✅ | — | + +------------------------------------+--------+--------+--------+ + | Encrypt connection with TLS | ✅ | ✅ | — | + +------------------------------------+--------+--------+--------+ + | Close server on context exit | ✅ | ✅ | — | + +------------------------------------+--------+--------+--------+ + | Close connection on handler exit | ✅ | ✅ | — | + +------------------------------------+--------+--------+--------+ + | Shut down server gracefully | ✅ | ✅ | — | + +------------------------------------+--------+--------+--------+ + | Check ``Origin`` header | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Customize subprotocol selection | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Configure ``Server`` header | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Alter opening handshake request | ❌ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Alter opening handshake response | ❌ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Perform HTTP Basic Authentication | ✅ | ❌ | ❌ | + +------------------------------------+--------+--------+--------+ + | Perform HTTP Digest Authentication | ❌ | ❌ | ❌ | + +------------------------------------+--------+--------+--------+ + | Force HTTP response | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + +Client +------ + +.. table:: + :class: support-matrix-table + + +------------------------------------+--------+--------+--------+ + | | |aio| | |sync| | |sans| | + +====================================+========+========+========+ + | Connect to a TCP socket | ✅ | ✅ | — | + +------------------------------------+--------+--------+--------+ + | Connect to a Unix socket | ✅ | ✅ | — | + +------------------------------------+--------+--------+--------+ + | Connect using a preexisting socket | ✅ | ✅ | — | + +------------------------------------+--------+--------+--------+ + | Encrypt connection with TLS | ✅ | ✅ | — | + +------------------------------------+--------+--------+--------+ + | Close connection on context exit | ✅ | ✅ | — | + +------------------------------------+--------+--------+--------+ + | Reconnect automatically | ✅ | ❌ | — | + +------------------------------------+--------+--------+--------+ + | Configure ``Origin`` header | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Configure ``User-Agent`` header | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Alter opening handshake request | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Connect to non-ASCII IRIs | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Perform HTTP Basic Authentication | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+ + | Perform HTTP Digest Authentication | ❌ | ❌ | ❌ | + | (`#784`_) | | | | + +------------------------------------+--------+--------+--------+ + | Follow HTTP redirects | ✅ | ❌ | — | + +------------------------------------+--------+--------+--------+ + | Connect via a HTTP proxy (`#364`_) | ❌ | ❌ | — | + +------------------------------------+--------+--------+--------+ + | Connect via a SOCKS5 proxy | ❌ | ❌ | — | + | (`#475`_) | | | | + +------------------------------------+--------+--------+--------+ + +.. _#364: https://github.com/aaugustin/websockets/issues/364 +.. _#475: https://github.com/aaugustin/websockets/issues/475 +.. _#784: https://github.com/aaugustin/websockets/issues/784 + +Known limitations +----------------- + +There is no way to control compression of outgoing frames on a per-frame basis +(`#538`_). If compression is enabled, all frames are compressed. + +.. _#538: https://github.com/aaugustin/websockets/issues/538 + +The client API doesn't attempt to guarantee that there is no more than one +connection to a given IP address in a CONNECTING state. This behavior is +`mandated by RFC 6455`_. However, :func:`~client.connect()` isn't the right +layer for enforcing this constraint. It's the caller's responsibility. + +.. _mandated by RFC 6455: https://www.rfc-editor.org/rfc/rfc6455.html#section-4.1 diff --git a/docs/reference/index.rst b/docs/reference/index.rst index cc4658dab..9364bc887 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -3,6 +3,17 @@ API reference .. currentmodule:: websockets +Features +-------- + +Check which implementations support which features and known limitations. + +.. toctree:: + :titlesonly: + + features + + :mod:`asyncio` -------------- @@ -93,13 +104,3 @@ For convenience, many public APIs can be imported directly from the If you're using such tools, stick to the full import paths, as explained in this FAQ: :ref:`real-import-paths` - -Limitations ------------ - -There are a few known limitations in the current API. - -.. toctree:: - :titlesonly: - - limitations diff --git a/docs/reference/limitations.rst b/docs/reference/limitations.rst deleted file mode 100644 index 696aa38fd..000000000 --- a/docs/reference/limitations.rst +++ /dev/null @@ -1,39 +0,0 @@ -Limitations -=========== - -.. currentmodule:: websockets - -Client ------- - -The client doesn't attempt to guarantee that there is no more than one -connection to a given IP address in a CONNECTING state. This behavior is -`mandated by RFC 6455`_. However, :func:`~client.connect()` isn't the -right layer for enforcing this constraint. It's the caller's responsibility. - -.. _mandated by RFC 6455: https://www.rfc-editor.org/rfc/rfc6455.html#section-4.1 - -The client doesn't support connecting through an HTTP proxy (`issue 364`_) or a -SOCKS proxy (`issue 475`_). - -.. _issue 364: https://github.com/aaugustin/websockets/issues/364 -.. _issue 475: https://github.com/aaugustin/websockets/issues/475 - -Server ------- - -At this time, there are no known limitations affecting only the server. - -Both sides ----------- - -There is no way to control compression of outgoing frames on a per-frame basis -(`issue 538`_). If compression is enabled, all frames are compressed. - -.. _issue 538: https://github.com/aaugustin/websockets/issues/538 - -There is no way to receive each fragment of a fragmented messages as it -arrives (`issue 479`_). websockets always reassembles fragmented messages -before returning them. - -.. _issue 479: https://github.com/aaugustin/websockets/issues/479 From 6dcac04c10c61dde9aa5be0374965c31f16e77f4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Apr 2023 18:32:46 +0200 Subject: [PATCH 410/762] Hide "Both sides" API references from navigation. The only reason for their existence is the need to hyperlink to functions that may be used on both sides. --- docs/reference/asyncio/common.rst | 2 ++ docs/reference/index.rst | 3 --- docs/reference/sansio/common.rst | 2 ++ docs/reference/sync/common.rst | 2 ++ 4 files changed, 6 insertions(+), 3 deletions(-) diff --git a/docs/reference/asyncio/common.rst b/docs/reference/asyncio/common.rst index ee8dc54ac..dc7a54ee1 100644 --- a/docs/reference/asyncio/common.rst +++ b/docs/reference/asyncio/common.rst @@ -1,3 +1,5 @@ +:orphan: + Both sides (:mod:`asyncio`) =========================== diff --git a/docs/reference/index.rst b/docs/reference/index.rst index 9364bc887..2a9556dd9 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -25,7 +25,6 @@ clients concurrently. asyncio/server asyncio/client - asyncio/common :mod:`threading` ---------------- @@ -37,7 +36,6 @@ This alternative implementation can be a good choice for clients. sync/server sync/client - sync/common `Sans-I/O`_ ----------- @@ -52,7 +50,6 @@ application servers. sansio/server sansio/client - sansio/common Extensions ---------- diff --git a/docs/reference/sansio/common.rst b/docs/reference/sansio/common.rst index 2678c1361..cd1ef3c63 100644 --- a/docs/reference/sansio/common.rst +++ b/docs/reference/sansio/common.rst @@ -1,3 +1,5 @@ +:orphan: + Both sides (`Sans-I/O`_) ========================= diff --git a/docs/reference/sync/common.rst b/docs/reference/sync/common.rst index 8d97ab3c1..3dc6d4a50 100644 --- a/docs/reference/sync/common.rst +++ b/docs/reference/sync/common.rst @@ -1,3 +1,5 @@ +:orphan: + Both sides (:mod:`threading`) ============================= From d3d4cf4a2baf7362b004e9bffe5344287dbb9a51 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Apr 2023 18:43:44 +0200 Subject: [PATCH 411/762] Build architecture independent wheels. Fix #1300. --- .github/workflows/wheels.yml | 18 ++++++++++++++---- setup.py | 7 +++++-- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 90a075843..68bfbdef4 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -8,8 +8,10 @@ on: jobs: sdist: - name: Build source distribution + name: Build source distribution and architecture-independent wheel runs-on: ubuntu-latest + env: + BUILD_EXTENSION: no steps: - name: Check out repository uses: actions/checkout@v3 @@ -23,10 +25,20 @@ jobs: uses: actions/upload-artifact@v3 with: path: dist/*.tar.gz + - name: Install wheel + run: pip install wheel + - name: Build wheel + run: python setup.py bdist_wheel + - name: Save wheel + uses: actions/upload-artifact@v3 + with: + path: dist/*.whl wheels: - name: Build wheels on ${{ matrix.os }} + name: Build architecture-specific wheels on ${{ matrix.os }} runs-on: ${{ matrix.os }} + env: + BUILD_EXTENSION: yes strategy: matrix: os: @@ -36,8 +48,6 @@ jobs: steps: - name: Check out repository uses: actions/checkout@v3 - - name: Make extension build mandatory - run: touch .cibuildwheel - name: Install Python 3.x uses: actions/setup-python@v4 with: diff --git a/setup.py b/setup.py index 5ed472503..c8e01f24b 100644 --- a/setup.py +++ b/setup.py @@ -1,3 +1,4 @@ +import os import pathlib import re @@ -27,11 +28,13 @@ "websockets/sync", ] -ext_modules = [ +# Set BUILD_EXTENSION to yes or no to force building or not building the +# speedups extension. If unset, the extension is built only if possible. +ext_modules = [] if os.environ.get("BUILD_EXTENSION") == "no" else [ setuptools.Extension( "websockets.speedups", sources=["src/websockets/speedups.c"], - optional=not (root_dir / ".cibuildwheel").exists(), + optional=os.environ.get("BUILD_EXTENSION") != "yes", ) ] From d81a4cebe033e0c8933fff12581444234e8bb4db Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Apr 2023 21:48:56 +0200 Subject: [PATCH 412/762] Move build configuration to pyproject.toml. Keep only dynamic configuration in setup.py. Remove configuration that setuptools no longer requires. --- pyproject.toml | 35 +++++++++++++++++++++++++++ setup.cfg | 11 --------- setup.py | 64 ++++++++++++-------------------------------------- 3 files changed, 50 insertions(+), 60 deletions(-) create mode 100644 pyproject.toml diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..87538c5e0 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,35 @@ +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + +[project] +name = "websockets" +description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)" +requires-python = ">=3.7" +license = { text = "BSD-3-Clause" } +authors = [ + { name = "Aymeric Augustin", email = "aymeric.augustin@m4x.org" }, +] +keywords = ["WebSocket"] +classifiers = [ + "Development Status :: 5 - Production/Stable", + "Environment :: Web Environment", + "Intended Audience :: Developers", + "License :: OSI Approved :: BSD License", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", +] +dynamic = ["version", "readme"] + +[project.urls] +homepage = "https://github.com/aaugustin/websockets" +changelog = "https://websockets.readthedocs.io/en/stable/project/changelog.html" +documentation = "https://websockets.readthedocs.io/" +funding = "https://tidelift.com/subscription/pkg/pypi-websockets?utm_source=pypi-websockets&utm_medium=referral&utm_campaign=readme" +tracker = "https://github.com/aaugustin/websockets/issues" diff --git a/setup.cfg b/setup.cfg index c2ca2d5dd..28ea12c12 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,14 +1,3 @@ -[bdist_wheel] -python-tag = py37.py38.py39.py310.py311 - -[metadata] -license_files = LICENSE -project_urls = - Changelog = https://websockets.readthedocs.io/en/stable/project/changelog.html - Documentation = https://websockets.readthedocs.io/ - Funding = https://tidelift.com/subscription/pkg/pypi-websockets?utm_source=pypi-websockets&utm_medium=referral&utm_campaign=readme - Tracker = https://github.com/aaugustin/websockets/issues - [flake8] ignore = E203,E731,F403,F405,W503 max-line-length = 88 diff --git a/setup.py b/setup.py index c8e01f24b..ae0aaa65d 100644 --- a/setup.py +++ b/setup.py @@ -7,66 +7,32 @@ root_dir = pathlib.Path(__file__).parent -description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)" - -long_description = (root_dir / "README.rst").read_text(encoding="utf-8") +exec((root_dir / "src" / "websockets" / "version.py").read_text(encoding="utf-8")) -# PyPI disables the "raw" directive. +# PyPI disables the "raw" directive. Remove this section of the README. long_description = re.sub( r"^\.\. raw:: html.*?^(?=\w)", "", - long_description, + (root_dir / "README.rst").read_text(encoding="utf-8"), flags=re.DOTALL | re.MULTILINE, ) -exec((root_dir / "src" / "websockets" / "version.py").read_text(encoding="utf-8")) - -packages = [ - "websockets", - "websockets/extensions", - "websockets/legacy", - "websockets/sync", -] - # Set BUILD_EXTENSION to yes or no to force building or not building the # speedups extension. If unset, the extension is built only if possible. -ext_modules = [] if os.environ.get("BUILD_EXTENSION") == "no" else [ - setuptools.Extension( - "websockets.speedups", - sources=["src/websockets/speedups.c"], - optional=os.environ.get("BUILD_EXTENSION") != "yes", - ) -] - +if os.environ.get("BUILD_EXTENSION") == "no": + ext_modules = [] +else: + ext_modules = [ + setuptools.Extension( + "websockets.speedups", + sources=["src/websockets/speedups.c"], + optional=os.environ.get("BUILD_EXTENSION") != "yes", + ) + ] + +# Static values are declared in pyproject.toml. setuptools.setup( - name="websockets", version=version, - description=description, long_description=long_description, - url="https://github.com/aaugustin/websockets", - author="Aymeric Augustin", - author_email="aymeric.augustin@m4x.org", - license="BSD-3-Clause", - classifiers=[ - "Development Status :: 5 - Production/Stable", - "Environment :: Web Environment", - "Intended Audience :: Developers", - "License :: OSI Approved :: BSD License", - "Operating System :: OS Independent", - "Programming Language :: Python", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - ], - package_dir={"": "src"}, - package_data={"websockets": ["py.typed"]}, - packages=packages, ext_modules=ext_modules, - include_package_data=True, - zip_safe=False, - python_requires=">=3.7", - test_loader="unittest:TestLoader", ) From 0924e9cdd9e74065aa23d160414187f202771d61 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Apr 2023 22:14:37 +0200 Subject: [PATCH 413/762] Move coverage configuration to pyproject.toml. --- pyproject.toml | 25 +++++++++++++++++++++++++ setup.cfg | 22 ---------------------- 2 files changed, 25 insertions(+), 22 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 87538c5e0..fe0602b98 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,3 +33,28 @@ changelog = "https://websockets.readthedocs.io/en/stable/project/changelog.html" documentation = "https://websockets.readthedocs.io/" funding = "https://tidelift.com/subscription/pkg/pypi-websockets?utm_source=pypi-websockets&utm_medium=referral&utm_campaign=readme" tracker = "https://github.com/aaugustin/websockets/issues" + +[tool.coverage.run] +branch = true +omit = [ + # */websockets matches src/websockets and .tox/**/site-packages/websockets + "*/websockets/__main__.py", + "*/websockets/legacy/compatibility.py", + "tests/maxi_cov.py", +] + +[tool.coverage.paths] +source = [ + "src/websockets", + ".tox/*/lib/python*/site-packages/websockets", +] + +[tool.coverage.report] +exclude_lines = [ + "if self.debug:", + "pragma: no cover", + "raise AssertionError", + "raise NotImplementedError", + "self.fail\\(\".*\"\\)", + "@unittest.skip", +] diff --git a/setup.cfg b/setup.cfg index 28ea12c12..ca211fc63 100644 --- a/setup.cfg +++ b/setup.cfg @@ -6,25 +6,3 @@ max-line-length = 88 profile = black combine_as_imports = True lines_after_imports = 2 - -[coverage:run] -branch = True -omit = - # */websockets matches src/websockets and .tox/**/site-packages/websockets - */websockets/__main__.py - */websockets/legacy/compatibility.py - tests/maxi_cov.py - -[coverage:paths] -source = - src/websockets - .tox/*/lib/python*/site-packages/websockets - -[coverage:report] -exclude_lines = - if self.debug: - pragma: no cover - raise AssertionError - raise NotImplementedError - self.fail\(".*"\) - @unittest.skip From af9173790a620c5f3687ea61f17e58c7466fc020 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Apr 2023 22:33:30 +0200 Subject: [PATCH 414/762] Replace flake8 and isort by ruff. Configure in pyproject.toml instead of setup.cfg. Fix #1306. --- .github/workflows/tests.yml | 4 +--- Makefile | 11 ++++++----- pyproject.toml | 16 ++++++++++++++++ setup.cfg | 8 -------- src/websockets/legacy/server.py | 8 ++++---- tox.ini | 13 ++++--------- 6 files changed, 31 insertions(+), 29 deletions(-) delete mode 100644 setup.cfg diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 34f8d8c5c..163d1bbae 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -44,9 +44,7 @@ jobs: - name: Check code formatting run: tox -e black - name: Check code style - run: tox -e flake8 - - name: Check imports ordering - run: tox -e isort + run: tox -e ruff - name: Check types statically run: tox -e mypy diff --git a/Makefile b/Makefile index ac5d6a4aa..cf3b53393 100644 --- a/Makefile +++ b/Makefile @@ -1,18 +1,19 @@ -.PHONY: default style test coverage maxi_cov build clean +.PHONY: default style types tests coverage maxi_cov build clean export PYTHONASYNCIODEBUG=1 export PYTHONPATH=src export PYTHONWARNINGS=default -default: coverage style +default: style types tests style: - isort --project websockets src tests black src tests - flake8 src tests + ruff --fix src tests + +types: mypy --strict src -test: +tests: python -m unittest coverage: diff --git a/pyproject.toml b/pyproject.toml index fe0602b98..99d0fa08c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,3 +58,19 @@ exclude_lines = [ "self.fail\\(\".*\"\\)", "@unittest.skip", ] + +[tool.ruff] +select = [ + "E", # pycodestyle + "F", # Pyflakes + "W", # pycodestyle + "I", # isort +] +ignore = [ + "F403", + "F405", +] + +[tool.ruff.isort] +combine-as-imports = true +lines-after-imports = 2 diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index ca211fc63..000000000 --- a/setup.cfg +++ /dev/null @@ -1,8 +0,0 @@ -[flake8] -ignore = E203,E731,F403,F405,W503 -max-line-length = 88 - -[isort] -profile = black -combine_as_imports = True -lines_after_imports = 2 diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 3be86e45c..92252b136 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -552,10 +552,10 @@ def select_subprotocol( subprotocols = set(client_subprotocols) & set(server_subprotocols) if not subprotocols: return None - priority = lambda p: ( - client_subprotocols.index(p) + server_subprotocols.index(p) - ) - return sorted(subprotocols, key=priority)[0] + return sorted( + subprotocols, + key=lambda p: client_subprotocols.index(p) + server_subprotocols.index(p), + )[0] async def handshake( self, diff --git a/tox.ini b/tox.ini index 0fcab4d79..939d8c0cd 100644 --- a/tox.ini +++ b/tox.ini @@ -7,8 +7,7 @@ envlist = py311 coverage black - flake8 - isort + ruff mypy [testenv] @@ -31,13 +30,9 @@ deps = coverage commands = black --check src tests deps = black -[testenv:flake8] -commands = flake8 src tests -deps = flake8 - -[testenv:isort] -commands = isort --check-only src tests -deps = isort +[testenv:ruff] +commands = ruff src tests +deps = ruff [testenv:mypy] commands = mypy --strict src From 5b376aa4fc97424a10ad0b095c6cdbc5af81fd24 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 Apr 2023 08:58:55 +0200 Subject: [PATCH 415/762] Reduce reliance on # pragma: no cover. --- pyproject.toml | 1 + src/websockets/frames.py | 2 +- src/websockets/legacy/framing.py | 2 +- src/websockets/legacy/protocol.py | 20 ++++++-------------- src/websockets/legacy/server.py | 18 +++++++++--------- src/websockets/sync/compatibility.py | 2 +- tests/legacy/test_client_server.py | 4 ++-- tests/test_utils.py | 2 +- 8 files changed, 22 insertions(+), 29 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 99d0fa08c..989b6b5e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,7 @@ source = [ [tool.coverage.report] exclude_lines = [ + "except ImportError:", "if self.debug:", "pragma: no cover", "raise AssertionError", diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 52d81746d..8e0e6d873 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -13,7 +13,7 @@ try: from .speedups import apply_mask -except ImportError: # pragma: no cover +except ImportError: from .utils import apply_mask diff --git a/src/websockets/legacy/framing.py b/src/websockets/legacy/framing.py index 4836eb284..dab501d2a 100644 --- a/src/websockets/legacy/framing.py +++ b/src/websockets/legacy/framing.py @@ -9,7 +9,7 @@ try: from ..speedups import apply_mask -except ImportError: # pragma: no cover +except ImportError: from ..utils import apply_mask diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 67bca0ef4..7f9ab2bd8 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -714,9 +714,7 @@ async def send( await self.write_frame(False, opcode, data) # Other fragments. - # coverage reports this code as not covered, but it is - # exercised by tests - changing it breaks the tests! - async for fragment in aiter_message: # pragma: no cover + async for fragment in aiter_message: confirm_opcode, data = prepare_data(fragment) if confirm_opcode != opcode: raise TypeError("data contains inconsistent types") @@ -1150,8 +1148,8 @@ async def read_data_frame(self, max_size: Optional[int]) -> Optional[Frame]: if ping_id == frame.data: self.latency = pong_timestamp - ping_timestamp break - else: # pragma: no cover - assert False, "ping_id is in self.pings" + else: + raise AssertionError("solicited pong not found in pings") # Remove acknowledged pings from self.pings. for ping_id in ping_ids: del self.pings[ping_id] @@ -1320,9 +1318,7 @@ async def close_connection(self) -> None: # A client should wait for a TCP close from the server. if self.is_client and hasattr(self, "transfer_data_task"): if await self.wait_for_connection_lost(): - # Coverage marks this line as a partially executed branch. - # I suspect a bug in coverage. Ignore it for now. - return # pragma: no cover + return if self.debug: self.logger.debug("! timed out waiting for TCP close") @@ -1340,9 +1336,7 @@ async def close_connection(self) -> None: pass if await self.wait_for_connection_lost(): - # Coverage marks this line as a partially executed branch. - # I suspect a bug in coverage. Ignore it for now. - return # pragma: no cover + return if self.debug: self.logger.debug("! timed out waiting for TCP close") @@ -1378,9 +1372,7 @@ async def close_transport(self) -> None: self.transport.abort() # connection_lost() is called quickly after aborting. - # Coverage marks this line as a partially executed branch. - # I suspect a bug in coverage. Ignore it for now. - await self.wait_for_connection_lost() # pragma: no cover + await self.wait_for_connection_lost() async def wait_for_connection_lost(self) -> bool: """ diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 92252b136..3506276b9 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -840,7 +840,7 @@ def is_serving(self) -> bool: """ return self.server.is_serving() - async def start_serving(self) -> None: + async def start_serving(self) -> None: # pragma: no cover """ See :meth:`asyncio.Server.start_serving`. @@ -852,9 +852,9 @@ async def start_serving(self) -> None: await server.start_serving() """ - await self.server.start_serving() # pragma: no cover + await self.server.start_serving() - async def serve_forever(self) -> None: + async def serve_forever(self) -> None: # pragma: no cover """ See :meth:`asyncio.Server.serve_forever`. @@ -870,7 +870,7 @@ async def serve_forever(self) -> None: instead of exiting a :func:`serve` context. """ - await self.server.serve_forever() # pragma: no cover + await self.server.serve_forever() @property def sockets(self) -> Iterable[socket.socket]: @@ -880,17 +880,17 @@ def sockets(self) -> Iterable[socket.socket]: """ return self.server.sockets - async def __aenter__(self) -> WebSocketServer: - return self # pragma: no cover + async def __aenter__(self) -> WebSocketServer: # pragma: no cover + return self async def __aexit__( self, exc_type: Optional[Type[BaseException]], exc_value: Optional[BaseException], traceback: Optional[TracebackType], - ) -> None: - self.close() # pragma: no cover - await self.wait_closed() # pragma: no cover + ) -> None: # pragma: no cover + self.close() + await self.wait_closed() class Serve: diff --git a/src/websockets/sync/compatibility.py b/src/websockets/sync/compatibility.py index 3064263e9..38d2ab668 100644 --- a/src/websockets/sync/compatibility.py +++ b/src/websockets/sync/compatibility.py @@ -3,7 +3,7 @@ try: from socket import create_server as socket_create_server -except ImportError: # pragma: no cover +except ImportError: import socket def socket_create_server(address, family=socket.AF_INET): # type: ignore diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index b8fd259c9..3dd06d01e 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -406,11 +406,11 @@ def __init__(self, *args, **kwargs): self.used_for_write = False super().__init__(*args, **kwargs) - def recv(self, *args, **kwargs): # pragma: no cover + def recv(self, *args, **kwargs): self.used_for_read = True return super().recv(*args, **kwargs) - def recv_into(self, *args, **kwargs): # pragma: no cover + def recv_into(self, *args, **kwargs): self.used_for_read = True return super().recv_into(*args, **kwargs) diff --git a/tests/test_utils.py b/tests/test_utils.py index acd60edfc..678fcfe79 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -82,7 +82,7 @@ def test_apply_mask_check_mask_length(self): try: from websockets.speedups import apply_mask as c_apply_mask -except ImportError: # pragma: no cover +except ImportError: pass else: From 5c7a44266ff63a486a2ab6dd8fa994136fbbbaff Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 Apr 2023 09:01:46 +0200 Subject: [PATCH 416/762] Reduce usage of # noqa. Specify which warnings are ignored. --- src/websockets/__init__.py | 4 ++-- src/websockets/__main__.py | 2 +- src/websockets/auth.py | 2 +- src/websockets/client.py | 4 ++-- src/websockets/connection.py | 2 +- src/websockets/legacy/framing.py | 7 +++++-- src/websockets/server.py | 2 +- tests/extensions/test_base.py | 2 +- tests/test_auth.py | 2 +- tests/test_http.py | 2 +- tests/test_typing.py | 2 +- 11 files changed, 17 insertions(+), 14 deletions(-) diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index 826decc48..dcf3d8150 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -1,10 +1,10 @@ from __future__ import annotations from .imports import lazy_import -from .version import version as __version__ # noqa +from .version import version as __version__ # noqa: F401 -__all__ = [ # noqa +__all__ = [ "AbortHandshake", "basic_auth_protocol_factory", "BasicAuthWebSocketServerProtocol", diff --git a/src/websockets/__main__.py b/src/websockets/__main__.py index a7dd1aaef..f2ea5cf4e 100644 --- a/src/websockets/__main__.py +++ b/src/websockets/__main__.py @@ -8,7 +8,7 @@ try: - import readline # noqa + import readline # noqa: F401 except ImportError: # Windows has no `readline` normally pass diff --git a/src/websockets/auth.py b/src/websockets/auth.py index afcb38cff..5292e4f7f 100644 --- a/src/websockets/auth.py +++ b/src/websockets/auth.py @@ -1,4 +1,4 @@ from __future__ import annotations # See #940 for why lazy_import isn't used here for backwards compatibility. -from .legacy.auth import * # noqa +from .legacy.auth import * diff --git a/src/websockets/client.py b/src/websockets/client.py index a0d077fc2..bf8427c37 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -38,7 +38,7 @@ # See #940 for why lazy_import isn't used here for backwards compatibility. -from .legacy.client import * # isort:skip # noqa +from .legacy.client import * # isort:skip # noqa: I001 __all__ = ["ClientProtocol"] @@ -90,7 +90,7 @@ def __init__( self.available_subprotocols = subprotocols self.key = generate_key() - def connect(self) -> Request: # noqa: F811 + def connect(self) -> Request: """ Create a handshake request to open a connection. diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 5ce4d6a3b..88bcda1aa 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -3,7 +3,7 @@ import warnings # lazy_import doesn't support this use case. -from .protocol import SEND_EOF, Protocol as Connection, Side, State # noqa +from .protocol import SEND_EOF, Protocol as Connection, Side, State # noqa: F401 warnings.warn( diff --git a/src/websockets/legacy/framing.py b/src/websockets/legacy/framing.py index dab501d2a..b77b869e3 100644 --- a/src/websockets/legacy/framing.py +++ b/src/websockets/legacy/framing.py @@ -145,8 +145,11 @@ def write( # Backwards compatibility with previously documented public APIs - -from ..frames import Close, prepare_ctrl as encode_data, prepare_data # noqa +from ..frames import ( # noqa: E402, F401, I001 + Close, + prepare_ctrl as encode_data, + prepare_data, +) def parse_close(data: bytes) -> Tuple[int, str]: diff --git a/src/websockets/server.py b/src/websockets/server.py index dcdf3b71e..ecb0f74a6 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -39,7 +39,7 @@ # See #940 for why lazy_import isn't used here for backwards compatibility. -from .legacy.server import * # isort:skip # noqa +from .legacy.server import * # isort:skip # noqa: I001 __all__ = ["ServerProtocol"] diff --git a/tests/extensions/test_base.py b/tests/extensions/test_base.py index ba8657b65..b18ffb6fb 100644 --- a/tests/extensions/test_base.py +++ b/tests/extensions/test_base.py @@ -1,4 +1,4 @@ -from websockets.extensions.base import * # noqa +from websockets.extensions.base import * # Abstract classes don't provide any behavior to test. diff --git a/tests/test_auth.py b/tests/test_auth.py index d5a8bd9ad..28db93155 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -1 +1 @@ -from websockets.auth import * # noqa +from websockets.auth import * diff --git a/tests/test_http.py b/tests/test_http.py index 16bec9468..036bc1410 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -1 +1 @@ -from websockets.http import * # noqa +from websockets.http import * diff --git a/tests/test_typing.py b/tests/test_typing.py index 6eb1fe6c5..202de840f 100644 --- a/tests/test_typing.py +++ b/tests/test_typing.py @@ -1 +1 @@ -from websockets.typing import * # noqa +from websockets.typing import * From 5113cd3afe7cbf5f740d80db8449670c27c90101 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 7 Mar 2023 11:56:19 -1000 Subject: [PATCH 417/762] Use asyncio.timeout instead of asyncio.wait_for. asyncio.wait_for creates a task whereas asyncio.timeout doesn't. Fallback to a vendored version of async_timeout on Python < 3.11. async.timeout will become the underlying implementation for async.wait_for in Python 3.12: https://github.com/python/cpython/pull/98518 --- pyproject.toml | 1 + src/websockets/legacy/async_timeout.py | 225 +++++++++++++++++++++++++ src/websockets/legacy/compatibility.py | 9 + src/websockets/legacy/protocol.py | 34 ++-- 4 files changed, 246 insertions(+), 23 deletions(-) create mode 100644 src/websockets/legacy/async_timeout.py diff --git a/pyproject.toml b/pyproject.toml index 989b6b5e0..0707c6442 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ branch = true omit = [ # */websockets matches src/websockets and .tox/**/site-packages/websockets "*/websockets/__main__.py", + "*/websockets/legacy/async_timeout.py", "*/websockets/legacy/compatibility.py", "tests/maxi_cov.py", ] diff --git a/src/websockets/legacy/async_timeout.py b/src/websockets/legacy/async_timeout.py new file mode 100644 index 000000000..0a2208927 --- /dev/null +++ b/src/websockets/legacy/async_timeout.py @@ -0,0 +1,225 @@ +# From https://github.com/aio-libs/async-timeout/blob/master/async_timeout/__init__.py +# Licensed under the Apache License, Version 2.0. + +import asyncio +import enum +import sys +import warnings +from types import TracebackType +from typing import Optional, Type + + +if sys.version_info >= (3, 8): + from typing import final +else: + from typing_extensions import final + + +__version__ = "4.0.2" + + +__all__ = ("timeout", "timeout_at", "Timeout") + + +def timeout(delay: Optional[float]) -> "Timeout": + """timeout context manager. + + Useful in cases when you want to apply timeout logic around block + of code or in cases when asyncio.wait_for is not suitable. For example: + + >>> async with timeout(0.001): + ... async with aiohttp.get('https://github.com') as r: + ... await r.text() + + + delay - value in seconds or None to disable timeout logic + """ + loop = asyncio.get_running_loop() + if delay is not None: + deadline = loop.time() + delay # type: Optional[float] + else: + deadline = None + return Timeout(deadline, loop) + + +def timeout_at(deadline: Optional[float]) -> "Timeout": + """Schedule the timeout at absolute time. + + deadline argument points on the time in the same clock system + as loop.time(). + + Please note: it is not POSIX time but a time with + undefined starting base, e.g. the time of the system power on. + + >>> async with timeout_at(loop.time() + 10): + ... async with aiohttp.get('https://github.com') as r: + ... await r.text() + + + """ + loop = asyncio.get_running_loop() + return Timeout(deadline, loop) + + +class _State(enum.Enum): + INIT = "INIT" + ENTER = "ENTER" + TIMEOUT = "TIMEOUT" + EXIT = "EXIT" + + +@final +class Timeout: + # Internal class, please don't instantiate it directly + # Use timeout() and timeout_at() public factories instead. + # + # Implementation note: `async with timeout()` is preferred + # over `with timeout()`. + # While technically the Timeout class implementation + # doesn't need to be async at all, + # the `async with` statement explicitly points that + # the context manager should be used from async function context. + # + # This design allows to avoid many silly misusages. + # + # TimeoutError is raised immediately when scheduled + # if the deadline is passed. + # The purpose is to time out as soon as possible + # without waiting for the next await expression. + + __slots__ = ("_deadline", "_loop", "_state", "_timeout_handler") + + def __init__( + self, deadline: Optional[float], loop: asyncio.AbstractEventLoop + ) -> None: + self._loop = loop + self._state = _State.INIT + + self._timeout_handler = None # type: Optional[asyncio.Handle] + if deadline is None: + self._deadline = None # type: Optional[float] + else: + self.update(deadline) + + def __enter__(self) -> "Timeout": + warnings.warn( + "with timeout() is deprecated, use async with timeout() instead", + DeprecationWarning, + stacklevel=2, + ) + self._do_enter() + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Optional[bool]: + self._do_exit(exc_type) + return None + + async def __aenter__(self) -> "Timeout": + self._do_enter() + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Optional[bool]: + self._do_exit(exc_type) + return None + + @property + def expired(self) -> bool: + """Is timeout expired during execution?""" + return self._state == _State.TIMEOUT + + @property + def deadline(self) -> Optional[float]: + return self._deadline + + def reject(self) -> None: + """Reject scheduled timeout if any.""" + # cancel is maybe better name but + # task.cancel() raises CancelledError in asyncio world. + if self._state not in (_State.INIT, _State.ENTER): + raise RuntimeError(f"invalid state {self._state.value}") + self._reject() + + def _reject(self) -> None: + if self._timeout_handler is not None: + self._timeout_handler.cancel() + self._timeout_handler = None + + def shift(self, delay: float) -> None: + """Advance timeout on delay seconds. + + The delay can be negative. + + Raise RuntimeError if shift is called when deadline is not scheduled + """ + deadline = self._deadline + if deadline is None: + raise RuntimeError("cannot shift timeout if deadline is not scheduled") + self.update(deadline + delay) + + def update(self, deadline: float) -> None: + """Set deadline to absolute value. + + deadline argument points on the time in the same clock system + as loop.time(). + + If new deadline is in the past the timeout is raised immediately. + + Please note: it is not POSIX time but a time with + undefined starting base, e.g. the time of the system power on. + """ + if self._state == _State.EXIT: + raise RuntimeError("cannot reschedule after exit from context manager") + if self._state == _State.TIMEOUT: + raise RuntimeError("cannot reschedule expired timeout") + if self._timeout_handler is not None: + self._timeout_handler.cancel() + self._deadline = deadline + if self._state != _State.INIT: + self._reschedule() + + def _reschedule(self) -> None: + assert self._state == _State.ENTER + deadline = self._deadline + if deadline is None: + return + + now = self._loop.time() + if self._timeout_handler is not None: + self._timeout_handler.cancel() + + task = asyncio.current_task() + if deadline <= now: + self._timeout_handler = self._loop.call_soon(self._on_timeout, task) + else: + self._timeout_handler = self._loop.call_at(deadline, self._on_timeout, task) + + def _do_enter(self) -> None: + if self._state != _State.INIT: + raise RuntimeError(f"invalid state {self._state.value}") + self._state = _State.ENTER + self._reschedule() + + def _do_exit(self, exc_type: Optional[Type[BaseException]]) -> None: + if exc_type is asyncio.CancelledError and self._state == _State.TIMEOUT: + self._timeout_handler = None + raise asyncio.TimeoutError + # timeout has not expired + self._state = _State.EXIT + self._reject() + return None + + def _on_timeout(self, task: "asyncio.Task[None]") -> None: + task.cancel() + self._state = _State.TIMEOUT + # drop the reference early + self._timeout_handler = None diff --git a/src/websockets/legacy/compatibility.py b/src/websockets/legacy/compatibility.py index 303e203b4..cb9b02c86 100644 --- a/src/websockets/legacy/compatibility.py +++ b/src/websockets/legacy/compatibility.py @@ -5,6 +5,9 @@ from typing import Any, Dict +__all__ = ["asyncio_timeout", "loop_if_py_lt_38"] + + if sys.version_info[:2] >= (3, 8): def loop_if_py_lt_38(loop: asyncio.AbstractEventLoop) -> Dict[str, Any]: @@ -22,3 +25,9 @@ def loop_if_py_lt_38(loop: asyncio.AbstractEventLoop) -> Dict[str, Any]: """ return {"loop": loop} + + +if sys.version_info[:2] >= (3, 11): + from asyncio import timeout as asyncio_timeout # noqa: F401 +else: + from .async_timeout import timeout as asyncio_timeout # noqa: F401 diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 7f9ab2bd8..78b59ee81 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -53,7 +53,7 @@ ) from ..protocol import State from ..typing import Data, LoggerLike, Subprotocol -from .compatibility import loop_if_py_lt_38 +from .compatibility import asyncio_timeout, loop_if_py_lt_38 from .framing import Frame @@ -761,19 +761,16 @@ async def close(self, code: int = 1000, reason: str = "") -> None: """ try: - await asyncio.wait_for( - self.write_close_frame(Close(code, reason)), - self.close_timeout, - **loop_if_py_lt_38(self.loop), - ) + async with asyncio_timeout(self.close_timeout): + await self.write_close_frame(Close(code, reason)) except asyncio.TimeoutError: # If the close frame cannot be sent because the send buffers # are full, the closing handshake won't complete anyway. # Fail the connection to shut down faster. self.fail_connection() - # If no close frame is received within the timeout, wait_for() cancels - # the data transfer task and raises TimeoutError. + # If no close frame is received within the timeout, asyncio_timeout() + # cancels the data transfer task and raises TimeoutError. # If close() is called multiple times concurrently and one of these # calls hits the timeout, the data transfer task will be canceled. @@ -782,11 +779,8 @@ async def close(self, code: int = 1000, reason: str = "") -> None: try: # If close() is canceled during the wait, self.transfer_data_task # is canceled before the timeout elapses. - await asyncio.wait_for( - self.transfer_data_task, - self.close_timeout, - **loop_if_py_lt_38(self.loop), - ) + async with asyncio_timeout(self.close_timeout): + await self.transfer_data_task except (asyncio.TimeoutError, asyncio.CancelledError): pass @@ -1268,11 +1262,8 @@ async def keepalive_ping(self) -> None: if self.ping_timeout is not None: try: - await asyncio.wait_for( - pong_waiter, - self.ping_timeout, - **loop_if_py_lt_38(self.loop), - ) + async with asyncio_timeout(self.ping_timeout): + await pong_waiter self.logger.debug("% received keepalive pong") except asyncio.TimeoutError: if self.debug: @@ -1384,11 +1375,8 @@ async def wait_for_connection_lost(self) -> bool: """ if not self.connection_lost_waiter.done(): try: - await asyncio.wait_for( - asyncio.shield(self.connection_lost_waiter), - self.close_timeout, - **loop_if_py_lt_38(self.loop), - ) + async with asyncio_timeout(self.close_timeout): + await asyncio.shield(self.connection_lost_waiter) except asyncio.TimeoutError: pass # Re-check self.connection_lost_waiter.done() synchronously because From 5a6f74e2248181ccd25a638304e8959d3ea90f91 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 Apr 2023 07:44:51 +0200 Subject: [PATCH 418/762] Backport typing.final for compatibility with Python 3.7. --- src/websockets/legacy/async_timeout.py | 46 ++++++++++++++++++++++++-- 1 file changed, 43 insertions(+), 3 deletions(-) diff --git a/src/websockets/legacy/async_timeout.py b/src/websockets/legacy/async_timeout.py index 0a2208927..8264094f5 100644 --- a/src/websockets/legacy/async_timeout.py +++ b/src/websockets/legacy/async_timeout.py @@ -1,5 +1,5 @@ # From https://github.com/aio-libs/async-timeout/blob/master/async_timeout/__init__.py -# Licensed under the Apache License, Version 2.0. +# Licensed under the Apache License (Apache-2.0) import asyncio import enum @@ -9,11 +9,48 @@ from typing import Optional, Type -if sys.version_info >= (3, 8): +# From https://github.com/python/typing_extensions/blob/main/src/typing_extensions.py +# Licensed under the Python Software Foundation License (PSF-2.0) + +if sys.version_info >= (3, 11): from typing import final else: - from typing_extensions import final + # @final exists in 3.8+, but we backport it for all versions + # before 3.11 to keep support for the __final__ attribute. + # See https://bugs.python.org/issue46342 + def final(f): + """This decorator can be used to indicate to type checkers that + the decorated method cannot be overridden, and decorated class + cannot be subclassed. For example: + + class Base: + @final + def done(self) -> None: + ... + class Sub(Base): + def done(self) -> None: # Error reported by type checker + ... + @final + class Leaf: + ... + class Other(Leaf): # Error reported by type checker + ... + + There is no runtime checking of these properties. The decorator + sets the ``__final__`` attribute to ``True`` on the decorated object + to allow runtime introspection. + """ + try: + f.__final__ = True + except (AttributeError, TypeError): + # Skip the attribute silently if it is not writable. + # AttributeError happens if the object has __slots__ or a + # read-only property, TypeError if it's a builtin class. + pass + return f + +# End https://github.com/aio-libs/async-timeout/blob/master/async_timeout/__init__.py __version__ = "4.0.2" @@ -223,3 +260,6 @@ def _on_timeout(self, task: "asyncio.Task[None]") -> None: self._state = _State.TIMEOUT # drop the reference early self._timeout_handler = None + + +# End https://github.com/aio-libs/async-timeout/blob/master/async_timeout/__init__.py From 808d8540af05e6bbfd74be8ae501b200ce7c966e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 Apr 2023 08:14:27 +0200 Subject: [PATCH 419/762] Replace asyncio.wait_for with asyncio.timeout. Some instances were missed in a previous commit. Also update documentation. --- docs/faq/common.rst | 9 +++++++-- docs/topics/design.rst | 2 +- src/websockets/legacy/client.py | 4 +++- src/websockets/legacy/protocol.py | 5 +++-- src/websockets/legacy/server.py | 10 ++++------ tests/legacy/test_client_server.py | 4 +++- 6 files changed, 21 insertions(+), 13 deletions(-) diff --git a/docs/faq/common.rst b/docs/faq/common.rst index 2a512ea90..505149a64 100644 --- a/docs/faq/common.rst +++ b/docs/faq/common.rst @@ -105,9 +105,14 @@ you can adjust it with the ``ping_timeout`` argument. How do I set a timeout on :meth:`~legacy.protocol.WebSocketCommonProtocol.recv`? -------------------------------------------------------------------------------- -Use :func:`~asyncio.wait_for`:: +On Python ≥ 3.11, use :func:`asyncio.timeout`:: - await asyncio.wait_for(websocket.recv(), timeout=10) + async with asyncio.timeout(timeout=10): + message = await websocket.recv() + +On older versions of Python, use :func:`asyncio.wait_for`:: + + message = await asyncio.wait_for(websocket.recv(), timeout=10) This technique works for most APIs. When it doesn't, for example with asynchronous context managers, websockets provides an ``open_timeout`` argument. diff --git a/docs/topics/design.rst b/docs/topics/design.rst index 33dd187b9..f164d2990 100644 --- a/docs/topics/design.rst +++ b/docs/topics/design.rst @@ -437,7 +437,7 @@ propagate cancellation to them. prevent cancellation. :meth:`~legacy.protocol.WebSocketCommonProtocol.close` waits for the data transfer -task to terminate with :func:`~asyncio.wait_for`. If it's canceled or if the +task to terminate with :func:`~asyncio.timeout`. If it's canceled or if the timeout elapses, :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` is canceled, which is correct at this point. :meth:`~legacy.protocol.WebSocketCommonProtocol.close` then waits for diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index b79bbab8c..c5e9d0d52 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -44,6 +44,7 @@ from ..http import USER_AGENT from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol from ..uri import WebSocketURI, parse_uri +from .compatibility import asyncio_timeout from .handshake import build_request, check_response from .http import read_response from .protocol import WebSocketCommonProtocol @@ -650,7 +651,8 @@ def __await__(self) -> Generator[Any, None, WebSocketClientProtocol]: return self.__await_impl_timeout__().__await__() async def __await_impl_timeout__(self) -> WebSocketClientProtocol: - return await asyncio.wait_for(self.__await_impl__(), self.open_timeout) + async with asyncio_timeout(self.open_timeout): + return await self.__await_impl__() async def __await_impl__(self) -> WebSocketClientProtocol: for redirects in range(self.MAX_REDIRECTS_ALLOWED): diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 78b59ee81..8b921e6fe 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -107,7 +107,8 @@ class WebSocketCommonProtocol(asyncio.Protocol): context manager; * on the server side, when the connection handler terminates. - To apply a timeout to any other API, wrap it in :func:`~asyncio.wait_for`. + To apply a timeout to any other API, wrap it in :func:`~asyncio.timeout` or + :func:`~asyncio.wait_for`. The ``max_size`` parameter enforces the maximum size for incoming messages in bytes. The default value is 1 MiB. If a larger message is received, @@ -513,7 +514,7 @@ async def recv(self) -> Data: message. The next invocation of :meth:`recv` will return it. This makes it possible to enforce a timeout by wrapping :meth:`recv` in - :func:`~asyncio.wait_for`. + :func:`~asyncio.timeout` or :func:`~asyncio.wait_for`. Returns: Data: A string (:class:`str`) for a Text_ frame. A bytestring diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 3506276b9..25d5a7144 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -46,7 +46,7 @@ from ..http import USER_AGENT from ..protocol import State from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol -from .compatibility import loop_if_py_lt_38 +from .compatibility import asyncio_timeout, loop_if_py_lt_38 from .handshake import build_response, check_request from .http import read_request from .protocol import WebSocketCommonProtocol @@ -163,15 +163,13 @@ async def handler(self) -> None: """ try: try: - await asyncio.wait_for( - self.handshake( + async with asyncio_timeout(self.open_timeout): + await self.handshake( origins=self.origins, available_extensions=self.available_extensions, available_subprotocols=self.available_subprotocols, extra_headers=self.extra_headers, - ), - self.open_timeout, - ) + ) # Remove this branch when dropping support for Python < 3.8 # because CancelledError no longer inherits Exception. except asyncio.CancelledError: # pragma: no cover diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 3dd06d01e..133af0536 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -29,6 +29,7 @@ ) from websockets.http import USER_AGENT from websockets.legacy.client import * +from websockets.legacy.compatibility import asyncio_timeout from websockets.legacy.handshake import build_response from websockets.legacy.http import read_response from websockets.legacy.server import * @@ -1129,7 +1130,8 @@ def test_client_connect_canceled_during_handshake(self): async def cancelled_client(): start_client = connect(get_server_uri(self.server), sock=sock) - await asyncio.wait_for(start_client, 5 * MS) + async with asyncio_timeout(5 * MS): + await start_client with self.assertRaises(asyncio.TimeoutError): self.loop.run_until_complete(cancelled_client()) From f075aac67e15cdf4bc06078e23b82eac5fb2d758 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 Apr 2023 08:38:39 +0200 Subject: [PATCH 420/762] Restore semantics of tests. They relied (accidentally) on wait_for() creating a task, causing the event loop to run once when calling close(). --- tests/legacy/test_protocol.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index ab8155b9b..409aef901 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -190,7 +190,6 @@ def half_close_connection_local(self, code=1000, reason="close"): close_frame_data = Close(code, reason).serialize() # Trigger the closing handshake from the local endpoint. close_task = self.loop.create_task(self.protocol.close(code, reason)) - self.run_loop_once() # wait_for executes self.run_loop_once() # write_frame executes # Empty the outgoing data stream so we can make assertions later on. self.assertOneFrameSent(True, OP_CLOSE, close_frame_data) @@ -919,6 +918,7 @@ def test_answer_ping_does_not_crash_if_connection_closing(self): close_task = self.half_close_connection_local() self.receive_frame(Frame(True, OP_PING, b"test")) + self.run_loop_once() with self.assertNoLogs(): self.loop.run_until_complete(self.protocol.close()) @@ -931,6 +931,7 @@ def test_answer_ping_does_not_crash_if_connection_closed(self): # which prevents responding with a pong frame properly. self.receive_frame(Frame(True, OP_PING, b"test")) self.receive_eof() + self.run_loop_once() with self.assertNoLogs(): self.loop.run_until_complete(self.protocol.close()) @@ -1362,6 +1363,7 @@ def test_remote_close_and_connection_lost(self): # which prevents echoing the close frame properly. self.receive_frame(self.close_frame) self.receive_eof() + self.run_loop_once() with self.assertNoLogs(): self.loop.run_until_complete(self.protocol.close(reason="oh noes!")) @@ -1375,6 +1377,7 @@ def test_simultaneous_close(self): # https://github.com/aaugustin/websockets/issues/339 self.loop.call_soon(self.receive_frame, self.remote_close) self.loop.call_soon(self.receive_eof_if_client) + self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason="local")) @@ -1386,6 +1389,7 @@ def test_simultaneous_close(self): def test_close_preserves_incoming_frames(self): self.receive_frame(Frame(True, OP_TEXT, b"hello")) + self.run_loop_once() self.loop.call_later(MS, self.receive_frame, self.close_frame) self.loop.call_later(MS, self.receive_eof_if_client) @@ -1573,6 +1577,7 @@ def test_local_close_connection_lost_timeout_after_write_eof(self): # HACK: disable write_eof => other end drops connection emulation. self.transport._eof = True self.receive_frame(self.close_frame) + self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason="close")) self.assertConnectionClosed(1000, "close") @@ -1589,6 +1594,7 @@ def test_local_close_connection_lost_timeout_after_close(self): # HACK: disable close => other end drops connection emulation. self.transport._closing = True self.receive_frame(self.close_frame) + self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason="close")) self.assertConnectionClosed(1000, "close") @@ -1631,6 +1637,7 @@ def test_local_close_connection_lost_timeout_after_write_eof(self): # HACK: disable write_eof => other end drops connection emulation. self.transport._eof = True self.receive_frame(self.close_frame) + self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason="close")) self.assertConnectionClosed(1000, "close") @@ -1650,5 +1657,6 @@ def test_local_close_connection_lost_timeout_after_close(self): # HACK: disable close => other end drops connection emulation. self.transport._closing = True self.receive_frame(self.close_frame) + self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason="close")) self.assertConnectionClosed(1000, "close") From 901e434fac7bf60018c950bdaf85b9946cc4309d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 Apr 2023 09:30:00 +0200 Subject: [PATCH 421/762] Work around bug in coverage. --- src/websockets/legacy/protocol.py | 3 ++- tests/legacy/test_protocol.py | 18 ++++++++++++------ 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 8b921e6fe..733abb3b9 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -1361,7 +1361,8 @@ async def close_transport(self) -> None: # Abort the TCP connection. Buffers are discarded. if self.debug: self.logger.debug("x aborting TCP connection") - self.transport.abort() + # Due to a bug in coverage, this is erroneously reported as not covered. + self.transport.abort() # pragma: no cover # connection_lost() is called quickly after aborting. await self.wait_for_connection_lost() diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index 409aef901..a05dcc6f6 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -1579,7 +1579,8 @@ def test_local_close_connection_lost_timeout_after_write_eof(self): self.receive_frame(self.close_frame) self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason="close")) - self.assertConnectionClosed(1000, "close") + # Due to a bug in coverage, this is erroneously reported as not covered. + self.assertConnectionClosed(1000, "close") # pragma: no cover def test_local_close_connection_lost_timeout_after_close(self): self.protocol.close_timeout = 10 * MS @@ -1596,7 +1597,8 @@ def test_local_close_connection_lost_timeout_after_close(self): self.receive_frame(self.close_frame) self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason="close")) - self.assertConnectionClosed(1000, "close") + # Due to a bug in coverage, this is erroneously reported as not covered. + self.assertConnectionClosed(1000, "close") # pragma: no cover class ClientTests(CommonTests, AsyncioTestCase): @@ -1614,7 +1616,8 @@ def test_local_close_send_close_frame_timeout(self): # Check the timing within -1/+9ms for robustness. with self.assertCompletesWithin(19 * MS, 29 * MS): self.loop.run_until_complete(self.protocol.close(reason="close")) - self.assertConnectionClosed(1006, "") + # Due to a bug in coverage, this is erroneously reported as not covered. + self.assertConnectionClosed(1006, "") # pragma: no cover def test_local_close_receive_close_frame_timeout(self): self.protocol.close_timeout = 10 * MS @@ -1624,7 +1627,8 @@ def test_local_close_receive_close_frame_timeout(self): # Check the timing within -1/+9ms for robustness. with self.assertCompletesWithin(19 * MS, 29 * MS): self.loop.run_until_complete(self.protocol.close(reason="close")) - self.assertConnectionClosed(1006, "") + # Due to a bug in coverage, this is erroneously reported as not covered. + self.assertConnectionClosed(1006, "") # pragma: no cover def test_local_close_connection_lost_timeout_after_write_eof(self): self.protocol.close_timeout = 10 * MS @@ -1639,7 +1643,8 @@ def test_local_close_connection_lost_timeout_after_write_eof(self): self.receive_frame(self.close_frame) self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason="close")) - self.assertConnectionClosed(1000, "close") + # Due to a bug in coverage, this is erroneously reported as not covered. + self.assertConnectionClosed(1000, "close") # pragma: no cover def test_local_close_connection_lost_timeout_after_close(self): self.protocol.close_timeout = 10 * MS @@ -1659,4 +1664,5 @@ def test_local_close_connection_lost_timeout_after_close(self): self.receive_frame(self.close_frame) self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason="close")) - self.assertConnectionClosed(1000, "close") + # Due to a bug in coverage, this is erroneously reported as not covered. + self.assertConnectionClosed(1000, "close") # pragma: no cover From 00835ccf2bc9bb483b6bf3a69dd487d3745fbb27 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 Apr 2023 10:07:35 +0200 Subject: [PATCH 422/762] Fix typo. --- docs/reference/features.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/reference/features.rst b/docs/reference/features.rst index f4b592a4b..7e6e262dc 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -28,7 +28,7 @@ Both sides +------------------------------------+--------+--------+--------+ | | |aio| | |sync| | |sans| | +====================================+========+========+========+ - | Perfom the opening handshake | ✅ | ✅ | ✅ | + | Perform the opening handshake | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+ | Send a message | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+ @@ -50,7 +50,7 @@ Both sides +------------------------------------+--------+--------+--------+ | Send a pong | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+ - | Perfom the closing handshake | ✅ | ✅ | ✅ | + | Perform the closing handshake | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+ | Report close codes and reasons | ❌ | ✅ | ✅ | | from both sides | | | | From 7dd4ede471f95b59b2d15c669b927773a2371183 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 Apr 2023 10:06:06 +0200 Subject: [PATCH 423/762] Add changelog for d3d4cf4a. --- docs/project/changelog.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 68fb2709f..8e9be789e 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -93,6 +93,8 @@ New features Improvements ............ +* Added platform-independent wheels. + * Improved error handling in :func:`~websockets.broadcast`. * Set ``server_hostname`` automatically on TLS connections when providing a From f516cf51e166ec0cc797fa6bb68b559e4c1fed8b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 Apr 2023 10:00:26 +0200 Subject: [PATCH 424/762] Release version 11.0 --- docs/project/changelog.rst | 3 +-- src/websockets/version.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 8e9be789e..14caf2df9 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -25,11 +25,10 @@ fixing regressions shortly after a release. Only documented APIs are public. Undocumented, private APIs may change without notice. - 11.0 ---- -*In development* +*April 2, 2023* Backwards-incompatible changes .............................. diff --git a/src/websockets/version.py b/src/websockets/version.py index 1a3d884d8..375112512 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -16,7 +16,7 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = False +released = True tag = version = commit = "11.0" From fe1879fc22ced71f3b9b23f9fd7ac5a727ebb6bd Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 Apr 2023 10:01:02 +0200 Subject: [PATCH 425/762] Start version 11.1 --- docs/project/changelog.rst | 5 +++++ src/websockets/version.py | 4 ++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 14caf2df9..bdb7d7f7d 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -25,6 +25,11 @@ fixing regressions shortly after a release. Only documented APIs are public. Undocumented, private APIs may change without notice. +11.1 +---- + +*In development* + 11.0 ---- diff --git a/src/websockets/version.py b/src/websockets/version.py index 375112512..802dba546 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -16,9 +16,9 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = True +released = False -tag = version = commit = "11.0" +tag = version = commit = "11.1" if not released: # pragma: no cover From 1bf73423044dedc325435d261a39473337f5ddcf Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 Apr 2023 10:48:19 +0200 Subject: [PATCH 426/762] Drop support for Python 3.7. --- .github/workflows/tests.yml | 4 ---- docs/intro/index.rst | 2 +- docs/project/changelog.rst | 11 ++++++++++- pyproject.toml | 3 +-- src/websockets/datastructures.py | 8 +------- src/websockets/legacy/client.py | 8 -------- src/websockets/legacy/compatibility.py | 23 +---------------------- src/websockets/legacy/protocol.py | 17 ++++------------- src/websockets/legacy/server.py | 16 ++++------------ src/websockets/sync/compatibility.py | 21 --------------------- src/websockets/sync/server.py | 5 ++--- src/websockets/version.py | 6 +++--- tests/legacy/test_protocol.py | 3 +-- tests/legacy/utils.py | 3 ++- tests/sync/client.py | 16 ---------------- tests/sync/server.py | 4 +--- tests/sync/test_server.py | 3 +-- 17 files changed, 32 insertions(+), 121 deletions(-) delete mode 100644 src/websockets/sync/compatibility.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 163d1bbae..603426412 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -57,19 +57,15 @@ jobs: strategy: matrix: python: - - "3.7" - "3.8" - "3.9" - "3.10" - "3.11" - - "pypy-3.7" - "pypy-3.8" - "pypy-3.9" is_main: - ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }} exclude: - - python: "pypy-3.7" - is_main: false - python: "pypy-3.8" is_main: false - python: "pypy-3.9" diff --git a/docs/intro/index.rst b/docs/intro/index.rst index fe4e704d6..095262a20 100644 --- a/docs/intro/index.rst +++ b/docs/intro/index.rst @@ -6,7 +6,7 @@ Getting started Requirements ------------ -websockets requires Python ≥ 3.7. +websockets requires Python ≥ 3.8. .. admonition:: Use the most recent Python release :class: tip diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index bdb7d7f7d..7b6972fd0 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -25,11 +25,20 @@ fixing regressions shortly after a release. Only documented APIs are public. Undocumented, private APIs may change without notice. -11.1 +12.0 ---- *In development* +Backwards-incompatible changes +.............................. + +.. admonition:: websockets 12.0 requires Python ≥ 3.8. + :class: tip + + websockets 11.0 is the last version supporting Python 3.7. + + 11.0 ---- diff --git a/pyproject.toml b/pyproject.toml index 0707c6442..a4622ae2d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta" [project] name = "websockets" description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)" -requires-python = ">=3.7" +requires-python = ">=3.8" license = { text = "BSD-3-Clause" } authors = [ { name = "Aymeric Augustin", email = "aymeric.augustin@m4x.org" }, @@ -19,7 +19,6 @@ classifiers = [ "Operating System :: OS Independent", "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", diff --git a/src/websockets/datastructures.py b/src/websockets/datastructures.py index 36a2cbaf9..a0a648463 100644 --- a/src/websockets/datastructures.py +++ b/src/websockets/datastructures.py @@ -1,6 +1,5 @@ from __future__ import annotations -import sys from typing import ( Any, Dict, @@ -9,17 +8,12 @@ List, Mapping, MutableMapping, + Protocol, Tuple, Union, ) -if sys.version_info[:2] >= (3, 8): - from typing import Protocol -else: # pragma: no cover - Protocol = object # mypy will report errors on Python 3.7. - - __all__ = ["Headers", "HeadersLike", "MultipleValuesError"] diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index c5e9d0d52..48622523e 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -137,10 +137,6 @@ async def read_http_response(self) -> Tuple[int, Headers]: """ try: status_code, reason, headers = await read_response(self.reader) - # Remove this branch when dropping support for Python < 3.8 - # because CancelledError no longer inherits Exception. - except asyncio.CancelledError: # pragma: no cover - raise except Exception as exc: raise InvalidMessage("did not receive a valid HTTP response") from exc @@ -601,10 +597,6 @@ async def __aiter__(self) -> AsyncIterator[WebSocketClientProtocol]: try: async with self as protocol: yield protocol - # Remove this branch when dropping support for Python < 3.8 - # because CancelledError no longer inherits Exception. - except asyncio.CancelledError: # pragma: no cover - raise except Exception: # Add a random initial delay between 0 and 5 seconds. # See 7.2.3. Recovering from Abnormal Closure in RFC 6544. diff --git a/src/websockets/legacy/compatibility.py b/src/websockets/legacy/compatibility.py index cb9b02c86..6bd01e70d 100644 --- a/src/websockets/legacy/compatibility.py +++ b/src/websockets/legacy/compatibility.py @@ -1,30 +1,9 @@ from __future__ import annotations -import asyncio import sys -from typing import Any, Dict -__all__ = ["asyncio_timeout", "loop_if_py_lt_38"] - - -if sys.version_info[:2] >= (3, 8): - - def loop_if_py_lt_38(loop: asyncio.AbstractEventLoop) -> Dict[str, Any]: - """ - Helper for the removal of the loop argument in Python 3.10. - - """ - return {} - -else: - - def loop_if_py_lt_38(loop: asyncio.AbstractEventLoop) -> Dict[str, Any]: - """ - Helper for the removal of the loop argument in Python 3.10. - - """ - return {"loop": loop} +__all__ = ["asyncio_timeout"] if sys.version_info[:2] >= (3, 11): diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 733abb3b9..0422c10da 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -53,7 +53,7 @@ ) from ..protocol import State from ..typing import Data, LoggerLike, Subprotocol -from .compatibility import asyncio_timeout, loop_if_py_lt_38 +from .compatibility import asyncio_timeout from .framing import Frame @@ -244,7 +244,7 @@ def __init__( self._paused = False self._drain_waiter: Optional[asyncio.Future[None]] = None - self._drain_lock = asyncio.Lock(**loop_if_py_lt_38(loop)) + self._drain_lock = asyncio.Lock() # This class implements the data transfer and closing handshake, which # are shared between the client-side and the server-side. @@ -339,7 +339,7 @@ async def _drain(self) -> None: # pragma: no cover # write(...); yield from drain() # in a loop would never call connection_lost(), so it # would not see an error when the socket is closed. - await asyncio.sleep(0, **loop_if_py_lt_38(self.loop)) + await asyncio.sleep(0) await self._drain_helper() def connection_open(self) -> None: @@ -551,7 +551,6 @@ async def recv(self) -> Data: await asyncio.wait( [pop_message_waiter, self.transfer_data_task], return_when=asyncio.FIRST_COMPLETED, - **loop_if_py_lt_38(self.loop), ) finally: self._pop_message_waiter = None @@ -1247,10 +1246,7 @@ async def keepalive_ping(self) -> None: try: while True: - await asyncio.sleep( - self.ping_interval, - **loop_if_py_lt_38(self.loop), - ) + await asyncio.sleep(self.ping_interval) # ping() raises CancelledError if the connection is closed, # when close_connection() cancels self.keepalive_ping_task. @@ -1272,11 +1268,6 @@ async def keepalive_ping(self) -> None: self.fail_connection(1011, "keepalive ping timeout") break - # Remove this branch when dropping support for Python < 3.8 - # because CancelledError no longer inherits Exception. - except asyncio.CancelledError: - raise - except ConnectionClosed: pass diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 25d5a7144..a17c52328 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -46,7 +46,7 @@ from ..http import USER_AGENT from ..protocol import State from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol -from .compatibility import asyncio_timeout, loop_if_py_lt_38 +from .compatibility import asyncio_timeout from .handshake import build_response, check_request from .http import read_request from .protocol import WebSocketCommonProtocol @@ -170,10 +170,6 @@ async def handler(self) -> None: available_subprotocols=self.available_subprotocols, extra_headers=self.extra_headers, ) - # Remove this branch when dropping support for Python < 3.8 - # because CancelledError no longer inherits Exception. - except asyncio.CancelledError: # pragma: no cover - raise except asyncio.TimeoutError: # pragma: no cover raise except ConnectionError: @@ -770,7 +766,7 @@ async def _close(self, close_connections: bool) -> None: # Wait until all accepted connections reach connection_made() and call # register(). See https://bugs.python.org/issue34852 for details. - await asyncio.sleep(0, **loop_if_py_lt_38(self.get_loop())) + await asyncio.sleep(0) if close_connections: # Close OPEN connections with status code 1001. Since the server was @@ -784,18 +780,14 @@ async def _close(self, close_connections: bool) -> None: ] # asyncio.wait doesn't accept an empty first argument. if close_tasks: - await asyncio.wait( - close_tasks, - **loop_if_py_lt_38(self.get_loop()), - ) + await asyncio.wait(close_tasks) # Wait until all connection handlers are complete. # asyncio.wait doesn't accept an empty first argument. if self.websockets: await asyncio.wait( - [websocket.handler_task for websocket in self.websockets], - **loop_if_py_lt_38(self.get_loop()), + [websocket.handler_task for websocket in self.websockets] ) # Tell wait_closed() to return. diff --git a/src/websockets/sync/compatibility.py b/src/websockets/sync/compatibility.py deleted file mode 100644 index 38d2ab668..000000000 --- a/src/websockets/sync/compatibility.py +++ /dev/null @@ -1,21 +0,0 @@ -from __future__ import annotations - - -try: - from socket import create_server as socket_create_server -except ImportError: - import socket - - def socket_create_server(address, family=socket.AF_INET): # type: ignore - """Simplified backport of socket.create_server from Python 3.8.""" - sock = socket.socket(family, socket.SOCK_STREAM) - try: - sock.bind(address) - sock.listen() - return sock - except socket.error: - sock.close() - raise - - -__all__ = ["socket_create_server"] diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 9284c6188..e25646a5c 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -18,7 +18,6 @@ from ..protocol import CONNECTING, OPEN, Event from ..server import ServerProtocol from ..typing import LoggerLike, Origin, Subprotocol -from .compatibility import socket_create_server from .connection import Connection from .utils import Deadline @@ -397,9 +396,9 @@ def handler(websocket): if unix: if path is None: raise TypeError("missing path argument") - sock = socket_create_server(path, family=socket.AF_UNIX) + sock = socket.create_server(path, family=socket.AF_UNIX) else: - sock = socket_create_server((host, port)) + sock = socket.create_server((host, port)) else: if path is not None: raise TypeError("path and sock arguments are incompatible") diff --git a/src/websockets/version.py b/src/websockets/version.py index 802dba546..f45f069ab 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -1,5 +1,7 @@ from __future__ import annotations +import importlib.metadata + __all__ = ["tag", "version", "commit"] @@ -18,7 +20,7 @@ released = False -tag = version = commit = "11.1" +tag = version = commit = "12.0" if not released: # pragma: no cover @@ -56,8 +58,6 @@ def get_version(tag: str) -> str: # Read version from package metadata if it is installed. try: - import importlib.metadata # move up when dropping Python 3.7 - return importlib.metadata.version("websockets") except ImportError: pass diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index a05dcc6f6..9338e15bd 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -16,7 +16,6 @@ OP_TEXT, Close, ) -from websockets.legacy.compatibility import loop_if_py_lt_38 from websockets.legacy.framing import Frame from websockets.legacy.protocol import WebSocketCommonProtocol, broadcast from websockets.protocol import State @@ -117,7 +116,7 @@ def make_drain_slow(self, delay=MS): original_drain = self.protocol._drain async def delayed_drain(): - await asyncio.sleep(delay, **loop_if_py_lt_38(self.loop)) + await asyncio.sleep(delay) await original_drain() self.protocol._drain = delayed_drain diff --git a/tests/legacy/utils.py b/tests/legacy/utils.py index bb4eebb52..4a21dcaeb 100644 --- a/tests/legacy/utils.py +++ b/tests/legacy/utils.py @@ -9,7 +9,8 @@ class AsyncioTestCase(unittest.TestCase): """ Base class for tests that sets up an isolated event loop for each test. - Replace with IsolatedAsyncioTestCase when dropping Python < 3.8. + IsolatedAsyncioTestCase was introduced in Python 3.8 for similar purposes + but isn't a drop-in replacement. """ diff --git a/tests/sync/client.py b/tests/sync/client.py index 51bbd4388..683893e88 100644 --- a/tests/sync/client.py +++ b/tests/sync/client.py @@ -1,7 +1,5 @@ import contextlib import ssl -import sys -import warnings from websockets.sync.client import * from websockets.sync.server import WebSocketServer @@ -19,20 +17,6 @@ CLIENT_CONTEXT = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) CLIENT_CONTEXT.load_verify_locations(CERTIFICATE) -# Work around https://github.com/openssl/openssl/issues/7967 - -# This bug causes connect() to hang in tests for the client. Including this -# workaround acknowledges that the issue could happen outside of the test suite. - -# It shouldn't happen too often, or else OpenSSL 1.1.1 would be unusable. If it -# happens, we can look for a library-level fix, but it won't be easy. - -if sys.version_info[:2] < (3, 8): # pragma: no cover - # ssl.OP_NO_TLSv1_3 was introduced and deprecated on Python 3.7. - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - CLIENT_CONTEXT.options |= ssl.OP_NO_TLSv1_3 - @contextlib.contextmanager def run_client(wsuri_or_server, secure=None, resource_name="/", **kwargs): diff --git a/tests/sync/server.py b/tests/sync/server.py index 5f0cd3b07..a9a77438c 100644 --- a/tests/sync/server.py +++ b/tests/sync/server.py @@ -1,6 +1,5 @@ import contextlib import ssl -import sys import threading from websockets.sync.server import * @@ -19,8 +18,7 @@ # It shouldn't happen too often, or else OpenSSL 1.1.1 would be unusable. If it # happens, we can look for a library-level fix, but it won't be easy. -if sys.version_info[:2] >= (3, 8): # pragma: no cover - SERVER_CONTEXT.num_tickets = 0 +SERVER_CONTEXT.num_tickets = 0 def crash(ws): diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index 536858149..60e70c0a5 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -12,7 +12,6 @@ NegotiationError, ) from websockets.http11 import Request, Response -from websockets.sync.compatibility import socket_create_server from websockets.sync.server import * from ..utils import MS, temp_unix_socket_path @@ -72,7 +71,7 @@ def test_connection_handler_raises_exception(self): def test_existing_socket(self): """Server receives connection using a pre-existing socket.""" - with socket_create_server(("localhost", 0)) as sock: + with socket.create_server(("localhost", 0)) as sock: with run_server(sock=sock): # Build WebSocket URI to ensure we connect to the right socket. with run_client("ws://{}:{}/".format(*sock.getsockname())) as client: From fff80c316e25af645ab5cc9b7217b2064d429d73 Mon Sep 17 00:00:00 2001 From: Ran Benita Date: Mon, 3 Apr 2023 00:03:43 +0300 Subject: [PATCH 427/762] Fix FAQ link in issue template --- .github/ISSUE_TEMPLATE/issue.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ISSUE_TEMPLATE/issue.md b/.github/ISSUE_TEMPLATE/issue.md index f2704152c..6c54e7fd5 100644 --- a/.github/ISSUE_TEMPLATE/issue.md +++ b/.github/ISSUE_TEMPLATE/issue.md @@ -12,7 +12,7 @@ assignees: '' Thanks for taking the time to report an issue! Did you check the FAQ? Perhaps you'll find the answer you need: -https://websockets.readthedocs.io/en/stable/howto/faq.html +https://websockets.readthedocs.io/en/stable/faq/index.html Is your question really about asyncio? Perhaps the dev guide will help: https://docs.python.org/3/library/asyncio-dev.html From 19be15b387bf92697ff71e6f3e6d5038fe1515d8 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 6 Apr 2023 07:54:30 +0200 Subject: [PATCH 428/762] Restore speedups.c in source distribution. --- .github/workflows/wheels.yml | 7 +++---- docs/project/changelog.rst | 9 +++++++++ 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 68bfbdef4..00bd0ccc5 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -10,8 +10,6 @@ jobs: sdist: name: Build source distribution and architecture-independent wheel runs-on: ubuntu-latest - env: - BUILD_EXTENSION: no steps: - name: Check out repository uses: actions/checkout@v3 @@ -28,6 +26,8 @@ jobs: - name: Install wheel run: pip install wheel - name: Build wheel + env: + BUILD_EXTENSION: no run: python setup.py bdist_wheel - name: Save wheel uses: actions/upload-artifact@v3 @@ -37,8 +37,6 @@ jobs: wheels: name: Build architecture-specific wheels on ${{ matrix.os }} runs-on: ${{ matrix.os }} - env: - BUILD_EXTENSION: yes strategy: matrix: os: @@ -60,6 +58,7 @@ jobs: - name: Build wheels uses: pypa/cibuildwheel@v2.12.1 env: + BUILD_EXTENSION: yes CIBW_ARCHS_MACOS: "x86_64 universal2 arm64" CIBW_ARCHS_LINUX: "auto aarch64" CIBW_SKIP: cp36-* diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 7b6972fd0..f35caa19e 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -38,6 +38,15 @@ Backwards-incompatible changes websockets 11.0 is the last version supporting Python 3.7. +11.0.1 +------ + +*April 6, 2023* + +Bug fixes +......... + +* Restored the C extension in the source distribution. 11.0 ---- From 659e562d83187a419592e1a472b0d7af657d6642 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 6 Apr 2023 07:37:02 +0200 Subject: [PATCH 429/762] Move cibuildwheel configuration to pyproject.toml. cibuildwheel can read requires-python. --- .github/workflows/wheels.yml | 3 --- pyproject.toml | 8 ++++++++ 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 00bd0ccc5..d9c9ea833 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -59,9 +59,6 @@ jobs: uses: pypa/cibuildwheel@v2.12.1 env: BUILD_EXTENSION: yes - CIBW_ARCHS_MACOS: "x86_64 universal2 arm64" - CIBW_ARCHS_LINUX: "auto aarch64" - CIBW_SKIP: cp36-* - name: Save wheels uses: actions/upload-artifact@v3 with: diff --git a/pyproject.toml b/pyproject.toml index a4622ae2d..9065aab45 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,14 @@ documentation = "https://websockets.readthedocs.io/" funding = "https://tidelift.com/subscription/pkg/pypi-websockets?utm_source=pypi-websockets&utm_medium=referral&utm_campaign=readme" tracker = "https://github.com/aaugustin/websockets/issues" +# On a macOS runner, build Intel, Universal, and Apple Silicon wheels. +[tool.cibuildwheel.macos] +archs = ["x86_64", "universal2", "arm64"] + +# On an Linux Intel runner with QEMU installed, build Intel and ARM wheels. +[tool.cibuildwheel.linux] +archs = ["auto", "aarch64"] + [tool.coverage.run] branch = true omit = [ From 1bc18193b32b210cc804e7e174d15c0b27db21f0 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 7 Apr 2023 08:11:30 +0200 Subject: [PATCH 430/762] Move repository to python-websockets organization. --- .github/ISSUE_TEMPLATE/issue.md | 4 ++-- README.rst | 14 +++++++------- docs/conf.py | 2 +- docs/faq/asyncio.rst | 2 +- docs/index.rst | 4 ++-- docs/intro/tutorial3.rst | 8 ++++---- docs/project/changelog.rst | 2 +- docs/project/contributing.rst | 6 +++--- docs/reference/features.rst | 10 +++++----- docs/topics/authentication.rst | 2 +- docs/topics/broadcast.rst | 2 +- docs/topics/compression.rst | 4 ++-- example/tutorial/step3/main.js | 2 +- experiments/compression/benchmark.py | 2 +- pyproject.toml | 4 ++-- tests/legacy/test_protocol.py | 2 +- 16 files changed, 35 insertions(+), 35 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/issue.md b/.github/ISSUE_TEMPLATE/issue.md index 6c54e7fd5..3cf4e3b77 100644 --- a/.github/ISSUE_TEMPLATE/issue.md +++ b/.github/ISSUE_TEMPLATE/issue.md @@ -18,12 +18,12 @@ Is your question really about asyncio? Perhaps the dev guide will help: https://docs.python.org/3/library/asyncio-dev.html Did you look for similar issues? Please keep the discussion in one place :-) -https://github.com/aaugustin/websockets/issues?q=is%3Aissue +https://github.com/python-websockets/websockets/issues?q=is%3Aissue Is your issue related to cryptocurrency in any way? Please don't file it. https://websockets.readthedocs.io/en/stable/project/contributing.html#cryptocurrency-users For bugs, providing a reproduction helps a lot. Take an existing example and tweak it! -https://github.com/aaugustin/websockets/tree/main/example +https://github.com/python-websockets/websockets/tree/main/example --> diff --git a/README.rst b/README.rst index 5ba523e8f..f53d3d0fc 100644 --- a/README.rst +++ b/README.rst @@ -13,8 +13,8 @@ .. |pyversions| image:: https://img.shields.io/pypi/pyversions/websockets.svg :target: https://pypi.python.org/pypi/websockets -.. |tests| image:: https://img.shields.io/github/checks-status/aaugustin/websockets/main?label=tests - :target: https://github.com/aaugustin/websockets/actions/workflows/tests.yml +.. |tests| image:: https://img.shields.io/github/checks-status/python-websockets/websockets/main?label=tests + :target: https://github.com/python-websockets/websockets/actions/workflows/tests.yml .. |docs| image:: https://img.shields.io/readthedocs/websockets.svg :target: https://websockets.readthedocs.io/ @@ -84,7 +84,7 @@ Does that look good? .. raw:: html
- +

websockets for enterprise

Available as part of the Tidelift Subscription

The maintainers of websockets and thousands of other packages are working with Tidelift to deliver commercial support and maintenance for the open source dependencies you use to build your applications. Save time, reduce risk, and improve code health, while paying the maintainers of the exact dependencies you use. Learn more.

@@ -147,13 +147,13 @@ contact`_. Tidelift will coordinate the fix and disclosure. For anything else, please open an issue_ or send a `pull request`_. -.. _issue: https://github.com/aaugustin/websockets/issues/new -.. _pull request: https://github.com/aaugustin/websockets/compare/ +.. _issue: https://github.com/python-websockets/websockets/issues/new +.. _pull request: https://github.com/python-websockets/websockets/compare/ Participants must uphold the `Contributor Covenant code of conduct`_. -.. _Contributor Covenant code of conduct: https://github.com/aaugustin/websockets/blob/main/CODE_OF_CONDUCT.md +.. _Contributor Covenant code of conduct: https://github.com/python-websockets/websockets/blob/main/CODE_OF_CONDUCT.md ``websockets`` is released under the `BSD license`_. -.. _BSD license: https://github.com/aaugustin/websockets/blob/main/LICENSE +.. _BSD license: https://github.com/python-websockets/websockets/blob/main/LICENSE diff --git a/docs/conf.py b/docs/conf.py index 58f0c2d55..9d61dc717 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -97,7 +97,7 @@ # Configure viewcode extension. from websockets.version import commit -code_url = f"https://github.com/aaugustin/websockets/blob/{commit}" +code_url = f"https://github.com/python-websockets/websockets/blob/{commit}" def linkcode_resolve(domain, info): # Non-linkable objects from the starter kit in the tutorial. diff --git a/docs/faq/asyncio.rst b/docs/faq/asyncio.rst index e56a42d36..e77f50add 100644 --- a/docs/faq/asyncio.rst +++ b/docs/faq/asyncio.rst @@ -41,7 +41,7 @@ every coroutine. See `issue 867`_. -.. _issue 867: https://github.com/aaugustin/websockets/issues/867 +.. _issue 867: https://github.com/python-websockets/websockets/issues/867 Why am I having problems with threads? -------------------------------------- diff --git a/docs/index.rst b/docs/index.rst index be6a0da05..d9737db12 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -12,8 +12,8 @@ websockets .. |pyversions| image:: https://img.shields.io/pypi/pyversions/websockets.svg :target: https://pypi.python.org/pypi/websockets -.. |tests| image:: https://img.shields.io/github/checks-status/aaugustin/websockets/main?label=tests - :target: https://github.com/aaugustin/websockets/actions/workflows/tests.yml +.. |tests| image:: https://img.shields.io/github/checks-status/python-websockets/websockets/main?label=tests + :target: https://github.com/python-websockets/websockets/actions/workflows/tests.yml .. |docs| image:: https://img.shields.io/readthedocs/websockets.svg :target: https://websockets.readthedocs.io/ diff --git a/docs/intro/tutorial3.rst b/docs/intro/tutorial3.rst index 4d42447b7..6fdec113b 100644 --- a/docs/intro/tutorial3.rst +++ b/docs/intro/tutorial3.rst @@ -198,7 +198,7 @@ in ``main.js``: You can take this strategy one step further by checking the address of the HTTP server and determining the address of the WebSocket server accordingly. -Add this function to ``main.js``; replace ``aaugustin`` by your GitHub +Add this function to ``main.js``; replace ``python-websockets`` by your GitHub username and ``websockets-tutorial`` by the name of your app on Heroku: .. literalinclude:: ../../example/tutorial/step3/main.js @@ -226,12 +226,12 @@ Deploy the web application Go to GitHub and create a new repository called ``websockets-tutorial``. -Push your code to this repository. You must replace ``aaugustin`` by your -GitHub username in the following command: +Push your code to this repository. You must replace ``python-websockets`` by +your GitHub username in the following command: .. code-block:: console - $ git remote add origin git@github.com:aaugustin/websockets-tutorial.git + $ git remote add origin git@github.com:python-websockets/websockets-tutorial.git $ git push -u origin main Enumerating objects: 11, done. Counting objects: 100% (11/11), done. diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index f35caa19e..127ea3142 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -225,7 +225,7 @@ Improvements * Reverted optimization of default compression settings for clients, mainly to avoid triggering bugs in poorly implemented servers like `AWS API Gateway`_. - .. _AWS API Gateway: https://github.com/aaugustin/websockets/issues/1065 + .. _AWS API Gateway: https://github.com/python-websockets/websockets/issues/1065 * Mirrored the entire :class:`~asyncio.Server` API in :class:`~server.WebSocketServer`. diff --git a/docs/project/contributing.rst b/docs/project/contributing.rst index 43fd58dc8..020ed7ad8 100644 --- a/docs/project/contributing.rst +++ b/docs/project/contributing.rst @@ -10,7 +10,7 @@ This project and everyone participating in it is governed by the `Code of Conduct`_. By participating, you are expected to uphold this code. Please report inappropriate behavior to aymeric DOT augustin AT fractalideas DOT com. -.. _Code of Conduct: https://github.com/aaugustin/websockets/blob/main/CODE_OF_CONDUCT.md +.. _Code of Conduct: https://github.com/python-websockets/websockets/blob/main/CODE_OF_CONDUCT.md *(If I'm the person with the inappropriate behavior, please accept my apologies. I know I can mess up. I can't expect you to tell me, but if you @@ -31,8 +31,8 @@ If you're wondering why things are done in a certain way, the :doc:`design document <../topics/design>` provides lots of details about the internals of websockets. -.. _issue: https://github.com/aaugustin/websockets/issues/new -.. _pull request: https://github.com/aaugustin/websockets/compare/ +.. _issue: https://github.com/python-websockets/websockets/issues/new +.. _pull request: https://github.com/python-websockets/websockets/compare/ Questions --------- diff --git a/docs/reference/features.rst b/docs/reference/features.rst index 7e6e262dc..3cc52ec10 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -78,7 +78,7 @@ Both sides | Heartbeat | ✅ | ❌ | — | +------------------------------------+--------+--------+--------+ -.. _#479: https://github.com/aaugustin/websockets/issues/479 +.. _#479: https://github.com/python-websockets/websockets/issues/479 Server ------ @@ -162,9 +162,9 @@ Client | (`#475`_) | | | | +------------------------------------+--------+--------+--------+ -.. _#364: https://github.com/aaugustin/websockets/issues/364 -.. _#475: https://github.com/aaugustin/websockets/issues/475 -.. _#784: https://github.com/aaugustin/websockets/issues/784 +.. _#364: https://github.com/python-websockets/websockets/issues/364 +.. _#475: https://github.com/python-websockets/websockets/issues/475 +.. _#784: https://github.com/python-websockets/websockets/issues/784 Known limitations ----------------- @@ -172,7 +172,7 @@ Known limitations There is no way to control compression of outgoing frames on a per-frame basis (`#538`_). If compression is enabled, all frames are compressed. -.. _#538: https://github.com/aaugustin/websockets/issues/538 +.. _#538: https://github.com/python-websockets/websockets/issues/538 The client API doesn't attempt to guarantee that there is no more than one connection to a given IP address in a CONNECTING state. This behavior is diff --git a/docs/topics/authentication.rst b/docs/topics/authentication.rst index 1849d635a..60dd38766 100644 --- a/docs/topics/authentication.rst +++ b/docs/topics/authentication.rst @@ -152,7 +152,7 @@ The `experiments/authentication`_ directory demonstrates these techniques. Run the experiment in an environment where websockets is installed: -.. _experiments/authentication: https://github.com/aaugustin/websockets/tree/main/experiments/authentication +.. _experiments/authentication: https://github.com/python-websockets/websockets/tree/main/experiments/authentication .. code-block:: console diff --git a/docs/topics/broadcast.rst b/docs/topics/broadcast.rst index 9a25cbf7d..1acb372d4 100644 --- a/docs/topics/broadcast.rst +++ b/docs/topics/broadcast.rst @@ -51,7 +51,7 @@ message, or else it will never let the server run. That's why it includes A complete example is available in the `experiments/broadcast`_ directory. -.. _experiments/broadcast: https://github.com/aaugustin/websockets/tree/main/experiments/broadcast +.. _experiments/broadcast: https://github.com/python-websockets/websockets/tree/main/experiments/broadcast The naive way ------------- diff --git a/docs/topics/compression.rst b/docs/topics/compression.rst index f0b7ce898..eaf99070d 100644 --- a/docs/topics/compression.rst +++ b/docs/topics/compression.rst @@ -158,7 +158,7 @@ corpus. Defaults must be safe for all applications, hence a more conservative choice. -.. _compression/benchmark.py: https://github.com/aaugustin/websockets/blob/main/experiments/compression/benchmark.py +.. _compression/benchmark.py: https://github.com/python-websockets/websockets/blob/main/experiments/compression/benchmark.py The benchmark focuses on compression because it's more expensive than decompression. Indeed, leaving aside small allocations, theoretical memory @@ -199,7 +199,7 @@ like the server side: 3. On a more pragmatic note, some servers misbehave badly when a client configures compression settings. `AWS API Gateway`_ is the worst offender. - .. _AWS API Gateway: https://github.com/aaugustin/websockets/issues/1065 + .. _AWS API Gateway: https://github.com/python-websockets/websockets/issues/1065 Unfortunately, even though websockets is right and AWS is wrong, many users jump to the conclusion that websockets doesn't work. diff --git a/example/tutorial/step3/main.js b/example/tutorial/step3/main.js index 15afd4163..3000fa2f7 100644 --- a/example/tutorial/step3/main.js +++ b/example/tutorial/step3/main.js @@ -1,7 +1,7 @@ import { createBoard, playMove } from "./connect4.js"; function getWebSocketServer() { - if (window.location.host === "aaugustin.github.io") { + if (window.location.host === "python-websockets.github.io") { return "wss://websockets-tutorial.herokuapp.com/"; } else if (window.location.host === "localhost:8000") { return "ws://localhost:8001/"; diff --git a/experiments/compression/benchmark.py b/experiments/compression/benchmark.py index bdcdd8e95..c5b13c8fa 100644 --- a/experiments/compression/benchmark.py +++ b/experiments/compression/benchmark.py @@ -20,7 +20,7 @@ def _corpus(): OAUTH_TOKEN = getpass.getpass("OAuth Token? ") COMMIT_API = ( f'curl -H "Authorization: token {OAUTH_TOKEN}" ' - f"https://api.github.com/repos/aaugustin/websockets/git/commits/:sha" + f"https://api.github.com/repos/python-websockets/websockets/git/commits/:sha" ) commits = [] diff --git a/pyproject.toml b/pyproject.toml index 9065aab45..941b01c08 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,11 +27,11 @@ classifiers = [ dynamic = ["version", "readme"] [project.urls] -homepage = "https://github.com/aaugustin/websockets" +homepage = "https://github.com/python-websockets/websockets" changelog = "https://websockets.readthedocs.io/en/stable/project/changelog.html" documentation = "https://websockets.readthedocs.io/" funding = "https://tidelift.com/subscription/pkg/pypi-websockets?utm_source=pypi-websockets&utm_medium=referral&utm_campaign=readme" -tracker = "https://github.com/aaugustin/websockets/issues" +tracker = "https://github.com/python-websockets/websockets/issues" # On a macOS runner, build Intel, Universal, and Apple Silicon wheels. [tool.cibuildwheel.macos] diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index 9338e15bd..514d91ef8 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -1373,7 +1373,7 @@ def test_remote_close_and_connection_lost(self): def test_simultaneous_close(self): # Receive the incoming close frame right after self.protocol.close() # starts executing. This reproduces the error described in: - # https://github.com/aaugustin/websockets/issues/339 + # https://github.com/python-websockets/websockets/issues/339 self.loop.call_soon(self.receive_frame, self.remote_close) self.loop.call_soon(self.receive_eof_if_client) self.run_loop_once() From 8c6f7267e6ebb619904df7a4c8c005cc30307de6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 7 Apr 2023 08:13:02 +0200 Subject: [PATCH 431/762] Capitalize link titles for nicer display on PyPI. --- pyproject.toml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 941b01c08..c26e3aabc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,11 +27,11 @@ classifiers = [ dynamic = ["version", "readme"] [project.urls] -homepage = "https://github.com/python-websockets/websockets" -changelog = "https://websockets.readthedocs.io/en/stable/project/changelog.html" -documentation = "https://websockets.readthedocs.io/" -funding = "https://tidelift.com/subscription/pkg/pypi-websockets?utm_source=pypi-websockets&utm_medium=referral&utm_campaign=readme" -tracker = "https://github.com/python-websockets/websockets/issues" +Homepage = "https://github.com/python-websockets/websockets" +Changelog = "https://websockets.readthedocs.io/en/stable/project/changelog.html" +Documentation = "https://websockets.readthedocs.io/" +Funding = "https://tidelift.com/subscription/pkg/pypi-websockets?utm_source=pypi-websockets&utm_medium=referral&utm_campaign=readme" +Tracker = "https://github.com/python-websockets/websockets/issues" # On a macOS runner, build Intel, Universal, and Apple Silicon wheels. [tool.cibuildwheel.macos] From a0daea5b870535996f61231541a38aa43b797b9a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 9 Apr 2023 18:00:55 +0200 Subject: [PATCH 432/762] Make recv buffer size a class attribute. This supports overriding it by subclassing, rather than monkey-patching. Also use a more realistic value in the docs. --- docs/howto/sansio.rst | 2 +- src/websockets/sync/connection.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/howto/sansio.rst b/docs/howto/sansio.rst index 08b09f7ce..d41519ff0 100644 --- a/docs/howto/sansio.rst +++ b/docs/howto/sansio.rst @@ -133,7 +133,7 @@ When reaching the end of the data stream, call the protocol's For example, if ``sock`` is a :obj:`~socket.socket`:: try: - data = sock.recv(4096) + data = sock.recv(65536) except OSError: # socket closed data = b"" if data: diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 64e5c8b44..59c0bc071 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -23,8 +23,6 @@ logger = logging.getLogger(__name__) -BUFSIZE = 65536 - class Connection: """ @@ -39,6 +37,8 @@ class Connection: """ + recv_bufsize = 65536 + def __init__( self, socket: socket.socket, @@ -525,7 +525,7 @@ def recv_events(self) -> None: try: if self.close_deadline is not None: self.socket.settimeout(self.close_deadline.timeout()) - data = self.socket.recv(BUFSIZE) + data = self.socket.recv(self.recv_bufsize) except Exception as exc: if self.debug: self.logger.debug("error while receiving data", exc_info=True) From a6e14978190ddf79e03f34df39a27a99b5619c2a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 9 Apr 2023 18:30:22 +0200 Subject: [PATCH 433/762] Archive benchmark script for stream readers. --- experiments/optimization/streams.py | 301 ++++++++++++++++++++++++++++ 1 file changed, 301 insertions(+) create mode 100644 experiments/optimization/streams.py diff --git a/experiments/optimization/streams.py b/experiments/optimization/streams.py new file mode 100644 index 000000000..ca24a5983 --- /dev/null +++ b/experiments/optimization/streams.py @@ -0,0 +1,301 @@ +""" +Benchmark two possible implementations of a stream reader. + +The difference lies in the data structure that buffers incoming data: + +* ``ByteArrayStreamReader`` uses a ``bytearray``; +* ``BytesDequeStreamReader`` uses a ``deque[bytes]``. + +``ByteArrayStreamReader`` is faster for streaming small frames, which is the +standard use case of websockets, likely due to its simple implementation and +to ``bytearray`` being fast at appending data and removing data at the front +(https://hg.python.org/cpython/rev/499a96611baa). + +``BytesDequeStreamReader`` is faster for large frames and for bursts, likely +because it copies payloads only once, while ``ByteArrayStreamReader`` copies +them twice. + +""" + + +import collections +import os +import timeit + + +# Implementations + + +class ByteArrayStreamReader: + def __init__(self): + self.buffer = bytearray() + self.eof = False + + def readline(self): + n = 0 # number of bytes to read + p = 0 # number of bytes without a newline + while True: + n = self.buffer.find(b"\n", p) + 1 + if n > 0: + break + p = len(self.buffer) + yield + r = self.buffer[:n] + del self.buffer[:n] + return r + + def readexactly(self, n): + assert n >= 0 + while len(self.buffer) < n: + yield + r = self.buffer[:n] + del self.buffer[:n] + return r + + def feed_data(self, data): + self.buffer += data + + def feed_eof(self): + self.eof = True + + def at_eof(self): + return self.eof and not self.buffer + + +class BytesDequeStreamReader: + def __init__(self): + self.buffer = collections.deque() + self.eof = False + + def readline(self): + b = [] + while True: + # Read next chunk + while True: + try: + c = self.buffer.popleft() + except IndexError: + yield + else: + break + # Handle chunk + n = c.find(b"\n") + 1 + if n == len(c): + # Read exactly enough data + b.append(c) + break + elif n > 0: + # Read too much data + b.append(c[:n]) + self.buffer.appendleft(c[n:]) + break + else: # n == 0 + # Need to read more data + b.append(c) + return b"".join(b) + + def readexactly(self, n): + if n == 0: + return b"" + b = [] + while True: + # Read next chunk + while True: + try: + c = self.buffer.popleft() + except IndexError: + yield + else: + break + # Handle chunk + n -= len(c) + if n == 0: + # Read exactly enough data + b.append(c) + break + elif n < 0: + # Read too much data + b.append(c[:n]) + self.buffer.appendleft(c[n:]) + break + else: # n >= 0 + # Need to read more data + b.append(c) + return b"".join(b) + + def feed_data(self, data): + self.buffer.append(data) + + def feed_eof(self): + self.eof = True + + def at_eof(self): + return self.eof and not self.buffer + + +# Tests + + +class Protocol: + def __init__(self, StreamReader): + self.reader = StreamReader() + self.events = [] + # Start parser coroutine + self.parser = self.run_parser() + next(self.parser) + + def run_parser(self): + while True: + frame = yield from self.reader.readexactly(2) + self.events.append(frame) + frame = yield from self.reader.readline() + self.events.append(frame) + + def data_received(self, data): + self.reader.feed_data(data) + next(self.parser) # run parser until more data is needed + events, self.events = self.events, [] + return events + + +def run_test(StreamReader): + proto = Protocol(StreamReader) + + actual = proto.data_received(b"a") + expected = [] + assert actual == expected, f"{actual} != {expected}" + + actual = proto.data_received(b"b") + expected = [b"ab"] + assert actual == expected, f"{actual} != {expected}" + + actual = proto.data_received(b"c") + expected = [] + assert actual == expected, f"{actual} != {expected}" + + actual = proto.data_received(b"\n") + expected = [b"c\n"] + assert actual == expected, f"{actual} != {expected}" + + actual = proto.data_received(b"efghi\njklmn") + expected = [b"ef", b"ghi\n", b"jk"] + assert actual == expected, f"{actual} != {expected}" + + +# Benchmarks + + +def get_frame_packets(size, packet_size=None): + if size < 126: + frame = bytes([138, size]) + elif size < 65536: + frame = bytes([138, 126]) + bytes(divmod(size, 256)) + else: + size1, size2 = divmod(size, 65536) + frame = ( + bytes([138, 127]) + bytes(divmod(size1, 256)) + bytes(divmod(size2, 256)) + ) + frame += os.urandom(size) + if packet_size is None: + return [frame] + else: + packets = [] + while frame: + packets.append(frame[:packet_size]) + frame = frame[packet_size:] + return packets + + +def benchmark_stream(StreamReader, packets, size, count): + reader = StreamReader() + for _ in range(count): + for packet in packets: + reader.feed_data(packet) + yield from reader.readexactly(2) + if size >= 65536: + yield from reader.readexactly(4) + elif size >= 126: + yield from reader.readexactly(2) + yield from reader.readexactly(size) + reader.feed_eof() + assert reader.at_eof() + + +def benchmark_burst(StreamReader, packets, size, count): + reader = StreamReader() + for _ in range(count): + for packet in packets: + reader.feed_data(packet) + reader.feed_eof() + for _ in range(count): + yield from reader.readexactly(2) + if size >= 65536: + yield from reader.readexactly(4) + elif size >= 126: + yield from reader.readexactly(2) + yield from reader.readexactly(size) + assert reader.at_eof() + + +def run_benchmark(size, count, packet_size=None, number=1000): + stmt = f"list(benchmark(StreamReader, packets, {size}, {count}))" + setup = f"packets = get_frame_packets({size}, {packet_size})" + context = globals() + + context["StreamReader"] = context["ByteArrayStreamReader"] + context["benchmark"] = context["benchmark_stream"] + bas = min(timeit.repeat(stmt, setup, number=number, globals=context)) + context["benchmark"] = context["benchmark_burst"] + bab = min(timeit.repeat(stmt, setup, number=number, globals=context)) + + context["StreamReader"] = context["BytesDequeStreamReader"] + context["benchmark"] = context["benchmark_stream"] + bds = min(timeit.repeat(stmt, setup, number=number, globals=context)) + context["benchmark"] = context["benchmark_burst"] + bdb = min(timeit.repeat(stmt, setup, number=number, globals=context)) + + print( + f"Frame size = {size} bytes, " + f"frame count = {count}, " + f"packet size = {packet_size}" + ) + print(f"* ByteArrayStreamReader (stream): {bas / number * 1_000_000:.1f}µs") + print( + f"* BytesDequeStreamReader (stream): " + f"{bds / number * 1_000_000:.1f}µs ({(bds / bas - 1) * 100:+.1f}%)" + ) + print(f"* ByteArrayStreamReader (burst): {bab / number * 1_000_000:.1f}µs") + print( + f"* BytesDequeStreamReader (burst): " + f"{bdb / number * 1_000_000:.1f}µs ({(bdb / bab - 1) * 100:+.1f}%)" + ) + print() + + +if __name__ == "__main__": + run_test(ByteArrayStreamReader) + run_test(BytesDequeStreamReader) + + run_benchmark(size=8, count=1000) + run_benchmark(size=60, count=1000) + run_benchmark(size=500, count=500) + run_benchmark(size=4_000, count=200) + run_benchmark(size=30_000, count=100) + run_benchmark(size=250_000, count=50) + run_benchmark(size=2_000_000, count=20) + + run_benchmark(size=4_000, count=200, packet_size=1024) + run_benchmark(size=30_000, count=100, packet_size=1024) + run_benchmark(size=250_000, count=50, packet_size=1024) + run_benchmark(size=2_000_000, count=20, packet_size=1024) + + run_benchmark(size=30_000, count=100, packet_size=4096) + run_benchmark(size=250_000, count=50, packet_size=4096) + run_benchmark(size=2_000_000, count=20, packet_size=4096) + + run_benchmark(size=30_000, count=100, packet_size=16384) + run_benchmark(size=250_000, count=50, packet_size=16384) + run_benchmark(size=2_000_000, count=20, packet_size=16384) + + run_benchmark(size=250_000, count=50, packet_size=65536) + run_benchmark(size=2_000_000, count=20, packet_size=65536) From bb17be2fb0b99c8638788648fdd83f9049c1c344 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 11 Apr 2023 21:50:53 +0200 Subject: [PATCH 434/762] Add scripts to benchmark parsers. --- experiments/optimization/parse_frames.py | 101 +++++++++++++++++++ experiments/optimization/parse_handshake.py | 102 ++++++++++++++++++++ 2 files changed, 203 insertions(+) create mode 100644 experiments/optimization/parse_frames.py create mode 100644 experiments/optimization/parse_handshake.py diff --git a/experiments/optimization/parse_frames.py b/experiments/optimization/parse_frames.py new file mode 100644 index 000000000..e3acbe3c2 --- /dev/null +++ b/experiments/optimization/parse_frames.py @@ -0,0 +1,101 @@ +"""Benchark parsing WebSocket frames.""" + +import subprocess +import sys +import timeit + +from websockets.extensions.permessage_deflate import PerMessageDeflate +from websockets.frames import Frame, Opcode +from websockets.streams import StreamReader + + +# 256kB of text, compressible by about 70%. +text = subprocess.check_output(["git", "log", "8dd8e410"], text=True) + + +def get_frame(size): + repeat, remainder = divmod(size, 256 * 1024) + payload = repeat * text + text[:remainder] + return Frame(Opcode.TEXT, payload.encode(), True) + + +def parse_frame(data, count, mask, extensions): + reader = StreamReader() + for _ in range(count): + reader.feed_data(data) + parser = Frame.parse( + reader.read_exact, + mask=mask, + extensions=extensions, + ) + try: + next(parser) + except StopIteration: + pass + else: + assert False, "parser should return frame" + reader.feed_eof() + assert reader.at_eof(), "parser should consume all data" + + +def run_benchmark(size, count, compression=False, number=100): + if compression: + extensions = [PerMessageDeflate(True, True, 12, 12, {"memLevel": 5})] + else: + extensions = [] + globals = { + "get_frame": get_frame, + "parse_frame": parse_frame, + "extensions": extensions, + } + sppf = ( + min( + timeit.repeat( + f"parse_frame(data, {count}, mask=True, extensions=extensions)", + f"data = get_frame({size})" + f".serialize(mask=True, extensions=extensions)", + number=number, + globals=globals, + ) + ) + / number + / count + * 1_000_000 + ) + cppf = ( + min( + timeit.repeat( + f"parse_frame(data, {count}, mask=False, extensions=extensions)", + f"data = get_frame({size})" + f".serialize(mask=False, extensions=extensions)", + number=number, + globals=globals, + ) + ) + / number + / count + * 1_000_000 + ) + print(f"{size}\t{compression}\t{sppf:.2f}\t{cppf:.2f}") + + +if __name__ == "__main__": + print("Sizes are in bytes. Times are in µs per frame.", file=sys.stderr) + print("Run `tabs -16` for clean output. Pipe stdout to TSV for saving.") + print(file=sys.stderr) + + print("size\tcompression\tserver\tclient") + run_benchmark(size=8, count=1000, compression=False) + run_benchmark(size=60, count=1000, compression=False) + run_benchmark(size=500, count=1000, compression=False) + run_benchmark(size=4_000, count=1000, compression=False) + run_benchmark(size=30_000, count=200, compression=False) + run_benchmark(size=250_000, count=100, compression=False) + run_benchmark(size=2_000_000, count=20, compression=False) + + run_benchmark(size=8, count=1000, compression=True) + run_benchmark(size=60, count=1000, compression=True) + run_benchmark(size=500, count=200, compression=True) + run_benchmark(size=4_000, count=100, compression=True) + run_benchmark(size=30_000, count=20, compression=True) + run_benchmark(size=250_000, count=10, compression=True) diff --git a/experiments/optimization/parse_handshake.py b/experiments/optimization/parse_handshake.py new file mode 100644 index 000000000..af5a4ecae --- /dev/null +++ b/experiments/optimization/parse_handshake.py @@ -0,0 +1,102 @@ +"""Benchark parsing WebSocket handshake requests.""" + +# The parser for responses is designed similarly and should perform similarly. + +import sys +import timeit + +from websockets.http11 import Request +from websockets.streams import StreamReader + + +CHROME_HANDSHAKE = ( + b"GET / HTTP/1.1\r\n" + b"Host: localhost:5678\r\n" + b"Connection: Upgrade\r\n" + b"Pragma: no-cache\r\n" + b"Cache-Control: no-cache\r\n" + b"User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) " + b"AppleWebKit/537.36 (KHTML, like Gecko) Chrome/111.0.0.0 Safari/537.36\r\n" + b"Upgrade: websocket\r\n" + b"Origin: null\r\n" + b"Sec-WebSocket-Version: 13\r\n" + b"Accept-Encoding: gzip, deflate, br\r\n" + b"Accept-Language: en-GB,en;q=0.9,en-US;q=0.8,fr;q=0.7\r\n" + b"Sec-WebSocket-Key: ebkySAl+8+e6l5pRKTMkyQ==\r\n" + b"Sec-WebSocket-Extensions: permessage-deflate; client_max_window_bits\r\n" + b"\r\n" +) + +FIREFOX_HANDSHAKE = ( + b"GET / HTTP/1.1\r\n" + b"Host: localhost:5678\r\n" + b"User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:109.0) " + b"Gecko/20100101 Firefox/111.0\r\n" + b"Accept: */*\r\n" + b"Accept-Language: en-US,en;q=0.7,fr-FR;q=0.3\r\n" + b"Accept-Encoding: gzip, deflate, br\r\n" + b"Sec-WebSocket-Version: 13\r\n" + b"Origin: null\r\n" + b"Sec-WebSocket-Extensions: permessage-deflate\r\n" + b"Sec-WebSocket-Key: 1PuS+hnb+0AXsL7z2hNAhw==\r\n" + b"Connection: keep-alive, Upgrade\r\n" + b"Sec-Fetch-Dest: websocket\r\n" + b"Sec-Fetch-Mode: websocket\r\n" + b"Sec-Fetch-Site: cross-site\r\n" + b"Pragma: no-cache\r\n" + b"Cache-Control: no-cache\r\n" + b"Upgrade: websocket\r\n" + b"\r\n" +) + +WEBSOCKETS_HANDSHAKE = ( + b"GET / HTTP/1.1\r\n" + b"Host: localhost:8765\r\n" + b"Upgrade: websocket\r\n" + b"Connection: Upgrade\r\n" + b"Sec-WebSocket-Key: 9c55e0/siQ6tJPCs/QR8ZA==\r\n" + b"Sec-WebSocket-Version: 13\r\n" + b"Sec-WebSocket-Extensions: permessage-deflate; client_max_window_bits\r\n" + b"User-Agent: Python/3.11 websockets/11.0\r\n" + b"\r\n" +) + + +def parse_handshake(handshake): + reader = StreamReader() + reader.feed_data(handshake) + parser = Request.parse(reader.read_line) + try: + next(parser) + except StopIteration: + pass + else: + assert False, "parser should return request" + reader.feed_eof() + assert reader.at_eof(), "parser should consume all data" + + +def run_benchmark(name, handshake, number=10000): + ph = ( + min( + timeit.repeat( + "parse_handshake(handshake)", + number=number, + globals={"parse_handshake": parse_handshake, "handshake": handshake}, + ) + ) + / number + * 1_000_000 + ) + print(f"{name}\t{len(handshake)}\t{ph:.1f}") + + +if __name__ == "__main__": + print("Sizes are in bytes. Times are in µs per frame.", file=sys.stderr) + print("Run `tabs -16` for clean output. Pipe stdout to TSV for saving.") + print(file=sys.stderr) + + print("client\tsize\ttime") + run_benchmark("Chrome", CHROME_HANDSHAKE) + run_benchmark("Firefox", FIREFOX_HANDSHAKE) + run_benchmark("websockets", WEBSOCKETS_HANDSHAKE) From c38507fde13b9d2ac9bf62b8f68dba8d93f1fcd3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 16 Apr 2023 08:14:33 +0200 Subject: [PATCH 435/762] Avoid deadlock when closing sync connection with unread messages. Fix #1336. --- docs/project/changelog.rst | 11 ++++++++ src/websockets/sync/connection.py | 23 +++++++++++------ tests/sync/test_connection.py | 43 +++++++++++++++++++++++++++++++ 3 files changed, 69 insertions(+), 8 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 127ea3142..f54880b06 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -38,6 +38,17 @@ Backwards-incompatible changes websockets 11.0 is the last version supporting Python 3.7. +11.0.2 +------ + +*April 18, 2023* + +Bug fixes +......... + +* Fixed a deadlock in the :mod:`threading` implementation when closing a + connection without reading all messages. + 11.0.1 ------ diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 59c0bc071..afebf5ea9 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -382,8 +382,9 @@ def close(self, code: int = 1000, reason: str = "") -> None: """ Perform the closing handshake. - :meth:`close` waits for the other end to complete the handshake and - for the TCP connection to terminate. + :meth:`close` waits for the other end to complete the handshake, for the + TCP connection to terminate, and for all incoming messages to be read + with :meth:`recv`. :meth:`close` is idempotent: it doesn't do anything once the connection is closed. @@ -574,9 +575,13 @@ def recv_events(self) -> None: # Given that automatic responses write small amounts of data, # this should be uncommon, so we don't handle the edge case. - for event in events: - # This isn't expected to raise an exception. - self.process_event(event) + try: + for event in events: + # This may raise EOFError if the closing handshake + # times out while a message is waiting to be read. + self.process_event(event) + except EOFError: + break # Breaking out of the while True: ... loop means that we believe # that the socket doesn't work anymore. @@ -600,7 +605,6 @@ def recv_events(self) -> None: self.protocol.state = CLOSED finally: # This isn't expected to raise an exception. - self.recv_messages.close() self.close_socket() @contextlib.contextmanager @@ -745,13 +749,16 @@ def set_recv_events_exc(self, exc: Optional[BaseException]) -> None: def close_socket(self) -> None: """ - Shutdown and close socket. + Shutdown and close socket. Close message assembler. - shutdown() is required to interrupt recv() on Linux. + Calling close_socket() guarantees that recv_events() terminates. Indeed, + recv_events() may block only on socket.recv() or on recv_messages.put(). """ + # shutdown() is required to interrupt recv() on Linux. try: self.socket.shutdown(socket.SHUT_RDWR) except OSError: pass # socket is already closed self.socket.close() + self.recv_messages.close() diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index f26ec3f95..0e7cff948 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -458,6 +458,49 @@ def test_close_timeout_waiting_for_connection_closed(self): # Remove socket.timeout when dropping Python < 3.10. self.assertIsInstance(exc.__cause__, (socket.timeout, TimeoutError)) + def test_close_waits_for_recv(self): + self.remote_connection.send("😀") + + close_thread = threading.Thread(target=self.connection.close) + close_thread.start() + + # Let close() initiate the closing handshake and send a close frame. + time.sleep(MS) + self.assertTrue(close_thread.is_alive()) + + # Connection isn't closed yet. + self.connection.recv() + + # Let close() receive a close frame and finish the closing handshake. + time.sleep(MS) + self.assertFalse(close_thread.is_alive()) + + # Connection is closed now. + with self.assertRaises(ConnectionClosedOK) as raised: + self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + def test_close_timeout_waiting_for_recv(self): + self.remote_connection.send("😀") + + close_thread = threading.Thread(target=self.connection.close) + close_thread.start() + + # Let close() time out during the closing handshake. + time.sleep(3 * MS) + self.assertFalse(close_thread.is_alive()) + + # Connection is closed now. + with self.assertRaises(ConnectionClosedError) as raised: + self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); no close frame received") + self.assertIsInstance(exc.__cause__, TimeoutError) + def test_close_idempotency(self): """close does nothing if the connection is already closed.""" self.connection.close() From fe629dede6eb083013e0d2373d5c3120c0078db3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 20 Apr 2023 16:45:55 +0200 Subject: [PATCH 436/762] Fix typo in changelog. Fix #1340. --- docs/project/changelog.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index f54880b06..c5574f442 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -105,7 +105,7 @@ Backwards-incompatible changes New features ............ -.. admonition:: websockets 10.0 introduces a implementation on top of :mod:`threading`. +.. admonition:: websockets 11.0 introduces a implementation on top of :mod:`threading`. :class: important It may be more convenient if you don't need to manage many connections and From e152ceda9ec5d9e5052b3ae395ade2fae83d73fb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 4 May 2023 15:30:42 +0200 Subject: [PATCH 437/762] Add Open Collective and GitHub Sponsors. --- .github/FUNDING.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml index 7ae223b3d..c6c5426a5 100644 --- a/.github/FUNDING.yml +++ b/.github/FUNDING.yml @@ -1 +1,3 @@ -tidelift: "pypi/websockets" +github: python-websockets +open_collective: websockets +tidelift: pypi/websockets From 5aafc9e88fb1abf03012d71bdace770e3771c376 Mon Sep 17 00:00:00 2001 From: Carl Harris Date: Sun, 7 May 2023 09:41:16 -0400 Subject: [PATCH 438/762] Use selectors instead of select.poll in sync.WebSocket Server for multi-platform support (#1349) * use multiplatform selector instead of poll * don't use os.pipe with the I/O multiplexing selector on win32 On the Win32 platform, only sockets can be used with I/O multiplexing (such as that performed by selectors.DefaultSelector); the pipe cannot be added to the selector. However, on the win32 platform, simply closing the listener socket is enough to cause the call to select to return -- the additional pipe is redundant. On Mac OS X (and possibly other BSD derivatives), closing the listener socket isn't enough. In the interest of maximum compatibility, we simply disable the use of os.pipe on the Win32 platform. * exclude platform checks for win32 from coverage testing --- pyproject.toml | 1 + src/websockets/sync/server.py | 20 ++++++++++++-------- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c26e3aabc..530052ddd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,6 +61,7 @@ source = [ exclude_lines = [ "except ImportError:", "if self.debug:", + "if sys.platform != \"win32\":", "pragma: no cover", "raise AssertionError", "raise NotImplementedError", diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index e25646a5c..072fce717 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -3,9 +3,10 @@ import http import logging import os -import select +import selectors import socket import ssl +import sys import threading from types import TracebackType from typing import Any, Callable, Optional, Sequence, Type @@ -199,7 +200,8 @@ def __init__( if logger is None: logger = logging.getLogger("websockets.server") self.logger = logger - self.shutdown_watcher, self.shutdown_notifier = os.pipe() + if sys.platform != "win32": + self.shutdown_watcher, self.shutdown_notifier = os.pipe() def serve_forever(self) -> None: """ @@ -214,15 +216,16 @@ def serve_forever(self) -> None: server.serve_forever() """ - poller = select.poll() - poller.register(self.socket) - poller.register(self.shutdown_watcher) + poller = selectors.DefaultSelector() + poller.register(self.socket, selectors.EVENT_READ) + if sys.platform != "win32": + poller.register(self.shutdown_watcher, selectors.EVENT_READ) while True: - poller.poll() + poller.select() try: # If the socket is closed, this will raise an exception and exit - # the loop. So we don't need to check the return value of poll(). + # the loop. So we don't need to check the return value of select(). sock, addr = self.socket.accept() except OSError: break @@ -235,7 +238,8 @@ def shutdown(self) -> None: """ self.socket.close() - os.write(self.shutdown_notifier, b"x") + if sys.platform != "win32": + os.write(self.shutdown_notifier, b"x") def fileno(self) -> int: """ From bf51a57f3d9c1f9ed7c04709a3b09ec6f337c4c5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 7 May 2023 15:44:46 +0200 Subject: [PATCH 439/762] Add changelog for previous commit. --- docs/project/changelog.rst | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index c5574f442..36bcfc58b 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -38,6 +38,16 @@ Backwards-incompatible changes websockets 11.0 is the last version supporting Python 3.7. +11.0.3 +------ + +*May 7, 2023* + +Bug fixes +......... + +* Fixed the :mod:`threading` implementation of servers on Windows. + 11.0.2 ------ From 9c578a1f5a5e20b8901ac75f12a4dfec4877f6d2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 7 May 2023 16:20:55 +0200 Subject: [PATCH 440/762] Create GitHub release when pushing tag. Fix #1347. --- .github/workflows/wheels.yml | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index d9c9ea833..8aa5c0b7b 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -64,18 +64,25 @@ jobs: with: path: wheelhouse/*.whl - upload_pypi: - name: Upload to PyPI + release: + name: Release needs: - sdist - wheels runs-on: ubuntu-latest + # Don't release when running the workflow manually from GitHub's UI. if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') steps: - - uses: actions/download-artifact@v3 + - name: Download artifacts + uses: actions/download-artifact@v3 with: name: artifact path: dist - - uses: pypa/gh-action-pypi-publish@release/v1 + - name: Upload to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 with: password: ${{ secrets.PYPI_API_TOKEN }} + - name: Create GitHub release + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: gh release create ${{ github.ref_name }} --notes "See https://websockets.readthedocs.io/en/stable/project/changelog.html for details." From 73394df61d3b4c62f4868b75d16e1a81b2329f0c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 18 May 2023 09:23:47 +0200 Subject: [PATCH 441/762] Document how to override DNS resolution. Ref #1343. --- docs/faq/client.rst | 10 ++++++++++ docs/faq/server.rst | 8 +++++--- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/docs/faq/client.rst b/docs/faq/client.rst index c4f5a35b9..c590ac107 100644 --- a/docs/faq/client.rst +++ b/docs/faq/client.rst @@ -48,6 +48,16 @@ In the :mod:`threading` API, this argument is named ``additional_headers``:: with connect(..., additional_headers={"Authorization": ...}) as websocket: ... +How do I force the IP address that the client connects to? +---------------------------------------------------------- + +Use the ``host`` argument of :meth:`~asyncio.loop.create_connection`:: + + await websockets.connect("ws://example.com", host="192.168.0.1") + +:func:`~client.connect` accepts the same arguments as +:meth:`~asyncio.loop.create_connection`. + How do I close a connection? ---------------------------- diff --git a/docs/faq/server.rst b/docs/faq/server.rst index d1388c701..08b412d30 100644 --- a/docs/faq/server.rst +++ b/docs/faq/server.rst @@ -252,10 +252,12 @@ It's available in :attr:`~legacy.protocol.WebSocketCommonProtocol.remote_address async def handler(websocket): remote_ip = websocket.remote_address[0] -How do I set the IP addresses my server listens on? ---------------------------------------------------- +How do I set the IP addresses that my server listens on? +-------------------------------------------------------- -Look at the ``host`` argument of :meth:`~asyncio.loop.create_server`. +Use the ``host`` argument of :meth:`~asyncio.loop.create_server`:: + + await websockets.serve(handler, host="192.168.0.1", port=8080) :func:`~server.serve` accepts the same arguments as :meth:`~asyncio.loop.create_server`. From 03d62c97fcafffa5cdbe4bb55b2a8d17a62eca33 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 18 May 2023 16:52:39 +0200 Subject: [PATCH 442/762] Fix server shutdown on Python 3.12. Ref https://github.com/python/cpython/issues/79033. Fix #1356. --- src/websockets/legacy/server.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index a17c52328..77e0fdab7 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -761,18 +761,13 @@ async def _close(self, close_connections: bool) -> None: # Stop accepting new connections. self.server.close() - # Wait until self.server.close() completes. - await self.server.wait_closed() - # Wait until all accepted connections reach connection_made() and call # register(). See https://bugs.python.org/issue34852 for details. await asyncio.sleep(0) if close_connections: - # Close OPEN connections with status code 1001. Since the server was - # closed, handshake() closes OPENING connections with an HTTP 503 - # error. Wait until all connections are closed. - + # Close OPEN connections with close code 1001. After server.close(), + # handshake() closes OPENING connections with an HTTP 503 error. close_tasks = [ asyncio.create_task(websocket.close(1001)) for websocket in self.websockets @@ -782,8 +777,10 @@ async def _close(self, close_connections: bool) -> None: if close_tasks: await asyncio.wait(close_tasks) - # Wait until all connection handlers are complete. + # Wait until all TCP connections are closed. + await self.server.wait_closed() + # Wait until all connection handlers terminate. # asyncio.wait doesn't accept an empty first argument. if self.websockets: await asyncio.wait( From 2845257be549b41e8279b2ae19f38258f72fb587 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 18 May 2023 18:45:25 +0200 Subject: [PATCH 443/762] Mention application-level heartbeats. Fix #1330. --- docs/faq/common.rst | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/docs/faq/common.rst b/docs/faq/common.rst index 505149a64..0f3cd5910 100644 --- a/docs/faq/common.rst +++ b/docs/faq/common.rst @@ -151,4 +151,11 @@ See :doc:`../topics/timeouts` for details. How do I respond to pings? -------------------------- -Don't bother; websockets takes care of responding to pings with pongs. +If you are referring to Ping_ and Pong_ frames defined in the WebSocket +protocol, don't bother, because websockets handles them for you. + +.. _Ping: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 +.. _Pong: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 + +If you are connecting to a server that defines its own heartbeat at the +application level, then you need to build that logic into your application. From 89fc4086b9bd8bb4c070284276c2eb5973f8b27c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 18 May 2023 18:49:14 +0200 Subject: [PATCH 444/762] Don't crash if git is super slow. Fix #1334. --- src/websockets/version.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/websockets/version.py b/src/websockets/version.py index f45f069ab..3f171b391 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -46,7 +46,11 @@ def get_version(tag: str) -> str: text=True, ).stdout.strip() # subprocess.run raises FileNotFoundError if git isn't on $PATH. - except (FileNotFoundError, subprocess.CalledProcessError): + except ( + FileNotFoundError, + subprocess.CalledProcessError, + subprocess.TimeoutExpired, + ): pass else: description_re = r"[0-9.]+-([0-9]+)-(g[0-9a-f]{7,}(?:-dirty)?)" From 2b627b26b3cd15f222881450cf0e8c19f7e826fd Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 18 May 2023 10:21:41 +0200 Subject: [PATCH 445/762] Provide an enum for close codes. Fix #1335. --- docs/faq/common.rst | 8 +- docs/howto/django.rst | 2 +- docs/howto/fly.rst | 3 +- docs/howto/heroku.rst | 3 +- docs/howto/kubernetes.rst | 2 +- docs/howto/render.rst | 3 +- docs/project/changelog.rst | 5 + docs/reference/datastructures.rst | 23 +- docs/topics/authentication.rst | 2 +- example/django/authentication.py | 3 +- example/django/notifications.py | 3 +- experiments/authentication/app.py | 3 +- src/websockets/exceptions.py | 2 +- src/websockets/frames.py | 89 +++++--- src/websockets/legacy/protocol.py | 36 ++-- src/websockets/protocol.py | 17 +- src/websockets/sync/connection.py | 19 +- src/websockets/sync/server.py | 4 +- tests/extensions/test_permessage_deflate.py | 3 +- tests/legacy/test_client_server.py | 27 +-- tests/legacy/test_framing.py | 8 +- tests/legacy/test_protocol.py | 131 +++++++---- tests/sync/test_connection.py | 14 +- tests/sync/test_server.py | 4 +- tests/test_exceptions.py | 67 ++++-- tests/test_frames.py | 41 +++- tests/test_protocol.py | 227 ++++++++++++-------- 27 files changed, 480 insertions(+), 269 deletions(-) diff --git a/docs/faq/common.rst b/docs/faq/common.rst index 0f3cd5910..2c63c4f36 100644 --- a/docs/faq/common.rst +++ b/docs/faq/common.rst @@ -54,8 +54,8 @@ There are several reasons why long-lived connections may be lost: If you're facing a reproducible issue, :ref:`enable debug logs ` to see when and how connections are closed. -What does ``ConnectionClosedError: sent 1011 (unexpected error) keepalive ping timeout; no close frame received`` mean? ------------------------------------------------------------------------------------------------------------------------ +What does ``ConnectionClosedError: sent 1011 (internal error) keepalive ping timeout; no close frame received`` mean? +--------------------------------------------------------------------------------------------------------------------- If you're seeing this traceback in the logs of a server: @@ -70,7 +70,7 @@ If you're seeing this traceback in the logs of a server: Traceback (most recent call last): ... - websockets.exceptions.ConnectionClosedError: sent 1011 (unexpected error) keepalive ping timeout; no close frame received + websockets.exceptions.ConnectionClosedError: sent 1011 (internal error) keepalive ping timeout; no close frame received or if a client crashes with this traceback: @@ -84,7 +84,7 @@ or if a client crashes with this traceback: Traceback (most recent call last): ... - websockets.exceptions.ConnectionClosedError: sent 1011 (unexpected error) keepalive ping timeout; no close frame received + websockets.exceptions.ConnectionClosedError: sent 1011 (internal error) keepalive ping timeout; no close frame received it means that the WebSocket connection suffered from excessive latency and was closed after reaching the timeout of websockets' keepalive mechanism. diff --git a/docs/howto/django.rst b/docs/howto/django.rst index c955a5ec1..e3da0a878 100644 --- a/docs/howto/django.rst +++ b/docs/howto/django.rst @@ -158,7 +158,7 @@ closes the connection: $ python -m websockets ws://localhost:8888/ Connected to ws://localhost:8888. > not a token - Connection closed: 1011 (unexpected error) authentication failed. + Connection closed: 1011 (internal error) authentication failed. You can also test from a browser by generating a new token and running the following code in the JavaScript console of the browser: diff --git a/docs/howto/fly.rst b/docs/howto/fly.rst index 7e404de61..ed001a2ae 100644 --- a/docs/howto/fly.rst +++ b/docs/howto/fly.rst @@ -174,5 +174,4 @@ away). Connection closed: 1001 (going away). If graceful shutdown wasn't working, the server wouldn't perform a closing -handshake and the connection would be closed with code 1006 (connection closed -abnormally). +handshake and the connection would be closed with code 1006 (abnormal closure). diff --git a/docs/howto/heroku.rst b/docs/howto/heroku.rst index 2b3a44819..a97d2e7ce 100644 --- a/docs/howto/heroku.rst +++ b/docs/howto/heroku.rst @@ -180,5 +180,4 @@ away). Connection closed: 1001 (going away). If graceful shutdown wasn't working, the server wouldn't perform a closing -handshake and the connection would be closed with code 1006 (connection closed -abnormally). +handshake and the connection would be closed with code 1006 (abnormal closure). diff --git a/docs/howto/kubernetes.rst b/docs/howto/kubernetes.rst index c217e5946..064a6ac4d 100644 --- a/docs/howto/kubernetes.rst +++ b/docs/howto/kubernetes.rst @@ -77,7 +77,7 @@ shut down gracefully: < Hey there! Connection closed: 1001 (going away). -If it didn't, you'd get code 1006 (connection closed abnormally). +If it didn't, you'd get code 1006 (abnormal closure). Deploy application ------------------ diff --git a/docs/howto/render.rst b/docs/howto/render.rst index b8c417b66..70bf8c376 100644 --- a/docs/howto/render.rst +++ b/docs/howto/render.rst @@ -167,7 +167,6 @@ deployment completes, the connection is closed with code 1001 (going away). Connection closed: 1001 (going away). If graceful shutdown wasn't working, the server wouldn't perform a closing -handshake and the connection would be closed with code 1006 (connection closed -abnormally). +handshake and the connection would be closed with code 1006 (abnormal closure). Remember to downgrade to a free plan if you upgraded just for testing this feature. diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 36bcfc58b..94fc5ebd9 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -38,6 +38,11 @@ Backwards-incompatible changes websockets 11.0 is the last version supporting Python 3.7. +Improvements +............ + +* Added :class:`~frames.CloseCode`. + 11.0.3 ------ diff --git a/docs/reference/datastructures.rst b/docs/reference/datastructures.rst index 8217052d1..7c037da4c 100644 --- a/docs/reference/datastructures.rst +++ b/docs/reference/datastructures.rst @@ -11,19 +11,32 @@ WebSocket events .. autoclass:: Opcode .. autoattribute:: CONT - .. autoattribute:: TEXT - .. autoattribute:: BINARY - .. autoattribute:: CLOSE - .. autoattribute:: PING - .. autoattribute:: PONG .. autoclass:: Close + .. autoclass:: CloseCode + + .. autoattribute:: OK + .. autoattribute:: GOING_AWAY + .. autoattribute:: PROTOCOL_ERROR + .. autoattribute:: UNSUPPORTED_DATA + .. autoattribute:: NO_STATUS_RCVD + .. autoattribute:: CONNECTION_CLOSED_ABNORMALLY + .. autoattribute:: INVALID_DATA + .. autoattribute:: POLICY_VIOLATION + .. autoattribute:: MESSAGE_TOO_BIG + .. autoattribute:: MANDATORY_EXTENSION + .. autoattribute:: INTERNAL_ERROR + .. autoattribute:: SERVICE_RESTART + .. autoattribute:: TRY_AGAIN_LATER + .. autoattribute:: BAD_GATEWAY + .. autoattribute:: TLS_FAILURE + HTTP events ----------- diff --git a/docs/topics/authentication.rst b/docs/topics/authentication.rst index 60dd38766..31bc8e6da 100644 --- a/docs/topics/authentication.rst +++ b/docs/topics/authentication.rst @@ -189,7 +189,7 @@ connection: token = await websocket.recv() user = get_user(token) if user is None: - await websocket.close(1011, "authentication failed") + await websocket.close(CloseCode.INTERNAL_ERROR, "authentication failed") return ... diff --git a/example/django/authentication.py b/example/django/authentication.py index 7f60f8275..f6dad0f55 100644 --- a/example/django/authentication.py +++ b/example/django/authentication.py @@ -8,13 +8,14 @@ django.setup() from sesame.utils import get_user +from websockets.frames import CloseCode async def handler(websocket): sesame = await websocket.recv() user = await asyncio.to_thread(get_user, sesame) if user is None: - await websocket.close(1011, "authentication failed") + await websocket.close(CloseCode.INTERNAL_ERROR, "authentication failed") return await websocket.send(f"Hello {user}!") diff --git a/example/django/notifications.py b/example/django/notifications.py index 7275a1ef7..3a9ed10cf 100644 --- a/example/django/notifications.py +++ b/example/django/notifications.py @@ -11,6 +11,7 @@ from django.contrib.contenttypes.models import ContentType from sesame.utils import get_user +from websockets.frames import CloseCode CONNECTIONS = {} @@ -33,7 +34,7 @@ async def handler(websocket): sesame = await websocket.recv() user = await asyncio.to_thread(get_user, sesame) if user is None: - await websocket.close(1011, "authentication failed") + await websocket.close(CloseCode.INTERNAL_ERROR, "authentication failed") return ct_ids = await asyncio.to_thread(get_content_types, user) diff --git a/experiments/authentication/app.py b/experiments/authentication/app.py index 6b3b2ae3f..039e21174 100644 --- a/experiments/authentication/app.py +++ b/experiments/authentication/app.py @@ -9,6 +9,7 @@ import uuid import websockets +from websockets.frames import CloseCode # User accounts database @@ -95,7 +96,7 @@ async def first_message_handler(websocket): token = await websocket.recv() user = get_user(token) if user is None: - await websocket.close(1011, "authentication failed") + await websocket.close(CloseCode.INTERNAL_ERROR, "authentication failed") return await websocket.send(f"Hello {user}!") diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 22a3b583f..9d8476648 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -121,7 +121,7 @@ def __str__(self) -> str: @property def code(self) -> int: if self.rcvd is None: - return 1006 + return frames.CloseCode.ABNORMAL_CLOSURE return self.rcvd.code @property diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 8e0e6d873..6b1befb2e 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -52,48 +52,69 @@ class Opcode(enum.IntEnum): CTRL_OPCODES = OP_CLOSE, OP_PING, OP_PONG -# See https://www.iana.org/assignments/websocket/websocket.xhtml -CLOSE_CODES = { - 1000: "OK", - 1001: "going away", - 1002: "protocol error", - 1003: "unsupported type", +class CloseCode(enum.IntEnum): + """Close code values for WebSocket close frames.""" + + NORMAL_CLOSURE = 1000 + GOING_AWAY = 1001 + PROTOCOL_ERROR = 1002 + UNSUPPORTED_DATA = 1003 # 1004 is reserved - 1005: "no status code [internal]", - 1006: "connection closed abnormally [internal]", - 1007: "invalid data", - 1008: "policy violation", - 1009: "message too big", - 1010: "extension required", - 1011: "unexpected error", - 1012: "service restart", - 1013: "try again later", - 1014: "bad gateway", - 1015: "TLS failure [internal]", + NO_STATUS_RCVD = 1005 + ABNORMAL_CLOSURE = 1006 + INVALID_DATA = 1007 + POLICY_VIOLATION = 1008 + MESSAGE_TOO_BIG = 1009 + MANDATORY_EXTENSION = 1010 + INTERNAL_ERROR = 1011 + SERVICE_RESTART = 1012 + TRY_AGAIN_LATER = 1013 + BAD_GATEWAY = 1014 + TLS_HANDSHAKE = 1015 + + +# See https://www.iana.org/assignments/websocket/websocket.xhtml +CLOSE_CODE_EXPLANATIONS: dict[int, str] = { + CloseCode.NORMAL_CLOSURE: "OK", + CloseCode.GOING_AWAY: "going away", + CloseCode.PROTOCOL_ERROR: "protocol error", + CloseCode.UNSUPPORTED_DATA: "unsupported data", + CloseCode.NO_STATUS_RCVD: "no status received [internal]", + CloseCode.ABNORMAL_CLOSURE: "abnormal closure [internal]", + CloseCode.INVALID_DATA: "invalid frame payload data", + CloseCode.POLICY_VIOLATION: "policy violation", + CloseCode.MESSAGE_TOO_BIG: "message too big", + CloseCode.MANDATORY_EXTENSION: "mandatory extension", + CloseCode.INTERNAL_ERROR: "internal error", + CloseCode.SERVICE_RESTART: "service restart", + CloseCode.TRY_AGAIN_LATER: "try again later", + CloseCode.BAD_GATEWAY: "bad gateway", + CloseCode.TLS_HANDSHAKE: "TLS handshake failure [internal]", } # Close code that are allowed in a close frame. # Using a set optimizes `code in EXTERNAL_CLOSE_CODES`. EXTERNAL_CLOSE_CODES = { - 1000, - 1001, - 1002, - 1003, - 1007, - 1008, - 1009, - 1010, - 1011, - 1012, - 1013, - 1014, + CloseCode.NORMAL_CLOSURE, + CloseCode.GOING_AWAY, + CloseCode.PROTOCOL_ERROR, + CloseCode.UNSUPPORTED_DATA, + CloseCode.INVALID_DATA, + CloseCode.POLICY_VIOLATION, + CloseCode.MESSAGE_TOO_BIG, + CloseCode.MANDATORY_EXTENSION, + CloseCode.INTERNAL_ERROR, + CloseCode.SERVICE_RESTART, + CloseCode.TRY_AGAIN_LATER, + CloseCode.BAD_GATEWAY, } + OK_CLOSE_CODES = { - 1000, - 1001, - 1005, + CloseCode.NORMAL_CLOSURE, + CloseCode.GOING_AWAY, + CloseCode.NO_STATUS_RCVD, } @@ -397,7 +418,7 @@ def __str__(self) -> str: elif 4000 <= self.code < 5000: explanation = "private use" else: - explanation = CLOSE_CODES.get(self.code, "unknown") + explanation = CLOSE_CODE_EXPLANATIONS.get(self.code, "unknown") result = f"{self.code} ({explanation})" if self.reason: @@ -425,7 +446,7 @@ def parse(cls, data: bytes) -> Close: close.check() return close elif len(data) == 0: - return cls(1005, "") + return cls(CloseCode.NO_STATUS_RCVD, "") else: raise exceptions.ProtocolError("close frame too short") diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 0422c10da..19cee0e65 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -47,6 +47,7 @@ OP_PONG, OP_TEXT, Close, + CloseCode, Opcode, prepare_ctrl, prepare_data, @@ -459,7 +460,7 @@ def close_code(self) -> Optional[int]: if self.state is not State.CLOSED: return None elif self.close_rcvd is None: - return 1006 + return CloseCode.ABNORMAL_CLOSURE else: return self.close_rcvd.code @@ -681,7 +682,7 @@ async def send( except (Exception, asyncio.CancelledError): # We're half-way through a fragmented message and we can't # complete it. This makes the connection unusable. - self.fail_connection(1011) + self.fail_connection(CloseCode.INTERNAL_ERROR) raise finally: @@ -726,7 +727,7 @@ async def send( except (Exception, asyncio.CancelledError): # We're half-way through a fragmented message and we can't # complete it. This makes the connection unusable. - self.fail_connection(1011) + self.fail_connection(CloseCode.INTERNAL_ERROR) raise finally: @@ -736,7 +737,11 @@ async def send( else: raise TypeError("data must be str, bytes-like, or iterable") - async def close(self, code: int = 1000, reason: str = "") -> None: + async def close( + self, + code: int = CloseCode.NORMAL_CLOSURE, + reason: str = "", + ) -> None: """ Perform the closing handshake. @@ -986,7 +991,7 @@ async def transfer_data(self) -> None: except ProtocolError as exc: self.transfer_data_exc = exc - self.fail_connection(1002) + self.fail_connection(CloseCode.PROTOCOL_ERROR) except (ConnectionError, TimeoutError, EOFError, ssl.SSLError) as exc: # Reading data with self.reader.readexactly may raise: @@ -997,15 +1002,15 @@ async def transfer_data(self) -> None: # bytes are available than requested; # - ssl.SSLError if the other side infringes the TLS protocol. self.transfer_data_exc = exc - self.fail_connection(1006) + self.fail_connection(CloseCode.ABNORMAL_CLOSURE) except UnicodeDecodeError as exc: self.transfer_data_exc = exc - self.fail_connection(1007) + self.fail_connection(CloseCode.INVALID_DATA) except PayloadTooBig as exc: self.transfer_data_exc = exc - self.fail_connection(1009) + self.fail_connection(CloseCode.MESSAGE_TOO_BIG) except Exception as exc: # This shouldn't happen often because exceptions expected under @@ -1014,7 +1019,7 @@ async def transfer_data(self) -> None: self.logger.error("data transfer failed", exc_info=True) self.transfer_data_exc = exc - self.fail_connection(1011) + self.fail_connection(CloseCode.INTERNAL_ERROR) async def read_message(self) -> Optional[Data]: """ @@ -1265,7 +1270,10 @@ async def keepalive_ping(self) -> None: except asyncio.TimeoutError: if self.debug: self.logger.debug("! timed out waiting for keepalive pong") - self.fail_connection(1011, "keepalive ping timeout") + self.fail_connection( + CloseCode.INTERNAL_ERROR, + "keepalive ping timeout", + ) break except ConnectionClosed: @@ -1377,7 +1385,11 @@ async def wait_for_connection_lost(self) -> bool: # and the moment this coroutine resumes running. return self.connection_lost_waiter.done() - def fail_connection(self, code: int = 1006, reason: str = "") -> None: + def fail_connection( + self, + code: int = CloseCode.ABNORMAL_CLOSURE, + reason: str = "", + ) -> None: """ 7.1.7. Fail the WebSocket Connection @@ -1408,7 +1420,7 @@ def fail_connection(self, code: int = 1006, reason: str = "") -> None: # sent if it's CLOSING), except when failing the connection because of # an error reading from or writing to the network. # Don't send a close frame if the connection is broken. - if code != 1006 and self.state is State.OPEN: + if code != CloseCode.ABNORMAL_CLOSURE and self.state is State.OPEN: close = Close(code, reason) # Write the close frame without draining the write buffer. diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 3fdd3881c..765e6b9bb 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -23,6 +23,7 @@ OP_PONG, OP_TEXT, Close, + CloseCode, Frame, ) from .http11 import Request, Response @@ -181,7 +182,7 @@ def close_code(self) -> Optional[int]: if self.state is not CLOSED: return None elif self.close_rcvd is None: - return 1006 + return CloseCode.ABNORMAL_CLOSURE else: return self.close_rcvd.code @@ -362,7 +363,7 @@ def send_close(self, code: Optional[int] = None, reason: str = "") -> None: if code is None: if reason != "": raise ProtocolError("cannot send a reason without a code") - close = Close(1005, "") + close = Close(CloseCode.NO_STATUS_RCVD, "") data = b"" else: close = Close(code, reason) @@ -419,7 +420,7 @@ def fail(self, code: int, reason: str = "") -> None: # sent if it's CLOSING), except when failing the connection because # of an error reading from or writing to the network. if self.state is OPEN: - if code != 1006: + if code != CloseCode.ABNORMAL_CLOSURE: close = Close(code, reason) data = close.serialize() self.send_frame(Frame(OP_CLOSE, data)) @@ -549,25 +550,25 @@ def parse(self) -> Generator[None, None, None]: self.recv_frame(frame) except ProtocolError as exc: - self.fail(1002, str(exc)) + self.fail(CloseCode.PROTOCOL_ERROR, str(exc)) self.parser_exc = exc except EOFError as exc: - self.fail(1006, str(exc)) + self.fail(CloseCode.ABNORMAL_CLOSURE, str(exc)) self.parser_exc = exc except UnicodeDecodeError as exc: - self.fail(1007, f"{exc.reason} at position {exc.start}") + self.fail(CloseCode.INVALID_DATA, f"{exc.reason} at position {exc.start}") self.parser_exc = exc except PayloadTooBig as exc: - self.fail(1009, str(exc)) + self.fail(CloseCode.MESSAGE_TOO_BIG, str(exc)) self.parser_exc = exc except Exception as exc: self.logger.error("parser failed", exc_info=True) # Don't include exception details, which may be security-sensitive. - self.fail(1011) + self.fail(CloseCode.INTERNAL_ERROR) self.parser_exc = exc # During an abnormal closure, execution ends here after catching an diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index afebf5ea9..4a8879e37 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -11,7 +11,7 @@ from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, Type, Union from ..exceptions import ConnectionClosed, ConnectionClosedOK, ProtocolError -from ..frames import DATA_OPCODES, BytesLike, Frame, Opcode, prepare_ctrl +from ..frames import DATA_OPCODES, BytesLike, CloseCode, Frame, Opcode, prepare_ctrl from ..http11 import Request, Response from ..protocol import CLOSED, OPEN, Event, Protocol, State from ..typing import Data, LoggerLike, Subprotocol @@ -141,7 +141,10 @@ def __exit__( exc_value: Optional[BaseException], traceback: Optional[TracebackType], ) -> None: - self.close(1000 if exc_type is None else 1011) + if exc_type is None: + self.close() + else: + self.close(CloseCode.INTERNAL_ERROR) def __iter__(self) -> Iterator[Data]: """ @@ -372,13 +375,16 @@ def send(self, message: Union[Data, Iterable[Data]]) -> None: # We're half-way through a fragmented message and we can't # complete it. This makes the connection unusable. with self.send_context(): - self.protocol.fail(1011, "error in fragmented message") + self.protocol.fail( + CloseCode.INTERNAL_ERROR, + "error in fragmented message", + ) raise else: raise TypeError("data must be bytes, str, or iterable") - def close(self, code: int = 1000, reason: str = "") -> None: + def close(self, code: int = CloseCode.NORMAL_CLOSURE, reason: str = "") -> None: """ Perform the closing handshake. @@ -399,7 +405,10 @@ def close(self, code: int = 1000, reason: str = "") -> None: # to terminate after calling a method that sends a close frame. with self.send_context(): if self.send_in_progress: - self.protocol.fail(1011, "close during fragmented message") + self.protocol.fail( + CloseCode.INTERNAL_ERROR, + "close during fragmented message", + ) else: self.protocol.send_close(code, reason) except ConnectionClosed: diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 072fce717..14767968c 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -11,6 +11,8 @@ from types import TracebackType from typing import Any, Callable, Optional, Sequence, Type +from websockets.frames import CloseCode + from ..extensions.base import ServerExtensionFactory from ..extensions.permessage_deflate import enable_server_permessage_deflate from ..headers import validate_subprotocols @@ -497,7 +499,7 @@ def protocol_select_subprotocol( handler(connection) except Exception: protocol.logger.error("connection handler failed", exc_info=True) - connection.close(1011) + connection.close(CloseCode.INTERNAL_ERROR) else: connection.close() diff --git a/tests/extensions/test_permessage_deflate.py b/tests/extensions/test_permessage_deflate.py index c341fdb32..0e698566f 100644 --- a/tests/extensions/test_permessage_deflate.py +++ b/tests/extensions/test_permessage_deflate.py @@ -18,6 +18,7 @@ OP_PONG, OP_TEXT, Close, + CloseCode, Frame, ) @@ -74,7 +75,7 @@ def test_no_encode_decode_pong_frame(self): self.assertEqual(self.extension.decode(frame), frame) def test_no_encode_decode_close_frame(self): - frame = Frame(OP_CLOSE, Close(1000, "").serialize()) + frame = Frame(OP_CLOSE, Close(CloseCode.NORMAL_CLOSURE, "").serialize()) self.assertEqual(self.extension.encode(frame), frame) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 133af0536..cba24ad32 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -27,6 +27,7 @@ PerMessageDeflate, ServerPerMessageDeflateFactory, ) +from websockets.frames import CloseCode from websockets.http import USER_AGENT from websockets.legacy.client import * from websockets.legacy.compatibility import asyncio_timeout @@ -1150,7 +1151,7 @@ def test_server_handler_crashes(self, send): self.loop.run_until_complete(self.client.recv()) # Connection ends with an unexpected error. - self.assertEqual(self.client.close_code, 1011) + self.assertEqual(self.client.close_code, CloseCode.INTERNAL_ERROR) @with_server() @unittest.mock.patch("websockets.legacy.server.WebSocketServerProtocol.close") @@ -1163,7 +1164,7 @@ def test_server_close_crashes(self, close): self.assertEqual(reply, "Hello!") # Connection ends with an abnormal closure. - self.assertEqual(self.client.close_code, 1006) + self.assertEqual(self.client.close_code, CloseCode.ABNORMAL_CLOSURE) @with_server() @with_client() @@ -1196,8 +1197,8 @@ def test_server_shuts_down_during_connection_handling(self): self.loop.run_until_complete(self.client.recv()) # Server closed the connection with 1001 Going Away. - self.assertEqual(self.client.close_code, 1001) - self.assertEqual(server_ws.close_code, 1001) + self.assertEqual(self.client.close_code, CloseCode.GOING_AWAY) + self.assertEqual(server_ws.close_code, CloseCode.GOING_AWAY) @with_server() def test_server_shuts_down_gracefully_during_connection_handling(self): @@ -1208,8 +1209,8 @@ def test_server_shuts_down_gracefully_during_connection_handling(self): self.loop.run_until_complete(self.client.recv()) # Client closed the connection with 1000 OK. - self.assertEqual(self.client.close_code, 1000) - self.assertEqual(server_ws.close_code, 1000) + self.assertEqual(self.client.close_code, CloseCode.NORMAL_CLOSURE) + self.assertEqual(server_ws.close_code, CloseCode.NORMAL_CLOSURE) @with_server() def test_server_shuts_down_and_waits_until_handlers_terminate(self): @@ -1271,7 +1272,7 @@ def test_connection_error_during_closing_handshake(self, close): self.assertEqual(reply, "Hello!") # Connection ends with an abnormal closure. - self.assertEqual(self.client.close_code, 1006) + self.assertEqual(self.client.close_code, CloseCode.ABNORMAL_CLOSURE) class ClientServerTests( @@ -1467,12 +1468,12 @@ async def run_client(): self.assertEqual(messages, self.MESSAGES) - async def echo_handler_1001(ws): + async def echo_handler_going_away(ws): for message in AsyncIteratorTests.MESSAGES: await ws.send(message) - await ws.close(1001) + await ws.close(CloseCode.GOING_AWAY) - @with_server(handler=echo_handler_1001) + @with_server(handler=echo_handler_going_away) def test_iterate_on_messages_going_away_exit_ok(self): messages = [] @@ -1486,12 +1487,12 @@ async def run_client(): self.assertEqual(messages, self.MESSAGES) - async def echo_handler_1011(ws): + async def echo_handler_internal_error(ws): for message in AsyncIteratorTests.MESSAGES: await ws.send(message) - await ws.close(1011) + await ws.close(CloseCode.INTERNAL_ERROR) - @with_server(handler=echo_handler_1011) + @with_server(handler=echo_handler_internal_error) def test_iterate_on_messages_internal_error_exit_not_ok(self): messages = [] diff --git a/tests/legacy/test_framing.py b/tests/legacy/test_framing.py index 035f0f03c..e1e4c891b 100644 --- a/tests/legacy/test_framing.py +++ b/tests/legacy/test_framing.py @@ -6,7 +6,7 @@ import warnings from websockets.exceptions import PayloadTooBig, ProtocolError -from websockets.frames import OP_BINARY, OP_CLOSE, OP_PING, OP_PONG, OP_TEXT +from websockets.frames import OP_BINARY, OP_CLOSE, OP_PING, OP_PONG, OP_TEXT, CloseCode from websockets.legacy.framing import * from .utils import AsyncioTestCase @@ -187,11 +187,11 @@ def assertCloseData(self, code, reason, data): self.assertEqual(parsed, (code, reason)) def test_parse_close_and_serialize_close(self): - self.assertCloseData(1000, "", b"\x03\xe8") - self.assertCloseData(1000, "OK", b"\x03\xe8OK") + self.assertCloseData(CloseCode.NORMAL_CLOSURE, "", b"\x03\xe8") + self.assertCloseData(CloseCode.NORMAL_CLOSURE, "OK", b"\x03\xe8OK") def test_parse_close_empty(self): - self.assertEqual(parse_close(b""), (1005, "")) + self.assertEqual(parse_close(b""), (CloseCode.NO_STATUS_RCVD, "")) def test_parse_close_errors(self): with self.assertRaises(ProtocolError): diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index 514d91ef8..f2eb0fea0 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -15,6 +15,7 @@ OP_PONG, OP_TEXT, Close, + CloseCode, ) from websockets.legacy.framing import Frame from websockets.legacy.protocol import WebSocketCommonProtocol, broadcast @@ -121,9 +122,21 @@ async def delayed_drain(): self.protocol._drain = delayed_drain - close_frame = Frame(True, OP_CLOSE, Close(1000, "close").serialize()) - local_close = Frame(True, OP_CLOSE, Close(1000, "local").serialize()) - remote_close = Frame(True, OP_CLOSE, Close(1000, "remote").serialize()) + close_frame = Frame( + True, + OP_CLOSE, + Close(CloseCode.NORMAL_CLOSURE, "close").serialize(), + ) + local_close = Frame( + True, + OP_CLOSE, + Close(CloseCode.NORMAL_CLOSURE, "local").serialize(), + ) + remote_close = Frame( + True, + OP_CLOSE, + Close(CloseCode.NORMAL_CLOSURE, "remote").serialize(), + ) def receive_frame(self, frame): """ @@ -157,7 +170,7 @@ def receive_eof_if_client(self): if self.protocol.is_client: self.receive_eof() - def close_connection(self, code=1000, reason="close"): + def close_connection(self, code=CloseCode.NORMAL_CLOSURE, reason="close"): """ Execute a closing handshake. @@ -175,7 +188,11 @@ def close_connection(self, code=1000, reason="close"): assert self.protocol.state is State.CLOSED - def half_close_connection_local(self, code=1000, reason="close"): + def half_close_connection_local( + self, + code=CloseCode.NORMAL_CLOSURE, + reason="close", + ): """ Start a closing handshake but do not complete it. @@ -205,7 +222,11 @@ def half_close_connection_local(self, code=1000, reason="close"): # This task must be awaited or canceled by the caller. return close_task - def half_close_connection_remote(self, code=1000, reason="close"): + def half_close_connection_remote( + self, + code=CloseCode.NORMAL_CLOSURE, + reason="close", + ): """ Receive a closing handshake but do not complete it. @@ -299,10 +320,10 @@ def assertConnectionFailed(self, code, message): # The following line guarantees that connection_lost was called. self.assertEqual(self.protocol.state, State.CLOSED) # No close frame was received. - self.assertEqual(self.protocol.close_code, 1006) + self.assertEqual(self.protocol.close_code, CloseCode.ABNORMAL_CLOSURE) self.assertEqual(self.protocol.close_reason, "") # A close frame was sent -- unless the connection was already lost. - if code == 1006: + if code == CloseCode.ABNORMAL_CLOSURE: self.assertNoFrameSent() else: self.assertOneFrameSent(True, OP_CLOSE, Close(code, message).serialize()) @@ -391,11 +412,11 @@ def test_wait_closed(self): self.assertTrue(wait_closed.done()) def test_close_code(self): - self.close_connection(1001, "Bye!") - self.assertEqual(self.protocol.close_code, 1001) + self.close_connection(CloseCode.GOING_AWAY, "Bye!") + self.assertEqual(self.protocol.close_code, CloseCode.GOING_AWAY) def test_close_reason(self): - self.close_connection(1001, "Bye!") + self.close_connection(CloseCode.GOING_AWAY, "Bye!") self.assertEqual(self.protocol.close_reason, "Bye!") def test_close_code_not_set(self): @@ -439,24 +460,24 @@ def test_recv_on_closed_connection(self): def test_recv_protocol_error(self): self.receive_frame(Frame(True, OP_CONT, "café".encode("utf-8"))) self.process_invalid_frames() - self.assertConnectionFailed(1002, "") + self.assertConnectionFailed(CloseCode.PROTOCOL_ERROR, "") def test_recv_unicode_error(self): self.receive_frame(Frame(True, OP_TEXT, "café".encode("latin-1"))) self.process_invalid_frames() - self.assertConnectionFailed(1007, "") + self.assertConnectionFailed(CloseCode.INVALID_DATA, "") def test_recv_text_payload_too_big(self): self.protocol.max_size = 1024 self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8") * 205)) self.process_invalid_frames() - self.assertConnectionFailed(1009, "") + self.assertConnectionFailed(CloseCode.MESSAGE_TOO_BIG, "") def test_recv_binary_payload_too_big(self): self.protocol.max_size = 1024 self.receive_frame(Frame(True, OP_BINARY, b"tea" * 342)) self.process_invalid_frames() - self.assertConnectionFailed(1009, "") + self.assertConnectionFailed(CloseCode.MESSAGE_TOO_BIG, "") def test_recv_text_no_max_size(self): self.protocol.max_size = None # for test coverage @@ -531,7 +552,7 @@ async def read_message(): self.protocol.read_message = read_message self.process_invalid_frames() - self.assertConnectionFailed(1011, "") + self.assertConnectionFailed(CloseCode.INTERNAL_ERROR, "") def test_recv_canceled(self): recv = self.loop.create_task(self.protocol.recv()) @@ -667,7 +688,7 @@ def test_send_iterable_mixed_type_error(self): self.loop.run_until_complete(self.protocol.send(["café", b"tea"])) self.assertFramesSent( (False, OP_TEXT, "café".encode("utf-8")), - (True, OP_CLOSE, Close(1011, "").serialize()), + (True, OP_CLOSE, Close(CloseCode.INTERNAL_ERROR, "").serialize()), ) def test_send_iterable_prevents_concurrent_send(self): @@ -741,7 +762,7 @@ def test_send_async_iterable_mixed_type_error(self): ) self.assertFramesSent( (False, OP_TEXT, "café".encode("utf-8")), - (True, OP_CLOSE, Close(1011, "").serialize()), + (True, OP_CLOSE, Close(CloseCode.INTERNAL_ERROR, "").serialize()), ) def test_send_async_iterable_prevents_concurrent_send(self): @@ -1068,14 +1089,14 @@ def test_fragmented_text_payload_too_big(self): self.receive_frame(Frame(False, OP_TEXT, "café".encode("utf-8") * 100)) self.receive_frame(Frame(True, OP_CONT, "café".encode("utf-8") * 105)) self.process_invalid_frames() - self.assertConnectionFailed(1009, "") + self.assertConnectionFailed(CloseCode.MESSAGE_TOO_BIG, "") def test_fragmented_binary_payload_too_big(self): self.protocol.max_size = 1024 self.receive_frame(Frame(False, OP_BINARY, b"tea" * 171)) self.receive_frame(Frame(True, OP_CONT, b"tea" * 171)) self.process_invalid_frames() - self.assertConnectionFailed(1009, "") + self.assertConnectionFailed(CloseCode.MESSAGE_TOO_BIG, "") def test_fragmented_text_no_max_size(self): self.protocol.max_size = None # for test coverage @@ -1104,7 +1125,7 @@ def test_unterminated_fragmented_text(self): # Missing the second part of the fragmented frame. self.receive_frame(Frame(True, OP_BINARY, b"tea")) self.process_invalid_frames() - self.assertConnectionFailed(1002, "") + self.assertConnectionFailed(CloseCode.PROTOCOL_ERROR, "") def test_close_handshake_in_fragmented_text(self): self.receive_frame(Frame(False, OP_TEXT, "ca".encode("utf-8"))) @@ -1114,12 +1135,12 @@ def test_close_handshake_in_fragmented_text(self): # can be interjected in the middle of a fragmented message and that a # close frame must be echoed. Even though there's an unterminated # message, technically, the closing handshake was successful. - self.assertConnectionClosed(1005, "") + self.assertConnectionClosed(CloseCode.NO_STATUS_RCVD, "") def test_connection_close_in_fragmented_text(self): self.receive_frame(Frame(False, OP_TEXT, "ca".encode("utf-8"))) self.process_invalid_frames() - self.assertConnectionFailed(1006, "") + self.assertConnectionFailed(CloseCode.ABNORMAL_CLOSURE, "") # Test miscellaneous code paths to ensure full coverage. @@ -1127,7 +1148,7 @@ def test_connection_lost(self): # Test calling connection_lost without going through close_connection. self.protocol.connection_lost(None) - self.assertConnectionFailed(1006, "") + self.assertConnectionFailed(CloseCode.ABNORMAL_CLOSURE, "") def test_ensure_open_before_opening_handshake(self): # Simulate a bug by forcibly reverting the protocol state. @@ -1168,7 +1189,7 @@ def test_connection_closed_attributes(self): self.loop.run_until_complete(self.protocol.recv()) connection_closed_exc = context.exception - self.assertEqual(connection_closed_exc.code, 1000) + self.assertEqual(connection_closed_exc.code, CloseCode.NORMAL_CLOSURE) self.assertEqual(connection_closed_exc.reason, "close") # Test the protocol logic for sending keepalive pings. @@ -1228,7 +1249,9 @@ def test_keepalive_ping_not_acknowledged_closes_connection(self): # Connection is closed at 6ms. self.loop.run_until_complete(asyncio.sleep(4 * MS)) self.assertOneFrameSent( - True, OP_CLOSE, Close(1011, "keepalive ping timeout").serialize() + True, + OP_CLOSE, + Close(CloseCode.INTERNAL_ERROR, "keepalive ping timeout").serialize(), ) # The keepalive ping task is complete. @@ -1328,13 +1351,13 @@ def test_local_close(self): # Run the closing handshake. self.loop.run_until_complete(self.protocol.close(reason="close")) - self.assertConnectionClosed(1000, "close") + self.assertConnectionClosed(CloseCode.NORMAL_CLOSURE, "close") self.assertOneFrameSent(*self.close_frame) # Closing the connection again is a no-op. self.loop.run_until_complete(self.protocol.close(reason="oh noes!")) - self.assertConnectionClosed(1000, "close") + self.assertConnectionClosed(CloseCode.NORMAL_CLOSURE, "close") self.assertNoFrameSent() def test_remote_close(self): @@ -1347,13 +1370,13 @@ def test_remote_close(self): with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(self.protocol.recv()) - self.assertConnectionClosed(1000, "close") + self.assertConnectionClosed(CloseCode.NORMAL_CLOSURE, "close") self.assertOneFrameSent(*self.close_frame) # Closing the connection again is a no-op. self.loop.run_until_complete(self.protocol.close(reason="oh noes!")) - self.assertConnectionClosed(1000, "close") + self.assertConnectionClosed(CloseCode.NORMAL_CLOSURE, "close") self.assertNoFrameSent() def test_remote_close_and_connection_lost(self): @@ -1367,7 +1390,7 @@ def test_remote_close_and_connection_lost(self): with self.assertNoLogs(): self.loop.run_until_complete(self.protocol.close(reason="oh noes!")) - self.assertConnectionClosed(1000, "close") + self.assertConnectionClosed(CloseCode.NORMAL_CLOSURE, "close") self.assertOneFrameSent(*self.close_frame) def test_simultaneous_close(self): @@ -1380,7 +1403,7 @@ def test_simultaneous_close(self): self.loop.run_until_complete(self.protocol.close(reason="local")) - self.assertConnectionClosed(1000, "remote") + self.assertConnectionClosed(CloseCode.NORMAL_CLOSURE, "remote") # The current implementation sends a close frame in response to the # close frame received from the remote end. It skips the close frame # that should be sent as a result of calling close(). @@ -1394,7 +1417,7 @@ def test_close_preserves_incoming_frames(self): self.loop.call_later(MS, self.receive_eof_if_client) self.loop.run_until_complete(self.protocol.close(reason="close")) - self.assertConnectionClosed(1000, "close") + self.assertConnectionClosed(CloseCode.NORMAL_CLOSURE, "close") self.assertOneFrameSent(*self.close_frame) next_message = self.loop.run_until_complete(self.protocol.recv()) @@ -1407,14 +1430,14 @@ def test_close_protocol_error(self): self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason="close")) - self.assertConnectionFailed(1002, "") + self.assertConnectionFailed(CloseCode.PROTOCOL_ERROR, "") def test_close_connection_lost(self): self.receive_eof() self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason="close")) - self.assertConnectionFailed(1006, "") + self.assertConnectionFailed(CloseCode.ABNORMAL_CLOSURE, "") def test_local_close_during_recv(self): recv = self.loop.create_task(self.protocol.recv()) @@ -1427,7 +1450,7 @@ def test_local_close_during_recv(self): with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(recv) - self.assertConnectionClosed(1000, "close") + self.assertConnectionClosed(CloseCode.NORMAL_CLOSURE, "close") # There is no test_remote_close_during_recv because it would be identical # to test_remote_close. @@ -1442,7 +1465,7 @@ def test_remote_close_during_send(self): with self.assertRaises(ConnectionClosed): self.loop.run_until_complete(send) - self.assertConnectionClosed(1000, "close") + self.assertConnectionClosed(CloseCode.NORMAL_CLOSURE, "close") # There is no test_local_close_during_send because this cannot really # happen, considering that writes are serialized. @@ -1557,7 +1580,7 @@ def test_local_close_send_close_frame_timeout(self): # Check the timing within -1/+9ms for robustness. with self.assertCompletesWithin(9 * MS, 19 * MS): self.loop.run_until_complete(self.protocol.close(reason="close")) - self.assertConnectionClosed(1006, "") + self.assertConnectionClosed(CloseCode.ABNORMAL_CLOSURE, "") def test_local_close_receive_close_frame_timeout(self): self.protocol.close_timeout = 10 * MS @@ -1565,7 +1588,7 @@ def test_local_close_receive_close_frame_timeout(self): # Check the timing within -1/+9ms for robustness. with self.assertCompletesWithin(9 * MS, 19 * MS): self.loop.run_until_complete(self.protocol.close(reason="close")) - self.assertConnectionClosed(1006, "") + self.assertConnectionClosed(CloseCode.ABNORMAL_CLOSURE, "") def test_local_close_connection_lost_timeout_after_write_eof(self): self.protocol.close_timeout = 10 * MS @@ -1579,7 +1602,10 @@ def test_local_close_connection_lost_timeout_after_write_eof(self): self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason="close")) # Due to a bug in coverage, this is erroneously reported as not covered. - self.assertConnectionClosed(1000, "close") # pragma: no cover + self.assertConnectionClosed( # pragma: no cover + CloseCode.NORMAL_CLOSURE, + "close", + ) def test_local_close_connection_lost_timeout_after_close(self): self.protocol.close_timeout = 10 * MS @@ -1597,7 +1623,10 @@ def test_local_close_connection_lost_timeout_after_close(self): self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason="close")) # Due to a bug in coverage, this is erroneously reported as not covered. - self.assertConnectionClosed(1000, "close") # pragma: no cover + self.assertConnectionClosed( # pragma: no cover + CloseCode.NORMAL_CLOSURE, + "close", + ) class ClientTests(CommonTests, AsyncioTestCase): @@ -1616,7 +1645,10 @@ def test_local_close_send_close_frame_timeout(self): with self.assertCompletesWithin(19 * MS, 29 * MS): self.loop.run_until_complete(self.protocol.close(reason="close")) # Due to a bug in coverage, this is erroneously reported as not covered. - self.assertConnectionClosed(1006, "") # pragma: no cover + self.assertConnectionClosed( # pragma: no cover + CloseCode.ABNORMAL_CLOSURE, + "", + ) def test_local_close_receive_close_frame_timeout(self): self.protocol.close_timeout = 10 * MS @@ -1627,7 +1659,10 @@ def test_local_close_receive_close_frame_timeout(self): with self.assertCompletesWithin(19 * MS, 29 * MS): self.loop.run_until_complete(self.protocol.close(reason="close")) # Due to a bug in coverage, this is erroneously reported as not covered. - self.assertConnectionClosed(1006, "") # pragma: no cover + self.assertConnectionClosed( # pragma: no cover + CloseCode.ABNORMAL_CLOSURE, + "", + ) def test_local_close_connection_lost_timeout_after_write_eof(self): self.protocol.close_timeout = 10 * MS @@ -1643,7 +1678,10 @@ def test_local_close_connection_lost_timeout_after_write_eof(self): self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason="close")) # Due to a bug in coverage, this is erroneously reported as not covered. - self.assertConnectionClosed(1000, "close") # pragma: no cover + self.assertConnectionClosed( # pragma: no cover + CloseCode.NORMAL_CLOSURE, + "close", + ) def test_local_close_connection_lost_timeout_after_close(self): self.protocol.close_timeout = 10 * MS @@ -1664,4 +1702,7 @@ def test_local_close_connection_lost_timeout_after_close(self): self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason="close")) # Due to a bug in coverage, this is erroneously reported as not covered. - self.assertConnectionClosed(1000, "close") # pragma: no cover + self.assertConnectionClosed( # pragma: no cover + CloseCode.NORMAL_CLOSURE, + "close", + ) diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index 0e7cff948..63544d4ad 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -10,7 +10,7 @@ from unittest.mock import patch from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK -from websockets.frames import Frame, Opcode +from websockets.frames import CloseCode, Frame, Opcode from websockets.protocol import CLIENT, SERVER, Protocol from websockets.sync.connection import * @@ -134,7 +134,7 @@ def test_iter_connection_closed_ok(self): def test_iter_connection_closed_error(self): """__iter__ raises ConnnectionClosedError after an error.""" iterator = iter(self.connection) - self.remote_connection.close(code=1011) + self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) with self.assertRaises(ConnectionClosedError): next(iterator) @@ -168,7 +168,7 @@ def test_recv_connection_closed_ok(self): def test_recv_connection_closed_error(self): """recv raises ConnectionClosedError after an error.""" - self.remote_connection.close(code=1011) + self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) with self.assertRaises(ConnectionClosedError): self.connection.recv() @@ -248,7 +248,7 @@ def test_recv_streaming_connection_closed_ok(self): def test_recv_streaming_connection_closed_error(self): """recv_streaming raises ConnectionClosedError after an error.""" - self.remote_connection.close(code=1011) + self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) with self.assertRaises(ConnectionClosedError): list(self.connection.recv_streaming()) @@ -322,7 +322,7 @@ def test_send_connection_closed_ok(self): def test_send_connection_closed_error(self): """send raises ConnectionClosedError after an error.""" - self.remote_connection.close(code=1011) + self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) with self.assertRaises(ConnectionClosedError): self.connection.send("😀") @@ -400,7 +400,7 @@ def test_close(self): def test_close_explicit_code_reason(self): """close sends a close frame with a given code and reason.""" - self.connection.close(1001, "bye!") + self.connection.close(CloseCode.GOING_AWAY, "bye!") self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe9bye!")) def test_close_waits_for_close_frame(self): @@ -567,7 +567,7 @@ def fragments(): exc = raised.exception self.assertEqual( str(exc), - "sent 1011 (unexpected error) close during fragmented message; " + "sent 1011 (internal error) close during fragmented message; " "no close frame received", ) self.assertIsNone(exc.__cause__) diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index 60e70c0a5..f9db84246 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -64,8 +64,8 @@ def test_connection_handler_raises_exception(self): with run_client(server) as client: with self.assertRaisesRegex( ConnectionClosedError, - r"received 1011 \(unexpected error\); " - r"then sent 1011 \(unexpected error\)", + r"received 1011 \(internal error\); " + r"then sent 1011 \(internal error\)", ): client.recv() diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index a0f9dfcd2..1e6f58fad 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -2,49 +2,80 @@ from websockets.datastructures import Headers from websockets.exceptions import * -from websockets.frames import Close +from websockets.frames import Close, CloseCode from websockets.http11 import Response class ExceptionsTests(unittest.TestCase): def test_str(self): for exception, exception_str in [ - # fmt: off ( WebSocketException("something went wrong"), "something went wrong", ), ( - ConnectionClosed(Close(1000, ""), Close(1000, ""), True), + ConnectionClosed( + Close(CloseCode.NORMAL_CLOSURE, ""), + Close(CloseCode.NORMAL_CLOSURE, ""), + True, + ), "received 1000 (OK); then sent 1000 (OK)", ), ( - ConnectionClosed(Close(1001, "Bye!"), Close(1001, "Bye!"), False), + ConnectionClosed( + Close(CloseCode.GOING_AWAY, "Bye!"), + Close(CloseCode.GOING_AWAY, "Bye!"), + False, + ), "sent 1001 (going away) Bye!; then received 1001 (going away) Bye!", ), ( - ConnectionClosed(Close(1000, "race"), Close(1000, "cond"), True), + ConnectionClosed( + Close(CloseCode.NORMAL_CLOSURE, "race"), + Close(CloseCode.NORMAL_CLOSURE, "cond"), + True, + ), "received 1000 (OK) race; then sent 1000 (OK) cond", ), ( - ConnectionClosed(Close(1000, "cond"), Close(1000, "race"), False), + ConnectionClosed( + Close(CloseCode.NORMAL_CLOSURE, "cond"), + Close(CloseCode.NORMAL_CLOSURE, "race"), + False, + ), "sent 1000 (OK) race; then received 1000 (OK) cond", ), ( - ConnectionClosed(None, Close(1009, ""), None), + ConnectionClosed( + None, + Close(CloseCode.MESSAGE_TOO_BIG, ""), + None, + ), "sent 1009 (message too big); no close frame received", ), ( - ConnectionClosed(Close(1002, ""), None, None), + ConnectionClosed( + Close(CloseCode.PROTOCOL_ERROR, ""), + None, + None, + ), "received 1002 (protocol error); no close frame sent", ), ( - ConnectionClosedOK(Close(1000, ""), Close(1000, ""), True), + ConnectionClosedOK( + Close(CloseCode.NORMAL_CLOSURE, ""), + Close(CloseCode.NORMAL_CLOSURE, ""), + True, + ), "received 1000 (OK); then sent 1000 (OK)", ), ( - ConnectionClosedError(None, None, None), - "no close frame received or sent" + ConnectionClosedError( + None, + None, + None, + ), + "no close frame received or sent", ), ( InvalidHandshake("invalid request"), @@ -75,11 +106,8 @@ def test_str(self): "invalid Name header: Value", ), ( - InvalidHeaderFormat( - "Sec-WebSocket-Protocol", "expected token", "a=|", 3 - ), - "invalid Sec-WebSocket-Protocol header: " - "expected token at 3 in a=|", + InvalidHeaderFormat("Sec-WebSocket-Protocol", "exp. token", "a=|", 3), + "invalid Sec-WebSocket-Protocol header: exp. token at 3 in a=|", ), ( InvalidHeaderValue("Sec-WebSocket-Version", "42"), @@ -153,17 +181,16 @@ def test_str(self): ProtocolError("invalid opcode: 7"), "invalid opcode: 7", ), - # fmt: on ]: with self.subTest(exception=exception): self.assertEqual(str(exception), exception_str) def test_connection_closed_attributes_backwards_compatibility(self): - exception = ConnectionClosed(Close(1000, "OK"), None, None) - self.assertEqual(exception.code, 1000) + exception = ConnectionClosed(Close(CloseCode.NORMAL_CLOSURE, "OK"), None, None) + self.assertEqual(exception.code, CloseCode.NORMAL_CLOSURE) self.assertEqual(exception.reason, "OK") def test_connection_closed_attributes_backwards_compatibility_defaults(self): exception = ConnectionClosed(None, None, None) - self.assertEqual(exception.code, 1006) + self.assertEqual(exception.code, CloseCode.ABNORMAL_CLOSURE) self.assertEqual(exception.reason, "") diff --git a/tests/test_frames.py b/tests/test_frames.py index e7c48b930..e323b3b57 100644 --- a/tests/test_frames.py +++ b/tests/test_frames.py @@ -5,6 +5,7 @@ from websockets.exceptions import PayloadTooBig, ProtocolError from websockets.frames import * +from websockets.frames import CloseCode from websockets.streams import StreamReader from .utils import GeneratorTestCase @@ -444,18 +445,42 @@ def assertCloseData(self, close, data): self.assertEqual(parsed, close) def test_str(self): - self.assertEqual(str(Close(1000, "")), "1000 (OK)") - self.assertEqual(str(Close(1001, "Bye!")), "1001 (going away) Bye!") - self.assertEqual(str(Close(3000, "")), "3000 (registered)") - self.assertEqual(str(Close(4000, "")), "4000 (private use)") - self.assertEqual(str(Close(5000, "")), "5000 (unknown)") + self.assertEqual( + str(Close(CloseCode.NORMAL_CLOSURE, "")), + "1000 (OK)", + ) + self.assertEqual( + str(Close(CloseCode.GOING_AWAY, "Bye!")), + "1001 (going away) Bye!", + ) + self.assertEqual( + str(Close(3000, "")), + "3000 (registered)", + ) + self.assertEqual( + str(Close(4000, "")), + "4000 (private use)", + ) + self.assertEqual( + str(Close(5000, "")), + "5000 (unknown)", + ) def test_parse_and_serialize(self): - self.assertCloseData(Close(1001, ""), b"\x03\xe9") - self.assertCloseData(Close(1000, "OK"), b"\x03\xe8OK") + self.assertCloseData( + Close(CloseCode.NORMAL_CLOSURE, "OK"), + b"\x03\xe8OK", + ) + self.assertCloseData( + Close(CloseCode.GOING_AWAY, ""), + b"\x03\xe9", + ) def test_parse_empty(self): - self.assertEqual(Close.parse(b""), Close(1005, "")) + self.assertEqual( + Close.parse(b""), + Close(CloseCode.NO_STATUS_RCVD, ""), + ) def test_parse_errors(self): with self.assertRaises(ProtocolError): diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 7321d2594..a64172b53 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -16,6 +16,7 @@ OP_PONG, OP_TEXT, Close, + CloseCode, Frame, ) from websockets.protocol import * @@ -133,14 +134,18 @@ def test_client_receives_masked_frame(self): client.receive_data(self.masked_text_frame_data) self.assertIsInstance(client.parser_exc, ProtocolError) self.assertEqual(str(client.parser_exc), "incorrect masking") - self.assertConnectionFailing(client, 1002, "incorrect masking") + self.assertConnectionFailing( + client, CloseCode.PROTOCOL_ERROR, "incorrect masking" + ) def test_server_receives_unmasked_frame(self): server = Protocol(SERVER) server.receive_data(self.unmasked_text_frame_date) self.assertIsInstance(server.parser_exc, ProtocolError) self.assertEqual(str(server.parser_exc), "incorrect masking") - self.assertConnectionFailing(server, 1002, "incorrect masking") + self.assertConnectionFailing( + server, CloseCode.PROTOCOL_ERROR, "incorrect masking" + ) class ContinuationTests(ProtocolTestCase): @@ -166,14 +171,18 @@ def test_client_receives_unexpected_continuation(self): client.receive_data(b"\x00\x00") self.assertIsInstance(client.parser_exc, ProtocolError) self.assertEqual(str(client.parser_exc), "unexpected continuation frame") - self.assertConnectionFailing(client, 1002, "unexpected continuation frame") + self.assertConnectionFailing( + client, CloseCode.PROTOCOL_ERROR, "unexpected continuation frame" + ) def test_server_receives_unexpected_continuation(self): server = Protocol(SERVER) server.receive_data(b"\x00\x80\x00\x00\x00\x00") self.assertIsInstance(server.parser_exc, ProtocolError) self.assertEqual(str(server.parser_exc), "unexpected continuation frame") - self.assertConnectionFailing(server, 1002, "unexpected continuation frame") + self.assertConnectionFailing( + server, CloseCode.PROTOCOL_ERROR, "unexpected continuation frame" + ) def test_client_sends_continuation_after_sending_close(self): client = Protocol(CLIENT) @@ -181,7 +190,7 @@ def test_client_sends_continuation_after_sending_close(self): # message (see test_client_send_close_in_fragmented_message), in fact, # this is the same test as test_client_sends_unexpected_continuation. with self.enforce_mask(b"\x00\x00\x00\x00"): - client.send_close(1001) + client.send_close(CloseCode.GOING_AWAY) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) with self.assertRaises(ProtocolError) as raised: client.send_continuation(b"", fin=False) @@ -192,7 +201,7 @@ def test_server_sends_continuation_after_sending_close(self): # message (see test_server_send_close_in_fragmented_message), in fact, # this is the same test as test_server_sends_unexpected_continuation. server = Protocol(SERVER) - server.send_close(1000) + server.send_close(CloseCode.NORMAL_CLOSURE) self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) with self.assertRaises(ProtocolError) as raised: server.send_continuation(b"", fin=False) @@ -201,7 +210,7 @@ def test_server_sends_continuation_after_sending_close(self): def test_client_receives_continuation_after_receiving_close(self): client = Protocol(CLIENT) client.receive_data(b"\x88\x02\x03\xe8") - self.assertConnectionClosing(client, 1000) + self.assertConnectionClosing(client, CloseCode.NORMAL_CLOSURE) client.receive_data(b"\x00\x00") self.assertFrameReceived(client, None) self.assertFrameSent(client, None) @@ -209,7 +218,7 @@ def test_client_receives_continuation_after_receiving_close(self): def test_server_receives_continuation_after_receiving_close(self): server = Protocol(SERVER) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") - self.assertConnectionClosing(server, 1001) + self.assertConnectionClosing(server, CloseCode.GOING_AWAY) server.receive_data(b"\x00\x80\x00\xff\x00\xff") self.assertFrameReceived(server, None) self.assertFrameSent(server, None) @@ -255,14 +264,18 @@ def test_client_receives_text_over_size_limit(self): client.receive_data(b"\x81\x04\xf0\x9f\x98\x80") self.assertIsInstance(client.parser_exc, PayloadTooBig) self.assertEqual(str(client.parser_exc), "over size limit (4 > 3 bytes)") - self.assertConnectionFailing(client, 1009, "over size limit (4 > 3 bytes)") + self.assertConnectionFailing( + client, CloseCode.MESSAGE_TOO_BIG, "over size limit (4 > 3 bytes)" + ) def test_server_receives_text_over_size_limit(self): server = Protocol(SERVER, max_size=3) server.receive_data(b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80") self.assertIsInstance(server.parser_exc, PayloadTooBig) self.assertEqual(str(server.parser_exc), "over size limit (4 > 3 bytes)") - self.assertConnectionFailing(server, 1009, "over size limit (4 > 3 bytes)") + self.assertConnectionFailing( + server, CloseCode.MESSAGE_TOO_BIG, "over size limit (4 > 3 bytes)" + ) def test_client_receives_text_without_size_limit(self): client = Protocol(CLIENT, max_size=None) @@ -349,7 +362,9 @@ def test_client_receives_fragmented_text_over_size_limit(self): client.receive_data(b"\x80\x02\x98\x80") self.assertIsInstance(client.parser_exc, PayloadTooBig) self.assertEqual(str(client.parser_exc), "over size limit (2 > 1 bytes)") - self.assertConnectionFailing(client, 1009, "over size limit (2 > 1 bytes)") + self.assertConnectionFailing( + client, CloseCode.MESSAGE_TOO_BIG, "over size limit (2 > 1 bytes)" + ) def test_server_receives_fragmented_text_over_size_limit(self): server = Protocol(SERVER, max_size=3) @@ -361,7 +376,9 @@ def test_server_receives_fragmented_text_over_size_limit(self): server.receive_data(b"\x80\x82\x00\x00\x00\x00\x98\x80") self.assertIsInstance(server.parser_exc, PayloadTooBig) self.assertEqual(str(server.parser_exc), "over size limit (2 > 1 bytes)") - self.assertConnectionFailing(server, 1009, "over size limit (2 > 1 bytes)") + self.assertConnectionFailing( + server, CloseCode.MESSAGE_TOO_BIG, "over size limit (2 > 1 bytes)" + ) def test_client_receives_fragmented_text_without_size_limit(self): client = Protocol(CLIENT, max_size=None) @@ -423,7 +440,9 @@ def test_client_receives_unexpected_text(self): client.receive_data(b"\x01\x00") self.assertIsInstance(client.parser_exc, ProtocolError) self.assertEqual(str(client.parser_exc), "expected a continuation frame") - self.assertConnectionFailing(client, 1002, "expected a continuation frame") + self.assertConnectionFailing( + client, CloseCode.PROTOCOL_ERROR, "expected a continuation frame" + ) def test_server_receives_unexpected_text(self): server = Protocol(SERVER) @@ -435,19 +454,21 @@ def test_server_receives_unexpected_text(self): server.receive_data(b"\x01\x80\x00\x00\x00\x00") self.assertIsInstance(server.parser_exc, ProtocolError) self.assertEqual(str(server.parser_exc), "expected a continuation frame") - self.assertConnectionFailing(server, 1002, "expected a continuation frame") + self.assertConnectionFailing( + server, CloseCode.PROTOCOL_ERROR, "expected a continuation frame" + ) def test_client_sends_text_after_sending_close(self): client = Protocol(CLIENT) with self.enforce_mask(b"\x00\x00\x00\x00"): - client.send_close(1001) + client.send_close(CloseCode.GOING_AWAY) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) with self.assertRaises(InvalidState): client.send_text(b"") def test_server_sends_text_after_sending_close(self): server = Protocol(SERVER) - server.send_close(1000) + server.send_close(CloseCode.NORMAL_CLOSURE) self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) with self.assertRaises(InvalidState): server.send_text(b"") @@ -455,7 +476,7 @@ def test_server_sends_text_after_sending_close(self): def test_client_receives_text_after_receiving_close(self): client = Protocol(CLIENT) client.receive_data(b"\x88\x02\x03\xe8") - self.assertConnectionClosing(client, 1000) + self.assertConnectionClosing(client, CloseCode.NORMAL_CLOSURE) client.receive_data(b"\x81\x00") self.assertFrameReceived(client, None) self.assertFrameSent(client, None) @@ -463,7 +484,7 @@ def test_client_receives_text_after_receiving_close(self): def test_server_receives_text_after_receiving_close(self): server = Protocol(SERVER) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") - self.assertConnectionClosing(server, 1001) + self.assertConnectionClosing(server, CloseCode.GOING_AWAY) server.receive_data(b"\x81\x80\x00\xff\x00\xff") self.assertFrameReceived(server, None) self.assertFrameSent(server, None) @@ -509,14 +530,18 @@ def test_client_receives_binary_over_size_limit(self): client.receive_data(b"\x82\x04\x01\x02\xfe\xff") self.assertIsInstance(client.parser_exc, PayloadTooBig) self.assertEqual(str(client.parser_exc), "over size limit (4 > 3 bytes)") - self.assertConnectionFailing(client, 1009, "over size limit (4 > 3 bytes)") + self.assertConnectionFailing( + client, CloseCode.MESSAGE_TOO_BIG, "over size limit (4 > 3 bytes)" + ) def test_server_receives_binary_over_size_limit(self): server = Protocol(SERVER, max_size=3) server.receive_data(b"\x82\x84\x00\x00\x00\x00\x01\x02\xfe\xff") self.assertIsInstance(server.parser_exc, PayloadTooBig) self.assertEqual(str(server.parser_exc), "over size limit (4 > 3 bytes)") - self.assertConnectionFailing(server, 1009, "over size limit (4 > 3 bytes)") + self.assertConnectionFailing( + server, CloseCode.MESSAGE_TOO_BIG, "over size limit (4 > 3 bytes)" + ) def test_client_sends_fragmented_binary(self): client = Protocol(CLIENT) @@ -587,7 +612,9 @@ def test_client_receives_fragmented_binary_over_size_limit(self): client.receive_data(b"\x80\x02\xfe\xff") self.assertIsInstance(client.parser_exc, PayloadTooBig) self.assertEqual(str(client.parser_exc), "over size limit (2 > 1 bytes)") - self.assertConnectionFailing(client, 1009, "over size limit (2 > 1 bytes)") + self.assertConnectionFailing( + client, CloseCode.MESSAGE_TOO_BIG, "over size limit (2 > 1 bytes)" + ) def test_server_receives_fragmented_binary_over_size_limit(self): server = Protocol(SERVER, max_size=3) @@ -599,7 +626,9 @@ def test_server_receives_fragmented_binary_over_size_limit(self): server.receive_data(b"\x80\x82\x00\x00\x00\x00\xfe\xff") self.assertIsInstance(server.parser_exc, PayloadTooBig) self.assertEqual(str(server.parser_exc), "over size limit (2 > 1 bytes)") - self.assertConnectionFailing(server, 1009, "over size limit (2 > 1 bytes)") + self.assertConnectionFailing( + server, CloseCode.MESSAGE_TOO_BIG, "over size limit (2 > 1 bytes)" + ) def test_client_sends_unexpected_binary(self): client = Protocol(CLIENT) @@ -625,7 +654,9 @@ def test_client_receives_unexpected_binary(self): client.receive_data(b"\x02\x00") self.assertIsInstance(client.parser_exc, ProtocolError) self.assertEqual(str(client.parser_exc), "expected a continuation frame") - self.assertConnectionFailing(client, 1002, "expected a continuation frame") + self.assertConnectionFailing( + client, CloseCode.PROTOCOL_ERROR, "expected a continuation frame" + ) def test_server_receives_unexpected_binary(self): server = Protocol(SERVER) @@ -637,19 +668,21 @@ def test_server_receives_unexpected_binary(self): server.receive_data(b"\x02\x80\x00\x00\x00\x00") self.assertIsInstance(server.parser_exc, ProtocolError) self.assertEqual(str(server.parser_exc), "expected a continuation frame") - self.assertConnectionFailing(server, 1002, "expected a continuation frame") + self.assertConnectionFailing( + server, CloseCode.PROTOCOL_ERROR, "expected a continuation frame" + ) def test_client_sends_binary_after_sending_close(self): client = Protocol(CLIENT) with self.enforce_mask(b"\x00\x00\x00\x00"): - client.send_close(1001) + client.send_close(CloseCode.GOING_AWAY) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) with self.assertRaises(InvalidState): client.send_binary(b"") def test_server_sends_binary_after_sending_close(self): server = Protocol(SERVER) - server.send_close(1000) + server.send_close(CloseCode.NORMAL_CLOSURE) self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) with self.assertRaises(InvalidState): server.send_binary(b"") @@ -657,7 +690,7 @@ def test_server_sends_binary_after_sending_close(self): def test_client_receives_binary_after_receiving_close(self): client = Protocol(CLIENT) client.receive_data(b"\x88\x02\x03\xe8") - self.assertConnectionClosing(client, 1000) + self.assertConnectionClosing(client, CloseCode.NORMAL_CLOSURE) client.receive_data(b"\x82\x00") self.assertFrameReceived(client, None) self.assertFrameSent(client, None) @@ -665,7 +698,7 @@ def test_client_receives_binary_after_receiving_close(self): def test_server_receives_binary_after_receiving_close(self): server = Protocol(SERVER) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") - self.assertConnectionClosing(server, 1001) + self.assertConnectionClosing(server, CloseCode.GOING_AWAY) server.receive_data(b"\x82\x80\x00\xff\x00\xff") self.assertFrameReceived(server, None) self.assertFrameSent(server, None) @@ -687,7 +720,7 @@ def test_close_code(self): client = Protocol(CLIENT) client.receive_data(b"\x88\x04\x03\xe8OK") client.receive_eof() - self.assertEqual(client.close_code, 1000) + self.assertEqual(client.close_code, CloseCode.NORMAL_CLOSURE) def test_close_reason(self): server = Protocol(SERVER) @@ -699,7 +732,7 @@ def test_close_code_not_provided(self): server = Protocol(SERVER) server.receive_data(b"\x88\x80\x00\x00\x00\x00") server.receive_eof() - self.assertEqual(server.close_code, 1005) + self.assertEqual(server.close_code, CloseCode.NO_STATUS_RCVD) def test_close_reason_not_provided(self): client = Protocol(CLIENT) @@ -710,7 +743,7 @@ def test_close_reason_not_provided(self): def test_close_code_not_available(self): client = Protocol(CLIENT) client.receive_eof() - self.assertEqual(client.close_code, 1006) + self.assertEqual(client.close_code, CloseCode.ABNORMAL_CLOSURE) def test_close_reason_not_available(self): server = Protocol(SERVER) @@ -812,32 +845,32 @@ def test_server_receives_close_then_sends_close(self): def test_client_sends_close_with_code(self): client = Protocol(CLIENT) with self.enforce_mask(b"\x00\x00\x00\x00"): - client.send_close(1001) + client.send_close(CloseCode.GOING_AWAY) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) self.assertIs(client.state, CLOSING) def test_server_sends_close_with_code(self): server = Protocol(SERVER) - server.send_close(1000) + server.send_close(CloseCode.NORMAL_CLOSURE) self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) self.assertIs(server.state, CLOSING) def test_client_receives_close_with_code(self): client = Protocol(CLIENT) client.receive_data(b"\x88\x02\x03\xe8") - self.assertConnectionClosing(client, 1000, "") + self.assertConnectionClosing(client, CloseCode.NORMAL_CLOSURE, "") self.assertIs(client.state, CLOSING) def test_server_receives_close_with_code(self): server = Protocol(SERVER) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") - self.assertConnectionClosing(server, 1001, "") + self.assertConnectionClosing(server, CloseCode.GOING_AWAY, "") self.assertIs(server.state, CLOSING) def test_client_sends_close_with_code_and_reason(self): client = Protocol(CLIENT) with self.enforce_mask(b"\x00\x00\x00\x00"): - client.send_close(1001, "going away") + client.send_close(CloseCode.GOING_AWAY, "going away") self.assertEqual( client.data_to_send(), [b"\x88\x8c\x00\x00\x00\x00\x03\xe9going away"] ) @@ -845,20 +878,20 @@ def test_client_sends_close_with_code_and_reason(self): def test_server_sends_close_with_code_and_reason(self): server = Protocol(SERVER) - server.send_close(1000, "OK") + server.send_close(CloseCode.NORMAL_CLOSURE, "OK") self.assertEqual(server.data_to_send(), [b"\x88\x04\x03\xe8OK"]) self.assertIs(server.state, CLOSING) def test_client_receives_close_with_code_and_reason(self): client = Protocol(CLIENT) client.receive_data(b"\x88\x04\x03\xe8OK") - self.assertConnectionClosing(client, 1000, "OK") + self.assertConnectionClosing(client, CloseCode.NORMAL_CLOSURE, "OK") self.assertIs(client.state, CLOSING) def test_server_receives_close_with_code_and_reason(self): server = Protocol(SERVER) server.receive_data(b"\x88\x8c\x00\x00\x00\x00\x03\xe9going away") - self.assertConnectionClosing(server, 1001, "going away") + self.assertConnectionClosing(server, CloseCode.GOING_AWAY, "going away") self.assertIs(server.state, CLOSING) def test_client_sends_close_with_reason_only(self): @@ -878,7 +911,9 @@ def test_client_receives_close_with_truncated_code(self): client.receive_data(b"\x88\x01\x03") self.assertIsInstance(client.parser_exc, ProtocolError) self.assertEqual(str(client.parser_exc), "close frame too short") - self.assertConnectionFailing(client, 1002, "close frame too short") + self.assertConnectionFailing( + client, CloseCode.PROTOCOL_ERROR, "close frame too short" + ) self.assertIs(client.state, CLOSING) def test_server_receives_close_with_truncated_code(self): @@ -886,7 +921,9 @@ def test_server_receives_close_with_truncated_code(self): server.receive_data(b"\x88\x81\x00\x00\x00\x00\x03") self.assertIsInstance(server.parser_exc, ProtocolError) self.assertEqual(str(server.parser_exc), "close frame too short") - self.assertConnectionFailing(server, 1002, "close frame too short") + self.assertConnectionFailing( + server, CloseCode.PROTOCOL_ERROR, "close frame too short" + ) self.assertIs(server.state, CLOSING) def test_client_receives_close_with_non_utf8_reason(self): @@ -898,7 +935,9 @@ def test_client_receives_close_with_non_utf8_reason(self): str(client.parser_exc), "'utf-8' codec can't decode byte 0xff in position 0: invalid start byte", ) - self.assertConnectionFailing(client, 1007, "invalid start byte at position 0") + self.assertConnectionFailing( + client, CloseCode.INVALID_DATA, "invalid start byte at position 0" + ) self.assertIs(client.state, CLOSING) def test_server_receives_close_with_non_utf8_reason(self): @@ -910,7 +949,9 @@ def test_server_receives_close_with_non_utf8_reason(self): str(server.parser_exc), "'utf-8' codec can't decode byte 0xff in position 0: invalid start byte", ) - self.assertConnectionFailing(server, 1007, "invalid start byte at position 0") + self.assertConnectionFailing( + server, CloseCode.INVALID_DATA, "invalid start byte at position 0" + ) self.assertIs(server.state, CLOSING) @@ -1011,19 +1052,23 @@ def test_client_receives_fragmented_ping_frame(self): client.receive_data(b"\x09\x00") self.assertIsInstance(client.parser_exc, ProtocolError) self.assertEqual(str(client.parser_exc), "fragmented control frame") - self.assertConnectionFailing(client, 1002, "fragmented control frame") + self.assertConnectionFailing( + client, CloseCode.PROTOCOL_ERROR, "fragmented control frame" + ) def test_server_receives_fragmented_ping_frame(self): server = Protocol(SERVER) server.receive_data(b"\x09\x80\x3c\x3c\x3c\x3c") self.assertIsInstance(server.parser_exc, ProtocolError) self.assertEqual(str(server.parser_exc), "fragmented control frame") - self.assertConnectionFailing(server, 1002, "fragmented control frame") + self.assertConnectionFailing( + server, CloseCode.PROTOCOL_ERROR, "fragmented control frame" + ) def test_client_sends_ping_after_sending_close(self): client = Protocol(CLIENT) with self.enforce_mask(b"\x00\x00\x00\x00"): - client.send_close(1001) + client.send_close(CloseCode.GOING_AWAY) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) # The spec says: "An endpoint MAY send a Ping frame any time (...) # before the connection is closed" but websockets doesn't support @@ -1037,7 +1082,7 @@ def test_client_sends_ping_after_sending_close(self): def test_server_sends_ping_after_sending_close(self): server = Protocol(SERVER) - server.send_close(1000) + server.send_close(CloseCode.NORMAL_CLOSURE) self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) # The spec says: "An endpoint MAY send a Ping frame any time (...) # before the connection is closed" but websockets doesn't support @@ -1052,7 +1097,7 @@ def test_server_sends_ping_after_sending_close(self): def test_client_receives_ping_after_receiving_close(self): client = Protocol(CLIENT) client.receive_data(b"\x88\x02\x03\xe8") - self.assertConnectionClosing(client, 1000) + self.assertConnectionClosing(client, CloseCode.NORMAL_CLOSURE) client.receive_data(b"\x89\x04\x22\x66\xaa\xee") self.assertFrameReceived(client, None) self.assertFrameSent(client, None) @@ -1060,7 +1105,7 @@ def test_client_receives_ping_after_receiving_close(self): def test_server_receives_ping_after_receiving_close(self): server = Protocol(SERVER) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") - self.assertConnectionClosing(server, 1001) + self.assertConnectionClosing(server, CloseCode.GOING_AWAY) server.receive_data(b"\x89\x84\x00\x44\x88\xcc\x22\x22\x22\x22") self.assertFrameReceived(server, None) self.assertFrameSent(server, None) @@ -1147,19 +1192,23 @@ def test_client_receives_fragmented_pong_frame(self): client.receive_data(b"\x0a\x00") self.assertIsInstance(client.parser_exc, ProtocolError) self.assertEqual(str(client.parser_exc), "fragmented control frame") - self.assertConnectionFailing(client, 1002, "fragmented control frame") + self.assertConnectionFailing( + client, CloseCode.PROTOCOL_ERROR, "fragmented control frame" + ) def test_server_receives_fragmented_pong_frame(self): server = Protocol(SERVER) server.receive_data(b"\x0a\x80\x3c\x3c\x3c\x3c") self.assertIsInstance(server.parser_exc, ProtocolError) self.assertEqual(str(server.parser_exc), "fragmented control frame") - self.assertConnectionFailing(server, 1002, "fragmented control frame") + self.assertConnectionFailing( + server, CloseCode.PROTOCOL_ERROR, "fragmented control frame" + ) def test_client_sends_pong_after_sending_close(self): client = Protocol(CLIENT) with self.enforce_mask(b"\x00\x00\x00\x00"): - client.send_close(1001) + client.send_close(CloseCode.GOING_AWAY) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) # websockets doesn't support sending a Pong frame after a Close frame. with self.assertRaises(InvalidState): @@ -1167,7 +1216,7 @@ def test_client_sends_pong_after_sending_close(self): def test_server_sends_pong_after_sending_close(self): server = Protocol(SERVER) - server.send_close(1000) + server.send_close(CloseCode.NORMAL_CLOSURE) self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) # websockets doesn't support sending a Pong frame after a Close frame. with self.assertRaises(InvalidState): @@ -1176,7 +1225,7 @@ def test_server_sends_pong_after_sending_close(self): def test_client_receives_pong_after_receiving_close(self): client = Protocol(CLIENT) client.receive_data(b"\x88\x02\x03\xe8") - self.assertConnectionClosing(client, 1000) + self.assertConnectionClosing(client, CloseCode.NORMAL_CLOSURE) client.receive_data(b"\x8a\x04\x22\x66\xaa\xee") self.assertFrameReceived(client, None) self.assertFrameSent(client, None) @@ -1184,7 +1233,7 @@ def test_client_receives_pong_after_receiving_close(self): def test_server_receives_pong_after_receiving_close(self): server = Protocol(SERVER) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") - self.assertConnectionClosing(server, 1001) + self.assertConnectionClosing(server, CloseCode.GOING_AWAY) server.receive_data(b"\x8a\x84\x00\x44\x88\xcc\x22\x22\x22\x22") self.assertFrameReceived(server, None) self.assertFrameSent(server, None) @@ -1200,15 +1249,15 @@ class FailTests(ProtocolTestCase): def test_client_stops_processing_frames_after_fail(self): client = Protocol(CLIENT) - client.fail(1002) - self.assertConnectionFailing(client, 1002) + client.fail(CloseCode.PROTOCOL_ERROR) + self.assertConnectionFailing(client, CloseCode.PROTOCOL_ERROR) client.receive_data(b"\x88\x02\x03\xea") self.assertFrameReceived(client, None) def test_server_stops_processing_frames_after_fail(self): server = Protocol(SERVER) - server.fail(1002) - self.assertConnectionFailing(server, 1002) + server.fail(CloseCode.PROTOCOL_ERROR) + self.assertConnectionFailing(server, CloseCode.PROTOCOL_ERROR) server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xea") self.assertFrameReceived(server, None) @@ -1320,7 +1369,7 @@ def test_client_send_close_in_fragmented_message(self): # endpoint must not send a data frame after a close frame, a close # frame can't be "in the middle" of a fragmented message. with self.assertRaises(ProtocolError) as raised: - client.send_close(1001) + client.send_close(CloseCode.GOING_AWAY) self.assertEqual(str(raised.exception), "expected a continuation frame") client.send_continuation(b"Eggs", fin=True) @@ -1333,7 +1382,7 @@ def test_server_send_close_in_fragmented_message(self): # endpoint must not send a data frame after a close frame, a close # frame can't be "in the middle" of a fragmented message. with self.assertRaises(ProtocolError) as raised: - server.send_close(1000) + server.send_close(CloseCode.NORMAL_CLOSURE) self.assertEqual(str(raised.exception), "expected a continuation frame") def test_client_receive_close_in_fragmented_message(self): @@ -1350,7 +1399,9 @@ def test_client_receive_close_in_fragmented_message(self): client.receive_data(b"\x88\x02\x03\xe8") self.assertIsInstance(client.parser_exc, ProtocolError) self.assertEqual(str(client.parser_exc), "incomplete fragmented message") - self.assertConnectionFailing(client, 1002, "incomplete fragmented message") + self.assertConnectionFailing( + client, CloseCode.PROTOCOL_ERROR, "incomplete fragmented message" + ) def test_server_receive_close_in_fragmented_message(self): server = Protocol(SERVER) @@ -1366,7 +1417,9 @@ def test_server_receive_close_in_fragmented_message(self): server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") self.assertIsInstance(server.parser_exc, ProtocolError) self.assertEqual(str(server.parser_exc), "incomplete fragmented message") - self.assertConnectionFailing(server, 1002, "incomplete fragmented message") + self.assertConnectionFailing( + server, CloseCode.PROTOCOL_ERROR, "incomplete fragmented message" + ) class EOFTests(ProtocolTestCase): @@ -1428,35 +1481,35 @@ def test_server_receives_eof_inside_frame(self): def test_client_receives_data_after_exception(self): client = Protocol(CLIENT) client.receive_data(b"\xff\xff") - self.assertConnectionFailing(client, 1002, "invalid opcode") + self.assertConnectionFailing(client, CloseCode.PROTOCOL_ERROR, "invalid opcode") client.receive_data(b"\x00\x00") self.assertFrameSent(client, None) def test_server_receives_data_after_exception(self): server = Protocol(SERVER) server.receive_data(b"\xff\xff") - self.assertConnectionFailing(server, 1002, "invalid opcode") + self.assertConnectionFailing(server, CloseCode.PROTOCOL_ERROR, "invalid opcode") server.receive_data(b"\x00\x00") self.assertFrameSent(server, None) def test_client_receives_eof_after_exception(self): client = Protocol(CLIENT) client.receive_data(b"\xff\xff") - self.assertConnectionFailing(client, 1002, "invalid opcode") + self.assertConnectionFailing(client, CloseCode.PROTOCOL_ERROR, "invalid opcode") client.receive_eof() self.assertFrameSent(client, None, eof=True) def test_server_receives_eof_after_exception(self): server = Protocol(SERVER) server.receive_data(b"\xff\xff") - self.assertConnectionFailing(server, 1002, "invalid opcode") + self.assertConnectionFailing(server, CloseCode.PROTOCOL_ERROR, "invalid opcode") server.receive_eof() self.assertFrameSent(server, None) def test_client_receives_data_and_eof_after_exception(self): client = Protocol(CLIENT) client.receive_data(b"\xff\xff") - self.assertConnectionFailing(client, 1002, "invalid opcode") + self.assertConnectionFailing(client, CloseCode.PROTOCOL_ERROR, "invalid opcode") client.receive_data(b"\x00\x00") client.receive_eof() self.assertFrameSent(client, None, eof=True) @@ -1464,7 +1517,7 @@ def test_client_receives_data_and_eof_after_exception(self): def test_server_receives_data_and_eof_after_exception(self): server = Protocol(SERVER) server.receive_data(b"\xff\xff") - self.assertConnectionFailing(server, 1002, "invalid opcode") + self.assertConnectionFailing(server, CloseCode.PROTOCOL_ERROR, "invalid opcode") server.receive_data(b"\x00\x00") server.receive_eof() self.assertFrameSent(server, None) @@ -1554,12 +1607,12 @@ def test_server_receives_close(self): def test_client_fails_connection(self): client = Protocol(CLIENT) - client.fail(1002) + client.fail(CloseCode.PROTOCOL_ERROR) self.assertTrue(client.close_expected()) def test_server_fails_connection(self): server = Protocol(SERVER) - server.fail(1002) + server.fail(CloseCode.PROTOCOL_ERROR) self.assertTrue(server.close_expected()) @@ -1572,25 +1625,25 @@ class ConnectionClosedTests(ProtocolTestCase): def test_client_sends_close_then_receives_close(self): # Client-initiated close handshake on the client side complete. client = Protocol(CLIENT) - client.send_close(1000, "") + client.send_close(CloseCode.NORMAL_CLOSURE, "") client.receive_data(b"\x88\x02\x03\xe8") client.receive_eof() exc = client.close_exc self.assertIsInstance(exc, ConnectionClosedOK) - self.assertEqual(exc.rcvd, Close(1000, "")) - self.assertEqual(exc.sent, Close(1000, "")) + self.assertEqual(exc.rcvd, Close(CloseCode.NORMAL_CLOSURE, "")) + self.assertEqual(exc.sent, Close(CloseCode.NORMAL_CLOSURE, "")) self.assertFalse(exc.rcvd_then_sent) def test_server_sends_close_then_receives_close(self): # Server-initiated close handshake on the server side complete. server = Protocol(SERVER) - server.send_close(1000, "") + server.send_close(CloseCode.NORMAL_CLOSURE, "") server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe8") server.receive_eof() exc = server.close_exc self.assertIsInstance(exc, ConnectionClosedOK) - self.assertEqual(exc.rcvd, Close(1000, "")) - self.assertEqual(exc.sent, Close(1000, "")) + self.assertEqual(exc.rcvd, Close(CloseCode.NORMAL_CLOSURE, "")) + self.assertEqual(exc.sent, Close(CloseCode.NORMAL_CLOSURE, "")) self.assertFalse(exc.rcvd_then_sent) def test_client_receives_close_then_sends_close(self): @@ -1600,8 +1653,8 @@ def test_client_receives_close_then_sends_close(self): client.receive_eof() exc = client.close_exc self.assertIsInstance(exc, ConnectionClosedOK) - self.assertEqual(exc.rcvd, Close(1000, "")) - self.assertEqual(exc.sent, Close(1000, "")) + self.assertEqual(exc.rcvd, Close(CloseCode.NORMAL_CLOSURE, "")) + self.assertEqual(exc.sent, Close(CloseCode.NORMAL_CLOSURE, "")) self.assertTrue(exc.rcvd_then_sent) def test_server_receives_close_then_sends_close(self): @@ -1611,30 +1664,30 @@ def test_server_receives_close_then_sends_close(self): server.receive_eof() exc = server.close_exc self.assertIsInstance(exc, ConnectionClosedOK) - self.assertEqual(exc.rcvd, Close(1000, "")) - self.assertEqual(exc.sent, Close(1000, "")) + self.assertEqual(exc.rcvd, Close(CloseCode.NORMAL_CLOSURE, "")) + self.assertEqual(exc.sent, Close(CloseCode.NORMAL_CLOSURE, "")) self.assertTrue(exc.rcvd_then_sent) def test_client_sends_close_then_receives_eof(self): # Client-initiated close handshake on the client side times out. client = Protocol(CLIENT) - client.send_close(1000, "") + client.send_close(CloseCode.NORMAL_CLOSURE, "") client.receive_eof() exc = client.close_exc self.assertIsInstance(exc, ConnectionClosedError) self.assertIsNone(exc.rcvd) - self.assertEqual(exc.sent, Close(1000, "")) + self.assertEqual(exc.sent, Close(CloseCode.NORMAL_CLOSURE, "")) self.assertIsNone(exc.rcvd_then_sent) def test_server_sends_close_then_receives_eof(self): # Server-initiated close handshake on the server side times out. server = Protocol(SERVER) - server.send_close(1000, "") + server.send_close(CloseCode.NORMAL_CLOSURE, "") server.receive_eof() exc = server.close_exc self.assertIsInstance(exc, ConnectionClosedError) self.assertIsNone(exc.rcvd) - self.assertEqual(exc.sent, Close(1000, "")) + self.assertEqual(exc.sent, Close(CloseCode.NORMAL_CLOSURE, "")) self.assertIsNone(exc.rcvd_then_sent) def test_client_receives_eof(self): @@ -1671,7 +1724,7 @@ def test_client_hits_internal_error_reading_frame(self): client.receive_data(b"\x81\x00") self.assertIsInstance(client.parser_exc, RuntimeError) self.assertEqual(str(client.parser_exc), "BOOM") - self.assertConnectionFailing(client, 1011, "") + self.assertConnectionFailing(client, CloseCode.INTERNAL_ERROR, "") def test_server_hits_internal_error_reading_frame(self): server = Protocol(SERVER) @@ -1680,7 +1733,7 @@ def test_server_hits_internal_error_reading_frame(self): server.receive_data(b"\x81\x80\x00\x00\x00\x00") self.assertIsInstance(server.parser_exc, RuntimeError) self.assertEqual(str(server.parser_exc), "BOOM") - self.assertConnectionFailing(server, 1011, "") + self.assertConnectionFailing(server, CloseCode.INTERNAL_ERROR, "") class ExtensionsTests(ProtocolTestCase): From e3abb88ad6c8fc3e57f42aed4a00b9644c4e65df Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 18 May 2023 18:59:55 +0200 Subject: [PATCH 446/762] Document that the Host header isn't validated. --- docs/reference/features.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/reference/features.rst b/docs/reference/features.rst index 3cc52ec10..98b3c0dda 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -174,6 +174,11 @@ There is no way to control compression of outgoing frames on a per-frame basis .. _#538: https://github.com/python-websockets/websockets/issues/538 +The server doesn't check the Host header and respond with a HTTP 400 Bad Request +if it is missing or invalid (`#1246`). + +.. _#1246: https://github.com/python-websockets/websockets/issues/1246 + The client API doesn't attempt to guarantee that there is no more than one connection to a given IP address in a CONNECTING state. This behavior is `mandated by RFC 6455`_. However, :func:`~client.connect()` isn't the right From 1bf9d1d766c80da4887240737266926f173fbcef Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 18 May 2023 19:16:31 +0200 Subject: [PATCH 447/762] Accept int status in process_request and reject. Fix #1309. --- src/websockets/exceptions.py | 3 ++- src/websockets/server.py | 2 ++ tests/legacy/test_client_server.py | 14 ++++++++++++++ tests/test_server.py | 6 ++++++ 4 files changed, 24 insertions(+), 1 deletion(-) diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 9d8476648..0f0686872 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -334,7 +334,8 @@ def __init__( headers: datastructures.HeadersLike, body: bytes = b"", ) -> None: - self.status = status + # If a user passes an int instead of a HTTPStatus, fix it automatically. + self.status = http.HTTPStatus(status) self.headers = datastructures.Headers(headers) self.body = body diff --git a/src/websockets/server.py b/src/websockets/server.py index ecb0f74a6..b9646ea81 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -499,6 +499,8 @@ def reject( Response: WebSocket handshake response event to send to the client. """ + # If a user passes an int instead of a HTTPStatus, fix it automatically. + status = http.HTTPStatus(status) body = text.encode() headers = Headers( [ diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index cba24ad32..02dbe9e3f 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -190,6 +190,12 @@ async def process_request(self, path, request_headers): return http.HTTPStatus.OK, [("X-Access", "OK")], b"status = green\n" +class ProcessRequestReturningIntProtocol(WebSocketServerProtocol): + async def process_request(self, path, request_headers): + if path == "/__health__/": + return 200, [], b"OK\n" + + class SlowOpeningHandshakeProtocol(WebSocketServerProtocol): async def process_request(self, path, request_headers): await asyncio.sleep(10 * MS) @@ -757,6 +763,14 @@ def test_http_request_custom_server_header(self): with contextlib.closing(response): self.assertEqual(response.headers["Server"], "websockets") + @with_server(create_protocol=ProcessRequestReturningIntProtocol) + def test_process_request_returns_int_status(self): + response = self.loop.run_until_complete(self.make_http_request("/__health__/")) + + with contextlib.closing(response): + self.assertEqual(response.code, 200) + self.assertEqual(response.read(), b"OK\n") + def assert_client_raises_code(self, status_code): with self.assertRaises(InvalidStatusCode) as raised: self.start_client() diff --git a/tests/test_server.py b/tests/test_server.py index ecf3d4cbe..b6f5e3568 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -188,6 +188,12 @@ def test_reject_response(self): ) self.assertEqual(response.body, b"Sorry folks.\n") + def test_reject_response_supports_int_status(self): + server = ServerProtocol() + response = server.reject(404, "Sorry folks.\n") + self.assertEqual(response.status_code, 404) + self.assertEqual(response.reason_phrase, "Not Found") + def test_basic(self): server = ServerProtocol() request = self.make_request() From 7b9a53b74c09b27cf0899a28f38f6835c5d141ed Mon Sep 17 00:00:00 2001 From: Benjamin Loison <12752145+Benjamin-Loison@users.noreply.github.com> Date: Tue, 27 Jun 2023 18:17:43 +0200 Subject: [PATCH 448/762] Remove the `asyncio` import in `README.rst` for the `threading` example --- README.rst | 1 - 1 file changed, 1 deletion(-) diff --git a/README.rst b/README.rst index f53d3d0fc..870b208ba 100644 --- a/README.rst +++ b/README.rst @@ -65,7 +65,6 @@ Here's how a client sends and receives messages with the ``threading`` API: #!/usr/bin/env python - import asyncio from websockets.sync.client import connect def hello(): From b870c46b2778444ae2d8960a0ae0bf4be53dc81c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 15 Jul 2023 10:14:45 +0200 Subject: [PATCH 449/762] Fix inconsistency in 2b627b26. --- docs/reference/datastructures.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/reference/datastructures.rst b/docs/reference/datastructures.rst index 7c037da4c..ec02d4210 100644 --- a/docs/reference/datastructures.rst +++ b/docs/reference/datastructures.rst @@ -21,12 +21,12 @@ WebSocket events .. autoclass:: CloseCode - .. autoattribute:: OK + .. autoattribute:: NORMAL_CLOSURE .. autoattribute:: GOING_AWAY .. autoattribute:: PROTOCOL_ERROR .. autoattribute:: UNSUPPORTED_DATA .. autoattribute:: NO_STATUS_RCVD - .. autoattribute:: CONNECTION_CLOSED_ABNORMALLY + .. autoattribute:: ABNORMAL_CLOSURE .. autoattribute:: INVALID_DATA .. autoattribute:: POLICY_VIOLATION .. autoattribute:: MESSAGE_TOO_BIG @@ -35,7 +35,7 @@ WebSocket events .. autoattribute:: SERVICE_RESTART .. autoattribute:: TRY_AGAIN_LATER .. autoattribute:: BAD_GATEWAY - .. autoattribute:: TLS_FAILURE + .. autoattribute:: TLS_HANDSHAKE HTTP events ----------- From 8ed5424a5bdcbe99565d05239e7ea6e5a15a3cdd Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 15 Jul 2023 10:24:40 +0200 Subject: [PATCH 450/762] Clarify that the TLS example is simplistic. Fix #1381. --- docs/howto/quickstart.rst | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/docs/howto/quickstart.rst b/docs/howto/quickstart.rst index da3c3999e..ab870952c 100644 --- a/docs/howto/quickstart.rst +++ b/docs/howto/quickstart.rst @@ -57,9 +57,6 @@ Here's how to adapt the server to encrypt connections. You must download :download:`localhost.pem <../../example/quickstart/localhost.pem>` and save it in the same directory as ``server_secure.py``. -See the documentation of the :mod:`ssl` module for details on configuring the -TLS context securely. - .. literalinclude:: ../../example/quickstart/server_secure.py :caption: server_secure.py :language: python @@ -79,6 +76,15 @@ When connecting to a secure WebSocket server with a valid certificate — any certificate signed by a CA that your Python installation trusts — you can simply pass ``ssl=True`` to :func:`~client.connect`. +.. admonition:: Configure the TLS context securely + :class: attention + + This example demonstrates the ``ssl`` argument with a TLS certificate shared + between the client and the server. This is a simplistic setup. + + Please review the advice and security considerations in the documentation of + the :mod:`ssl` module to configure the TLS context securely. + Connect from a browser ---------------------- From adfb8d69a7a1f6f4c8381c9e7182619d202c3cf2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 15 Jul 2023 10:47:27 +0200 Subject: [PATCH 451/762] Fix test that fails on IPv6. --- tests/legacy/test_client_server.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 02dbe9e3f..056ff193f 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -397,9 +397,9 @@ def test_explicit_host_port(self): wsuri = parse_uri(uri) # Change host and port to invalid values. - changed_uri = uri.replace(wsuri.host, "example.com").replace( - str(wsuri.port), str(65535 - wsuri.port) - ) + scheme = "wss" if wsuri.secure else "ws" + port = 65535 - wsuri.port + changed_uri = f"{scheme}://example.com:{port}/" with self.temp_client(uri=changed_uri, host=wsuri.host, port=wsuri.port): self.loop.run_until_complete(self.client.send("Hello!")) From c41ce814bb173e2b174c60ce935b8598da626863 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 1 Oct 2023 15:55:46 +0200 Subject: [PATCH 452/762] Rejecting a connection isn't always a failure. Fix #1402. --- src/websockets/legacy/server.py | 2 +- src/websockets/server.py | 2 +- tests/legacy/test_client_server.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 77e0fdab7..c9b32c417 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -227,7 +227,7 @@ async def handler(self) -> None: self.write_http_response(status, headers, body) self.logger.info( - "connection failed (%d %s)", status.value, status.phrase + "connection rejected (%d %s)", status.value, status.phrase ) await self.close_transport() return diff --git a/src/websockets/server.py b/src/websockets/server.py index b9646ea81..34069faf0 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -516,7 +516,7 @@ def reject( # "handshake_exc is None if and only if opening handshake succeeded." if self.handshake_exc is None: self.handshake_exc = InvalidStatus(response) - self.logger.info("connection failed (%d %s)", status.value, status.phrase) + self.logger.info("connection rejected (%d %s)", status.value, status.phrase) return response def send_response(self, response: Response) -> None: diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 056ff193f..f7e02d101 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -1597,12 +1597,12 @@ async def run_client(): self.assertEqual( [record.getMessage() for record in logs.records][4:-1], [ - "connection failed (503 Service Unavailable)", + "connection rejected (503 Service Unavailable)", "connection closed", "! connect failed; reconnecting in 0.0 seconds", ] + [ - "connection failed (503 Service Unavailable)", + "connection rejected (503 Service Unavailable)", "connection closed", "! connect failed again; retrying in 0 seconds", ] From 1b10ca16dbb4e6c84c8d27cd42e8e35cfb85f5e8 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 1 Oct 2023 16:39:03 +0200 Subject: [PATCH 453/762] Add compatibility imports from legacy package in __all__. Supersedes #1400. --- src/websockets/auth.py | 2 ++ src/websockets/client.py | 4 +++- src/websockets/server.py | 4 +++- tests/test_exports.py | 10 +++------- 4 files changed, 11 insertions(+), 9 deletions(-) diff --git a/src/websockets/auth.py b/src/websockets/auth.py index 5292e4f7f..b792e02f5 100644 --- a/src/websockets/auth.py +++ b/src/websockets/auth.py @@ -1,4 +1,6 @@ from __future__ import annotations # See #940 for why lazy_import isn't used here for backwards compatibility. +# See #1400 for why listing compatibility imports in __all__ helps PyCharm. from .legacy.auth import * +from .legacy.auth import __all__ # noqa: F401 diff --git a/src/websockets/client.py b/src/websockets/client.py index bf8427c37..b2f622042 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -38,10 +38,12 @@ # See #940 for why lazy_import isn't used here for backwards compatibility. +# See #1400 for why listing compatibility imports in __all__ helps PyCharm. from .legacy.client import * # isort:skip # noqa: I001 +from .legacy.client import __all__ as legacy__all__ -__all__ = ["ClientProtocol"] +__all__ = ["ClientProtocol"] + legacy__all__ class ClientProtocol(Protocol): diff --git a/src/websockets/server.py b/src/websockets/server.py index 34069faf0..872e4e5af 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -39,10 +39,12 @@ # See #940 for why lazy_import isn't used here for backwards compatibility. +# See #1400 for why listing compatibility imports in __all__ helps PyCharm. from .legacy.server import * # isort:skip # noqa: I001 +from .legacy.server import __all__ as legacy__all__ -__all__ = ["ServerProtocol"] +__all__ = ["ServerProtocol"] + legacy__all__ class ServerProtocol(Protocol): diff --git a/tests/test_exports.py b/tests/test_exports.py index 978b1d0e7..d63cb590c 100644 --- a/tests/test_exports.py +++ b/tests/test_exports.py @@ -1,24 +1,20 @@ import unittest import websockets +import websockets.auth import websockets.client import websockets.exceptions -import websockets.legacy.auth -import websockets.legacy.client import websockets.legacy.protocol -import websockets.legacy.server import websockets.server import websockets.typing import websockets.uri combined_exports = ( - websockets.legacy.auth.__all__ - + websockets.legacy.client.__all__ - + websockets.legacy.protocol.__all__ - + websockets.legacy.server.__all__ + websockets.auth.__all__ + websockets.client.__all__ + websockets.exceptions.__all__ + + websockets.legacy.protocol.__all__ + websockets.server.__all__ + websockets.typing.__all__ + websockets.uri.__all__ From c9eefae63d2e456296e32da177749b392dd37b35 Mon Sep 17 00:00:00 2001 From: kxxt Date: Wed, 14 Jun 2023 15:26:56 +0800 Subject: [PATCH 454/762] test: allow WEBSOCKETS_TESTS_TIMEOUT_FACTOR to be float --- tests/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils.py b/tests/utils.py index 4e9ac9f0e..2937a2f15 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -23,7 +23,7 @@ # Unit for timeouts. May be increased on slow machines by setting the # WEBSOCKETS_TESTS_TIMEOUT_FACTOR environment variable. -MS = 0.001 * int(os.environ.get("WEBSOCKETS_TESTS_TIMEOUT_FACTOR", "1")) +MS = 0.001 * float(os.environ.get("WEBSOCKETS_TESTS_TIMEOUT_FACTOR", "1")) # PyPy has a performance penalty for this test suite. if platform.python_implementation() == "PyPy": # pragma: no cover From ca5926ed1532fcd1dd264ef4f5a7a357fef66cfc Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 1 Oct 2023 16:59:33 +0200 Subject: [PATCH 455/762] Go back to 100% coverage. Broken in 1bf9d1d7. --- tests/legacy/test_client_server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index f7e02d101..c49d91b70 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -192,8 +192,8 @@ async def process_request(self, path, request_headers): class ProcessRequestReturningIntProtocol(WebSocketServerProtocol): async def process_request(self, path, request_headers): - if path == "/__health__/": - return 200, [], b"OK\n" + assert path == "/__health__/" + return 200, [], b"OK\n" class SlowOpeningHandshakeProtocol(WebSocketServerProtocol): From 439dafa656029f27c26956181d27a781ec04b105 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 1 Oct 2023 17:15:37 +0200 Subject: [PATCH 456/762] Import eagerly when running under a type checker. Fix #1292 (and many others). Also fix some inconsistencies in the lists. The new rule is: - when re-exporting a module, re-export it entirely; - don't re-export deprecated modules. --- docs/faq/misc.rst | 23 ---- docs/project/changelog.rst | 5 + docs/reference/index.rst | 13 --- src/websockets/__init__.py | 229 ++++++++++++++++++++++++------------- src/websockets/http.py | 29 +++-- tests/test_exports.py | 3 +- 6 files changed, 175 insertions(+), 127 deletions(-) diff --git a/docs/faq/misc.rst b/docs/faq/misc.rst index 4fc271322..ee5ad2372 100644 --- a/docs/faq/misc.rst +++ b/docs/faq/misc.rst @@ -12,29 +12,6 @@ instead of the websockets library. .. _real-import-paths: -Why does my IDE fail to show documentation for websockets APIs? -............................................................... - -You are probably using the convenience imports e.g.:: - - import websockets - - websockets.connect(...) - websockets.serve(...) - -This is incompatible with static code analysis. It may break auto-completion and -contextual documentation in IDEs, type checking with mypy_, etc. - -.. _mypy: https://github.com/python/mypy - -Instead, use the real import paths e.g.:: - - import websockets.client - import websockets.server - - websockets.client.connect(...) - websockets.server.serve(...) - Why is the default implementation located in ``websockets.legacy``? ................................................................... diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 94fc5ebd9..ad9ab5908 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -41,6 +41,11 @@ Backwards-incompatible changes Improvements ............ +* Made convenience imports from ``websockets`` compatible with static code + analysis tools such as auto-completion in an IDE or type checking with mypy_. + + .. _mypy: https://github.com/python/mypy + * Added :class:`~frames.CloseCode`. 11.0.3 diff --git a/docs/reference/index.rst b/docs/reference/index.rst index 2a9556dd9..0b80f087a 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -88,16 +88,3 @@ Convenience imports For convenience, many public APIs can be imported directly from the ``websockets`` package. - - -.. admonition:: Convenience imports are incompatible with some development tools. - :class: caution - - Specifically, static code analysis tools don't understand them. This breaks - auto-completion and contextual documentation in IDEs, type checking with - mypy_, etc. - - .. _mypy: https://github.com/python/mypy - - If you're using such tools, stick to the full import paths, as explained in - this FAQ: :ref:`real-import-paths` diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index dcf3d8150..2b9c3bc54 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -1,23 +1,24 @@ from __future__ import annotations +import typing + from .imports import lazy_import from .version import version as __version__ # noqa: F401 __all__ = [ - "AbortHandshake", - "basic_auth_protocol_factory", - "BasicAuthWebSocketServerProtocol", - "broadcast", + # .client "ClientProtocol", - "connect", + # .datastructures + "Headers", + "HeadersLike", + "MultipleValuesError", + # .exceptions + "AbortHandshake", "ConnectionClosed", "ConnectionClosedError", "ConnectionClosedOK", - "Data", "DuplicateParameter", - "ExtensionName", - "ExtensionParameter", "InvalidHandshake", "InvalidHeader", "InvalidHeaderFormat", @@ -31,84 +32,156 @@ "InvalidStatusCode", "InvalidUpgrade", "InvalidURI", - "LoggerLike", "NegotiationError", - "Origin", - "parse_uri", "PayloadTooBig", "ProtocolError", "RedirectHandshake", "SecurityError", - "serve", - "ServerProtocol", - "Subprotocol", - "unix_connect", - "unix_serve", - "WebSocketClientProtocol", - "WebSocketCommonProtocol", "WebSocketException", "WebSocketProtocolError", + # .legacy.auth + "BasicAuthWebSocketServerProtocol", + "basic_auth_protocol_factory", + # .legacy.client + "WebSocketClientProtocol", + "connect", + "unix_connect", + # .legacy.protocol + "WebSocketCommonProtocol", + "broadcast", + # .legacy.server "WebSocketServer", "WebSocketServerProtocol", - "WebSocketURI", + "serve", + "unix_serve", + # .server + "ServerProtocol", + # .typing + "Data", + "ExtensionName", + "ExtensionParameter", + "LoggerLike", + "Origin", + "Subprotocol", ] -lazy_import( - globals(), - aliases={ - "auth": ".legacy", - "basic_auth_protocol_factory": ".legacy.auth", - "BasicAuthWebSocketServerProtocol": ".legacy.auth", - "broadcast": ".legacy.protocol", - "ClientProtocol": ".client", - "connect": ".legacy.client", - "unix_connect": ".legacy.client", - "WebSocketClientProtocol": ".legacy.client", - "Headers": ".datastructures", - "MultipleValuesError": ".datastructures", - "WebSocketException": ".exceptions", - "ConnectionClosed": ".exceptions", - "ConnectionClosedError": ".exceptions", - "ConnectionClosedOK": ".exceptions", - "InvalidHandshake": ".exceptions", - "SecurityError": ".exceptions", - "InvalidMessage": ".exceptions", - "InvalidHeader": ".exceptions", - "InvalidHeaderFormat": ".exceptions", - "InvalidHeaderValue": ".exceptions", - "InvalidOrigin": ".exceptions", - "InvalidUpgrade": ".exceptions", - "InvalidStatus": ".exceptions", - "InvalidStatusCode": ".exceptions", - "NegotiationError": ".exceptions", - "DuplicateParameter": ".exceptions", - "InvalidParameterName": ".exceptions", - "InvalidParameterValue": ".exceptions", - "AbortHandshake": ".exceptions", - "RedirectHandshake": ".exceptions", - "InvalidState": ".exceptions", - "InvalidURI": ".exceptions", - "PayloadTooBig": ".exceptions", - "ProtocolError": ".exceptions", - "WebSocketProtocolError": ".exceptions", - "protocol": ".legacy", - "WebSocketCommonProtocol": ".legacy.protocol", - "ServerProtocol": ".server", - "serve": ".legacy.server", - "unix_serve": ".legacy.server", - "WebSocketServerProtocol": ".legacy.server", - "WebSocketServer": ".legacy.server", - "Data": ".typing", - "LoggerLike": ".typing", - "Origin": ".typing", - "ExtensionHeader": ".typing", - "ExtensionParameter": ".typing", - "Subprotocol": ".typing", - }, - deprecated_aliases={ - "framing": ".legacy", - "handshake": ".legacy", - "parse_uri": ".uri", - "WebSocketURI": ".uri", - }, -) +# When type checking, import non-deprecated aliases eagerly. Else, import on demand. +if typing.TYPE_CHECKING: + from .client import ClientProtocol + from .datastructures import Headers, HeadersLike, MultipleValuesError + from .exceptions import ( + AbortHandshake, + ConnectionClosed, + ConnectionClosedError, + ConnectionClosedOK, + DuplicateParameter, + InvalidHandshake, + InvalidHeader, + InvalidHeaderFormat, + InvalidHeaderValue, + InvalidMessage, + InvalidOrigin, + InvalidParameterName, + InvalidParameterValue, + InvalidState, + InvalidStatus, + InvalidStatusCode, + InvalidUpgrade, + InvalidURI, + NegotiationError, + PayloadTooBig, + ProtocolError, + RedirectHandshake, + SecurityError, + WebSocketException, + WebSocketProtocolError, + ) + from .legacy.auth import ( + BasicAuthWebSocketServerProtocol, + basic_auth_protocol_factory, + ) + from .legacy.client import WebSocketClientProtocol, connect, unix_connect + from .legacy.protocol import WebSocketCommonProtocol, broadcast + from .legacy.server import ( + WebSocketServer, + WebSocketServerProtocol, + serve, + unix_serve, + ) + from .server import ServerProtocol + from .typing import ( + Data, + ExtensionName, + ExtensionParameter, + LoggerLike, + Origin, + Subprotocol, + ) +else: + lazy_import( + globals(), + aliases={ + # .client + "ClientProtocol": ".client", + # .datastructures + "Headers": ".datastructures", + "HeadersLike": ".datastructures", + "MultipleValuesError": ".datastructures", + # .exceptions + "AbortHandshake": ".exceptions", + "ConnectionClosed": ".exceptions", + "ConnectionClosedError": ".exceptions", + "ConnectionClosedOK": ".exceptions", + "DuplicateParameter": ".exceptions", + "InvalidHandshake": ".exceptions", + "InvalidHeader": ".exceptions", + "InvalidHeaderFormat": ".exceptions", + "InvalidHeaderValue": ".exceptions", + "InvalidMessage": ".exceptions", + "InvalidOrigin": ".exceptions", + "InvalidParameterName": ".exceptions", + "InvalidParameterValue": ".exceptions", + "InvalidState": ".exceptions", + "InvalidStatus": ".exceptions", + "InvalidStatusCode": ".exceptions", + "InvalidUpgrade": ".exceptions", + "InvalidURI": ".exceptions", + "NegotiationError": ".exceptions", + "PayloadTooBig": ".exceptions", + "ProtocolError": ".exceptions", + "RedirectHandshake": ".exceptions", + "SecurityError": ".exceptions", + "WebSocketException": ".exceptions", + "WebSocketProtocolError": ".exceptions", + # .legacy.auth + "BasicAuthWebSocketServerProtocol": ".legacy.auth", + "basic_auth_protocol_factory": ".legacy.auth", + # .legacy.client + "WebSocketClientProtocol": ".legacy.client", + "connect": ".legacy.client", + "unix_connect": ".legacy.client", + # .legacy.protocol + "WebSocketCommonProtocol": ".legacy.protocol", + "broadcast": ".legacy.protocol", + # .legacy.server + "WebSocketServer": ".legacy.server", + "WebSocketServerProtocol": ".legacy.server", + "serve": ".legacy.server", + "unix_serve": ".legacy.server", + # .server + "ServerProtocol": ".server", + # .typing + "Data": ".typing", + "ExtensionName": ".typing", + "ExtensionParameter": ".typing", + "LoggerLike": ".typing", + "Origin": ".typing", + "Subprotocol": ".typing", + }, + deprecated_aliases={ + "framing": ".legacy", + "handshake": ".legacy", + "parse_uri": ".uri", + "WebSocketURI": ".uri", + }, + ) diff --git a/src/websockets/http.py b/src/websockets/http.py index b14fa94bd..9f86f6a1f 100644 --- a/src/websockets/http.py +++ b/src/websockets/http.py @@ -1,6 +1,7 @@ from __future__ import annotations import sys +import typing from .imports import lazy_import from .version import version as websockets_version @@ -9,18 +10,22 @@ # For backwards compatibility: -lazy_import( - globals(), - # Headers and MultipleValuesError used to be defined in this module. - aliases={ - "Headers": ".datastructures", - "MultipleValuesError": ".datastructures", - }, - deprecated_aliases={ - "read_request": ".legacy.http", - "read_response": ".legacy.http", - }, -) +# When type checking, import non-deprecated aliases eagerly. Else, import on demand. +if typing.TYPE_CHECKING: + from .datastructures import Headers, MultipleValuesError # noqa: F401 +else: + lazy_import( + globals(), + # Headers and MultipleValuesError used to be defined in this module. + aliases={ + "Headers": ".datastructures", + "MultipleValuesError": ".datastructures", + }, + deprecated_aliases={ + "read_request": ".legacy.http", + "read_response": ".legacy.http", + }, + ) __all__ = ["USER_AGENT"] diff --git a/tests/test_exports.py b/tests/test_exports.py index d63cb590c..67a1a6f99 100644 --- a/tests/test_exports.py +++ b/tests/test_exports.py @@ -3,6 +3,7 @@ import websockets import websockets.auth import websockets.client +import websockets.datastructures import websockets.exceptions import websockets.legacy.protocol import websockets.server @@ -13,11 +14,11 @@ combined_exports = ( websockets.auth.__all__ + websockets.client.__all__ + + websockets.datastructures.__all__ + websockets.exceptions.__all__ + websockets.legacy.protocol.__all__ + websockets.server.__all__ + websockets.typing.__all__ - + websockets.uri.__all__ ) From 678b4380a258abd2a074475f3a817823241a4362 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 1 Oct 2023 17:43:00 +0200 Subject: [PATCH 457/762] Fix coverage. --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 530052ddd..f24616dd7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,6 +62,7 @@ exclude_lines = [ "except ImportError:", "if self.debug:", "if sys.platform != \"win32\":", + "if typing.TYPE_CHECKING:", "pragma: no cover", "raise AssertionError", "raise NotImplementedError", From be551424ecf17b92ce5cc24f5eb73f2796eb1daa Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sat, 7 Oct 2023 05:42:16 +0000 Subject: [PATCH 458/762] Bump actions/checkout from 3 to 4 Bumps [actions/checkout](https://github.com/actions/checkout) from 3 to 4. - [Release notes](https://github.com/actions/checkout/releases) - [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md) - [Commits](https://github.com/actions/checkout/compare/v3...v4) --- updated-dependencies: - dependency-name: actions/checkout dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- .github/workflows/tests.yml | 6 +++--- .github/workflows/wheels.yml | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 603426412..470f5bc96 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -17,7 +17,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Check out repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Install Python 3.x uses: actions/setup-python@v4 with: @@ -34,7 +34,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Check out repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Install Python 3.x uses: actions/setup-python@v4 with: @@ -72,7 +72,7 @@ jobs: is_main: false steps: - name: Check out repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Install Python ${{ matrix.python }} uses: actions/setup-python@v4 with: diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 8aa5c0b7b..e846c54d9 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -12,7 +12,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Check out repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Install Python 3.x uses: actions/setup-python@v4 with: @@ -45,7 +45,7 @@ jobs: - macOS-latest steps: - name: Check out repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Install Python 3.x uses: actions/setup-python@v4 with: From f62d44ac1a0a606fe5753cf0017ce9e1dbf3640e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sat, 7 Oct 2023 05:42:20 +0000 Subject: [PATCH 459/762] Bump docker/setup-qemu-action from 2 to 3 Bumps [docker/setup-qemu-action](https://github.com/docker/setup-qemu-action) from 2 to 3. - [Release notes](https://github.com/docker/setup-qemu-action/releases) - [Commits](https://github.com/docker/setup-qemu-action/compare/v2...v3) --- updated-dependencies: - dependency-name: docker/setup-qemu-action dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- .github/workflows/wheels.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index e846c54d9..0e182cf78 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -52,7 +52,7 @@ jobs: python-version: 3.x - name: Set up QEMU if: runner.os == 'Linux' - uses: docker/setup-qemu-action@v2 + uses: docker/setup-qemu-action@v3 with: platforms: all - name: Build wheels From ed6cb1b8d0b3a6121ecef06fcec8e3f4de417a5b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Oct 2023 15:26:22 +0200 Subject: [PATCH 460/762] Make it explicit that status codes can be int. The code and tests already support it but the types didn't reflect it. Refs #1406. --- docs/reference/types.rst | 2 ++ src/websockets/__init__.py | 3 +++ src/websockets/exceptions.py | 3 ++- src/websockets/legacy/server.py | 8 ++++---- src/websockets/server.py | 3 ++- src/websockets/typing.py | 7 +++++++ 6 files changed, 20 insertions(+), 6 deletions(-) diff --git a/docs/reference/types.rst b/docs/reference/types.rst index 88550d08d..9d3aa8310 100644 --- a/docs/reference/types.rst +++ b/docs/reference/types.rst @@ -7,6 +7,8 @@ Types .. autodata:: LoggerLike + .. autodata:: StatusLike + .. autodata:: Origin .. autodata:: Subprotocol diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index 2b9c3bc54..fdb028f4c 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -61,6 +61,7 @@ "ExtensionName", "ExtensionParameter", "LoggerLike", + "StatusLike", "Origin", "Subprotocol", ] @@ -115,6 +116,7 @@ ExtensionParameter, LoggerLike, Origin, + StatusLike, Subprotocol, ) else: @@ -176,6 +178,7 @@ "ExtensionParameter": ".typing", "LoggerLike": ".typing", "Origin": ".typing", + "StatusLike": "typing", "Subprotocol": ".typing", }, deprecated_aliases={ diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 0f0686872..f7169e3b1 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -34,6 +34,7 @@ from typing import Optional from . import datastructures, frames, http11 +from .typing import StatusLike __all__ = [ @@ -330,7 +331,7 @@ class AbortHandshake(InvalidHandshake): def __init__( self, - status: http.HTTPStatus, + status: StatusLike, headers: datastructures.HeadersLike, body: bytes = b"", ) -> None: diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index c9b32c417..7c24dd74a 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -45,7 +45,7 @@ ) from ..http import USER_AGENT from ..protocol import State -from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol +from ..typing import ExtensionHeader, LoggerLike, Origin, StatusLike, Subprotocol from .compatibility import asyncio_timeout from .handshake import build_response, check_request from .http import read_request @@ -57,7 +57,7 @@ HeadersLikeOrCallable = Union[HeadersLike, Callable[[str, Headers], HeadersLike]] -HTTPResponse = Tuple[http.HTTPStatus, HeadersLike, bytes] +HTTPResponse = Tuple[StatusLike, HeadersLike, bytes] class WebSocketServerProtocol(WebSocketCommonProtocol): @@ -349,7 +349,7 @@ async def process_request( request_headers: request headers. Returns: - Optional[Tuple[http.HTTPStatus, HeadersLike, bytes]]: :obj:`None` + Optional[Tuple[StatusLike, HeadersLike, bytes]]: :obj:`None` to continue the WebSocket handshake normally. An HTTP response, represented by a 3-uple of the response status, @@ -943,7 +943,7 @@ class Serve: It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to :obj:`None` removes the header. process_request (Optional[Callable[[str, Headers], \ - Awaitable[Optional[Tuple[http.HTTPStatus, HeadersLike, bytes]]]]]): + Awaitable[Optional[Tuple[StatusLike, HeadersLike, bytes]]]]]): Intercept HTTP request before the opening handshake. See :meth:`~WebSocketServerProtocol.process_request` for details. select_subprotocol: Select a subprotocol supported by the client. diff --git a/src/websockets/server.py b/src/websockets/server.py index 872e4e5af..191660553 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -32,6 +32,7 @@ ExtensionHeader, LoggerLike, Origin, + StatusLike, Subprotocol, UpgradeProtocol, ) @@ -480,7 +481,7 @@ def select_subprotocol(protocol, subprotocols): def reject( self, - status: http.HTTPStatus, + status: StatusLike, text: str, ) -> Response: """ diff --git a/src/websockets/typing.py b/src/websockets/typing.py index e672ba006..cc3e3ec0d 100644 --- a/src/websockets/typing.py +++ b/src/websockets/typing.py @@ -1,5 +1,6 @@ from __future__ import annotations +import http import logging from typing import List, NewType, Optional, Tuple, Union @@ -7,6 +8,7 @@ __all__ = [ "Data", "LoggerLike", + "StatusLike", "Origin", "Subprotocol", "ExtensionName", @@ -30,6 +32,11 @@ """Types accepted where a :class:`~logging.Logger` is expected.""" +StatusLike = Union[http.HTTPStatus, int] +""" +Types accepted where an :class:`~http.HTTPStatus` is expected.""" + + Origin = NewType("Origin", str) """Value of a ``Origin`` header.""" From 1077920fec5ebc274df16efc86bbb0a8056f3601 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Oct 2023 15:39:16 +0200 Subject: [PATCH 461/762] Add changelog for previous commit. --- docs/project/changelog.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index ad9ab5908..d87f4545c 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -46,6 +46,8 @@ Improvements .. _mypy: https://github.com/python/mypy +* Accepted a plain :class:`int` where an :class:`~http.HTTPStatus` is expected. + * Added :class:`~frames.CloseCode`. 11.0.3 From 20413649b09aa98f136b1354a1f3ee801b663f36 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sat, 21 Oct 2023 13:50:17 +0000 Subject: [PATCH 462/762] Bump pypa/cibuildwheel from 2.12.1 to 2.16.2 Bumps [pypa/cibuildwheel](https://github.com/pypa/cibuildwheel) from 2.12.1 to 2.16.2. - [Release notes](https://github.com/pypa/cibuildwheel/releases) - [Changelog](https://github.com/pypa/cibuildwheel/blob/main/docs/changelog.md) - [Commits](https://github.com/pypa/cibuildwheel/compare/v2.12.1...v2.16.2) --- updated-dependencies: - dependency-name: pypa/cibuildwheel dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- .github/workflows/wheels.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 0e182cf78..707ef2c60 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -56,7 +56,7 @@ jobs: with: platforms: all - name: Build wheels - uses: pypa/cibuildwheel@v2.12.1 + uses: pypa/cibuildwheel@v2.16.2 env: BUILD_EXTENSION: yes - name: Save wheels From 01195322d2620a44039b716cb93c108c2ca9b6b9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Oct 2023 15:41:38 +0200 Subject: [PATCH 463/762] Release version 12.0 --- docs/project/changelog.rst | 2 +- src/websockets/version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index d87f4545c..264e6e42d 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -28,7 +28,7 @@ notice. 12.0 ---- -*In development* +*October 21, 2023* Backwards-incompatible changes .............................. diff --git a/src/websockets/version.py b/src/websockets/version.py index 3f171b391..d1c99458e 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -18,7 +18,7 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = False +released = True tag = version = commit = "12.0" From 310c29512955b37fffee685120108795a8436b6c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Oct 2023 16:30:14 +0200 Subject: [PATCH 464/762] Rename workflow for making a release. --- .github/workflows/{wheels.yml => release.yml} | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) rename .github/workflows/{wheels.yml => release.yml} (97%) diff --git a/.github/workflows/wheels.yml b/.github/workflows/release.yml similarity index 97% rename from .github/workflows/wheels.yml rename to .github/workflows/release.yml index 707ef2c60..90f24b6f1 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/release.yml @@ -1,4 +1,4 @@ -name: Build wheels +name: Make release on: push: @@ -64,8 +64,8 @@ jobs: with: path: wheelhouse/*.whl - release: - name: Release + upload: + name: Upload needs: - sdist - wheels From 88e702ddaf214b46fcf6b3ceca25961f79ca9d00 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Oct 2023 16:39:20 +0200 Subject: [PATCH 465/762] Upgrade to Trusted Publishing. --- .github/workflows/release.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 90f24b6f1..8fad13529 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -72,6 +72,8 @@ jobs: runs-on: ubuntu-latest # Don't release when running the workflow manually from GitHub's UI. if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') + permissions: + id-token: write steps: - name: Download artifacts uses: actions/download-artifact@v3 @@ -80,8 +82,6 @@ jobs: path: dist - name: Upload to PyPI uses: pypa/gh-action-pypi-publish@release/v1 - with: - password: ${{ secrets.PYPI_API_TOKEN }} - name: Create GitHub release env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} From 5121bd15f988cc446db95b15a0bcac8dc64b68ab Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Oct 2023 16:54:27 +0200 Subject: [PATCH 466/762] Blind fix for automatic release creation. --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 8fad13529..6e895e64e 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -85,4 +85,4 @@ jobs: - name: Create GitHub release env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - run: gh release create ${{ github.ref_name }} --notes "See https://websockets.readthedocs.io/en/stable/project/changelog.html for details." + run: gh release -R python-websockets/websockets create ${{ github.ref_name }} --notes "See https://websockets.readthedocs.io/en/stable/project/changelog.html for details." From 2431e09eebc75578e310627f0eab38cd81df2f6b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 3 Nov 2023 07:55:23 +0100 Subject: [PATCH 467/762] Fix import style (likely autogenerated). --- src/websockets/sync/server.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 14767968c..d12da0c65 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -11,10 +11,9 @@ from types import TracebackType from typing import Any, Callable, Optional, Sequence, Type -from websockets.frames import CloseCode - from ..extensions.base import ServerExtensionFactory from ..extensions.permessage_deflate import enable_server_permessage_deflate +from ..frames import CloseCode from ..headers import validate_subprotocols from ..http import USER_AGENT from ..http11 import Request, Response From ec3bd2ab06278602c1d6018b476699e090036373 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 3 Nov 2023 08:22:33 +0100 Subject: [PATCH 468/762] Make sync reassembler more readable. No logic changes. --- src/websockets/sync/messages.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index 67a22313c..d98ff855b 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -47,13 +47,13 @@ def __init__(self) -> None: # queue for transferring frames from the writing thread (library code) # to the reading thread (user code). We're buffering when chunks_queue # is None and streaming when it's a SimpleQueue. None is a sentinel - # value marking the end of the stream, superseding message_complete. + # value marking the end of the message, superseding message_complete. # Stream data from frames belonging to the same message. # Remove quotes around type when dropping Python < 3.9. self.chunks_queue: Optional["queue.SimpleQueue[Optional[Data]]"] = None - # This flag marks the end of the stream. + # This flag marks the end of the connection. self.closed = False def get(self, timeout: Optional[float] = None) -> Data: @@ -108,12 +108,12 @@ def get(self, timeout: Optional[float] = None) -> Data: # mypy cannot figure out that chunks have the proper type. message: Data = joiner.join(self.chunks) # type: ignore - assert not self.message_fetched.is_set() - self.message_fetched.set() - self.chunks = [] assert self.chunks_queue is None + assert not self.message_fetched.is_set() + self.message_fetched.set() + return message def get_iter(self) -> Iterator[Data]: @@ -169,26 +169,26 @@ def get_iter(self) -> Iterator[Data]: with self.mutex: self.get_in_progress = False - assert self.message_complete.is_set() - self.message_complete.clear() - # get_iter() was unblocked by close() rather than put(). if self.closed: raise EOFError("stream of frames ended") - assert not self.message_fetched.is_set() - self.message_fetched.set() + assert self.message_complete.is_set() + self.message_complete.clear() assert self.chunks == [] self.chunks_queue = None + assert not self.message_fetched.is_set() + self.message_fetched.set() + def put(self, frame: Frame) -> None: """ Add ``frame`` to the next message. When ``frame`` is the final frame in a message, :meth:`put` waits until - the message is fetched, either by calling :meth:`get` or by fully - consuming the return value of :meth:`get_iter`. + the message is fetched, which can be achieved by calling :meth:`get` or + by fully consuming the return value of :meth:`get_iter`. :meth:`put` assumes that the stream of frames respects the protocol. If it doesn't, the behavior is undefined. @@ -247,13 +247,13 @@ def put(self, frame: Frame) -> None: with self.mutex: self.put_in_progress = False - assert self.message_fetched.is_set() - self.message_fetched.clear() - # put() was unblocked by close() rather than get() or get_iter(). if self.closed: raise EOFError("stream of frames ended") + assert self.message_fetched.is_set() + self.message_fetched.clear() + self.decoder = None def close(self) -> None: From 5737b474ad7d4a3a5e04d68299f4e5ec34bd62ac Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 25 Nov 2023 14:46:44 +0100 Subject: [PATCH 469/762] Start version 12.1. This commit should have been made right after releasing 12.0. --- docs/project/changelog.rst | 5 +++++ src/websockets/version.py | 4 ++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 264e6e42d..200ca7ef3 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -25,6 +25,11 @@ fixing regressions shortly after a release. Only documented APIs are public. Undocumented, private APIs may change without notice. +12.1 +---- + +*In development* + 12.0 ---- diff --git a/src/websockets/version.py b/src/websockets/version.py index d1c99458e..f1de3cbf4 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -18,9 +18,9 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = True +released = False -tag = version = commit = "12.0" +tag = version = commit = "12.1" if not released: # pragma: no cover From 94dd203f63bb52b1a30faa228e63ada2f0f2e874 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sat, 9 Dec 2023 06:25:11 +0000 Subject: [PATCH 470/762] Bump actions/setup-python from 4 to 5 Bumps [actions/setup-python](https://github.com/actions/setup-python) from 4 to 5. - [Release notes](https://github.com/actions/setup-python/releases) - [Commits](https://github.com/actions/setup-python/compare/v4...v5) --- updated-dependencies: - dependency-name: actions/setup-python dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- .github/workflows/release.yml | 4 ++-- .github/workflows/tests.yml | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 6e895e64e..7d56b9aa5 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -14,7 +14,7 @@ jobs: - name: Check out repository uses: actions/checkout@v4 - name: Install Python 3.x - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: 3.x - name: Build sdist @@ -47,7 +47,7 @@ jobs: - name: Check out repository uses: actions/checkout@v4 - name: Install Python 3.x - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: 3.x - name: Set up QEMU diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 470f5bc96..b128defb5 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -19,7 +19,7 @@ jobs: - name: Check out repository uses: actions/checkout@v4 - name: Install Python 3.x - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: "3.x" - name: Install tox @@ -36,7 +36,7 @@ jobs: - name: Check out repository uses: actions/checkout@v4 - name: Install Python 3.x - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: "3.x" - name: Install tox @@ -74,7 +74,7 @@ jobs: - name: Check out repository uses: actions/checkout@v4 - name: Install Python ${{ matrix.python }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python }} - name: Install tox From fe1833fb103f4d63baee525c5b62dedd24b9884e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 25 Nov 2023 14:48:59 +0100 Subject: [PATCH 471/762] Confirm support for Python 3.12. Fix #1417. --- .github/workflows/tests.yml | 1 + docs/project/changelog.rst | 5 +++++ pyproject.toml | 1 + tox.ini | 1 + 4 files changed, 8 insertions(+) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index b128defb5..8161f1cbb 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -61,6 +61,7 @@ jobs: - "3.9" - "3.10" - "3.11" + - "3.12" - "pypy-3.8" - "pypy-3.9" is_main: diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 200ca7ef3..963353d0e 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -30,6 +30,11 @@ notice. *In development* +New features +............ + +* Validated compatibility with Python 3.12. + 12.0 ---- diff --git a/pyproject.toml b/pyproject.toml index f24616dd7..a7b4a6a9e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ classifiers = [ "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", ] dynamic = ["version", "readme"] diff --git a/tox.ini b/tox.ini index 939d8c0cd..538b638d9 100644 --- a/tox.ini +++ b/tox.ini @@ -5,6 +5,7 @@ envlist = py39 py310 py311 + py312 coverage black ruff From beeb9387dedb574c8d1a6c2a6e7312c17788c858 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sat, 16 Dec 2023 06:24:38 +0000 Subject: [PATCH 472/762] Bump actions/download-artifact from 3 to 4 Bumps [actions/download-artifact](https://github.com/actions/download-artifact) from 3 to 4. - [Release notes](https://github.com/actions/download-artifact/releases) - [Commits](https://github.com/actions/download-artifact/compare/v3...v4) --- updated-dependencies: - dependency-name: actions/download-artifact dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 7d56b9aa5..c1b750c80 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -76,7 +76,7 @@ jobs: id-token: write steps: - name: Download artifacts - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: artifact path: dist From 33b20e11e86f8490770185c78ed39adab8db4560 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sat, 16 Dec 2023 06:24:43 +0000 Subject: [PATCH 473/762] Bump actions/upload-artifact from 3 to 4 Bumps [actions/upload-artifact](https://github.com/actions/upload-artifact) from 3 to 4. - [Release notes](https://github.com/actions/upload-artifact/releases) - [Commits](https://github.com/actions/upload-artifact/compare/v3...v4) --- updated-dependencies: - dependency-name: actions/upload-artifact dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- .github/workflows/release.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index c1b750c80..4a00bf8fc 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -20,7 +20,7 @@ jobs: - name: Build sdist run: python setup.py sdist - name: Save sdist - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: path: dist/*.tar.gz - name: Install wheel @@ -30,7 +30,7 @@ jobs: BUILD_EXTENSION: no run: python setup.py bdist_wheel - name: Save wheel - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: path: dist/*.whl @@ -60,7 +60,7 @@ jobs: env: BUILD_EXTENSION: yes - name: Save wheels - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: path: wheelhouse/*.whl From b3c51958849c80209b4d68fca081ef3fffc5e2bd Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 1 Jan 2024 15:02:51 +0100 Subject: [PATCH 474/762] Make test_local/remote_address more robust. Fix #1427. --- tests/sync/test_connection.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index 63544d4ad..e128425d8 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -656,13 +656,17 @@ def test_logger(self): """Connection has a logger attribute.""" self.assertIsInstance(self.connection.logger, logging.LoggerAdapter) - def test_local_address(self): - """Connection has a local_address attribute.""" - self.assertIsNotNone(self.connection.local_address) - - def test_remote_address(self): - """Connection has a remote_address attribute.""" - self.assertIsNotNone(self.connection.remote_address) + @unittest.mock.patch("socket.socket.getsockname", return_value=("sock", 1234)) + def test_local_address(self, getsockname): + """Connection provides a local_address attribute.""" + self.assertEqual(self.connection.local_address, ("sock", 1234)) + getsockname.assert_called_with() + + @unittest.mock.patch("socket.socket.getpeername", return_value=("peer", 1234)) + def test_remote_address(self, getpeername): + """Connection provides a remote_address attribute.""" + self.assertEqual(self.connection.remote_address, ("peer", 1234)) + getpeername.assert_called_with() def test_request(self): """Connection has a request attribute.""" From 9038a62e7261af21109977407907038a1a0efc65 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 1 Jan 2024 15:38:53 +0100 Subject: [PATCH 475/762] Make mypy 1.8.0 happy. --- src/websockets/legacy/auth.py | 2 +- src/websockets/typing.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/websockets/legacy/auth.py b/src/websockets/legacy/auth.py index d3425836e..e8d6b75d5 100644 --- a/src/websockets/legacy/auth.py +++ b/src/websockets/legacy/auth.py @@ -159,7 +159,7 @@ def basic_auth_protocol_factory( if is_credentials(credentials): credentials_list = [cast(Credentials, credentials)] elif isinstance(credentials, Iterable): - credentials_list = list(credentials) + credentials_list = list(cast(Iterable[Credentials], credentials)) if not all(is_credentials(item) for item in credentials_list): raise TypeError(f"invalid credentials argument: {credentials}") else: diff --git a/src/websockets/typing.py b/src/websockets/typing.py index cc3e3ec0d..e073e650d 100644 --- a/src/websockets/typing.py +++ b/src/websockets/typing.py @@ -2,7 +2,7 @@ import http import logging -from typing import List, NewType, Optional, Tuple, Union +from typing import Any, List, NewType, Optional, Tuple, Union __all__ = [ @@ -28,7 +28,7 @@ """ -LoggerLike = Union[logging.Logger, logging.LoggerAdapter] +LoggerLike = Union[logging.Logger, logging.LoggerAdapter[Any]] """Types accepted where a :class:`~logging.Logger` is expected.""" From 230d5052a33c0d940d926a1fc88909d39f57efd8 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 1 Jan 2024 15:41:54 +0100 Subject: [PATCH 476/762] Add tests for abstract classes. This prevents Python 3.12 to complain that no test cases were run and to exit with code 5 (which breaks maxi_cov). --- pyproject.toml | 1 - tests/extensions/test_base.py | 28 +++++++++++++++++++++++++++- 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a7b4a6a9e..c4c5412c5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,7 +66,6 @@ exclude_lines = [ "if typing.TYPE_CHECKING:", "pragma: no cover", "raise AssertionError", - "raise NotImplementedError", "self.fail\\(\".*\"\\)", "@unittest.skip", ] diff --git a/tests/extensions/test_base.py b/tests/extensions/test_base.py index b18ffb6fb..62250b07f 100644 --- a/tests/extensions/test_base.py +++ b/tests/extensions/test_base.py @@ -1,4 +1,30 @@ +import unittest + from websockets.extensions.base import * +from websockets.frames import Frame, Opcode + + +class ExtensionTests(unittest.TestCase): + def test_encode(self): + with self.assertRaises(NotImplementedError): + Extension().encode(Frame(Opcode.TEXT, b"")) + + def test_decode(self): + with self.assertRaises(NotImplementedError): + Extension().decode(Frame(Opcode.TEXT, b"")) + + +class ClientExtensionFactoryTests(unittest.TestCase): + def test_get_request_params(self): + with self.assertRaises(NotImplementedError): + ClientExtensionFactory().get_request_params() + + def test_process_response_params(self): + with self.assertRaises(NotImplementedError): + ClientExtensionFactory().process_response_params([], []) -# Abstract classes don't provide any behavior to test. +class ServerExtensionFactoryTests(unittest.TestCase): + def test_process_request_params(self): + with self.assertRaises(NotImplementedError): + ServerExtensionFactory().process_request_params([], []) From 5209b2a1cba00b28b8f62502157d5dbb98625a49 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 1 Jan 2024 16:11:16 +0100 Subject: [PATCH 477/762] Remove empty test modules. This prevents Python 3.12 to complain that no test cases were run and to exit with code 5 (which breaks maxi_cov). --- tests/maxi_cov.py | 33 ++++++++++++++++++++++----------- tests/test_auth.py | 1 - tests/test_http.py | 7 +++++++ tests/test_typing.py | 1 - 4 files changed, 29 insertions(+), 13 deletions(-) delete mode 100644 tests/test_auth.py delete mode 100644 tests/test_typing.py diff --git a/tests/maxi_cov.py b/tests/maxi_cov.py index 2568dcf18..bc4a44e8c 100755 --- a/tests/maxi_cov.py +++ b/tests/maxi_cov.py @@ -8,8 +8,15 @@ import sys -UNMAPPED_SRC_FILES = ["websockets/version.py"] -UNMAPPED_TEST_FILES = ["tests/test_exports.py"] +UNMAPPED_SRC_FILES = [ + "websockets/auth.py", + "websockets/typing.py", + "websockets/version.py", +] + +UNMAPPED_TEST_FILES = [ + "tests/test_exports.py", +] def check_environment(): @@ -60,7 +67,7 @@ def get_mapping(src_dir="src"): # Map source files to test files. mapping = {} - unmapped_test_files = [] + unmapped_test_files = set() for test_file in test_files: dir_name, file_name = os.path.split(test_file) @@ -73,26 +80,30 @@ def get_mapping(src_dir="src"): if src_file in src_files: mapping[src_file] = test_file else: - unmapped_test_files.append(test_file) + unmapped_test_files.add(test_file) - unmapped_src_files = list(set(src_files) - set(mapping)) + unmapped_src_files = set(src_files) - set(mapping) # Ensure that all files are mapped. - assert unmapped_src_files == UNMAPPED_SRC_FILES - assert unmapped_test_files == UNMAPPED_TEST_FILES + assert unmapped_src_files == set(UNMAPPED_SRC_FILES) + assert unmapped_test_files == set(UNMAPPED_TEST_FILES) return mapping def get_ignored_files(src_dir="src"): """Return the list of files to exclude from coverage measurement.""" - + # */websockets matches src/websockets and .tox/**/site-packages/websockets. return [ - # */websockets matches src/websockets and .tox/**/site-packages/websockets. - # There are no tests for the __main__ module and for compatibility modules. + # There are no tests for the __main__ module. "*/websockets/__main__.py", + # There is nothing to test on type declarations. + "*/websockets/typing.py", + # We don't test compatibility modules with previous versions of Python + # or websockets (import locations). "*/websockets/*/compatibility.py", + "*/websockets/auth.py", # This approach isn't applicable to the test suite of the legacy # implementation, due to the huge test_client_server test module. "*/websockets/legacy/*", @@ -125,7 +136,7 @@ def run_coverage(mapping, src_dir="src"): "-m", "unittest", ] - + UNMAPPED_TEST_FILES, + + list(UNMAPPED_TEST_FILES), check=True, ) # Append coverage of each source module by the corresponding test module. diff --git a/tests/test_auth.py b/tests/test_auth.py deleted file mode 100644 index 28db93155..000000000 --- a/tests/test_auth.py +++ /dev/null @@ -1 +0,0 @@ -from websockets.auth import * diff --git a/tests/test_http.py b/tests/test_http.py index 036bc1410..baaa7d416 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -1 +1,8 @@ +import unittest + from websockets.http import * + + +class HTTPTests(unittest.TestCase): + def test_user_agent(self): + USER_AGENT # exists diff --git a/tests/test_typing.py b/tests/test_typing.py deleted file mode 100644 index 202de840f..000000000 --- a/tests/test_typing.py +++ /dev/null @@ -1 +0,0 @@ -from websockets.typing import * From 3c6b1aab96adde1a4b0d3e8f1a93b7f2c7310af0 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 13 Jan 2024 20:46:54 +0100 Subject: [PATCH 478/762] Restore compatibility with Python < 3.11. Broken in 9038a62e. --- src/websockets/typing.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/websockets/typing.py b/src/websockets/typing.py index e073e650d..5dfecf66f 100644 --- a/src/websockets/typing.py +++ b/src/websockets/typing.py @@ -2,6 +2,7 @@ import http import logging +import typing from typing import Any, List, NewType, Optional, Tuple, Union @@ -28,8 +29,12 @@ """ -LoggerLike = Union[logging.Logger, logging.LoggerAdapter[Any]] -"""Types accepted where a :class:`~logging.Logger` is expected.""" +if typing.TYPE_CHECKING: + LoggerLike = Union[logging.Logger, logging.LoggerAdapter[Any]] + """Types accepted where a :class:`~logging.Logger` is expected.""" +else: # remove this branch when dropping support for Python < 3.11 + LoggerLike = Union[logging.Logger, logging.LoggerAdapter] + """Types accepted where a :class:`~logging.Logger` is expected.""" StatusLike = Union[http.HTTPStatus, int] From 7b522ec0df8f4e26abe09046a0ae7861714f5a2a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 13 Jan 2024 21:54:08 +0100 Subject: [PATCH 479/762] Simplify code. It had to be written in that way with asyncio.wait_for but that isn't necessary anymore with asyncio.timeout. --- src/websockets/legacy/client.py | 55 ++++++++++++++++----------------- 1 file changed, 26 insertions(+), 29 deletions(-) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 48622523e..b85d22867 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -640,38 +640,35 @@ async def __aexit__( def __await__(self) -> Generator[Any, None, WebSocketClientProtocol]: # Create a suitable iterator by calling __await__ on a coroutine. - return self.__await_impl_timeout__().__await__() - - async def __await_impl_timeout__(self) -> WebSocketClientProtocol: - async with asyncio_timeout(self.open_timeout): - return await self.__await_impl__() + return self.__await_impl__().__await__() async def __await_impl__(self) -> WebSocketClientProtocol: - for redirects in range(self.MAX_REDIRECTS_ALLOWED): - _transport, _protocol = await self._create_connection() - protocol = cast(WebSocketClientProtocol, _protocol) - try: - await protocol.handshake( - self._wsuri, - origin=protocol.origin, - available_extensions=protocol.available_extensions, - available_subprotocols=protocol.available_subprotocols, - extra_headers=protocol.extra_headers, - ) - except RedirectHandshake as exc: - protocol.fail_connection() - await protocol.wait_closed() - self.handle_redirect(exc.uri) - # Avoid leaking a connected socket when the handshake fails. - except (Exception, asyncio.CancelledError): - protocol.fail_connection() - await protocol.wait_closed() - raise + async with asyncio_timeout(self.open_timeout): + for _redirects in range(self.MAX_REDIRECTS_ALLOWED): + _transport, _protocol = await self._create_connection() + protocol = cast(WebSocketClientProtocol, _protocol) + try: + await protocol.handshake( + self._wsuri, + origin=protocol.origin, + available_extensions=protocol.available_extensions, + available_subprotocols=protocol.available_subprotocols, + extra_headers=protocol.extra_headers, + ) + except RedirectHandshake as exc: + protocol.fail_connection() + await protocol.wait_closed() + self.handle_redirect(exc.uri) + # Avoid leaking a connected socket when the handshake fails. + except (Exception, asyncio.CancelledError): + protocol.fail_connection() + await protocol.wait_closed() + raise + else: + self.protocol = protocol + return protocol else: - self.protocol = protocol - return protocol - else: - raise SecurityError("too many redirects") + raise SecurityError("too many redirects") # ... = yield from connect(...) - remove when dropping Python < 3.10 From 35bc7dd8288445289134c335aae8af859862ccd1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 14 Jan 2024 21:20:06 +0100 Subject: [PATCH 480/762] Create futures with create_future. This is the preferred way to create Futures in asyncio. --- README.rst | 2 +- compliance/test_server.py | 2 +- docs/intro/tutorial1.rst | 2 +- docs/topics/broadcast.rst | 5 +++-- example/django/authentication.py | 2 +- example/echo.py | 2 +- example/faq/health_check_server.py | 2 +- example/legacy/basic_auth_server.py | 2 +- example/legacy/unix_server.py | 2 +- example/quickstart/counter.py | 2 +- example/quickstart/server.py | 2 +- example/quickstart/server_secure.py | 2 +- example/quickstart/show_time.py | 2 +- example/tutorial/step1/app.py | 2 +- example/tutorial/step2/app.py | 2 +- experiments/broadcast/server.py | 5 +++-- src/websockets/legacy/protocol.py | 4 ++-- src/websockets/legacy/server.py | 6 ++++-- 18 files changed, 26 insertions(+), 22 deletions(-) diff --git a/README.rst b/README.rst index 870b208ba..94cd79ab9 100644 --- a/README.rst +++ b/README.rst @@ -55,7 +55,7 @@ Here's an echo server with the ``asyncio`` API: async def main(): async with serve(echo, "localhost", 8765): - await asyncio.Future() # run forever + await asyncio.get_running_loop().create_future() # run forever asyncio.run(main()) diff --git a/compliance/test_server.py b/compliance/test_server.py index 92f895d92..5701e4485 100644 --- a/compliance/test_server.py +++ b/compliance/test_server.py @@ -21,7 +21,7 @@ async def echo(ws): async def main(): with websockets.serve(echo, HOST, PORT, max_size=2 ** 25, max_queue=1): try: - await asyncio.Future() + await asyncio.get_running_loop().create_future() # run forever except KeyboardInterrupt: pass diff --git a/docs/intro/tutorial1.rst b/docs/intro/tutorial1.rst index ff85003b5..6b32d47f6 100644 --- a/docs/intro/tutorial1.rst +++ b/docs/intro/tutorial1.rst @@ -195,7 +195,7 @@ Create an ``app.py`` file next to ``connect4.py`` with this content: async def main(): async with websockets.serve(handler, "", 8001): - await asyncio.Future() # run forever + await asyncio.get_running_loop().create_future() # run forever if __name__ == "__main__": diff --git a/docs/topics/broadcast.rst b/docs/topics/broadcast.rst index 1acb372d4..b6ddda734 100644 --- a/docs/topics/broadcast.rst +++ b/docs/topics/broadcast.rst @@ -273,10 +273,11 @@ Here's a message stream that supports multiple consumers:: class PubSub: def __init__(self): - self.waiter = asyncio.Future() + self.waiter = asyncio.get_running_loop().create_future() def publish(self, value): - waiter, self.waiter = self.waiter, asyncio.Future() + waiter = self.waiter + self.waiter = asyncio.get_running_loop().create_future() waiter.set_result((value, self.waiter)) async def subscribe(self): diff --git a/example/django/authentication.py b/example/django/authentication.py index f6dad0f55..83e128f07 100644 --- a/example/django/authentication.py +++ b/example/django/authentication.py @@ -23,7 +23,7 @@ async def handler(websocket): async def main(): async with websockets.serve(handler, "localhost", 8888): - await asyncio.Future() # run forever + await asyncio.get_running_loop().create_future() # run forever if __name__ == "__main__": diff --git a/example/echo.py b/example/echo.py index 2e47e52d9..d11b33527 100755 --- a/example/echo.py +++ b/example/echo.py @@ -9,6 +9,6 @@ async def echo(websocket): async def main(): async with serve(echo, "localhost", 8765): - await asyncio.Future() # run forever + await asyncio.get_running_loop().create_future() # run forever asyncio.run(main()) diff --git a/example/faq/health_check_server.py b/example/faq/health_check_server.py index 7b8bded77..6c7681e8a 100755 --- a/example/faq/health_check_server.py +++ b/example/faq/health_check_server.py @@ -17,6 +17,6 @@ async def main(): echo, "localhost", 8765, process_request=health_check, ): - await asyncio.Future() # run forever + await asyncio.get_running_loop().create_future() # run forever asyncio.run(main()) diff --git a/example/legacy/basic_auth_server.py b/example/legacy/basic_auth_server.py index d2efeb7e5..6f6020253 100755 --- a/example/legacy/basic_auth_server.py +++ b/example/legacy/basic_auth_server.py @@ -16,6 +16,6 @@ async def main(): realm="example", credentials=("mary", "p@ssw0rd") ), ): - await asyncio.Future() # run forever + await asyncio.get_running_loop().create_future() # run forever asyncio.run(main()) diff --git a/example/legacy/unix_server.py b/example/legacy/unix_server.py index 335039c35..5bfb66072 100755 --- a/example/legacy/unix_server.py +++ b/example/legacy/unix_server.py @@ -18,6 +18,6 @@ async def hello(websocket): async def main(): socket_path = os.path.join(os.path.dirname(__file__), "socket") async with websockets.unix_serve(hello, socket_path): - await asyncio.Future() # run forever + await asyncio.get_running_loop().create_future() # run forever asyncio.run(main()) diff --git a/example/quickstart/counter.py b/example/quickstart/counter.py index 566e12965..414919e04 100755 --- a/example/quickstart/counter.py +++ b/example/quickstart/counter.py @@ -43,7 +43,7 @@ async def counter(websocket): async def main(): async with websockets.serve(counter, "localhost", 6789): - await asyncio.Future() # run forever + await asyncio.get_running_loop().create_future() # run forever if __name__ == "__main__": asyncio.run(main()) diff --git a/example/quickstart/server.py b/example/quickstart/server.py index 31b182972..64d7adeb6 100755 --- a/example/quickstart/server.py +++ b/example/quickstart/server.py @@ -14,7 +14,7 @@ async def hello(websocket): async def main(): async with websockets.serve(hello, "localhost", 8765): - await asyncio.Future() # run forever + await asyncio.get_running_loop().create_future() # run forever if __name__ == "__main__": asyncio.run(main()) diff --git a/example/quickstart/server_secure.py b/example/quickstart/server_secure.py index de41d30dc..11db5fb3a 100755 --- a/example/quickstart/server_secure.py +++ b/example/quickstart/server_secure.py @@ -20,7 +20,7 @@ async def hello(websocket): async def main(): async with websockets.serve(hello, "localhost", 8765, ssl=ssl_context): - await asyncio.Future() # run forever + await asyncio.get_running_loop().create_future() # run forever if __name__ == "__main__": asyncio.run(main()) diff --git a/example/quickstart/show_time.py b/example/quickstart/show_time.py index a83078e8a..add226869 100755 --- a/example/quickstart/show_time.py +++ b/example/quickstart/show_time.py @@ -13,7 +13,7 @@ async def show_time(websocket): async def main(): async with websockets.serve(show_time, "localhost", 5678): - await asyncio.Future() # run forever + await asyncio.get_running_loop().create_future() # run forever if __name__ == "__main__": asyncio.run(main()) diff --git a/example/tutorial/step1/app.py b/example/tutorial/step1/app.py index 3b0fbd786..6ec1c60b8 100644 --- a/example/tutorial/step1/app.py +++ b/example/tutorial/step1/app.py @@ -58,7 +58,7 @@ async def handler(websocket): async def main(): async with websockets.serve(handler, "", 8001): - await asyncio.Future() # run forever + await asyncio.get_running_loop().create_future() # run forever if __name__ == "__main__": diff --git a/example/tutorial/step2/app.py b/example/tutorial/step2/app.py index 2693d4304..db3e36374 100644 --- a/example/tutorial/step2/app.py +++ b/example/tutorial/step2/app.py @@ -183,7 +183,7 @@ async def handler(websocket): async def main(): async with websockets.serve(handler, "", 8001): - await asyncio.Future() # run forever + await asyncio.get_running_loop().create_future() # run forever if __name__ == "__main__": diff --git a/experiments/broadcast/server.py b/experiments/broadcast/server.py index 9c9907b7f..b0407ba34 100644 --- a/experiments/broadcast/server.py +++ b/experiments/broadcast/server.py @@ -27,10 +27,11 @@ async def relay(queue, websocket): class PubSub: def __init__(self): - self.waiter = asyncio.Future() + self.waiter = asyncio.get_running_loop().create_future() def publish(self, value): - waiter, self.waiter = self.waiter, asyncio.Future() + waiter = self.waiter + self.waiter = asyncio.get_running_loop().create_future() waiter.set_result((value, self.waiter)) async def subscribe(self): diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 19cee0e65..47f948b7a 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -664,7 +664,7 @@ async def send( return opcode, data = prepare_data(fragment) - self._fragmented_message_waiter = asyncio.Future() + self._fragmented_message_waiter = self.loop.create_future() try: # First fragment. await self.write_frame(False, opcode, data) @@ -709,7 +709,7 @@ async def send( return opcode, data = prepare_data(fragment) - self._fragmented_message_waiter = asyncio.Future() + self._fragmented_message_waiter = self.loop.create_future() try: # First fragment. await self.write_frame(False, opcode, data) diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 7c24dd74a..d95bec4f6 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -897,7 +897,8 @@ class Serve: Awaiting :func:`serve` yields a :class:`WebSocketServer`. This object provides a :meth:`~WebSocketServer.close` method to shut down the server:: - stop = asyncio.Future() # set this future to exit the server + # set this future to exit the server + stop = asyncio.get_running_loop().create_future() server = await serve(...) await stop @@ -906,7 +907,8 @@ class Serve: :func:`serve` can be used as an asynchronous context manager. Then, the server is shut down automatically when exiting the context:: - stop = asyncio.Future() # set this future to exit the server + # set this future to exit the server + stop = asyncio.get_running_loop().create_future() async with serve(...): await stop From cba4c242614734a722891992e8bc005bc848c0c1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 20 Jan 2024 18:56:07 +0100 Subject: [PATCH 481/762] Avoid duplicating signature in API docs. --- docs/reference/sansio/client.rst | 2 +- docs/reference/sansio/common.rst | 2 +- docs/reference/sansio/server.rst | 2 +- docs/reference/sync/client.rst | 4 ++-- docs/reference/sync/server.rst | 4 ++-- src/websockets/sync/client.py | 18 ++++++++++-------- src/websockets/sync/server.py | 14 +++++++++----- 7 files changed, 26 insertions(+), 20 deletions(-) diff --git a/docs/reference/sansio/client.rst b/docs/reference/sansio/client.rst index 09bafc745..12f88b8ed 100644 --- a/docs/reference/sansio/client.rst +++ b/docs/reference/sansio/client.rst @@ -5,7 +5,7 @@ Client (`Sans-I/O`_) .. currentmodule:: websockets.client -.. autoclass:: ClientProtocol(wsuri, origin=None, extensions=None, subprotocols=None, state=State.CONNECTING, max_size=2 ** 20, logger=None) +.. autoclass:: ClientProtocol .. automethod:: receive_data diff --git a/docs/reference/sansio/common.rst b/docs/reference/sansio/common.rst index cd1ef3c63..7d5447ac9 100644 --- a/docs/reference/sansio/common.rst +++ b/docs/reference/sansio/common.rst @@ -7,7 +7,7 @@ Both sides (`Sans-I/O`_) .. automodule:: websockets.protocol -.. autoclass:: Protocol(side, state=State.OPEN, max_size=2 ** 20, logger=None) +.. autoclass:: Protocol .. automethod:: receive_data diff --git a/docs/reference/sansio/server.rst b/docs/reference/sansio/server.rst index d70df6277..3152f174e 100644 --- a/docs/reference/sansio/server.rst +++ b/docs/reference/sansio/server.rst @@ -5,7 +5,7 @@ Server (`Sans-I/O`_) .. currentmodule:: websockets.server -.. autoclass:: ServerProtocol(origins=None, extensions=None, subprotocols=None, state=State.CONNECTING, max_size=2 ** 20, logger=None) +.. autoclass:: ServerProtocol .. automethod:: receive_data diff --git a/docs/reference/sync/client.rst b/docs/reference/sync/client.rst index 6cccd6ec4..af1132412 100644 --- a/docs/reference/sync/client.rst +++ b/docs/reference/sync/client.rst @@ -6,9 +6,9 @@ Client (:mod:`threading`) Opening a connection -------------------- -.. autofunction:: connect(uri, *, sock=None, ssl_context=None, server_hostname=None, origin=None, extensions=None, subprotocols=None, additional_headers=None, user_agent_header="Python/x.y.z websockets/X.Y", compression="deflate", open_timeout=10, close_timeout=10, max_size=2 ** 20, logger=None, create_connection=None) +.. autofunction:: connect -.. autofunction:: unix_connect(path, uri=None, *, sock=None, ssl_context=None, server_hostname=None, origin=None, extensions=None, subprotocols=None, additional_headers=None, user_agent_header="Python/x.y.z websockets/X.Y", compression="deflate", open_timeout=10, close_timeout=10, max_size=2 ** 20, logger=None, create_connection=None) +.. autofunction:: unix_connect Using a connection ------------------ diff --git a/docs/reference/sync/server.rst b/docs/reference/sync/server.rst index 35c112046..7ed744df2 100644 --- a/docs/reference/sync/server.rst +++ b/docs/reference/sync/server.rst @@ -6,9 +6,9 @@ Server (:mod:`threading`) Creating a server ----------------- -.. autofunction:: serve(handler, host=None, port=None, *, sock=None, ssl_context=None, origins=None, extensions=None, subprotocols=None, select_subprotocol=None, process_request=None, process_response=None, server_header="Python/x.y.z websockets/X.Y", compression="deflate", open_timeout=10, close_timeout=10, max_size=2 ** 20, logger=None, create_connection=None) +.. autofunction:: serve -.. autofunction:: unix_serve(handler, path=None, *, sock=None, ssl_context=None, origins=None, extensions=None, subprotocols=None, select_subprotocol=None, process_request=None, process_response=None, server_header="Python/x.y.z websockets/X.Y", compression="deflate", open_timeout=10, close_timeout=10, max_size=2 ** 20, logger=None, create_connection=None) +.. autofunction:: unix_serve Running a server ---------------- diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 087ff5f56..78a9a3c86 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -126,12 +126,10 @@ def recv_events(self) -> None: def connect( uri: str, *, - # TCP/TLS — unix and path are only for unix_connect() + # TCP/TLS sock: Optional[socket.socket] = None, ssl_context: Optional[ssl.SSLContext] = None, server_hostname: Optional[str] = None, - unix: bool = False, - path: Optional[str] = None, # WebSocket origin: Optional[Origin] = None, extensions: Optional[Sequence[ClientExtensionFactory]] = None, @@ -148,6 +146,7 @@ def connect( logger: Optional[LoggerLike] = None, # Escape hatch for advanced customization create_connection: Optional[Type[ClientConnection]] = None, + **kwargs: Any, ) -> ClientConnection: """ Connect to the WebSocket server at ``uri``. @@ -210,13 +209,15 @@ def connect( if not wsuri.secure and ssl_context is not None: raise TypeError("ssl_context argument is incompatible with a ws:// URI") + # Private APIs for unix_connect() + unix: bool = kwargs.pop("unix", False) + path: Optional[str] = kwargs.pop("path", None) + if unix: if path is None and sock is None: raise TypeError("missing path argument") elif path is not None and sock is not None: raise TypeError("path and sock arguments are incompatible") - else: - assert path is None # private argument, only set by unix_connect() if subprotocols is not None: validate_subprotocols(subprotocols) @@ -241,7 +242,7 @@ def connect( if unix: sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) sock.settimeout(deadline.timeout()) - assert path is not None # validated above -- this is for mpypy + assert path is not None # mypy cannot figure this out sock.connect(path) else: sock = socket.create_connection( @@ -308,8 +309,9 @@ def unix_connect( """ Connect to a WebSocket server listening on a Unix socket. - This function is identical to :func:`connect`, except for the additional - ``path`` argument. It's only available on Unix. + This function accepts the same keyword arguments as :func:`connect`. + + It's only available on Unix. It's mainly useful for debugging servers listening on Unix sockets. diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index d12da0c65..7faab0a3d 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -266,11 +266,9 @@ def serve( host: Optional[str] = None, port: Optional[int] = None, *, - # TCP/TLS — unix and path are only for unix_serve() + # TCP/TLS sock: Optional[socket.socket] = None, ssl_context: Optional[ssl.SSLContext] = None, - unix: bool = False, - path: Optional[str] = None, # WebSocket origins: Optional[Sequence[Optional[Origin]]] = None, extensions: Optional[Sequence[ServerExtensionFactory]] = None, @@ -304,6 +302,7 @@ def serve( logger: Optional[LoggerLike] = None, # Escape hatch for advanced customization create_connection: Optional[Type[ServerConnection]] = None, + **kwargs: Any, ) -> WebSocketServer: """ Create a WebSocket server listening on ``host`` and ``port``. @@ -397,6 +396,10 @@ def handler(websocket): # Bind socket and listen + # Private APIs for unix_connect() + unix: bool = kwargs.pop("unix", False) + path: Optional[str] = kwargs.pop("path", None) + if sock is None: if unix: if path is None: @@ -515,8 +518,9 @@ def unix_serve( """ Create a WebSocket server listening on a Unix socket. - This function is identical to :func:`serve`, except the ``host`` and - ``port`` arguments are replaced by ``path``. It's only available on Unix. + This function accepts the same keyword arguments as :func:`serve`. + + It's only available on Unix. It's useful for deploying a server behind a reverse proxy such as nginx. From cd4bc7960658db6d51f60f528b3b53c718426591 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 20 Jan 2024 19:08:49 +0100 Subject: [PATCH 482/762] Pass arguments to create_server/connection. --- src/websockets/sync/client.py | 8 ++++---- src/websockets/sync/server.py | 8 ++++++-- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 78a9a3c86..79af0132f 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -195,6 +195,8 @@ def connect( the connection. Set it to a wrapper or a subclass to customize connection handling. + Any other keyword arguments are passed to :func:`~socket.create_connection`. + Raises: InvalidURI: If ``uri`` isn't a valid WebSocket URI. OSError: If the TCP connection fails. @@ -245,10 +247,8 @@ def connect( assert path is not None # mypy cannot figure this out sock.connect(path) else: - sock = socket.create_connection( - (wsuri.host, wsuri.port), - deadline.timeout(), - ) + kwargs.setdefault("timeout", deadline.timeout()) + sock = socket.create_connection((wsuri.host, wsuri.port), **kwargs) sock.settimeout(None) # Disable Nagle algorithm diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 7faab0a3d..c19992849 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -379,6 +379,9 @@ def handler(websocket): create_connection: Factory for the :class:`ServerConnection` managing the connection. Set it to a wrapper or a subclass to customize connection handling. + + Any other keyword arguments are passed to :func:`~socket.create_server`. + """ # Process parameters @@ -404,9 +407,10 @@ def handler(websocket): if unix: if path is None: raise TypeError("missing path argument") - sock = socket.create_server(path, family=socket.AF_UNIX) + kwargs.setdefault("family", socket.AF_UNIX) + sock = socket.create_server(path, **kwargs) else: - sock = socket.create_server((host, port)) + sock = socket.create_server((host, port), **kwargs) else: if path is not None: raise TypeError("path and sock arguments are incompatible") From 03ecfa5611f0c87ea9cfa7497f78e0c85408060e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 20 Jan 2024 19:14:28 +0100 Subject: [PATCH 483/762] Standardize on .encode(). We had a mix of .encode() and .encode("utf-8") -- which is the default. --- experiments/compression/benchmark.py | 2 +- src/websockets/frames.py | 6 +- src/websockets/sync/connection.py | 8 +- tests/extensions/test_permessage_deflate.py | 26 +++--- tests/legacy/test_framing.py | 6 +- tests/legacy/test_protocol.py | 90 ++++++++++----------- tests/test_frames.py | 4 +- 7 files changed, 69 insertions(+), 73 deletions(-) diff --git a/experiments/compression/benchmark.py b/experiments/compression/benchmark.py index c5b13c8fa..4fbdf6220 100644 --- a/experiments/compression/benchmark.py +++ b/experiments/compression/benchmark.py @@ -66,7 +66,7 @@ def _run(data): for _ in range(REPEAT): for item in data: if isinstance(item, str): - item = item.encode("utf-8") + item = item.encode() # Taken from PerMessageDeflate.encode item = encoder.compress(item) + encoder.flush(zlib.Z_SYNC_FLUSH) if item.endswith(b"\x00\x00\xff\xff"): diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 6b1befb2e..63c35ed4d 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -364,7 +364,7 @@ def prepare_data(data: Data) -> Tuple[int, bytes]: """ if isinstance(data, str): - return OP_TEXT, data.encode("utf-8") + return OP_TEXT, data.encode() elif isinstance(data, BytesLike): return OP_BINARY, data else: @@ -387,7 +387,7 @@ def prepare_ctrl(data: Data) -> bytes: """ if isinstance(data, str): - return data.encode("utf-8") + return data.encode() elif isinstance(data, BytesLike): return bytes(data) else: @@ -456,7 +456,7 @@ def serialize(self) -> bytes: """ self.check() - return struct.pack("!H", self.code) + self.reason.encode("utf-8") + return struct.pack("!H", self.code) + self.reason.encode() def check(self) -> None: """ diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 4a8879e37..62aa17ffd 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -287,7 +287,7 @@ def send(self, message: Union[Data, Iterable[Data]]) -> None: "cannot call send while another thread " "is already running send" ) - self.protocol.send_text(message.encode("utf-8")) + self.protocol.send_text(message.encode()) elif isinstance(message, BytesLike): with self.send_context(): @@ -324,7 +324,7 @@ def send(self, message: Union[Data, Iterable[Data]]) -> None: ) self.send_in_progress = True self.protocol.send_text( - chunk.encode("utf-8"), + chunk.encode(), fin=False, ) elif isinstance(chunk, BytesLike): @@ -349,7 +349,7 @@ def send(self, message: Union[Data, Iterable[Data]]) -> None: with self.send_context(): assert self.send_in_progress self.protocol.send_continuation( - chunk.encode("utf-8"), + chunk.encode(), fin=False, ) elif isinstance(chunk, BytesLike) and not text: @@ -630,7 +630,7 @@ def send_context( socket:: with self.send_context(): - self.protocol.send_text(message.encode("utf-8")) + self.protocol.send_text(message.encode()) When the connection isn't open on entry, when the connection is expected to close on exit, or when an unexpected error happens, terminating the diff --git a/tests/extensions/test_permessage_deflate.py b/tests/extensions/test_permessage_deflate.py index 0e698566f..ee09813c4 100644 --- a/tests/extensions/test_permessage_deflate.py +++ b/tests/extensions/test_permessage_deflate.py @@ -84,7 +84,7 @@ def test_no_encode_decode_close_frame(self): # Data frames are encoded and decoded. def test_encode_decode_text_frame(self): - frame = Frame(OP_TEXT, "café".encode("utf-8")) + frame = Frame(OP_TEXT, "café".encode()) enc_frame = self.extension.encode(frame) @@ -112,9 +112,9 @@ def test_encode_decode_binary_frame(self): self.assertEqual(dec_frame, frame) def test_encode_decode_fragmented_text_frame(self): - frame1 = Frame(OP_TEXT, "café".encode("utf-8"), fin=False) - frame2 = Frame(OP_CONT, " & ".encode("utf-8"), fin=False) - frame3 = Frame(OP_CONT, "croissants".encode("utf-8")) + frame1 = Frame(OP_TEXT, "café".encode(), fin=False) + frame2 = Frame(OP_CONT, " & ".encode(), fin=False) + frame3 = Frame(OP_CONT, "croissants".encode()) enc_frame1 = self.extension.encode(frame1) enc_frame2 = self.extension.encode(frame2) @@ -168,7 +168,7 @@ def test_encode_decode_fragmented_binary_frame(self): self.assertEqual(dec_frame2, frame2) def test_no_decode_text_frame(self): - frame = Frame(OP_TEXT, "café".encode("utf-8")) + frame = Frame(OP_TEXT, "café".encode()) # Try decoding a frame that wasn't encoded. self.assertEqual(self.extension.decode(frame), frame) @@ -180,9 +180,9 @@ def test_no_decode_binary_frame(self): self.assertEqual(self.extension.decode(frame), frame) def test_no_decode_fragmented_text_frame(self): - frame1 = Frame(OP_TEXT, "café".encode("utf-8"), fin=False) - frame2 = Frame(OP_CONT, " & ".encode("utf-8"), fin=False) - frame3 = Frame(OP_CONT, "croissants".encode("utf-8")) + frame1 = Frame(OP_TEXT, "café".encode(), fin=False) + frame2 = Frame(OP_CONT, " & ".encode(), fin=False) + frame3 = Frame(OP_CONT, "croissants".encode()) dec_frame1 = self.extension.decode(frame1) dec_frame2 = self.extension.decode(frame2) @@ -203,7 +203,7 @@ def test_no_decode_fragmented_binary_frame(self): self.assertEqual(dec_frame2, frame2) def test_context_takeover(self): - frame = Frame(OP_TEXT, "café".encode("utf-8")) + frame = Frame(OP_TEXT, "café".encode()) enc_frame1 = self.extension.encode(frame) enc_frame2 = self.extension.encode(frame) @@ -215,7 +215,7 @@ def test_remote_no_context_takeover(self): # No context takeover when decoding messages. self.extension = PerMessageDeflate(True, False, 15, 15) - frame = Frame(OP_TEXT, "café".encode("utf-8")) + frame = Frame(OP_TEXT, "café".encode()) enc_frame1 = self.extension.encode(frame) enc_frame2 = self.extension.encode(frame) @@ -233,7 +233,7 @@ def test_local_no_context_takeover(self): # No context takeover when encoding and decoding messages. self.extension = PerMessageDeflate(True, True, 15, 15) - frame = Frame(OP_TEXT, "café".encode("utf-8")) + frame = Frame(OP_TEXT, "café".encode()) enc_frame1 = self.extension.encode(frame) enc_frame2 = self.extension.encode(frame) @@ -253,7 +253,7 @@ def test_compress_settings(self): # Configure an extension so that no compression actually occurs. extension = PerMessageDeflate(False, False, 15, 15, {"level": 0}) - frame = Frame(OP_TEXT, "café".encode("utf-8")) + frame = Frame(OP_TEXT, "café".encode()) enc_frame = extension.encode(frame) @@ -269,7 +269,7 @@ def test_compress_settings(self): # Frames aren't decoded beyond max_size. def test_decompress_max_size(self): - frame = Frame(OP_TEXT, ("a" * 20).encode("utf-8")) + frame = Frame(OP_TEXT, ("a" * 20).encode()) enc_frame = self.extension.encode(frame) diff --git a/tests/legacy/test_framing.py b/tests/legacy/test_framing.py index e1e4c891b..6f811bd5e 100644 --- a/tests/legacy/test_framing.py +++ b/tests/legacy/test_framing.py @@ -76,14 +76,12 @@ def test_binary_masked(self): ) def test_non_ascii_text(self): - self.round_trip( - b"\x81\x05caf\xc3\xa9", Frame(True, OP_TEXT, "café".encode("utf-8")) - ) + self.round_trip(b"\x81\x05caf\xc3\xa9", Frame(True, OP_TEXT, "café".encode())) def test_non_ascii_text_masked(self): self.round_trip( b"\x81\x85\x64\xbe\xee\x7e\x07\xdf\x88\xbd\xcd", - Frame(True, OP_TEXT, "café".encode("utf-8")), + Frame(True, OP_TEXT, "café".encode()), mask=True, ) diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index f2eb0fea0..f3dcd9ac7 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -428,7 +428,7 @@ def test_close_reason_not_set(self): # Test the recv coroutine. def test_recv_text(self): - self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8"))) + self.receive_frame(Frame(True, OP_TEXT, "café".encode())) data = self.loop.run_until_complete(self.protocol.recv()) self.assertEqual(data, "café") @@ -458,7 +458,7 @@ def test_recv_on_closed_connection(self): self.loop.run_until_complete(self.protocol.recv()) def test_recv_protocol_error(self): - self.receive_frame(Frame(True, OP_CONT, "café".encode("utf-8"))) + self.receive_frame(Frame(True, OP_CONT, "café".encode())) self.process_invalid_frames() self.assertConnectionFailed(CloseCode.PROTOCOL_ERROR, "") @@ -469,7 +469,7 @@ def test_recv_unicode_error(self): def test_recv_text_payload_too_big(self): self.protocol.max_size = 1024 - self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8") * 205)) + self.receive_frame(Frame(True, OP_TEXT, "café".encode() * 205)) self.process_invalid_frames() self.assertConnectionFailed(CloseCode.MESSAGE_TOO_BIG, "") @@ -481,7 +481,7 @@ def test_recv_binary_payload_too_big(self): def test_recv_text_no_max_size(self): self.protocol.max_size = None # for test coverage - self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8") * 205)) + self.receive_frame(Frame(True, OP_TEXT, "café".encode() * 205)) data = self.loop.run_until_complete(self.protocol.recv()) self.assertEqual(data, "café" * 205) @@ -498,7 +498,7 @@ def test_recv_queue_empty(self): asyncio.wait_for(asyncio.shield(recv), timeout=MS) ) - self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8"))) + self.receive_frame(Frame(True, OP_TEXT, "café".encode())) data = self.loop.run_until_complete(recv) self.assertEqual(data, "café") @@ -507,7 +507,7 @@ def test_recv_queue_full(self): # Test internals because it's hard to verify buffers from the outside. self.assertEqual(list(self.protocol.messages), []) - self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8"))) + self.receive_frame(Frame(True, OP_TEXT, "café".encode())) self.run_loop_once() self.assertEqual(list(self.protocol.messages), ["café"]) @@ -535,7 +535,7 @@ def test_recv_queue_no_limit(self): self.protocol.max_queue = None for _ in range(100): - self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8"))) + self.receive_frame(Frame(True, OP_TEXT, "café".encode())) self.run_loop_once() # Incoming message queue can contain at least 100 messages. @@ -562,7 +562,7 @@ def test_recv_canceled(self): self.loop.run_until_complete(recv) # The next frame doesn't disappear in a vacuum (it used to). - self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8"))) + self.receive_frame(Frame(True, OP_TEXT, "café".encode())) data = self.loop.run_until_complete(self.protocol.recv()) self.assertEqual(data, "café") @@ -570,15 +570,13 @@ def test_recv_canceled_race_condition(self): recv = self.loop.create_task( asyncio.wait_for(self.protocol.recv(), timeout=0.000_001) ) - self.loop.call_soon( - self.receive_frame, Frame(True, OP_TEXT, "café".encode("utf-8")) - ) + self.loop.call_soon(self.receive_frame, Frame(True, OP_TEXT, "café".encode())) with self.assertRaises(asyncio.TimeoutError): self.loop.run_until_complete(recv) # The previous frame doesn't disappear in a vacuum (it used to). - self.receive_frame(Frame(True, OP_TEXT, "tea".encode("utf-8"))) + self.receive_frame(Frame(True, OP_TEXT, "tea".encode())) data = self.loop.run_until_complete(self.protocol.recv()) # If we're getting "tea" there, it means "café" was swallowed (ha, ha). self.assertEqual(data, "café") @@ -586,7 +584,7 @@ def test_recv_canceled_race_condition(self): def test_recv_when_transfer_data_cancelled(self): # Clog incoming queue. self.protocol.max_queue = 1 - self.receive_frame(Frame(True, OP_TEXT, "café".encode("utf-8"))) + self.receive_frame(Frame(True, OP_TEXT, "café".encode())) self.receive_frame(Frame(True, OP_BINARY, b"tea")) self.run_loop_once() @@ -620,7 +618,7 @@ def test_recv_prevents_concurrent_calls(self): def test_send_text(self): self.loop.run_until_complete(self.protocol.send("café")) - self.assertOneFrameSent(True, OP_TEXT, "café".encode("utf-8")) + self.assertOneFrameSent(True, OP_TEXT, "café".encode()) def test_send_binary(self): self.loop.run_until_complete(self.protocol.send(b"tea")) @@ -647,9 +645,9 @@ def test_send_type_error(self): def test_send_iterable_text(self): self.loop.run_until_complete(self.protocol.send(["ca", "fé"])) self.assertFramesSent( - (False, OP_TEXT, "ca".encode("utf-8")), - (False, OP_CONT, "fé".encode("utf-8")), - (True, OP_CONT, "".encode("utf-8")), + (False, OP_TEXT, "ca".encode()), + (False, OP_CONT, "fé".encode()), + (True, OP_CONT, "".encode()), ) def test_send_iterable_binary(self): @@ -687,7 +685,7 @@ def test_send_iterable_mixed_type_error(self): with self.assertRaises(TypeError): self.loop.run_until_complete(self.protocol.send(["café", b"tea"])) self.assertFramesSent( - (False, OP_TEXT, "café".encode("utf-8")), + (False, OP_TEXT, "café".encode()), (True, OP_CLOSE, Close(CloseCode.INTERNAL_ERROR, "").serialize()), ) @@ -710,18 +708,18 @@ async def run_concurrently(): self.loop.run_until_complete(run_concurrently()) self.assertFramesSent( - (False, OP_TEXT, "ca".encode("utf-8")), - (False, OP_CONT, "fé".encode("utf-8")), - (True, OP_CONT, "".encode("utf-8")), + (False, OP_TEXT, "ca".encode()), + (False, OP_CONT, "fé".encode()), + (True, OP_CONT, "".encode()), (True, OP_BINARY, b"tea"), ) def test_send_async_iterable_text(self): self.loop.run_until_complete(self.protocol.send(async_iterable(["ca", "fé"]))) self.assertFramesSent( - (False, OP_TEXT, "ca".encode("utf-8")), - (False, OP_CONT, "fé".encode("utf-8")), - (True, OP_CONT, "".encode("utf-8")), + (False, OP_TEXT, "ca".encode()), + (False, OP_CONT, "fé".encode()), + (True, OP_CONT, "".encode()), ) def test_send_async_iterable_binary(self): @@ -761,7 +759,7 @@ def test_send_async_iterable_mixed_type_error(self): self.protocol.send(async_iterable(["café", b"tea"])) ) self.assertFramesSent( - (False, OP_TEXT, "café".encode("utf-8")), + (False, OP_TEXT, "café".encode()), (True, OP_CLOSE, Close(CloseCode.INTERNAL_ERROR, "").serialize()), ) @@ -784,9 +782,9 @@ async def run_concurrently(): self.loop.run_until_complete(run_concurrently()) self.assertFramesSent( - (False, OP_TEXT, "ca".encode("utf-8")), - (False, OP_CONT, "fé".encode("utf-8")), - (True, OP_CONT, "".encode("utf-8")), + (False, OP_TEXT, "ca".encode()), + (False, OP_CONT, "fé".encode()), + (True, OP_CONT, "".encode()), (True, OP_BINARY, b"tea"), ) @@ -829,7 +827,7 @@ def test_ping_default(self): def test_ping_text(self): self.loop.run_until_complete(self.protocol.ping("café")) - self.assertOneFrameSent(True, OP_PING, "café".encode("utf-8")) + self.assertOneFrameSent(True, OP_PING, "café".encode()) def test_ping_binary(self): self.loop.run_until_complete(self.protocol.ping(b"tea")) @@ -882,7 +880,7 @@ def test_pong_default(self): def test_pong_text(self): self.loop.run_until_complete(self.protocol.pong("café")) - self.assertOneFrameSent(True, OP_PONG, "café".encode("utf-8")) + self.assertOneFrameSent(True, OP_PONG, "café".encode()) def test_pong_binary(self): self.loop.run_until_complete(self.protocol.pong(b"tea")) @@ -1072,8 +1070,8 @@ def test_return_latency_on_pong(self): # Test the protocol's logic for rebuilding fragmented messages. def test_fragmented_text(self): - self.receive_frame(Frame(False, OP_TEXT, "ca".encode("utf-8"))) - self.receive_frame(Frame(True, OP_CONT, "fé".encode("utf-8"))) + self.receive_frame(Frame(False, OP_TEXT, "ca".encode())) + self.receive_frame(Frame(True, OP_CONT, "fé".encode())) data = self.loop.run_until_complete(self.protocol.recv()) self.assertEqual(data, "café") @@ -1086,8 +1084,8 @@ def test_fragmented_binary(self): def test_fragmented_text_payload_too_big(self): self.protocol.max_size = 1024 - self.receive_frame(Frame(False, OP_TEXT, "café".encode("utf-8") * 100)) - self.receive_frame(Frame(True, OP_CONT, "café".encode("utf-8") * 105)) + self.receive_frame(Frame(False, OP_TEXT, "café".encode() * 100)) + self.receive_frame(Frame(True, OP_CONT, "café".encode() * 105)) self.process_invalid_frames() self.assertConnectionFailed(CloseCode.MESSAGE_TOO_BIG, "") @@ -1100,8 +1098,8 @@ def test_fragmented_binary_payload_too_big(self): def test_fragmented_text_no_max_size(self): self.protocol.max_size = None # for test coverage - self.receive_frame(Frame(False, OP_TEXT, "café".encode("utf-8") * 100)) - self.receive_frame(Frame(True, OP_CONT, "café".encode("utf-8") * 105)) + self.receive_frame(Frame(False, OP_TEXT, "café".encode() * 100)) + self.receive_frame(Frame(True, OP_CONT, "café".encode() * 105)) data = self.loop.run_until_complete(self.protocol.recv()) self.assertEqual(data, "café" * 205) @@ -1113,22 +1111,22 @@ def test_fragmented_binary_no_max_size(self): self.assertEqual(data, b"tea" * 342) def test_control_frame_within_fragmented_text(self): - self.receive_frame(Frame(False, OP_TEXT, "ca".encode("utf-8"))) + self.receive_frame(Frame(False, OP_TEXT, "ca".encode())) self.receive_frame(Frame(True, OP_PING, b"")) - self.receive_frame(Frame(True, OP_CONT, "fé".encode("utf-8"))) + self.receive_frame(Frame(True, OP_CONT, "fé".encode())) data = self.loop.run_until_complete(self.protocol.recv()) self.assertEqual(data, "café") self.assertOneFrameSent(True, OP_PONG, b"") def test_unterminated_fragmented_text(self): - self.receive_frame(Frame(False, OP_TEXT, "ca".encode("utf-8"))) + self.receive_frame(Frame(False, OP_TEXT, "ca".encode())) # Missing the second part of the fragmented frame. self.receive_frame(Frame(True, OP_BINARY, b"tea")) self.process_invalid_frames() self.assertConnectionFailed(CloseCode.PROTOCOL_ERROR, "") def test_close_handshake_in_fragmented_text(self): - self.receive_frame(Frame(False, OP_TEXT, "ca".encode("utf-8"))) + self.receive_frame(Frame(False, OP_TEXT, "ca".encode())) self.receive_frame(Frame(True, OP_CLOSE, b"")) self.process_invalid_frames() # The RFC may have overlooked this case: it says that control frames @@ -1138,7 +1136,7 @@ def test_close_handshake_in_fragmented_text(self): self.assertConnectionClosed(CloseCode.NO_STATUS_RCVD, "") def test_connection_close_in_fragmented_text(self): - self.receive_frame(Frame(False, OP_TEXT, "ca".encode("utf-8"))) + self.receive_frame(Frame(False, OP_TEXT, "ca".encode())) self.process_invalid_frames() self.assertConnectionFailed(CloseCode.ABNORMAL_CLOSURE, "") @@ -1472,7 +1470,7 @@ def test_remote_close_during_send(self): def test_broadcast_text(self): broadcast([self.protocol], "café") - self.assertOneFrameSent(True, OP_TEXT, "café".encode("utf-8")) + self.assertOneFrameSent(True, OP_TEXT, "café".encode()) def test_broadcast_binary(self): broadcast([self.protocol], b"tea") @@ -1489,8 +1487,8 @@ def test_broadcast_no_clients(self): def test_broadcast_two_clients(self): broadcast([self.protocol, self.protocol], "café") self.assertFramesSent( - (True, OP_TEXT, "café".encode("utf-8")), - (True, OP_TEXT, "café".encode("utf-8")), + (True, OP_TEXT, "café".encode()), + (True, OP_TEXT, "café".encode()), ) def test_broadcast_skips_closed_connection(self): @@ -1513,7 +1511,7 @@ def test_broadcast_skips_connection_sending_fragmented_text(self): self.make_drain_slow() self.loop.create_task(self.protocol.send(["ca", "fé"])) self.run_loop_once() - self.assertOneFrameSent(False, OP_TEXT, "ca".encode("utf-8")) + self.assertOneFrameSent(False, OP_TEXT, "ca".encode()) with self.assertLogs("websockets", logging.WARNING) as logs: broadcast([self.protocol], "café") @@ -1530,7 +1528,7 @@ def test_broadcast_reports_connection_sending_fragmented_text(self): self.make_drain_slow() self.loop.create_task(self.protocol.send(["ca", "fé"])) self.run_loop_once() - self.assertOneFrameSent(False, OP_TEXT, "ca".encode("utf-8")) + self.assertOneFrameSent(False, OP_TEXT, "ca".encode()) with self.assertRaises(ExceptionGroup) as raised: broadcast([self.protocol], "café", raise_exceptions=True) diff --git a/tests/test_frames.py b/tests/test_frames.py index e323b3b57..3e9f5d6f8 100644 --- a/tests/test_frames.py +++ b/tests/test_frames.py @@ -77,14 +77,14 @@ def test_binary_masked(self): def test_non_ascii_text_unmasked(self): self.assertFrameData( - Frame(OP_TEXT, "café".encode("utf-8")), + Frame(OP_TEXT, "café".encode()), b"\x81\x05caf\xc3\xa9", mask=False, ) def test_non_ascii_text_masked(self): self.assertFrameData( - Frame(OP_TEXT, "café".encode("utf-8")), + Frame(OP_TEXT, "café".encode()), b"\x81\x85\x64\xbe\xee\x7e\x07\xdf\x88\xbd\xcd", mask=True, ) From ebc9890d4f2a7b1675d50d4fea167b9107082e9a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 21 Jan 2024 09:27:40 +0100 Subject: [PATCH 484/762] Remove redundant return types from docs. sphinx picks them from function signatures. This changes slightly the output e.g. ExtensionParameter becomes Tuple[str, str | None]. While this can be a bit less readable, it looks like an improvement because the information is available without needing to navigate to the definition of ExtensionParameter. --- src/websockets/client.py | 6 +++--- src/websockets/extensions/base.py | 13 ++++++------- src/websockets/legacy/auth.py | 2 +- src/websockets/legacy/handshake.py | 4 ++-- src/websockets/legacy/protocol.py | 11 +++++------ src/websockets/legacy/server.py | 11 +++++------ src/websockets/protocol.py | 6 +++--- src/websockets/server.py | 16 +++++++--------- src/websockets/uri.py | 2 +- 9 files changed, 33 insertions(+), 38 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index b2f622042..85bc81b47 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -101,7 +101,7 @@ def connect(self) -> Request: You can modify it before sending it, for example to add HTTP headers. Returns: - Request: WebSocket handshake request event to send to the server. + WebSocket handshake request event to send to the server. """ headers = Headers() @@ -213,7 +213,7 @@ def process_extensions(self, headers: Headers) -> List[Extension]: headers: WebSocket handshake response headers. Returns: - List[Extension]: List of accepted extensions. + List of accepted extensions. Raises: InvalidHandshake: to abort the handshake. @@ -271,7 +271,7 @@ def process_subprotocol(self, headers: Headers) -> Optional[Subprotocol]: headers: WebSocket handshake response headers. Returns: - Optional[Subprotocol]: Subprotocol, if one was selected. + Subprotocol, if one was selected. """ subprotocol: Optional[Subprotocol] = None diff --git a/src/websockets/extensions/base.py b/src/websockets/extensions/base.py index 6c481a46c..9eba6c9e7 100644 --- a/src/websockets/extensions/base.py +++ b/src/websockets/extensions/base.py @@ -32,7 +32,7 @@ def decode( max_size: maximum payload size in bytes. Returns: - Frame: Decoded frame. + Decoded frame. Raises: PayloadTooBig: if decoding the payload exceeds ``max_size``. @@ -48,7 +48,7 @@ def encode(self, frame: frames.Frame) -> frames.Frame: frame (Frame): outgoing frame. Returns: - Frame: Encoded frame. + Encoded frame. """ raise NotImplementedError @@ -68,7 +68,7 @@ def get_request_params(self) -> List[ExtensionParameter]: Build parameters to send to the server for this extension. Returns: - List[ExtensionParameter]: Parameters to send to the server. + Parameters to send to the server. """ raise NotImplementedError @@ -88,7 +88,7 @@ def process_response_params( accepted extensions. Returns: - Extension: An extension instance. + An extension instance. Raises: NegotiationError: if parameters aren't acceptable. @@ -121,9 +121,8 @@ def process_request_params( accepted extensions. Returns: - Tuple[List[ExtensionParameter], Extension]: To accept the offer, - parameters to send to the client for this extension and an - extension instance. + To accept the offer, parameters to send to the client for this + extension and an extension instance. Raises: NegotiationError: to reject the offer, if parameters received from diff --git a/src/websockets/legacy/auth.py b/src/websockets/legacy/auth.py index e8d6b75d5..8217afedd 100644 --- a/src/websockets/legacy/auth.py +++ b/src/websockets/legacy/auth.py @@ -66,7 +66,7 @@ async def check_credentials(self, username: str, password: str) -> bool: password: HTTP Basic Auth password. Returns: - bool: :obj:`True` if the handshake should continue; + :obj:`True` if the handshake should continue; :obj:`False` if it should fail with an HTTP 401 error. """ diff --git a/src/websockets/legacy/handshake.py b/src/websockets/legacy/handshake.py index ad8faf040..5853c31db 100644 --- a/src/websockets/legacy/handshake.py +++ b/src/websockets/legacy/handshake.py @@ -24,7 +24,7 @@ def build_request(headers: Headers) -> str: headers: Handshake request headers. Returns: - str: ``key`` that must be passed to :func:`check_response`. + ``key`` that must be passed to :func:`check_response`. """ key = generate_key() @@ -48,7 +48,7 @@ def check_request(headers: Headers) -> str: headers: Handshake request headers. Returns: - str: ``key`` that must be passed to :func:`build_response`. + ``key`` that must be passed to :func:`build_response`. Raises: InvalidHandshake: If the handshake request is invalid. diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 47f948b7a..a9fbd5a7a 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -518,7 +518,7 @@ async def recv(self) -> Data: :func:`~asyncio.timeout` or :func:`~asyncio.wait_for`. Returns: - Data: A string (:class:`str`) for a Text_ frame. A bytestring + A string (:class:`str`) for a Text_ frame. A bytestring (:class:`bytes`) for a Binary_ frame. .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 @@ -805,7 +805,7 @@ async def wait_closed(self) -> None: """ await asyncio.shield(self.connection_lost_waiter) - async def ping(self, data: Optional[Data] = None) -> Awaitable[None]: + async def ping(self, data: Optional[Data] = None) -> Awaitable[float]: """ Send a Ping_. @@ -827,10 +827,9 @@ async def ping(self, data: Optional[Data] = None) -> Awaitable[None]: containing four random bytes. Returns: - ~asyncio.Future[float]: A future that will be completed when the - corresponding pong is received. You can ignore it if you don't - intend to wait. The result of the future is the latency of the - connection in seconds. + A future that will be completed when the corresponding pong is + received. You can ignore it if you don't intend to wait. The result + of the future is the latency of the connection in seconds. :: diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index d95bec4f6..297613591 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -349,8 +349,8 @@ async def process_request( request_headers: request headers. Returns: - Optional[Tuple[StatusLike, HeadersLike, bytes]]: :obj:`None` - to continue the WebSocket handshake normally. + Tuple[StatusLike, HeadersLike, bytes] | None: :obj:`None` to + continue the WebSocket handshake normally. An HTTP response, represented by a 3-uple of the response status, headers, and body, to abort the WebSocket handshake and return @@ -534,8 +534,7 @@ def select_subprotocol( server_subprotocols: list of subprotocols available on the server. Returns: - Optional[Subprotocol]: Selected subprotocol, if a common subprotocol - was found. + Selected subprotocol, if a common subprotocol was found. :obj:`None` to continue without a subprotocol. @@ -572,7 +571,7 @@ async def handshake( the handshake succeeds. Returns: - str: path of the URI of the request. + path of the URI of the request. Raises: InvalidHandshake: if the handshake fails. @@ -968,7 +967,7 @@ class Serve: outside of websockets. Returns: - WebSocketServer: WebSocket server. + WebSocket server. """ diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 765e6b9bb..342aba413 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -452,7 +452,7 @@ def events_received(self) -> List[Event]: Process resulting events, likely by passing them to the application. Returns: - List[Event]: Events read from the connection. + Events read from the connection. """ events, self.events = self.events, [] return events @@ -473,7 +473,7 @@ def data_to_send(self) -> List[bytes]: connection. Returns: - List[bytes]: Data to write to the connection. + Data to write to the connection. """ writes, self.writes = self.writes, [] @@ -490,7 +490,7 @@ def close_expected(self) -> bool: short timeout if the other side hasn't already closed it. Returns: - bool: Whether the TCP connection is expected to close soon. + Whether the TCP connection is expected to close soon. """ # We expect a TCP close if and only if we sent a close frame: diff --git a/src/websockets/server.py b/src/websockets/server.py index 191660553..58391d3cf 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -213,7 +213,6 @@ def process_request( request: WebSocket handshake request received from the client. Returns: - Tuple[str, Optional[str], Optional[str]]: ``Sec-WebSocket-Accept``, ``Sec-WebSocket-Extensions``, and ``Sec-WebSocket-Protocol`` headers for the handshake response. @@ -294,7 +293,7 @@ def process_origin(self, headers: Headers) -> Optional[Origin]: headers: WebSocket handshake request headers. Returns: - Optional[Origin]: origin, if it is acceptable. + origin, if it is acceptable. Raises: InvalidHandshake: if the Origin header is invalid. @@ -344,8 +343,8 @@ def process_extensions( headers: WebSocket handshake request headers. Returns: - Tuple[Optional[str], List[Extension]]: ``Sec-WebSocket-Extensions`` - HTTP response header and list of accepted extensions. + ``Sec-WebSocket-Extensions`` HTTP response header and list of + accepted extensions. Raises: InvalidHandshake: if the Sec-WebSocket-Extensions header is invalid. @@ -401,8 +400,8 @@ def process_subprotocol(self, headers: Headers) -> Optional[Subprotocol]: headers: WebSocket handshake request headers. Returns: - Optional[Subprotocol]: Subprotocol, if one was selected; this is - also the value of the ``Sec-WebSocket-Protocol`` response header. + Subprotocol, if one was selected; this is also the value of the + ``Sec-WebSocket-Protocol`` response header. Raises: InvalidHandshake: if the Sec-WebSocket-Subprotocol header is invalid. @@ -449,8 +448,7 @@ def select_subprotocol(protocol, subprotocols): subprotocols: list of subprotocols offered by the client. Returns: - Optional[Subprotocol]: Selected subprotocol, if a common subprotocol - was found. + Selected subprotocol, if a common subprotocol was found. :obj:`None` to continue without a subprotocol. @@ -499,7 +497,7 @@ def reject( text: HTTP response body; will be encoded to UTF-8. Returns: - Response: WebSocket handshake response event to send to the client. + WebSocket handshake response event to send to the client. """ # If a user passes an int instead of a HTTPStatus, fix it automatically. diff --git a/src/websockets/uri.py b/src/websockets/uri.py index 385090f66..970020e26 100644 --- a/src/websockets/uri.py +++ b/src/websockets/uri.py @@ -66,7 +66,7 @@ def parse_uri(uri: str) -> WebSocketURI: uri: WebSocket URI. Returns: - WebSocketURI: Parsed WebSocket URI. + Parsed WebSocket URI. Raises: InvalidURI: if ``uri`` isn't a valid WebSocket URI. From c53fc3b7eed17c12c1b4db5d456b7921c2cde98f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 21 Jan 2024 10:01:53 +0100 Subject: [PATCH 485/762] Start argument descriptions with uppercase letter. This change was automated with regex replaces: ^( Args:\n(?: .*\n(?: .*\n)*)*?(?: \w+(?: \(.*\))?): )([a-z]) $1\U$2 ^( Args:\n(?: .*\n(?: .*\n)*)*?(?: \w+(?: \(.*\))?): )([a-z]) $1\U$2 Also remove redundant type annotations. --- src/websockets/client.py | 12 ++++---- src/websockets/datastructures.py | 2 +- src/websockets/extensions/base.py | 18 +++++------- .../extensions/permessage_deflate.py | 22 +++++++-------- src/websockets/frames.py | 14 +++++----- src/websockets/headers.py | 6 ++-- src/websockets/http11.py | 12 ++++---- src/websockets/legacy/protocol.py | 11 +++----- src/websockets/legacy/server.py | 28 +++++++++---------- src/websockets/protocol.py | 6 ++-- src/websockets/server.py | 14 +++++----- src/websockets/streams.py | 8 +++--- src/websockets/sync/utils.py | 2 +- src/websockets/utils.py | 4 +-- 14 files changed, 76 insertions(+), 83 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 85bc81b47..028e7ce47 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -53,17 +53,17 @@ class ClientProtocol(Protocol): Args: wsuri: URI of the WebSocket server, parsed with :func:`~websockets.uri.parse_uri`. - origin: value of the ``Origin`` header. This is useful when connecting + origin: Value of the ``Origin`` header. This is useful when connecting to a server that validates the ``Origin`` header to defend against Cross-Site WebSocket Hijacking attacks. - extensions: list of supported extensions, in order in which they + extensions: List of supported extensions, in order in which they should be tried. - subprotocols: list of supported subprotocols, in order of decreasing + subprotocols: List of supported subprotocols, in order of decreasing preference. - state: initial state of the WebSocket connection. - max_size: maximum size of incoming messages in bytes; + state: Initial state of the WebSocket connection. + max_size: Maximum size of incoming messages in bytes; :obj:`None` disables the limit. - logger: logger for this connection; + logger: Logger for this connection; defaults to ``logging.getLogger("websockets.client")``; see the :doc:`logging guide <../../topics/logging>` for details. diff --git a/src/websockets/datastructures.py b/src/websockets/datastructures.py index a0a648463..c2a5acfee 100644 --- a/src/websockets/datastructures.py +++ b/src/websockets/datastructures.py @@ -152,7 +152,7 @@ def get_all(self, key: str) -> List[str]: Return the (possibly empty) list of all values for a header. Args: - key: header name. + key: Header name. """ return self._dict.get(key.lower(), []) diff --git a/src/websockets/extensions/base.py b/src/websockets/extensions/base.py index 9eba6c9e7..cca3fe513 100644 --- a/src/websockets/extensions/base.py +++ b/src/websockets/extensions/base.py @@ -28,8 +28,8 @@ def decode( Decode an incoming frame. Args: - frame (Frame): incoming frame. - max_size: maximum payload size in bytes. + frame: Incoming frame. + max_size: Maximum payload size in bytes. Returns: Decoded frame. @@ -45,7 +45,7 @@ def encode(self, frame: frames.Frame) -> frames.Frame: Encode an outgoing frame. Args: - frame (Frame): outgoing frame. + frame: Outgoing frame. Returns: Encoded frame. @@ -82,10 +82,8 @@ def process_response_params( Process parameters received from the server. Args: - params (Sequence[ExtensionParameter]): parameters received from - the server for this extension. - accepted_extensions (Sequence[Extension]): list of previously - accepted extensions. + params: Parameters received from the server for this extension. + accepted_extensions: List of previously accepted extensions. Returns: An extension instance. @@ -115,10 +113,8 @@ def process_request_params( Process parameters received from the client. Args: - params (Sequence[ExtensionParameter]): parameters received from - the client for this extension. - accepted_extensions (Sequence[Extension]): list of previously - accepted extensions. + params: Parameters received from the client for this extension. + accepted_extensions: List of previously accepted extensions. Returns: To accept the offer, parameters to send to the client for this diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index b391837c6..edccac3ca 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -268,14 +268,14 @@ class ClientPerMessageDeflateFactory(ClientExtensionFactory): value or to an integer value to include them with this value. Args: - server_no_context_takeover: prevent server from using context takeover. - client_no_context_takeover: prevent client from using context takeover. - server_max_window_bits: maximum size of the server's LZ77 sliding window + server_no_context_takeover: Prevent server from using context takeover. + client_no_context_takeover: Prevent client from using context takeover. + server_max_window_bits: Maximum size of the server's LZ77 sliding window in bits, between 8 and 15. - client_max_window_bits: maximum size of the client's LZ77 sliding window + client_max_window_bits: Maximum size of the client's LZ77 sliding window in bits, between 8 and 15, or :obj:`True` to indicate support without setting a limit. - compress_settings: additional keyword arguments for :func:`zlib.compressobj`, + compress_settings: Additional keyword arguments for :func:`zlib.compressobj`, excluding ``wbits``. """ @@ -468,15 +468,15 @@ class ServerPerMessageDeflateFactory(ServerExtensionFactory): value or to an integer value to include them with this value. Args: - server_no_context_takeover: prevent server from using context takeover. - client_no_context_takeover: prevent client from using context takeover. - server_max_window_bits: maximum size of the server's LZ77 sliding window + server_no_context_takeover: Prevent server from using context takeover. + client_no_context_takeover: Prevent client from using context takeover. + server_max_window_bits: Maximum size of the server's LZ77 sliding window in bits, between 8 and 15. - client_max_window_bits: maximum size of the client's LZ77 sliding window + client_max_window_bits: Maximum size of the client's LZ77 sliding window in bits, between 8 and 15. - compress_settings: additional keyword arguments for :func:`zlib.compressobj`, + compress_settings: Additional keyword arguments for :func:`zlib.compressobj`, excluding ``wbits``. - require_client_max_window_bits: do not enable compression at all if + require_client_max_window_bits: Do not enable compression at all if client doesn't advertise support for ``client_max_window_bits``; the default behavior is to enable compression without enforcing ``client_max_window_bits``. diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 63c35ed4d..e5e2af8b4 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -208,12 +208,12 @@ def parse( This is a generator-based coroutine. Args: - read_exact: generator-based coroutine that reads the requested + read_exact: Generator-based coroutine that reads the requested bytes or raises an exception if there isn't enough data. - mask: whether the frame should be masked i.e. whether the read + mask: Whether the frame should be masked i.e. whether the read happens on the server side. - max_size: maximum payload size in bytes. - extensions: list of extensions, applied in reverse order. + max_size: Maximum payload size in bytes. + extensions: List of extensions, applied in reverse order. Raises: EOFError: if the connection is closed without a full WebSocket frame. @@ -280,9 +280,9 @@ def serialize( Serialize a WebSocket frame. Args: - mask: whether the frame should be masked i.e. whether the write + mask: Whether the frame should be masked i.e. whether the write happens on the client side. - extensions: list of extensions, applied in order. + extensions: List of extensions, applied in order. Raises: ProtocolError: if the frame contains incorrect values. @@ -432,7 +432,7 @@ def parse(cls, data: bytes) -> Close: Parse the payload of a close frame. Args: - data: payload of the close frame. + data: Payload of the close frame. Raises: ProtocolError: if data is ill-formed. diff --git a/src/websockets/headers.py b/src/websockets/headers.py index 9ae3035a5..8391ad26c 100644 --- a/src/websockets/headers.py +++ b/src/websockets/headers.py @@ -289,7 +289,7 @@ def parse_upgrade(header: str) -> List[UpgradeProtocol]: Return a list of HTTP protocols. Args: - header: value of the ``Upgrade`` header. + header: Value of the ``Upgrade`` header. Raises: InvalidHeaderFormat: on invalid inputs. @@ -486,7 +486,7 @@ def build_www_authenticate_basic(realm: str) -> str: Build a ``WWW-Authenticate`` header for HTTP Basic Auth. Args: - realm: identifier of the protection space. + realm: Identifier of the protection space. """ # https://www.rfc-editor.org/rfc/rfc7617.html#section-2 @@ -532,7 +532,7 @@ def parse_authorization_basic(header: str) -> Tuple[str, str]: Return a ``(username, password)`` tuple. Args: - header: value of the ``Authorization`` header. + header: Value of the ``Authorization`` header. Raises: InvalidHeaderFormat: on invalid inputs. diff --git a/src/websockets/http11.py b/src/websockets/http11.py index ec4e3b8b7..c0a96f878 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -93,7 +93,7 @@ def parse( body, it may be read from the data stream after :meth:`parse` returns. Args: - read_line: generator-based coroutine that reads a LF-terminated + read_line: Generator-based coroutine that reads a LF-terminated line or raises an exception if there isn't enough data Raises: @@ -193,11 +193,11 @@ def parse( characters. Other characters are represented with surrogate escapes. Args: - read_line: generator-based coroutine that reads a LF-terminated + read_line: Generator-based coroutine that reads a LF-terminated line or raises an exception if there isn't enough data. - read_exact: generator-based coroutine that reads the requested + read_exact: Generator-based coroutine that reads the requested bytes or raises an exception if there isn't enough data. - read_to_eof: generator-based coroutine that reads until the end + read_to_eof: Generator-based coroutine that reads until the end of the stream. Raises: @@ -295,7 +295,7 @@ def parse_headers( Non-ASCII characters are represented with surrogate escapes. Args: - read_line: generator-based coroutine that reads a LF-terminated line + read_line: Generator-based coroutine that reads a LF-terminated line or raises an exception if there isn't enough data. Raises: @@ -346,7 +346,7 @@ def parse_line( CRLF is stripped from the return value. Args: - read_line: generator-based coroutine that reads a LF-terminated line + read_line: Generator-based coroutine that reads a LF-terminated line or raises an exception if there isn't enough data. Raises: diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index a9fbd5a7a..26d50a2cc 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -624,8 +624,7 @@ async def send( error or a network failure. Args: - message (Union[Data, Iterable[Data], AsyncIterable[Data]): message - to send. + message: Message to send. Raises: ConnectionClosed: When the connection is closed. @@ -822,9 +821,8 @@ async def ping(self, data: Optional[Data] = None) -> Awaitable[float]: effect. Args: - data (Optional[Data]): payload of the ping; a string will be - encoded to UTF-8; or :obj:`None` to generate a payload - containing four random bytes. + data: Payload of the ping. A string will be encoded to UTF-8. + If ``data`` is :obj:`None`, the payload is four random bytes. Returns: A future that will be completed when the corresponding pong is @@ -878,8 +876,7 @@ async def pong(self, data: Data = b"") -> None: wait, you should close the connection. Args: - data (Data): Payload of the pong. A string will be encoded to - UTF-8. + data: Payload of the pong. A string will be encoded to UTF-8. Raises: ConnectionClosed: When the connection is closed. diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 297613591..4af7ed109 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -345,8 +345,8 @@ async def process_request( from shutting down. Args: - path: request path, including optional query string. - request_headers: request headers. + path: Request path, including optional query string. + request_headers: Request headers. Returns: Tuple[StatusLike, HeadersLike, bytes] | None: :obj:`None` to @@ -377,8 +377,8 @@ def process_origin( Handle the Origin HTTP request header. Args: - headers: request headers. - origins: optional list of acceptable origins. + headers: Request headers. + origins: Optional list of acceptable origins. Raises: InvalidOrigin: if the origin isn't acceptable. @@ -428,8 +428,8 @@ def process_extensions( order of extensions, may be implemented by overriding this method. Args: - headers: request headers. - extensions: optional list of supported extensions. + headers: Request headers. + extensions: Optional list of supported extensions. Raises: InvalidHandshake: to abort the handshake with an HTTP 400 error. @@ -488,8 +488,8 @@ def process_subprotocol( as the selected subprotocol. Args: - headers: request headers. - available_subprotocols: optional list of supported subprotocols. + headers: Request headers. + available_subprotocols: Optional list of supported subprotocols. Raises: InvalidHandshake: to abort the handshake with an HTTP 400 error. @@ -530,8 +530,8 @@ def select_subprotocol( subprotocol. Args: - client_subprotocols: list of subprotocols offered by the client. - server_subprotocols: list of subprotocols available on the server. + client_subprotocols: List of subprotocols offered by the client. + server_subprotocols: List of subprotocols available on the server. Returns: Selected subprotocol, if a common subprotocol was found. @@ -561,13 +561,13 @@ async def handshake( Perform the server side of the opening handshake. Args: - origins: list of acceptable values of the Origin HTTP header; + origins: List of acceptable values of the Origin HTTP header; include :obj:`None` if the lack of an origin is acceptable. - extensions: list of supported extensions, in order in which they + extensions: List of supported extensions, in order in which they should be tried. - subprotocols: list of supported subprotocols, in order of + subprotocols: List of supported subprotocols, in order of decreasing preference. - extra_headers: arbitrary HTTP headers to add to the response when + extra_headers: Arbitrary HTTP headers to add to the response when the handshake succeeds. Returns: diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 342aba413..99c9ee1a8 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -74,10 +74,10 @@ class Protocol: Args: side: :attr:`~Side.CLIENT` or :attr:`~Side.SERVER`. - state: initial state of the WebSocket connection. - max_size: maximum size of incoming messages in bytes; + state: Initial state of the WebSocket connection. + max_size: Maximum size of incoming messages in bytes; :obj:`None` disables the limit. - logger: logger for this connection; depending on ``side``, + logger: Logger for this connection; depending on ``side``, defaults to ``logging.getLogger("websockets.client")`` or ``logging.getLogger("websockets.server")``; see the :doc:`logging guide <../../topics/logging>` for details. diff --git a/src/websockets/server.py b/src/websockets/server.py index 58391d3cf..6711a0bba 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -53,22 +53,22 @@ class ServerProtocol(Protocol): Sans-I/O implementation of a WebSocket server connection. Args: - origins: acceptable values of the ``Origin`` header; include + origins: Acceptable values of the ``Origin`` header; include :obj:`None` in the list if the lack of an origin is acceptable. This is useful for defending against Cross-Site WebSocket Hijacking attacks. - extensions: list of supported extensions, in order in which they + extensions: List of supported extensions, in order in which they should be tried. - subprotocols: list of supported subprotocols, in order of decreasing + subprotocols: List of supported subprotocols, in order of decreasing preference. select_subprotocol: Callback for selecting a subprotocol among those supported by the client and the server. It has the same signature as the :meth:`select_subprotocol` method, including a :class:`ServerProtocol` instance as first argument. - state: initial state of the WebSocket connection. - max_size: maximum size of incoming messages in bytes; + state: Initial state of the WebSocket connection. + max_size: Maximum size of incoming messages in bytes; :obj:`None` disables the limit. - logger: logger for this connection; + logger: Logger for this connection; defaults to ``logging.getLogger("websockets.client")``; see the :doc:`logging guide <../../topics/logging>` for details. @@ -445,7 +445,7 @@ def select_subprotocol(protocol, subprotocols): return "chat" Args: - subprotocols: list of subprotocols offered by the client. + subprotocols: List of subprotocols offered by the client. Returns: Selected subprotocol, if a common subprotocol was found. diff --git a/src/websockets/streams.py b/src/websockets/streams.py index f861d4bd2..d288cf0cc 100644 --- a/src/websockets/streams.py +++ b/src/websockets/streams.py @@ -26,7 +26,7 @@ def read_line(self, m: int) -> Generator[None, None, bytes]: The return value includes the LF character. Args: - m: maximum number bytes to read; this is a security limit. + m: Maximum number bytes to read; this is a security limit. Raises: EOFError: if the stream ends without a LF. @@ -58,7 +58,7 @@ def read_exact(self, n: int) -> Generator[None, None, bytes]: This is a generator-based coroutine. Args: - n: how many bytes to read. + n: How many bytes to read. Raises: EOFError: if the stream ends in less than ``n`` bytes. @@ -81,7 +81,7 @@ def read_to_eof(self, m: int) -> Generator[None, None, bytes]: This is a generator-based coroutine. Args: - m: maximum number bytes to read; this is a security limit. + m: Maximum number bytes to read; this is a security limit. Raises: RuntimeError: if the stream ends in more than ``m`` bytes. @@ -119,7 +119,7 @@ def feed_data(self, data: bytes) -> None: :meth:`feed_data` cannot be called after :meth:`feed_eof`. Args: - data: data to write. + data: Data to write. Raises: EOFError: if the stream has ended. diff --git a/src/websockets/sync/utils.py b/src/websockets/sync/utils.py index 471f32e19..3364bdc2d 100644 --- a/src/websockets/sync/utils.py +++ b/src/websockets/sync/utils.py @@ -28,7 +28,7 @@ def timeout(self, *, raise_if_elapsed: bool = True) -> Optional[float]: Calculate a timeout from a deadline. Args: - raise_if_elapsed (bool): Whether to raise :exc:`TimeoutError` + raise_if_elapsed: Whether to raise :exc:`TimeoutError` if the deadline lapsed. Raises: diff --git a/src/websockets/utils.py b/src/websockets/utils.py index c40404906..62d2dc177 100644 --- a/src/websockets/utils.py +++ b/src/websockets/utils.py @@ -26,7 +26,7 @@ def accept_key(key: str) -> str: Compute the value of the Sec-WebSocket-Accept header. Args: - key: value of the Sec-WebSocket-Key header. + key: Value of the Sec-WebSocket-Key header. """ sha1 = hashlib.sha1((key + GUID).encode()).digest() @@ -38,7 +38,7 @@ def apply_mask(data: bytes, mask: bytes) -> bytes: Apply masking to the data of a WebSocket message. Args: - data: data to mask. + data: Data to mask. mask: 4-bytes mask. """ From 2865bdcc8b93f78d019aa0c605c86535dd66d026 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 21 Jan 2024 10:14:25 +0100 Subject: [PATCH 486/762] Start exception descriptions with uppercase letter. This change was automated with regex replaces: ^( Raises:\n(?: .*\n(?: .*\n)*)*?(?: \w+): )([a-z]) $1\U$2 ^( Raises:\n(?: .*\n(?: .*\n)*)*?(?: \w+): )([a-z]) $1\U$2 --- src/websockets/client.py | 4 ++-- src/websockets/extensions/base.py | 6 +++--- src/websockets/frames.py | 22 +++++++++++----------- src/websockets/headers.py | 30 +++++++++++++++--------------- src/websockets/http11.py | 24 ++++++++++++------------ src/websockets/legacy/server.py | 10 +++++----- src/websockets/protocol.py | 16 ++++++++-------- src/websockets/server.py | 12 ++++++------ src/websockets/streams.py | 12 ++++++------ src/websockets/uri.py | 2 +- 10 files changed, 69 insertions(+), 69 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 028e7ce47..633b1960b 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -144,7 +144,7 @@ def process_response(self, response: Response) -> None: request: WebSocket handshake response received from the server. Raises: - InvalidHandshake: if the handshake response is invalid. + InvalidHandshake: If the handshake response is invalid. """ @@ -216,7 +216,7 @@ def process_extensions(self, headers: Headers) -> List[Extension]: List of accepted extensions. Raises: - InvalidHandshake: to abort the handshake. + InvalidHandshake: To abort the handshake. """ accepted_extensions: List[Extension] = [] diff --git a/src/websockets/extensions/base.py b/src/websockets/extensions/base.py index cca3fe513..7446c990c 100644 --- a/src/websockets/extensions/base.py +++ b/src/websockets/extensions/base.py @@ -35,7 +35,7 @@ def decode( Decoded frame. Raises: - PayloadTooBig: if decoding the payload exceeds ``max_size``. + PayloadTooBig: If decoding the payload exceeds ``max_size``. """ raise NotImplementedError @@ -89,7 +89,7 @@ def process_response_params( An extension instance. Raises: - NegotiationError: if parameters aren't acceptable. + NegotiationError: If parameters aren't acceptable. """ raise NotImplementedError @@ -121,7 +121,7 @@ def process_request_params( extension and an extension instance. Raises: - NegotiationError: to reject the offer, if parameters received from + NegotiationError: To reject the offer, if parameters received from the client aren't acceptable. """ diff --git a/src/websockets/frames.py b/src/websockets/frames.py index e5e2af8b4..201bc5068 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -216,10 +216,10 @@ def parse( extensions: List of extensions, applied in reverse order. Raises: - EOFError: if the connection is closed without a full WebSocket frame. - UnicodeDecodeError: if the frame contains invalid UTF-8. - PayloadTooBig: if the frame's payload size exceeds ``max_size``. - ProtocolError: if the frame contains incorrect values. + EOFError: If the connection is closed without a full WebSocket frame. + UnicodeDecodeError: If the frame contains invalid UTF-8. + PayloadTooBig: If the frame's payload size exceeds ``max_size``. + ProtocolError: If the frame contains incorrect values. """ # Read the header. @@ -285,7 +285,7 @@ def serialize( extensions: List of extensions, applied in order. Raises: - ProtocolError: if the frame contains incorrect values. + ProtocolError: If the frame contains incorrect values. """ self.check() @@ -334,7 +334,7 @@ def check(self) -> None: Check that reserved bits and opcode have acceptable values. Raises: - ProtocolError: if a reserved bit or the opcode is invalid. + ProtocolError: If a reserved bit or the opcode is invalid. """ if self.rsv1 or self.rsv2 or self.rsv3: @@ -360,7 +360,7 @@ def prepare_data(data: Data) -> Tuple[int, bytes]: object. Raises: - TypeError: if ``data`` doesn't have a supported type. + TypeError: If ``data`` doesn't have a supported type. """ if isinstance(data, str): @@ -383,7 +383,7 @@ def prepare_ctrl(data: Data) -> bytes: If ``data`` is a bytes-like object, return a :class:`bytes` object. Raises: - TypeError: if ``data`` doesn't have a supported type. + TypeError: If ``data`` doesn't have a supported type. """ if isinstance(data, str): @@ -435,8 +435,8 @@ def parse(cls, data: bytes) -> Close: data: Payload of the close frame. Raises: - ProtocolError: if data is ill-formed. - UnicodeDecodeError: if the reason isn't valid UTF-8. + ProtocolError: If data is ill-formed. + UnicodeDecodeError: If the reason isn't valid UTF-8. """ if len(data) >= 2: @@ -463,7 +463,7 @@ def check(self) -> None: Check that the close code has a valid value for a close frame. Raises: - ProtocolError: if the close code is invalid. + ProtocolError: If the close code is invalid. """ if not (self.code in EXTERNAL_CLOSE_CODES or 3000 <= self.code < 5000): diff --git a/src/websockets/headers.py b/src/websockets/headers.py index 8391ad26c..463df3061 100644 --- a/src/websockets/headers.py +++ b/src/websockets/headers.py @@ -103,7 +103,7 @@ def parse_token(header: str, pos: int, header_name: str) -> Tuple[str, int]: Return the token value and the new position. Raises: - InvalidHeaderFormat: on invalid inputs. + InvalidHeaderFormat: On invalid inputs. """ match = _token_re.match(header, pos) @@ -127,7 +127,7 @@ def parse_quoted_string(header: str, pos: int, header_name: str) -> Tuple[str, i Return the unquoted value and the new position. Raises: - InvalidHeaderFormat: on invalid inputs. + InvalidHeaderFormat: On invalid inputs. """ match = _quoted_string_re.match(header, pos) @@ -180,7 +180,7 @@ def parse_list( Return a list of items. Raises: - InvalidHeaderFormat: on invalid inputs. + InvalidHeaderFormat: On invalid inputs. """ # Per https://www.rfc-editor.org/rfc/rfc7230.html#section-7, "a recipient @@ -234,7 +234,7 @@ def parse_connection_option( Return the protocol value and the new position. Raises: - InvalidHeaderFormat: on invalid inputs. + InvalidHeaderFormat: On invalid inputs. """ item, pos = parse_token(header, pos, header_name) @@ -251,7 +251,7 @@ def parse_connection(header: str) -> List[ConnectionOption]: header: value of the ``Connection`` header. Raises: - InvalidHeaderFormat: on invalid inputs. + InvalidHeaderFormat: On invalid inputs. """ return parse_list(parse_connection_option, header, 0, "Connection") @@ -271,7 +271,7 @@ def parse_upgrade_protocol( Return the protocol value and the new position. Raises: - InvalidHeaderFormat: on invalid inputs. + InvalidHeaderFormat: On invalid inputs. """ match = _protocol_re.match(header, pos) @@ -292,7 +292,7 @@ def parse_upgrade(header: str) -> List[UpgradeProtocol]: header: Value of the ``Upgrade`` header. Raises: - InvalidHeaderFormat: on invalid inputs. + InvalidHeaderFormat: On invalid inputs. """ return parse_list(parse_upgrade_protocol, header, 0, "Upgrade") @@ -307,7 +307,7 @@ def parse_extension_item_param( Return a ``(name, value)`` pair and the new position. Raises: - InvalidHeaderFormat: on invalid inputs. + InvalidHeaderFormat: On invalid inputs. """ # Extract parameter name. @@ -344,7 +344,7 @@ def parse_extension_item( list of ``(name, value)`` pairs, and the new position. Raises: - InvalidHeaderFormat: on invalid inputs. + InvalidHeaderFormat: On invalid inputs. """ # Extract extension name. @@ -379,7 +379,7 @@ def parse_extension(header: str) -> List[ExtensionHeader]: Parameter values are :obj:`None` when no value is provided. Raises: - InvalidHeaderFormat: on invalid inputs. + InvalidHeaderFormat: On invalid inputs. """ return parse_list(parse_extension_item, header, 0, "Sec-WebSocket-Extensions") @@ -431,7 +431,7 @@ def parse_subprotocol_item( Return the subprotocol value and the new position. Raises: - InvalidHeaderFormat: on invalid inputs. + InvalidHeaderFormat: On invalid inputs. """ item, pos = parse_token(header, pos, header_name) @@ -445,7 +445,7 @@ def parse_subprotocol(header: str) -> List[Subprotocol]: Return a list of WebSocket subprotocols. Raises: - InvalidHeaderFormat: on invalid inputs. + InvalidHeaderFormat: On invalid inputs. """ return parse_list(parse_subprotocol_item, header, 0, "Sec-WebSocket-Protocol") @@ -505,7 +505,7 @@ def parse_token68(header: str, pos: int, header_name: str) -> Tuple[str, int]: Return the token value and the new position. Raises: - InvalidHeaderFormat: on invalid inputs. + InvalidHeaderFormat: On invalid inputs. """ match = _token68_re.match(header, pos) @@ -535,8 +535,8 @@ def parse_authorization_basic(header: str) -> Tuple[str, str]: header: Value of the ``Authorization`` header. Raises: - InvalidHeaderFormat: on invalid inputs. - InvalidHeaderValue: on unsupported inputs. + InvalidHeaderFormat: On invalid inputs. + InvalidHeaderValue: On unsupported inputs. """ # https://www.rfc-editor.org/rfc/rfc7235.html#section-2.1 diff --git a/src/websockets/http11.py b/src/websockets/http11.py index c0a96f878..6fe775eec 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -97,9 +97,9 @@ def parse( line or raises an exception if there isn't enough data Raises: - EOFError: if the connection is closed without a full HTTP request. - SecurityError: if the request exceeds a security limit. - ValueError: if the request isn't well formatted. + EOFError: If the connection is closed without a full HTTP request. + SecurityError: If the request exceeds a security limit. + ValueError: If the request isn't well formatted. """ # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.1.1 @@ -201,10 +201,10 @@ def parse( of the stream. Raises: - EOFError: if the connection is closed without a full HTTP response. - SecurityError: if the response exceeds a security limit. - LookupError: if the response isn't well formatted. - ValueError: if the response isn't well formatted. + EOFError: If the connection is closed without a full HTTP response. + SecurityError: If the response exceeds a security limit. + LookupError: If the response isn't well formatted. + ValueError: If the response isn't well formatted. """ # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.1.2 @@ -299,9 +299,9 @@ def parse_headers( or raises an exception if there isn't enough data. Raises: - EOFError: if the connection is closed without complete headers. - SecurityError: if the request exceeds a security limit. - ValueError: if the request isn't well formatted. + EOFError: If the connection is closed without complete headers. + SecurityError: If the request exceeds a security limit. + ValueError: If the request isn't well formatted. """ # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.2 @@ -350,8 +350,8 @@ def parse_line( or raises an exception if there isn't enough data. Raises: - EOFError: if the connection is closed without a CRLF. - SecurityError: if the response exceeds a security limit. + EOFError: If the connection is closed without a CRLF. + SecurityError: If the response exceeds a security limit. """ try: diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 4af7ed109..e8cf8220f 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -271,7 +271,7 @@ async def read_http_request(self) -> Tuple[str, Headers]: after this coroutine returns. Raises: - InvalidMessage: if the HTTP message is malformed or isn't an + InvalidMessage: If the HTTP message is malformed or isn't an HTTP/1.1 GET request. """ @@ -381,7 +381,7 @@ def process_origin( origins: Optional list of acceptable origins. Raises: - InvalidOrigin: if the origin isn't acceptable. + InvalidOrigin: If the origin isn't acceptable. """ # "The user agent MUST NOT include more than one Origin header field" @@ -432,7 +432,7 @@ def process_extensions( extensions: Optional list of supported extensions. Raises: - InvalidHandshake: to abort the handshake with an HTTP 400 error. + InvalidHandshake: To abort the handshake with an HTTP 400 error. """ response_header_value: Optional[str] = None @@ -492,7 +492,7 @@ def process_subprotocol( available_subprotocols: Optional list of supported subprotocols. Raises: - InvalidHandshake: to abort the handshake with an HTTP 400 error. + InvalidHandshake: To abort the handshake with an HTTP 400 error. """ subprotocol: Optional[Subprotocol] = None @@ -574,7 +574,7 @@ async def handshake( path of the URI of the request. Raises: - InvalidHandshake: if the handshake fails. + InvalidHandshake: If the handshake fails. """ path, request_headers = await self.read_http_request() diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 99c9ee1a8..6851f3b1f 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -217,7 +217,7 @@ def close_exc(self) -> ConnectionClosed: known only once the connection is closed. Raises: - AssertionError: if the connection isn't closed yet. + AssertionError: If the connection isn't closed yet. """ assert self.state is CLOSED, "connection isn't closed yet" @@ -252,7 +252,7 @@ def receive_data(self, data: bytes) -> None: - You should call :meth:`events_received` and process resulting events. Raises: - EOFError: if :meth:`receive_eof` was called earlier. + EOFError: If :meth:`receive_eof` was called earlier. """ self.reader.feed_data(data) @@ -270,7 +270,7 @@ def receive_eof(self) -> None: any new events. Raises: - EOFError: if :meth:`receive_eof` was called earlier. + EOFError: If :meth:`receive_eof` was called earlier. """ self.reader.feed_eof() @@ -292,7 +292,7 @@ def send_continuation(self, data: bytes, fin: bool) -> None: of a fragmented message and to :obj:`False` otherwise. Raises: - ProtocolError: if a fragmented message isn't in progress. + ProtocolError: If a fragmented message isn't in progress. """ if not self.expect_continuation_frame: @@ -313,7 +313,7 @@ def send_text(self, data: bytes, fin: bool = True) -> None: a fragmented message. Raises: - ProtocolError: if a fragmented message is in progress. + ProtocolError: If a fragmented message is in progress. """ if self.expect_continuation_frame: @@ -334,7 +334,7 @@ def send_binary(self, data: bytes, fin: bool = True) -> None: a fragmented message. Raises: - ProtocolError: if a fragmented message is in progress. + ProtocolError: If a fragmented message is in progress. """ if self.expect_continuation_frame: @@ -354,7 +354,7 @@ def send_close(self, code: Optional[int] = None, reason: str = "") -> None: reason: close reason. Raises: - ProtocolError: if a fragmented message is being sent, if the code + ProtocolError: If a fragmented message is being sent, if the code isn't valid, or if a reason is provided without a code """ @@ -412,7 +412,7 @@ def fail(self, code: int, reason: str = "") -> None: reason: close reason Raises: - ProtocolError: if the code isn't valid. + ProtocolError: If the code isn't valid. """ # 7.1.7. Fail the WebSocket Connection diff --git a/src/websockets/server.py b/src/websockets/server.py index 6711a0bba..330e54f37 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -217,7 +217,7 @@ def process_request( ``Sec-WebSocket-Protocol`` headers for the handshake response. Raises: - InvalidHandshake: if the handshake request is invalid; + InvalidHandshake: If the handshake request is invalid; then the server must return 400 Bad Request error. """ @@ -296,8 +296,8 @@ def process_origin(self, headers: Headers) -> Optional[Origin]: origin, if it is acceptable. Raises: - InvalidHandshake: if the Origin header is invalid. - InvalidOrigin: if the origin isn't acceptable. + InvalidHandshake: If the Origin header is invalid. + InvalidOrigin: If the origin isn't acceptable. """ # "The user agent MUST NOT include more than one Origin header field" @@ -347,7 +347,7 @@ def process_extensions( accepted extensions. Raises: - InvalidHandshake: if the Sec-WebSocket-Extensions header is invalid. + InvalidHandshake: If the Sec-WebSocket-Extensions header is invalid. """ response_header_value: Optional[str] = None @@ -404,7 +404,7 @@ def process_subprotocol(self, headers: Headers) -> Optional[Subprotocol]: ``Sec-WebSocket-Protocol`` response header. Raises: - InvalidHandshake: if the Sec-WebSocket-Subprotocol header is invalid. + InvalidHandshake: If the Sec-WebSocket-Subprotocol header is invalid. """ subprotocols: Sequence[Subprotocol] = sum( @@ -453,7 +453,7 @@ def select_subprotocol(protocol, subprotocols): :obj:`None` to continue without a subprotocol. Raises: - NegotiationError: custom implementations may raise this exception + NegotiationError: Custom implementations may raise this exception to abort the handshake with an HTTP 400 error. """ diff --git a/src/websockets/streams.py b/src/websockets/streams.py index d288cf0cc..956f139d4 100644 --- a/src/websockets/streams.py +++ b/src/websockets/streams.py @@ -29,8 +29,8 @@ def read_line(self, m: int) -> Generator[None, None, bytes]: m: Maximum number bytes to read; this is a security limit. Raises: - EOFError: if the stream ends without a LF. - RuntimeError: if the stream ends in more than ``m`` bytes. + EOFError: If the stream ends without a LF. + RuntimeError: If the stream ends in more than ``m`` bytes. """ n = 0 # number of bytes to read @@ -61,7 +61,7 @@ def read_exact(self, n: int) -> Generator[None, None, bytes]: n: How many bytes to read. Raises: - EOFError: if the stream ends in less than ``n`` bytes. + EOFError: If the stream ends in less than ``n`` bytes. """ assert n >= 0 @@ -84,7 +84,7 @@ def read_to_eof(self, m: int) -> Generator[None, None, bytes]: m: Maximum number bytes to read; this is a security limit. Raises: - RuntimeError: if the stream ends in more than ``m`` bytes. + RuntimeError: If the stream ends in more than ``m`` bytes. """ while not self.eof: @@ -122,7 +122,7 @@ def feed_data(self, data: bytes) -> None: data: Data to write. Raises: - EOFError: if the stream has ended. + EOFError: If the stream has ended. """ if self.eof: @@ -136,7 +136,7 @@ def feed_eof(self) -> None: :meth:`feed_eof` cannot be called more than once. Raises: - EOFError: if the stream has ended. + EOFError: If the stream has ended. """ if self.eof: diff --git a/src/websockets/uri.py b/src/websockets/uri.py index 970020e26..8cf581743 100644 --- a/src/websockets/uri.py +++ b/src/websockets/uri.py @@ -69,7 +69,7 @@ def parse_uri(uri: str) -> WebSocketURI: Parsed WebSocket URI. Raises: - InvalidURI: if ``uri`` isn't a valid WebSocket URI. + InvalidURI: If ``uri`` isn't a valid WebSocket URI. """ parsed = urllib.parse.urlparse(uri) From 908c7ba23168da52d0006d67bc068e315e90daae Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 6 Jan 2024 10:02:56 +0100 Subject: [PATCH 487/762] Clean up sync message assembler. Remove support for control frames, which isn't actually used. --- src/websockets/sync/messages.py | 34 +++++----- tests/sync/test_messages.py | 113 ++++++++++++-------------------- 2 files changed, 60 insertions(+), 87 deletions(-) diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index d98ff855b..dcba183d9 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -5,7 +5,7 @@ import threading from typing import Iterator, List, Optional, cast -from ..frames import Frame, Opcode +from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame from ..typing import Data @@ -25,8 +25,11 @@ def __init__(self) -> None: # primitives provided by the threading and queue modules. self.mutex = threading.Lock() - # We create a latch with two events to ensure proper interleaving of - # writing and reading messages. + # We create a latch with two events to synchronize the production of + # frames and the consumption of messages (or frames) without a buffer. + # This design requires a switch between the library thread and the user + # thread for each message; that shouldn't be a performance bottleneck. + # put() sets this event to tell get() that a message can be fetched. self.message_complete = threading.Event() # get() sets this event to let put() that the message was fetched. @@ -72,8 +75,10 @@ def get(self, timeout: Optional[float] = None) -> Data: Raises: EOFError: If the stream of frames has ended. - RuntimeError: If two threads run :meth:`get` or :meth:``get_iter` + RuntimeError: If two threads run :meth:`get` or :meth:`get_iter` concurrently. + TimeoutError: If a timeout is provided and elapses before a + complete message is received. """ with self.mutex: @@ -131,7 +136,7 @@ def get_iter(self) -> Iterator[Data]: Raises: EOFError: If the stream of frames has ended. - RuntimeError: If two threads run :meth:`get` or :meth:``get_iter` + RuntimeError: If two threads run :meth:`get` or :meth:`get_iter` concurrently. """ @@ -159,11 +164,10 @@ def get_iter(self) -> Iterator[Data]: self.get_in_progress = True # Locking with get_in_progress ensures only one thread can get here. - yield from chunks - while True: - chunk = self.chunks_queue.get() - if chunk is None: - break + chunk: Optional[Data] + for chunk in chunks: + yield chunk + while (chunk := self.chunks_queue.get()) is not None: yield chunk with self.mutex: @@ -205,15 +209,12 @@ def put(self, frame: Frame) -> None: if self.put_in_progress: raise RuntimeError("put is already running") - if frame.opcode is Opcode.TEXT: + if frame.opcode is OP_TEXT: self.decoder = UTF8Decoder(errors="strict") - elif frame.opcode is Opcode.BINARY: + elif frame.opcode is OP_BINARY: self.decoder = None - elif frame.opcode is Opcode.CONT: - pass else: - # Ignore control frames. - return + assert frame.opcode is OP_CONT data: Data if self.decoder is not None: @@ -242,6 +243,7 @@ def put(self, frame: Frame) -> None: self.put_in_progress = True # Release the lock to allow get() to run and eventually set the event. + # Locking with put_in_progress ensures only one coroutine can get here. self.message_fetched.wait() with self.mutex: diff --git a/tests/sync/test_messages.py b/tests/sync/test_messages.py index 825eb8797..c134b8304 100644 --- a/tests/sync/test_messages.py +++ b/tests/sync/test_messages.py @@ -1,6 +1,6 @@ import time -from websockets.frames import OP_BINARY, OP_CONT, OP_PING, OP_PONG, OP_TEXT, Frame +from websockets.frames import OP_BINARY, OP_CONT, OP_TEXT, Frame from websockets.sync.messages import * from ..utils import MS @@ -350,76 +350,6 @@ def test_get_with_timeout_times_out(self): with self.assertRaises(TimeoutError): self.assembler.get(MS) - # Test control frames - - def test_control_frame_before_message_is_ignored(self): - """get ignores control frames between messages.""" - - def putter(): - self.assembler.put(Frame(OP_PING, b"")) - self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) - - with self.run_in_thread(putter): - message = self.assembler.get() - - self.assertEqual(message, "café") - - def test_control_frame_in_fragmented_message_is_ignored(self): - """get ignores control frames within fragmented messages.""" - - def putter(): - self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) - self.assembler.put(Frame(OP_PING, b"")) - self.assembler.put(Frame(OP_CONT, b"e", fin=False)) - self.assembler.put(Frame(OP_PONG, b"")) - self.assembler.put(Frame(OP_CONT, b"a")) - - with self.run_in_thread(putter): - message = self.assembler.get() - - self.assertEqual(message, b"tea") - - # Test concurrency - - def test_get_fails_when_get_is_running(self): - """get cannot be called concurrently with itself.""" - with self.run_in_thread(self.assembler.get): - with self.assertRaises(RuntimeError): - self.assembler.get() - self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread - - def test_get_fails_when_get_iter_is_running(self): - """get cannot be called concurrently with get_iter.""" - with self.run_in_thread(lambda: list(self.assembler.get_iter())): - with self.assertRaises(RuntimeError): - self.assembler.get() - self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread - - def test_get_iter_fails_when_get_is_running(self): - """get_iter cannot be called concurrently with get.""" - with self.run_in_thread(self.assembler.get): - with self.assertRaises(RuntimeError): - list(self.assembler.get_iter()) - self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread - - def test_get_iter_fails_when_get_iter_is_running(self): - """get_iter cannot be called concurrently with itself.""" - with self.run_in_thread(lambda: list(self.assembler.get_iter())): - with self.assertRaises(RuntimeError): - list(self.assembler.get_iter()) - self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread - - def test_put_fails_when_put_is_running(self): - """put cannot be called concurrently with itself.""" - - def putter(): - self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) - - with self.run_in_thread(putter): - with self.assertRaises(RuntimeError): - self.assembler.put(Frame(OP_BINARY, b"tea")) - self.assembler.get() # unblock other thread - # Test termination def test_get_fails_when_interrupted_by_close(self): @@ -477,3 +407,44 @@ def test_close_is_idempotent(self): """close can be called multiple times safely.""" self.assembler.close() self.assembler.close() + + # Test (non-)concurrency + + def test_get_fails_when_get_is_running(self): + """get cannot be called concurrently with itself.""" + with self.run_in_thread(self.assembler.get): + with self.assertRaises(RuntimeError): + self.assembler.get() + self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread + + def test_get_fails_when_get_iter_is_running(self): + """get cannot be called concurrently with get_iter.""" + with self.run_in_thread(lambda: list(self.assembler.get_iter())): + with self.assertRaises(RuntimeError): + self.assembler.get() + self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread + + def test_get_iter_fails_when_get_is_running(self): + """get_iter cannot be called concurrently with get.""" + with self.run_in_thread(self.assembler.get): + with self.assertRaises(RuntimeError): + list(self.assembler.get_iter()) + self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread + + def test_get_iter_fails_when_get_iter_is_running(self): + """get_iter cannot be called concurrently with itself.""" + with self.run_in_thread(lambda: list(self.assembler.get_iter())): + with self.assertRaises(RuntimeError): + list(self.assembler.get_iter()) + self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread + + def test_put_fails_when_put_is_running(self): + """put cannot be called concurrently with itself.""" + + def putter(): + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + + with self.run_in_thread(putter): + with self.assertRaises(RuntimeError): + self.assembler.put(Frame(OP_BINARY, b"tea")) + self.assembler.get() # unblock other thread From e21811e751f3f4fef18ad13b1b6f7064be004af6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 21 Jan 2024 11:12:27 +0100 Subject: [PATCH 488/762] Rename ssl_context to ssl in sync implementation. --- docs/project/changelog.rst | 14 ++++++++++- src/websockets/sync/client.py | 27 ++++++++++++-------- src/websockets/sync/server.py | 21 ++++++++++------ tests/sync/client.py | 3 ++- tests/sync/test_client.py | 47 ++++++++++++++++++++--------------- tests/sync/test_server.py | 23 +++++++++++------ 6 files changed, 89 insertions(+), 46 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 963353d0e..e288831be 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -25,11 +25,23 @@ fixing regressions shortly after a release. Only documented APIs are public. Undocumented, private APIs may change without notice. -12.1 +13.0 ---- *In development* +Backwards-incompatible changes +.............................. + +.. admonition:: The ``ssl_context`` argument of :func:`~sync.client.connect` + and :func:`~sync.server.serve` is renamed to ``ssl``. + :class: note + + This aligns the API of the :mod:`threading` implementation with the + :mod:`asyncio` implementation. + + For backwards compatibility, ``ssl_context`` is still supported. + New features ............ diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 79af0132f..6faca7789 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -1,8 +1,9 @@ from __future__ import annotations import socket -import ssl +import ssl as ssl_module import threading +import warnings from typing import Any, Optional, Sequence, Type from ..client import ClientProtocol @@ -128,7 +129,7 @@ def connect( *, # TCP/TLS sock: Optional[socket.socket] = None, - ssl_context: Optional[ssl.SSLContext] = None, + ssl: Optional[ssl_module.SSLContext] = None, server_hostname: Optional[str] = None, # WebSocket origin: Optional[Origin] = None, @@ -166,7 +167,7 @@ def connect( sock: Preexisting TCP socket. ``sock`` overrides the host and port from ``uri``. You may call :func:`socket.create_connection` to create a suitable TCP socket. - ssl_context: Configuration for enabling TLS on the connection. + ssl: Configuration for enabling TLS on the connection. server_hostname: Host name for the TLS handshake. ``server_hostname`` overrides the host name from ``uri``. origin: Value of the ``Origin`` header, for servers that require it. @@ -207,9 +208,14 @@ def connect( # Process parameters + # Backwards compatibility: ssl used to be called ssl_context. + if ssl is None and "ssl_context" in kwargs: + ssl = kwargs.pop("ssl_context") + warnings.warn("ssl_context was renamed to ssl", DeprecationWarning) + wsuri = parse_uri(uri) - if not wsuri.secure and ssl_context is not None: - raise TypeError("ssl_context argument is incompatible with a ws:// URI") + if not wsuri.secure and ssl is not None: + raise TypeError("ssl argument is incompatible with a ws:// URI") # Private APIs for unix_connect() unix: bool = kwargs.pop("unix", False) @@ -259,12 +265,12 @@ def connect( # Initialize TLS wrapper and perform TLS handshake if wsuri.secure: - if ssl_context is None: - ssl_context = ssl.create_default_context() + if ssl is None: + ssl = ssl_module.create_default_context() if server_hostname is None: server_hostname = wsuri.host sock.settimeout(deadline.timeout()) - sock = ssl_context.wrap_socket(sock, server_hostname=server_hostname) + sock = ssl.wrap_socket(sock, server_hostname=server_hostname) sock.settimeout(None) # Initialize WebSocket connection @@ -318,12 +324,13 @@ def unix_connect( Args: path: File system path to the Unix socket. uri: URI of the WebSocket server. ``uri`` defaults to - ``ws://localhost/`` or, when a ``ssl_context`` is provided, to + ``ws://localhost/`` or, when a ``ssl`` is provided, to ``wss://localhost/``. """ if uri is None: - if kwargs.get("ssl_context") is None: + # Backwards compatibility: ssl used to be called ssl_context. + if kwargs.get("ssl") is None and kwargs.get("ssl_context") is None: uri = "ws://localhost/" else: uri = "wss://localhost/" diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index c19992849..fa6087d54 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -5,9 +5,10 @@ import os import selectors import socket -import ssl +import ssl as ssl_module import sys import threading +import warnings from types import TracebackType from typing import Any, Callable, Optional, Sequence, Type @@ -268,7 +269,7 @@ def serve( *, # TCP/TLS sock: Optional[socket.socket] = None, - ssl_context: Optional[ssl.SSLContext] = None, + ssl: Optional[ssl_module.SSLContext] = None, # WebSocket origins: Optional[Sequence[Optional[Origin]]] = None, extensions: Optional[Sequence[ServerExtensionFactory]] = None, @@ -337,7 +338,7 @@ def handler(websocket): sock: Preexisting TCP socket. ``sock`` replaces ``host`` and ``port``. You may call :func:`socket.create_server` to create a suitable TCP socket. - ssl_context: Configuration for enabling TLS on the connection. + ssl: Configuration for enabling TLS on the connection. origins: Acceptable values of the ``Origin`` header, for defending against Cross-Site WebSocket Hijacking attacks. Include :obj:`None` in the list if the lack of an origin is acceptable. @@ -386,6 +387,11 @@ def handler(websocket): # Process parameters + # Backwards compatibility: ssl used to be called ssl_context. + if ssl is None and "ssl_context" in kwargs: + ssl = kwargs.pop("ssl_context") + warnings.warn("ssl_context was renamed to ssl", DeprecationWarning) + if subprotocols is not None: validate_subprotocols(subprotocols) @@ -417,8 +423,8 @@ def handler(websocket): # Initialize TLS wrapper - if ssl_context is not None: - sock = ssl_context.wrap_socket( + if ssl is not None: + sock = ssl.wrap_socket( sock, server_side=True, # Delay TLS handshake until after we set a timeout on the socket. @@ -441,9 +447,10 @@ def conn_handler(sock: socket.socket, addr: Any) -> None: # Perform TLS handshake - if ssl_context is not None: + if ssl is not None: sock.settimeout(deadline.timeout()) - assert isinstance(sock, ssl.SSLSocket) # mypy cannot figure this out + # mypy cannot figure this out + assert isinstance(sock, ssl_module.SSLSocket) sock.do_handshake() sock.settimeout(None) diff --git a/tests/sync/client.py b/tests/sync/client.py index 683893e88..bb4855c7f 100644 --- a/tests/sync/client.py +++ b/tests/sync/client.py @@ -25,7 +25,8 @@ def run_client(wsuri_or_server, secure=None, resource_name="/", **kwargs): else: assert isinstance(wsuri_or_server, WebSocketServer) if secure is None: - secure = "ssl_context" in kwargs + # Backwards compatibility: ssl used to be called ssl_context. + secure = "ssl" in kwargs or "ssl_context" in kwargs protocol = "wss" if secure else "ws" host, port = wsuri_or_server.socket.getsockname() wsuri = f"{protocol}://{host}:{port}{resource_name}" diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index c900f3b0f..fa363debf 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -7,7 +7,7 @@ from websockets.extensions.permessage_deflate import PerMessageDeflate from websockets.sync.client import * -from ..utils import MS, temp_unix_socket_path +from ..utils import MS, DeprecationTestCase, temp_unix_socket_path from .client import CLIENT_CONTEXT, run_client, run_unix_client from .server import SERVER_CONTEXT, do_nothing, run_server, run_unix_server @@ -137,18 +137,18 @@ def close_connection(self, request): class SecureClientTests(unittest.TestCase): def test_connection(self): """Client connects to server securely.""" - with run_server(ssl_context=SERVER_CONTEXT) as server: - with run_client(server, ssl_context=CLIENT_CONTEXT) as client: + with run_server(ssl=SERVER_CONTEXT) as server: + with run_client(server, ssl=CLIENT_CONTEXT) as client: self.assertEqual(client.protocol.state.name, "OPEN") self.assertEqual(client.socket.version()[:3], "TLS") def test_set_server_hostname_implicitly(self): """Client sets server_hostname to the host in the WebSocket URI.""" with temp_unix_socket_path() as path: - with run_unix_server(path, ssl_context=SERVER_CONTEXT): + with run_unix_server(path, ssl=SERVER_CONTEXT): with run_unix_client( path, - ssl_context=CLIENT_CONTEXT, + ssl=CLIENT_CONTEXT, uri="wss://overridden/", ) as client: self.assertEqual(client.socket.server_hostname, "overridden") @@ -156,17 +156,17 @@ def test_set_server_hostname_implicitly(self): def test_set_server_hostname_explicitly(self): """Client sets server_hostname to the value provided in argument.""" with temp_unix_socket_path() as path: - with run_unix_server(path, ssl_context=SERVER_CONTEXT): + with run_unix_server(path, ssl=SERVER_CONTEXT): with run_unix_client( path, - ssl_context=CLIENT_CONTEXT, + ssl=CLIENT_CONTEXT, server_hostname="overridden", ) as client: self.assertEqual(client.socket.server_hostname, "overridden") def test_reject_invalid_server_certificate(self): """Client rejects certificate where server certificate isn't trusted.""" - with run_server(ssl_context=SERVER_CONTEXT) as server: + with run_server(ssl=SERVER_CONTEXT) as server: with self.assertRaisesRegex( ssl.SSLCertVerificationError, r"certificate verify failed: self[ -]signed certificate", @@ -177,15 +177,13 @@ def test_reject_invalid_server_certificate(self): def test_reject_invalid_server_hostname(self): """Client rejects certificate where server hostname doesn't match.""" - with run_server(ssl_context=SERVER_CONTEXT) as server: + with run_server(ssl=SERVER_CONTEXT) as server: with self.assertRaisesRegex( ssl.SSLCertVerificationError, r"certificate verify failed: Hostname mismatch", ): # This hostname isn't included in the test certificate. - with run_client( - server, ssl_context=CLIENT_CONTEXT, server_hostname="invalid" - ): + with run_client(server, ssl=CLIENT_CONTEXT, server_hostname="invalid"): self.fail("did not raise") @@ -212,8 +210,8 @@ class SecureUnixClientTests(unittest.TestCase): def test_connection(self): """Client connects to server securely over a Unix socket.""" with temp_unix_socket_path() as path: - with run_unix_server(path, ssl_context=SERVER_CONTEXT): - with run_unix_client(path, ssl_context=CLIENT_CONTEXT) as client: + with run_unix_server(path, ssl=SERVER_CONTEXT): + with run_unix_client(path, ssl=CLIENT_CONTEXT) as client: self.assertEqual(client.protocol.state.name, "OPEN") self.assertEqual(client.socket.version()[:3], "TLS") @@ -221,23 +219,23 @@ def test_set_server_hostname(self): """Client sets server_hostname to the host in the WebSocket URI.""" # This is part of the documented behavior of unix_connect(). with temp_unix_socket_path() as path: - with run_unix_server(path, ssl_context=SERVER_CONTEXT): + with run_unix_server(path, ssl=SERVER_CONTEXT): with run_unix_client( path, - ssl_context=CLIENT_CONTEXT, + ssl=CLIENT_CONTEXT, uri="wss://overridden/", ) as client: self.assertEqual(client.socket.server_hostname, "overridden") class ClientUsageErrorsTests(unittest.TestCase): - def test_ssl_context_without_secure_uri(self): - """Client rejects ssl_context when URI isn't secure.""" + def test_ssl_without_secure_uri(self): + """Client rejects ssl when URI isn't secure.""" with self.assertRaisesRegex( TypeError, - "ssl_context argument is incompatible with a ws:// URI", + "ssl argument is incompatible with a ws:// URI", ): - connect("ws://localhost/", ssl_context=CLIENT_CONTEXT) + connect("ws://localhost/", ssl=CLIENT_CONTEXT) def test_unix_without_path_or_sock(self): """Unix client requires path when sock isn't provided.""" @@ -272,3 +270,12 @@ def test_unsupported_compression(self): "unsupported compression: False", ): connect("ws://localhost/", compression=False) + + +class BackwardsCompatibilityTests(DeprecationTestCase): + def test_ssl_context_argument(self): + """Client supports the deprecated ssl_context argument.""" + with run_server(ssl=SERVER_CONTEXT) as server: + with self.assertDeprecationWarning("ssl_context was renamed to ssl"): + with run_client(server, ssl_context=CLIENT_CONTEXT): + pass diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index f9db84246..5e7e79c52 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -14,7 +14,7 @@ from websockets.http11 import Request, Response from websockets.sync.server import * -from ..utils import MS, temp_unix_socket_path +from ..utils import MS, DeprecationTestCase, temp_unix_socket_path from .client import CLIENT_CONTEXT, run_client, run_unix_client from .server import ( SERVER_CONTEXT, @@ -274,20 +274,20 @@ def handler(sock, addr): class SecureServerTests(EvalShellMixin, unittest.TestCase): def test_connection(self): """Server receives secure connection from client.""" - with run_server(ssl_context=SERVER_CONTEXT) as server: - with run_client(server, ssl_context=CLIENT_CONTEXT) as client: + with run_server(ssl=SERVER_CONTEXT) as server: + with run_client(server, ssl=CLIENT_CONTEXT) as client: self.assertEval(client, "ws.protocol.state.name", "OPEN") self.assertEval(client, "ws.socket.version()[:3]", "TLS") def test_timeout_during_tls_handshake(self): """Server times out before receiving TLS handshake request from client.""" - with run_server(ssl_context=SERVER_CONTEXT, open_timeout=MS) as server: + with run_server(ssl=SERVER_CONTEXT, open_timeout=MS) as server: with socket.create_connection(server.socket.getsockname()) as sock: self.assertEqual(sock.recv(4096), b"") def test_connection_closed_during_tls_handshake(self): """Server reads EOF before receiving TLS handshake request from client.""" - with run_server(ssl_context=SERVER_CONTEXT) as server: + with run_server(ssl=SERVER_CONTEXT) as server: # Patch handler to record a reference to the thread running it. server_thread = None conn_received = threading.Event() @@ -325,8 +325,8 @@ class SecureUnixServerTests(EvalShellMixin, unittest.TestCase): def test_connection(self): """Server receives secure connection from client over a Unix socket.""" with temp_unix_socket_path() as path: - with run_unix_server(path, ssl_context=SERVER_CONTEXT): - with run_unix_client(path, ssl_context=CLIENT_CONTEXT) as client: + with run_unix_server(path, ssl=SERVER_CONTEXT): + with run_unix_client(path, ssl=CLIENT_CONTEXT) as client: self.assertEval(client, "ws.protocol.state.name", "OPEN") self.assertEval(client, "ws.socket.version()[:3]", "TLS") @@ -386,3 +386,12 @@ def test_shutdown(self): # Check that the server socket is closed. with self.assertRaises(OSError): server.socket.accept() + + +class BackwardsCompatibilityTests(DeprecationTestCase): + def test_ssl_context_argument(self): + """Client supports the deprecated ssl_context argument.""" + with self.assertDeprecationWarning("ssl_context was renamed to ssl"): + with run_server(ssl_context=SERVER_CONTEXT) as server: + with run_client(server, ssl=CLIENT_CONTEXT): + pass From 45d8de7495ea33724bf93d753d65cad932472aac Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 21 Jan 2024 21:06:38 +0100 Subject: [PATCH 489/762] Standardize style for testing exceptions. --- tests/legacy/test_client_server.py | 19 +++-- tests/legacy/test_http.py | 76 +++++++++++++++----- tests/sync/test_client.py | 90 ++++++++++++----------- tests/sync/test_connection.py | 54 +++++++------- tests/sync/test_server.py | 110 ++++++++++++++++------------- tests/test_server.py | 73 +++++++++++++------ 6 files changed, 265 insertions(+), 157 deletions(-) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index c49d91b70..4a21f7cea 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -1331,20 +1331,24 @@ def test_checking_origin_succeeds(self): @with_server(origins=["http://localhost"]) def test_checking_origin_fails(self): - with self.assertRaisesRegex( - InvalidHandshake, "server rejected WebSocket connection: HTTP 403" - ): + with self.assertRaises(InvalidHandshake) as raised: self.start_client(origin="http://otherhost") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 403", + ) @with_server(origins=["http://localhost"]) def test_checking_origins_fails_with_multiple_headers(self): - with self.assertRaisesRegex( - InvalidHandshake, "server rejected WebSocket connection: HTTP 400" - ): + with self.assertRaises(InvalidHandshake) as raised: self.start_client( origin="http://localhost", extra_headers=[("Origin", "http://otherhost")], ) + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 400", + ) @with_server(origins=[None]) @with_client() @@ -1574,8 +1578,9 @@ async def run_client(): pass # work around bug in coverage with self.assertLogs("websockets", logging.INFO) as logs: - with self.assertRaisesRegex(Exception, "BOOM"): + with self.assertRaises(Exception) as raised: self.loop.run_until_complete(run_client()) + self.assertEqual(str(raised.exception), "BOOM") # Iteration 1 self.assertEqual( diff --git a/tests/legacy/test_http.py b/tests/legacy/test_http.py index 15d53e08d..76af61122 100644 --- a/tests/legacy/test_http.py +++ b/tests/legacy/test_http.py @@ -31,30 +31,48 @@ async def test_read_request(self): async def test_read_request_empty(self): self.stream.feed_eof() - with self.assertRaisesRegex( - EOFError, "connection closed while reading HTTP request line" - ): + with self.assertRaises(EOFError) as raised: await read_request(self.stream) + self.assertEqual( + str(raised.exception), + "connection closed while reading HTTP request line", + ) async def test_read_request_invalid_request_line(self): self.stream.feed_data(b"GET /\r\n\r\n") - with self.assertRaisesRegex(ValueError, "invalid HTTP request line: GET /"): + with self.assertRaises(ValueError) as raised: await read_request(self.stream) + self.assertEqual( + str(raised.exception), + "invalid HTTP request line: GET /", + ) async def test_read_request_unsupported_method(self): self.stream.feed_data(b"OPTIONS * HTTP/1.1\r\n\r\n") - with self.assertRaisesRegex(ValueError, "unsupported HTTP method: OPTIONS"): + with self.assertRaises(ValueError) as raised: await read_request(self.stream) + self.assertEqual( + str(raised.exception), + "unsupported HTTP method: OPTIONS", + ) async def test_read_request_unsupported_version(self): self.stream.feed_data(b"GET /chat HTTP/1.0\r\n\r\n") - with self.assertRaisesRegex(ValueError, "unsupported HTTP version: HTTP/1.0"): + with self.assertRaises(ValueError) as raised: await read_request(self.stream) + self.assertEqual( + str(raised.exception), + "unsupported HTTP version: HTTP/1.0", + ) async def test_read_request_invalid_header(self): self.stream.feed_data(b"GET /chat HTTP/1.1\r\nOops\r\n") - with self.assertRaisesRegex(ValueError, "invalid HTTP header line: Oops"): + with self.assertRaises(ValueError) as raised: await read_request(self.stream) + self.assertEqual( + str(raised.exception), + "invalid HTTP header line: Oops", + ) async def test_read_response(self): # Example from the protocol overview in RFC 6455 @@ -73,40 +91,66 @@ async def test_read_response(self): async def test_read_response_empty(self): self.stream.feed_eof() - with self.assertRaisesRegex( - EOFError, "connection closed while reading HTTP status line" - ): + with self.assertRaises(EOFError) as raised: await read_response(self.stream) + self.assertEqual( + str(raised.exception), + "connection closed while reading HTTP status line", + ) async def test_read_request_invalid_status_line(self): self.stream.feed_data(b"Hello!\r\n") - with self.assertRaisesRegex(ValueError, "invalid HTTP status line: Hello!"): + with self.assertRaises(ValueError) as raised: await read_response(self.stream) + self.assertEqual( + str(raised.exception), + "invalid HTTP status line: Hello!", + ) async def test_read_response_unsupported_version(self): self.stream.feed_data(b"HTTP/1.0 400 Bad Request\r\n\r\n") - with self.assertRaisesRegex(ValueError, "unsupported HTTP version: HTTP/1.0"): + with self.assertRaises(ValueError) as raised: await read_response(self.stream) + self.assertEqual( + str(raised.exception), + "unsupported HTTP version: HTTP/1.0", + ) async def test_read_response_invalid_status(self): self.stream.feed_data(b"HTTP/1.1 OMG WTF\r\n\r\n") - with self.assertRaisesRegex(ValueError, "invalid HTTP status code: OMG"): + with self.assertRaises(ValueError) as raised: await read_response(self.stream) + self.assertEqual( + str(raised.exception), + "invalid HTTP status code: OMG", + ) async def test_read_response_unsupported_status(self): self.stream.feed_data(b"HTTP/1.1 007 My name is Bond\r\n\r\n") - with self.assertRaisesRegex(ValueError, "unsupported HTTP status code: 007"): + with self.assertRaises(ValueError) as raised: await read_response(self.stream) + self.assertEqual( + str(raised.exception), + "unsupported HTTP status code: 007", + ) async def test_read_response_invalid_reason(self): self.stream.feed_data(b"HTTP/1.1 200 \x7f\r\n\r\n") - with self.assertRaisesRegex(ValueError, "invalid HTTP reason phrase: \\x7f"): + with self.assertRaises(ValueError) as raised: await read_response(self.stream) + self.assertEqual( + str(raised.exception), + "invalid HTTP reason phrase: \x7f", + ) async def test_read_response_invalid_header(self): self.stream.feed_data(b"HTTP/1.1 500 Internal Server Error\r\nOops\r\n") - with self.assertRaisesRegex(ValueError, "invalid HTTP header line: Oops"): + with self.assertRaises(ValueError) as raised: await read_response(self.stream) + self.assertEqual( + str(raised.exception), + "invalid HTTP header line: Oops", + ) async def test_header_name(self): self.stream.feed_data(b"foo bar: baz qux\r\n\r\n") diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index fa363debf..c403b9632 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -28,12 +28,13 @@ def remove_accept_header(self, request, response): # The connection will be open for the server but failed for the client. # Use a connection handler that exits immediately to avoid an exception. with run_server(do_nothing, process_response=remove_accept_header) as server: - with self.assertRaisesRegex( - InvalidHandshake, - "missing Sec-WebSocket-Accept header", - ): + with self.assertRaises(InvalidHandshake) as raised: with run_client(server, close_timeout=MS): self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "missing Sec-WebSocket-Accept header", + ) def test_tcp_connection_fails(self): """Client fails to connect to server.""" @@ -107,15 +108,16 @@ def stall_connection(self, request): # Use a connection handler that exits immediately to avoid an exception. with run_server(do_nothing, process_request=stall_connection) as server: try: - with self.assertRaisesRegex( - TimeoutError, - "timed out during handshake", - ): + with self.assertRaises(TimeoutError) as raised: # While it shouldn't take 50ms to open a connection, this # test becomes flaky in CI when setting a smaller timeout, # even after increasing WEBSOCKETS_TESTS_TIMEOUT_FACTOR. with run_client(server, open_timeout=5 * MS): self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "timed out during handshake", + ) finally: gate.set() @@ -126,12 +128,13 @@ def close_connection(self, request): self.close_socket() with run_server(process_request=close_connection) as server: - with self.assertRaisesRegex( - ConnectionError, - "connection closed during handshake", - ): + with self.assertRaises(ConnectionError) as raised: with run_client(server): self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "connection closed during handshake", + ) class SecureClientTests(unittest.TestCase): @@ -167,24 +170,26 @@ def test_set_server_hostname_explicitly(self): def test_reject_invalid_server_certificate(self): """Client rejects certificate where server certificate isn't trusted.""" with run_server(ssl=SERVER_CONTEXT) as server: - with self.assertRaisesRegex( - ssl.SSLCertVerificationError, - r"certificate verify failed: self[ -]signed certificate", - ): + with self.assertRaises(ssl.SSLCertVerificationError) as raised: # The test certificate isn't trusted system-wide. with run_client(server, secure=True): self.fail("did not raise") + self.assertIn( + "certificate verify failed: self signed certificate", + str(raised.exception).replace("-", " "), + ) def test_reject_invalid_server_hostname(self): """Client rejects certificate where server hostname doesn't match.""" with run_server(ssl=SERVER_CONTEXT) as server: - with self.assertRaisesRegex( - ssl.SSLCertVerificationError, - r"certificate verify failed: Hostname mismatch", - ): + with self.assertRaises(ssl.SSLCertVerificationError) as raised: # This hostname isn't included in the test certificate. with run_client(server, ssl=CLIENT_CONTEXT, server_hostname="invalid"): self.fail("did not raise") + self.assertIn( + "certificate verify failed: Hostname mismatch", + str(raised.exception), + ) @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") @@ -231,45 +236,50 @@ def test_set_server_hostname(self): class ClientUsageErrorsTests(unittest.TestCase): def test_ssl_without_secure_uri(self): """Client rejects ssl when URI isn't secure.""" - with self.assertRaisesRegex( - TypeError, - "ssl argument is incompatible with a ws:// URI", - ): + with self.assertRaises(TypeError) as raised: connect("ws://localhost/", ssl=CLIENT_CONTEXT) + self.assertEqual( + str(raised.exception), + "ssl argument is incompatible with a ws:// URI", + ) def test_unix_without_path_or_sock(self): """Unix client requires path when sock isn't provided.""" - with self.assertRaisesRegex( - TypeError, - "missing path argument", - ): + with self.assertRaises(TypeError) as raised: unix_connect() + self.assertEqual( + str(raised.exception), + "missing path argument", + ) def test_unix_with_path_and_sock(self): """Unix client rejects path when sock is provided.""" sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) self.addCleanup(sock.close) - with self.assertRaisesRegex( - TypeError, - "path and sock arguments are incompatible", - ): + with self.assertRaises(TypeError) as raised: unix_connect(path="/", sock=sock) + self.assertEqual( + str(raised.exception), + "path and sock arguments are incompatible", + ) def test_invalid_subprotocol(self): """Client rejects single value of subprotocols.""" - with self.assertRaisesRegex( - TypeError, - "subprotocols must be a list", - ): + with self.assertRaises(TypeError) as raised: connect("ws://localhost/", subprotocols="chat") + self.assertEqual( + str(raised.exception), + "subprotocols must be a list, not a str", + ) def test_unsupported_compression(self): """Client rejects incorrect value of compression.""" - with self.assertRaisesRegex( - ValueError, - "unsupported compression: False", - ): + with self.assertRaises(ValueError) as raised: connect("ws://localhost/", compression=False) + self.assertEqual( + str(raised.exception), + "unsupported compression: False", + ) class BackwardsCompatibilityTests(DeprecationTestCase): diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index e128425d8..953c8c253 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -177,12 +177,13 @@ def test_recv_during_recv(self): recv_thread = threading.Thread(target=self.connection.recv) recv_thread.start() - with self.assertRaisesRegex( - RuntimeError, + with self.assertRaises(RuntimeError) as raised: + self.connection.recv() + self.assertEqual( + str(raised.exception), "cannot call recv while another thread " "is already running recv or recv_streaming", - ): - self.connection.recv() + ) self.remote_connection.send("") recv_thread.join() @@ -194,12 +195,13 @@ def test_recv_during_recv_streaming(self): ) recv_streaming_thread.start() - with self.assertRaisesRegex( - RuntimeError, + with self.assertRaises(RuntimeError) as raised: + self.connection.recv() + self.assertEqual( + str(raised.exception), "cannot call recv while another thread " "is already running recv or recv_streaming", - ): - self.connection.recv() + ) self.remote_connection.send("") recv_streaming_thread.join() @@ -257,12 +259,13 @@ def test_recv_streaming_during_recv(self): recv_thread = threading.Thread(target=self.connection.recv) recv_thread.start() - with self.assertRaisesRegex( - RuntimeError, + with self.assertRaises(RuntimeError) as raised: + list(self.connection.recv_streaming()) + self.assertEqual( + str(raised.exception), "cannot call recv_streaming while another thread " "is already running recv or recv_streaming", - ): - list(self.connection.recv_streaming()) + ) self.remote_connection.send("") recv_thread.join() @@ -274,12 +277,13 @@ def test_recv_streaming_during_recv_streaming(self): ) recv_streaming_thread.start() - with self.assertRaisesRegex( - RuntimeError, + with self.assertRaises(RuntimeError) as raised: + list(self.connection.recv_streaming()) + self.assertEqual( + str(raised.exception), r"cannot call recv_streaming while another thread " r"is already running recv or recv_streaming", - ): - list(self.connection.recv_streaming()) + ) self.remote_connection.send("") recv_streaming_thread.join() @@ -355,11 +359,12 @@ def fragments(): [b"\x01\x02", b"\xfe\xff"], ]: with self.subTest(message=message): - with self.assertRaisesRegex( - RuntimeError, - "cannot call send while another thread is already running send", - ): + with self.assertRaises(RuntimeError) as raised: self.connection.send(message) + self.assertEqual( + str(raised.exception), + "cannot call send while another thread is already running send", + ) exit_gate.set() send_thread.join() @@ -598,11 +603,12 @@ def test_ping_duplicate_payload(self): """ping rejects the same payload until receiving the pong.""" with self.remote_connection.protocol_mutex: # block response to ping pong_waiter = self.connection.ping("idem") - with self.assertRaisesRegex( - RuntimeError, - "already waiting for a pong with the same data", - ): + with self.assertRaises(RuntimeError) as raised: self.connection.ping("idem") + self.assertEqual( + str(raised.exception), + "already waiting for a pong with the same data", + ) self.assertTrue(pong_waiter.wait(MS)) self.connection.ping("idem") # doesn't raise an exception diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index 5e7e79c52..f9f30baf1 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -41,33 +41,36 @@ def remove_key_header(self, request): del request.headers["Sec-WebSocket-Key"] with run_server(process_request=remove_key_header) as server: - with self.assertRaisesRegex( - InvalidStatus, - "server rejected WebSocket connection: HTTP 400", - ): + with self.assertRaises(InvalidStatus) as raised: with run_client(server): self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 400", + ) def test_connection_handler_returns(self): """Connection handler returns.""" with run_server(do_nothing) as server: with run_client(server) as client: - with self.assertRaisesRegex( - ConnectionClosedOK, - r"received 1000 \(OK\); then sent 1000 \(OK\)", - ): + with self.assertRaises(ConnectionClosedOK) as raised: client.recv() + self.assertEqual( + str(raised.exception), + "received 1000 (OK); then sent 1000 (OK)", + ) def test_connection_handler_raises_exception(self): """Connection handler raises an exception.""" with run_server(crash) as server: with run_client(server) as client: - with self.assertRaisesRegex( - ConnectionClosedError, - r"received 1011 \(internal error\); " - r"then sent 1011 \(internal error\)", - ): + with self.assertRaises(ConnectionClosedError) as raised: client.recv() + self.assertEqual( + str(raised.exception), + "received 1011 (internal error); " + "then sent 1011 (internal error)", + ) def test_existing_socket(self): """Server receives connection using a pre-existing socket.""" @@ -100,12 +103,13 @@ def select_subprotocol(ws, subprotocols): raise NegotiationError with run_server(select_subprotocol=select_subprotocol) as server: - with self.assertRaisesRegex( - InvalidStatus, - "server rejected WebSocket connection: HTTP 400", - ): + with self.assertRaises(InvalidStatus) as raised: with run_client(server): self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 400", + ) def test_select_subprotocol_raises_exception(self): """Server returns an error if select_subprotocol raises an exception.""" @@ -114,12 +118,13 @@ def select_subprotocol(ws, subprotocols): raise RuntimeError with run_server(select_subprotocol=select_subprotocol) as server: - with self.assertRaisesRegex( - InvalidStatus, - "server rejected WebSocket connection: HTTP 500", - ): + with self.assertRaises(InvalidStatus) as raised: with run_client(server): self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) def test_process_request(self): """Server runs process_request before processing the handshake.""" @@ -139,12 +144,13 @@ def process_request(ws, request): return ws.protocol.reject(http.HTTPStatus.FORBIDDEN, "Forbidden") with run_server(process_request=process_request) as server: - with self.assertRaisesRegex( - InvalidStatus, - "server rejected WebSocket connection: HTTP 403", - ): + with self.assertRaises(InvalidStatus) as raised: with run_client(server): self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 403", + ) def test_process_request_raises_exception(self): """Server returns an error if process_request raises an exception.""" @@ -153,12 +159,13 @@ def process_request(ws, request): raise RuntimeError with run_server(process_request=process_request) as server: - with self.assertRaisesRegex( - InvalidStatus, - "server rejected WebSocket connection: HTTP 500", - ): + with self.assertRaises(InvalidStatus) as raised: with run_client(server): self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) def test_process_response(self): """Server runs process_response after processing the handshake.""" @@ -193,12 +200,13 @@ def process_response(ws, request, response): raise RuntimeError with run_server(process_response=process_response) as server: - with self.assertRaisesRegex( - InvalidStatus, - "server rejected WebSocket connection: HTTP 500", - ): + with self.assertRaises(InvalidStatus) as raised: with run_client(server): self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) def test_override_server(self): """Server can override Server header with server_header.""" @@ -334,37 +342,41 @@ def test_connection(self): class ServerUsageErrorsTests(unittest.TestCase): def test_unix_without_path_or_sock(self): """Unix server requires path when sock isn't provided.""" - with self.assertRaisesRegex( - TypeError, - "missing path argument", - ): + with self.assertRaises(TypeError) as raised: unix_serve(eval_shell) + self.assertEqual( + str(raised.exception), + "missing path argument", + ) def test_unix_with_path_and_sock(self): """Unix server rejects path when sock is provided.""" sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) self.addCleanup(sock.close) - with self.assertRaisesRegex( - TypeError, - "path and sock arguments are incompatible", - ): + with self.assertRaises(TypeError) as raised: unix_serve(eval_shell, path="/", sock=sock) + self.assertEqual( + str(raised.exception), + "path and sock arguments are incompatible", + ) def test_invalid_subprotocol(self): """Server rejects single value of subprotocols.""" - with self.assertRaisesRegex( - TypeError, - "subprotocols must be a list", - ): + with self.assertRaises(TypeError) as raised: serve(eval_shell, subprotocols="chat") + self.assertEqual( + str(raised.exception), + "subprotocols must be a list, not a str", + ) def test_unsupported_compression(self): """Server rejects incorrect value of compression.""" - with self.assertRaisesRegex( - ValueError, - "unsupported compression: False", - ): + with self.assertRaises(ValueError) as raised: serve(eval_shell, compression=False) + self.assertEqual( + str(raised.exception), + "unsupported compression: False", + ) class WebSocketServerTests(unittest.TestCase): diff --git a/tests/test_server.py b/tests/test_server.py index b6f5e3568..e4460dcba 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -213,7 +213,10 @@ def test_unexpected_exception(self): self.assertEqual(response.status_code, 500) with self.assertRaises(Exception) as raised: raise server.handshake_exc - self.assertEqual(str(raised.exception), "BOOM") + self.assertEqual( + str(raised.exception), + "BOOM", + ) def test_missing_connection(self): server = ServerProtocol() @@ -225,7 +228,10 @@ def test_missing_connection(self): self.assertEqual(response.headers["Upgrade"], "websocket") with self.assertRaises(InvalidUpgrade) as raised: raise server.handshake_exc - self.assertEqual(str(raised.exception), "missing Connection header") + self.assertEqual( + str(raised.exception), + "missing Connection header", + ) def test_invalid_connection(self): server = ServerProtocol() @@ -238,7 +244,10 @@ def test_invalid_connection(self): self.assertEqual(response.headers["Upgrade"], "websocket") with self.assertRaises(InvalidUpgrade) as raised: raise server.handshake_exc - self.assertEqual(str(raised.exception), "invalid Connection header: close") + self.assertEqual( + str(raised.exception), + "invalid Connection header: close", + ) def test_missing_upgrade(self): server = ServerProtocol() @@ -250,7 +259,10 @@ def test_missing_upgrade(self): self.assertEqual(response.headers["Upgrade"], "websocket") with self.assertRaises(InvalidUpgrade) as raised: raise server.handshake_exc - self.assertEqual(str(raised.exception), "missing Upgrade header") + self.assertEqual( + str(raised.exception), + "missing Upgrade header", + ) def test_invalid_upgrade(self): server = ServerProtocol() @@ -263,7 +275,10 @@ def test_invalid_upgrade(self): self.assertEqual(response.headers["Upgrade"], "websocket") with self.assertRaises(InvalidUpgrade) as raised: raise server.handshake_exc - self.assertEqual(str(raised.exception), "invalid Upgrade header: h2c") + self.assertEqual( + str(raised.exception), + "invalid Upgrade header: h2c", + ) def test_missing_key(self): server = ServerProtocol() @@ -274,7 +289,10 @@ def test_missing_key(self): self.assertEqual(response.status_code, 400) with self.assertRaises(InvalidHeader) as raised: raise server.handshake_exc - self.assertEqual(str(raised.exception), "missing Sec-WebSocket-Key header") + self.assertEqual( + str(raised.exception), + "missing Sec-WebSocket-Key header", + ) def test_multiple_key(self): server = ServerProtocol() @@ -302,7 +320,8 @@ def test_invalid_key(self): with self.assertRaises(InvalidHeader) as raised: raise server.handshake_exc self.assertEqual( - str(raised.exception), "invalid Sec-WebSocket-Key header: not Base64 data!" + str(raised.exception), + "invalid Sec-WebSocket-Key header: not Base64 data!", ) def test_truncated_key(self): @@ -318,7 +337,8 @@ def test_truncated_key(self): with self.assertRaises(InvalidHeader) as raised: raise server.handshake_exc self.assertEqual( - str(raised.exception), f"invalid Sec-WebSocket-Key header: {KEY[:16]}" + str(raised.exception), + f"invalid Sec-WebSocket-Key header: {KEY[:16]}", ) def test_missing_version(self): @@ -330,7 +350,10 @@ def test_missing_version(self): self.assertEqual(response.status_code, 400) with self.assertRaises(InvalidHeader) as raised: raise server.handshake_exc - self.assertEqual(str(raised.exception), "missing Sec-WebSocket-Version header") + self.assertEqual( + str(raised.exception), + "missing Sec-WebSocket-Version header", + ) def test_multiple_version(self): server = ServerProtocol() @@ -358,7 +381,8 @@ def test_invalid_version(self): with self.assertRaises(InvalidHeader) as raised: raise server.handshake_exc self.assertEqual( - str(raised.exception), "invalid Sec-WebSocket-Version header: 11" + str(raised.exception), + "invalid Sec-WebSocket-Version header: 11", ) def test_no_origin(self): @@ -369,7 +393,10 @@ def test_no_origin(self): self.assertEqual(response.status_code, 403) with self.assertRaises(InvalidOrigin) as raised: raise server.handshake_exc - self.assertEqual(str(raised.exception), "missing Origin header") + self.assertEqual( + str(raised.exception), + "missing Origin header", + ) def test_origin(self): server = ServerProtocol(origins=["https://example.com"]) @@ -390,7 +417,8 @@ def test_unexpected_origin(self): with self.assertRaises(InvalidOrigin) as raised: raise server.handshake_exc self.assertEqual( - str(raised.exception), "invalid Origin header: https://other.example.com" + str(raised.exception), + "invalid Origin header: https://other.example.com", ) def test_multiple_origin(self): @@ -435,7 +463,8 @@ def test_unsupported_origin(self): with self.assertRaises(InvalidOrigin) as raised: raise server.handshake_exc self.assertEqual( - str(raised.exception), "invalid Origin header: https://original.example.com" + str(raised.exception), + "invalid Origin header: https://original.example.com", ) def test_no_origin_accepted(self): @@ -574,11 +603,12 @@ def test_no_subprotocol(self): response = server.accept(request) self.assertEqual(response.status_code, 400) - with self.assertRaisesRegex( - NegotiationError, - r"missing subprotocol", - ): + with self.assertRaises(NegotiationError) as raised: raise server.handshake_exc + self.assertEqual( + str(raised.exception), + "missing subprotocol", + ) def test_subprotocol(self): server = ServerProtocol(subprotocols=["chat"]) @@ -628,11 +658,12 @@ def test_unsupported_subprotocol(self): response = server.accept(request) self.assertEqual(response.status_code, 400) - with self.assertRaisesRegex( - NegotiationError, - r"invalid subprotocol; expected one of superchat, chat", - ): + with self.assertRaises(NegotiationError) as raised: raise server.handshake_exc + self.assertEqual( + str(raised.exception), + "invalid subprotocol; expected one of superchat, chat", + ) @staticmethod def optional_chat(protocol, subprotocols): From c06e44d214ca3650b12fbbcaa1a0266dae9432d0 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 27 Jan 2024 15:20:39 +0100 Subject: [PATCH 490/762] Support closing while sending a fragmented message. On one hand, it will close the connection with an unfinished fragmented message, which is less than ideal. On the other hand, RFC 6455 implies that it should be legal and it's probably best to let users close the connection if they want to close the connection (rather than force them to call fail() instead). --- src/websockets/protocol.py | 6 ++---- tests/test_protocol.py | 36 ++++++++++++------------------------ 2 files changed, 14 insertions(+), 28 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 6851f3b1f..4650cf16d 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -354,12 +354,10 @@ def send_close(self, code: Optional[int] = None, reason: str = "") -> None: reason: close reason. Raises: - ProtocolError: If a fragmented message is being sent, if the code - isn't valid, or if a reason is provided without a code + ProtocolError: If the code isn't valid or if a reason is provided + without a code. """ - if self.expect_continuation_frame: - raise ProtocolError("expected a continuation frame") if code is None: if reason != "": raise ProtocolError("cannot send a reason without a code") diff --git a/tests/test_protocol.py b/tests/test_protocol.py index a64172b53..a1661231f 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -1364,26 +1364,22 @@ def test_client_send_close_in_fragmented_message(self): client = Protocol(CLIENT) client.send_text(b"Spam", fin=False) self.assertFrameSent(client, Frame(OP_TEXT, b"Spam", fin=False)) - # The spec says: "An endpoint MUST be capable of handling control - # frames in the middle of a fragmented message." However, since the - # endpoint must not send a data frame after a close frame, a close - # frame can't be "in the middle" of a fragmented message. - with self.assertRaises(ProtocolError) as raised: - client.send_close(CloseCode.GOING_AWAY) - self.assertEqual(str(raised.exception), "expected a continuation frame") - client.send_continuation(b"Eggs", fin=True) + with self.enforce_mask(b"\x3c\x3c\x3c\x3c"): + client.send_close() + self.assertEqual(client.data_to_send(), [b"\x88\x80\x3c\x3c\x3c\x3c"]) + self.assertIs(client.state, CLOSING) + with self.assertRaises(InvalidState): + client.send_continuation(b"Eggs", fin=True) def test_server_send_close_in_fragmented_message(self): - server = Protocol(CLIENT) + server = Protocol(SERVER) server.send_text(b"Spam", fin=False) self.assertFrameSent(server, Frame(OP_TEXT, b"Spam", fin=False)) - # The spec says: "An endpoint MUST be capable of handling control - # frames in the middle of a fragmented message." However, since the - # endpoint must not send a data frame after a close frame, a close - # frame can't be "in the middle" of a fragmented message. - with self.assertRaises(ProtocolError) as raised: - server.send_close(CloseCode.NORMAL_CLOSURE) - self.assertEqual(str(raised.exception), "expected a continuation frame") + server.send_close() + self.assertEqual(server.data_to_send(), [b"\x88\x00"]) + self.assertIs(server.state, CLOSING) + with self.assertRaises(InvalidState): + server.send_continuation(b"Eggs", fin=True) def test_client_receive_close_in_fragmented_message(self): client = Protocol(CLIENT) @@ -1392,10 +1388,6 @@ def test_client_receive_close_in_fragmented_message(self): client, Frame(OP_TEXT, b"Spam", fin=False), ) - # The spec says: "An endpoint MUST be capable of handling control - # frames in the middle of a fragmented message." However, since the - # endpoint must not send a data frame after a close frame, a close - # frame can't be "in the middle" of a fragmented message. client.receive_data(b"\x88\x02\x03\xe8") self.assertIsInstance(client.parser_exc, ProtocolError) self.assertEqual(str(client.parser_exc), "incomplete fragmented message") @@ -1410,10 +1402,6 @@ def test_server_receive_close_in_fragmented_message(self): server, Frame(OP_TEXT, b"Spam", fin=False), ) - # The spec says: "An endpoint MUST be capable of handling control - # frames in the middle of a fragmented message." However, since the - # endpoint must not send a data frame after a close frame, a close - # frame can't be "in the middle" of a fragmented message. server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") self.assertIsInstance(server.parser_exc, ProtocolError) self.assertEqual(str(server.parser_exc), "incomplete fragmented message") From d28b71dd297da99aad9d644a2f4721707e464707 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 27 Jan 2024 15:41:42 +0100 Subject: [PATCH 491/762] Upgrade to the latest version of black. --- src/websockets/datastructures.py | 6 ++---- tests/test_protocol.py | 14 ++++++++------ 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/websockets/datastructures.py b/src/websockets/datastructures.py index c2a5acfee..aef11bf23 100644 --- a/src/websockets/datastructures.py +++ b/src/websockets/datastructures.py @@ -172,11 +172,9 @@ class SupportsKeysAndGetItem(Protocol): # pragma: no cover """ - def keys(self) -> Iterable[str]: - ... + def keys(self) -> Iterable[str]: ... - def __getitem__(self, key: str) -> str: - ... + def __getitem__(self, key: str) -> str: ... HeadersLike = Union[ diff --git a/tests/test_protocol.py b/tests/test_protocol.py index a1661231f..b53c8a1ec 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -37,12 +37,14 @@ def assertFrameSent(self, connection, frame, eof=False): """ frames_sent = [ - None - if write is SEND_EOF - else self.parse( - write, - mask=connection.side is CLIENT, - extensions=connection.extensions, + ( + None + if write is SEND_EOF + else self.parse( + write, + mask=connection.side is CLIENT, + extensions=connection.extensions, + ) ) for write in connection.data_to_send() ] From 705dc85e87bb1184d926ab95a591097780c4b855 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 27 Jan 2024 14:55:44 +0100 Subject: [PATCH 492/762] Allow sending ping and pong after close. Fix #1429. --- src/websockets/protocol.py | 28 ++++++++-- tests/test_protocol.py | 111 +++++++++++++++++++++++++++---------- 2 files changed, 105 insertions(+), 34 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 4650cf16d..0b36202e5 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -297,6 +297,8 @@ def send_continuation(self, data: bytes, fin: bool) -> None: """ if not self.expect_continuation_frame: raise ProtocolError("unexpected continuation frame") + if self._state is not OPEN: + raise InvalidState(f"connection is {self.state.name.lower()}") self.expect_continuation_frame = not fin self.send_frame(Frame(OP_CONT, data, fin)) @@ -318,6 +320,8 @@ def send_text(self, data: bytes, fin: bool = True) -> None: """ if self.expect_continuation_frame: raise ProtocolError("expected a continuation frame") + if self._state is not OPEN: + raise InvalidState(f"connection is {self.state.name.lower()}") self.expect_continuation_frame = not fin self.send_frame(Frame(OP_TEXT, data, fin)) @@ -339,6 +343,8 @@ def send_binary(self, data: bytes, fin: bool = True) -> None: """ if self.expect_continuation_frame: raise ProtocolError("expected a continuation frame") + if self._state is not OPEN: + raise InvalidState(f"connection is {self.state.name.lower()}") self.expect_continuation_frame = not fin self.send_frame(Frame(OP_BINARY, data, fin)) @@ -358,6 +364,10 @@ def send_close(self, code: Optional[int] = None, reason: str = "") -> None: without a code. """ + # While RFC 6455 doesn't rule out sending more than one close Frame, + # websockets is conservative in what it sends and doesn't allow that. + if self._state is not OPEN: + raise InvalidState(f"connection is {self.state.name.lower()}") if code is None: if reason != "": raise ProtocolError("cannot send a reason without a code") @@ -383,6 +393,9 @@ def send_ping(self, data: bytes) -> None: data: payload containing arbitrary binary data. """ + # RFC 6455 allows control frames after starting the closing handshake. + if self._state is not OPEN and self._state is not CLOSING: + raise InvalidState(f"connection is {self.state.name.lower()}") self.send_frame(Frame(OP_PING, data)) def send_pong(self, data: bytes) -> None: @@ -396,6 +409,9 @@ def send_pong(self, data: bytes) -> None: data: payload containing arbitrary binary data. """ + # RFC 6455 allows control frames after starting the closing handshake. + if self._state is not OPEN and self._state is not CLOSING: + raise InvalidState(f"connection is {self.state.name.lower()}") self.send_frame(Frame(OP_PONG, data)) def fail(self, code: int, reason: str = "") -> None: @@ -675,6 +691,8 @@ def recv_frame(self, frame: Frame) -> None: # 1.4. Closing Handshake: "after receiving a control frame # indicating the connection should be closed, a peer discards # any further data received." + # RFC 6455 allows reading Ping and Pong frames after a Close frame. + # However, that doesn't seem useful; websockets doesn't support it. self.parser = self.discard() next(self.parser) # start coroutine @@ -687,15 +705,13 @@ def recv_frame(self, frame: Frame) -> None: # Private methods for sending events. def send_frame(self, frame: Frame) -> None: - if self.state is not OPEN: - raise InvalidState( - f"cannot write to a WebSocket in the {self.state.name} state" - ) - if self.debug: self.logger.debug("> %s", frame) self.writes.append( - frame.serialize(mask=self.side is CLIENT, extensions=self.extensions) + frame.serialize( + mask=self.side is CLIENT, + extensions=self.extensions, + ) ) def send_eof(self) -> None: diff --git a/tests/test_protocol.py b/tests/test_protocol.py index b53c8a1ec..1d5dab7a0 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -465,15 +465,17 @@ def test_client_sends_text_after_sending_close(self): with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_close(CloseCode.GOING_AWAY) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) - with self.assertRaises(InvalidState): + with self.assertRaises(InvalidState) as raised: client.send_text(b"") + self.assertEqual(str(raised.exception), "connection is closing") def test_server_sends_text_after_sending_close(self): server = Protocol(SERVER) server.send_close(CloseCode.NORMAL_CLOSURE) self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) - with self.assertRaises(InvalidState): + with self.assertRaises(InvalidState) as raised: server.send_text(b"") + self.assertEqual(str(raised.exception), "connection is closing") def test_client_receives_text_after_receiving_close(self): client = Protocol(CLIENT) @@ -679,15 +681,17 @@ def test_client_sends_binary_after_sending_close(self): with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_close(CloseCode.GOING_AWAY) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) - with self.assertRaises(InvalidState): + with self.assertRaises(InvalidState) as raised: client.send_binary(b"") + self.assertEqual(str(raised.exception), "connection is closing") def test_server_sends_binary_after_sending_close(self): server = Protocol(SERVER) server.send_close(CloseCode.NORMAL_CLOSURE) self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) - with self.assertRaises(InvalidState): + with self.assertRaises(InvalidState) as raised: server.send_binary(b"") + self.assertEqual(str(raised.exception), "connection is closing") def test_client_receives_binary_after_receiving_close(self): client = Protocol(CLIENT) @@ -956,6 +960,37 @@ def test_server_receives_close_with_non_utf8_reason(self): ) self.assertIs(server.state, CLOSING) + def test_client_sends_close_twice(self): + client = Protocol(CLIENT) + with self.enforce_mask(b"\x00\x00\x00\x00"): + client.send_close(CloseCode.GOING_AWAY) + self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) + with self.assertRaises(InvalidState) as raised: + client.send_close(CloseCode.GOING_AWAY) + self.assertEqual(str(raised.exception), "connection is closing") + + def test_server_sends_close_twice(self): + server = Protocol(SERVER) + server.send_close(CloseCode.NORMAL_CLOSURE) + self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) + with self.assertRaises(InvalidState) as raised: + server.send_close(CloseCode.NORMAL_CLOSURE) + self.assertEqual(str(raised.exception), "connection is closing") + + def test_client_sends_close_after_connection_is_closed(self): + client = Protocol(CLIENT) + client.receive_eof() + with self.assertRaises(InvalidState) as raised: + client.send_close(CloseCode.GOING_AWAY) + self.assertEqual(str(raised.exception), "connection is closed") + + def test_server_sends_close_after_connection_is_closed(self): + server = Protocol(SERVER) + server.receive_eof() + with self.assertRaises(InvalidState) as raised: + server.send_close(CloseCode.NORMAL_CLOSURE) + self.assertEqual(str(raised.exception), "connection is closed") + class PingTests(ProtocolTestCase): """ @@ -1072,35 +1107,23 @@ def test_client_sends_ping_after_sending_close(self): with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_close(CloseCode.GOING_AWAY) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) - # The spec says: "An endpoint MAY send a Ping frame any time (...) - # before the connection is closed" but websockets doesn't support - # sending a Ping frame after a Close frame. - with self.assertRaises(InvalidState) as raised: + with self.enforce_mask(b"\x00\x44\x88\xcc"): client.send_ping(b"") - self.assertEqual( - str(raised.exception), - "cannot write to a WebSocket in the CLOSING state", - ) + self.assertEqual(client.data_to_send(), [b"\x89\x80\x00\x44\x88\xcc"]) def test_server_sends_ping_after_sending_close(self): server = Protocol(SERVER) server.send_close(CloseCode.NORMAL_CLOSURE) self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) - # The spec says: "An endpoint MAY send a Ping frame any time (...) - # before the connection is closed" but websockets doesn't support - # sending a Ping frame after a Close frame. - with self.assertRaises(InvalidState) as raised: - server.send_ping(b"") - self.assertEqual( - str(raised.exception), - "cannot write to a WebSocket in the CLOSING state", - ) + server.send_ping(b"") + self.assertEqual(server.data_to_send(), [b"\x89\x00"]) def test_client_receives_ping_after_receiving_close(self): client = Protocol(CLIENT) client.receive_data(b"\x88\x02\x03\xe8") self.assertConnectionClosing(client, CloseCode.NORMAL_CLOSURE) client.receive_data(b"\x89\x04\x22\x66\xaa\xee") + # websockets ignores control frames after a close frame. self.assertFrameReceived(client, None) self.assertFrameSent(client, None) @@ -1109,9 +1132,24 @@ def test_server_receives_ping_after_receiving_close(self): server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") self.assertConnectionClosing(server, CloseCode.GOING_AWAY) server.receive_data(b"\x89\x84\x00\x44\x88\xcc\x22\x22\x22\x22") + # websockets ignores control frames after a close frame. self.assertFrameReceived(server, None) self.assertFrameSent(server, None) + def test_client_sends_ping_after_connection_is_closed(self): + client = Protocol(CLIENT) + client.receive_eof() + with self.assertRaises(InvalidState) as raised: + client.send_ping(b"") + self.assertEqual(str(raised.exception), "connection is closed") + + def test_server_sends_ping_after_connection_is_closed(self): + server = Protocol(SERVER) + server.receive_eof() + with self.assertRaises(InvalidState) as raised: + server.send_ping(b"") + self.assertEqual(str(raised.exception), "connection is closed") + class PongTests(ProtocolTestCase): """ @@ -1212,23 +1250,23 @@ def test_client_sends_pong_after_sending_close(self): with self.enforce_mask(b"\x00\x00\x00\x00"): client.send_close(CloseCode.GOING_AWAY) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) - # websockets doesn't support sending a Pong frame after a Close frame. - with self.assertRaises(InvalidState): + with self.enforce_mask(b"\x00\x44\x88\xcc"): client.send_pong(b"") + self.assertEqual(client.data_to_send(), [b"\x8a\x80\x00\x44\x88\xcc"]) def test_server_sends_pong_after_sending_close(self): server = Protocol(SERVER) server.send_close(CloseCode.NORMAL_CLOSURE) self.assertEqual(server.data_to_send(), [b"\x88\x02\x03\xe8"]) - # websockets doesn't support sending a Pong frame after a Close frame. - with self.assertRaises(InvalidState): - server.send_pong(b"") + server.send_pong(b"") + self.assertEqual(server.data_to_send(), [b"\x8a\x00"]) def test_client_receives_pong_after_receiving_close(self): client = Protocol(CLIENT) client.receive_data(b"\x88\x02\x03\xe8") self.assertConnectionClosing(client, CloseCode.NORMAL_CLOSURE) client.receive_data(b"\x8a\x04\x22\x66\xaa\xee") + # websockets ignores control frames after a close frame. self.assertFrameReceived(client, None) self.assertFrameSent(client, None) @@ -1237,9 +1275,24 @@ def test_server_receives_pong_after_receiving_close(self): server.receive_data(b"\x88\x82\x00\x00\x00\x00\x03\xe9") self.assertConnectionClosing(server, CloseCode.GOING_AWAY) server.receive_data(b"\x8a\x84\x00\x44\x88\xcc\x22\x22\x22\x22") + # websockets ignores control frames after a close frame. self.assertFrameReceived(server, None) self.assertFrameSent(server, None) + def test_client_sends_pong_after_connection_is_closed(self): + client = Protocol(CLIENT) + client.receive_eof() + with self.assertRaises(InvalidState) as raised: + client.send_pong(b"") + self.assertEqual(str(raised.exception), "connection is closed") + + def test_server_sends_pong_after_connection_is_closed(self): + server = Protocol(SERVER) + server.receive_eof() + with self.assertRaises(InvalidState) as raised: + server.send_pong(b"") + self.assertEqual(str(raised.exception), "connection is closed") + class FailTests(ProtocolTestCase): """ @@ -1370,8 +1423,9 @@ def test_client_send_close_in_fragmented_message(self): client.send_close() self.assertEqual(client.data_to_send(), [b"\x88\x80\x3c\x3c\x3c\x3c"]) self.assertIs(client.state, CLOSING) - with self.assertRaises(InvalidState): + with self.assertRaises(InvalidState) as raised: client.send_continuation(b"Eggs", fin=True) + self.assertEqual(str(raised.exception), "connection is closing") def test_server_send_close_in_fragmented_message(self): server = Protocol(SERVER) @@ -1380,8 +1434,9 @@ def test_server_send_close_in_fragmented_message(self): server.send_close() self.assertEqual(server.data_to_send(), [b"\x88\x00"]) self.assertIs(server.state, CLOSING) - with self.assertRaises(InvalidState): + with self.assertRaises(InvalidState) as raised: server.send_continuation(b"Eggs", fin=True) + self.assertEqual(str(raised.exception), "connection is closing") def test_client_receive_close_in_fragmented_message(self): client = Protocol(CLIENT) From 96fddaf49b5a5af1f3215076bf2a73dfb4b72ca1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 28 Jan 2024 16:47:51 +0100 Subject: [PATCH 493/762] Wording and line wrapping fixes in changelog. --- docs/project/changelog.rst | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index e288831be..dc84a5ae2 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -34,7 +34,8 @@ Backwards-incompatible changes .............................. .. admonition:: The ``ssl_context`` argument of :func:`~sync.client.connect` - and :func:`~sync.server.serve` is renamed to ``ssl``. + and :func:`~sync.server.serve` in the :mod:`threading` implementation is + renamed to ``ssl``. :class: note This aligns the API of the :mod:`threading` implementation with the @@ -140,7 +141,8 @@ Backwards-incompatible changes As a consequence, calling ``WebSocket.close()`` without arguments in a browser isn't reported as an error anymore. -.. admonition:: :func:`~server.serve` times out on the opening handshake after 10 seconds by default. +.. admonition:: :func:`~server.serve` times out on the opening handshake after + 10 seconds by default. :class: note You can adjust the timeout with the ``open_timeout`` parameter. Set it to @@ -149,7 +151,7 @@ Backwards-incompatible changes New features ............ -.. admonition:: websockets 11.0 introduces a implementation on top of :mod:`threading`. +.. admonition:: websockets 11.0 introduces a :mod:`threading` implementation. :class: important It may be more convenient if you don't need to manage many connections and @@ -211,7 +213,8 @@ Improvements Backwards-incompatible changes .............................. -.. admonition:: The ``exception`` attribute of :class:`~http11.Request` and :class:`~http11.Response` is deprecated. +.. admonition:: The ``exception`` attribute of :class:`~http11.Request` and + :class:`~http11.Response` is deprecated. :class: note Use the ``handshake_exc`` attribute of :class:`~server.ServerProtocol` and @@ -565,11 +568,11 @@ Backwards-incompatible changes .. admonition:: ``process_request`` is now expected to be a coroutine. :class: note - If you're passing a ``process_request`` argument to - :func:`~server.serve` or :class:`~server.WebSocketServerProtocol`, or if - you're overriding + If you're passing a ``process_request`` argument to :func:`~server.serve` + or :class:`~server.WebSocketServerProtocol`, or if you're overriding :meth:`~server.WebSocketServerProtocol.process_request` in a subclass, - define it with ``async def`` instead of ``def``. Previously, both were supported. + define it with ``async def`` instead of ``def``. Previously, both were + supported. For backwards compatibility, functions are still accepted, but mixing functions and coroutines won't work in some inheritance scenarios. From 3b7fa7673bf6a96a5e9debd7dcfa65e04f85efbb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 28 Jan 2024 16:49:13 +0100 Subject: [PATCH 494/762] Enable deprecation for second argument of handlers. --- docs/project/changelog.rst | 26 ++++++++++++++++++++------ src/websockets/legacy/server.py | 4 +--- tests/legacy/test_client_server.py | 6 ++---- 3 files changed, 23 insertions(+), 13 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index dc84a5ae2..fd186a5fc 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -43,6 +43,21 @@ Backwards-incompatible changes For backwards compatibility, ``ssl_context`` is still supported. +.. admonition:: Receiving the request path in the second parameter of connection + handlers is deprecated. + :class: note + + If you implemented the connection handler of a server as:: + + async def handler(request, path): + ... + + You should switch to the recommended pattern since 10.1:: + + async def handler(request): + path = request.path # only if handler() uses the path argument + ... + New features ............ @@ -257,20 +272,19 @@ New features * Added a tutorial. -* Made the second parameter of connection handlers optional. It will be - deprecated in the next major release. The request path is available in - the :attr:`~legacy.protocol.WebSocketCommonProtocol.path` attribute of - the first argument. +* Made the second parameter of connection handlers optional. The request path is + available in the :attr:`~legacy.protocol.WebSocketCommonProtocol.path` + attribute of the first argument. If you implemented the connection handler of a server as:: async def handler(request, path): ... - You should replace it by:: + You should replace it with:: async def handler(request): - path = request.path # if handler() uses the path argument + path = request.path # only if handler() uses the path argument ... * Added ``python -m websockets --version``. diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index e8cf8220f..4659ed9a6 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -1168,9 +1168,7 @@ def remove_path_argument( pass else: # ws_handler accepts two arguments; activate backwards compatibility. - - # Enable deprecation warning and announce deprecation in 11.0. - # warnings.warn("remove second argument of ws_handler", DeprecationWarning) + warnings.warn("remove second argument of ws_handler", DeprecationWarning) async def _ws_handler(websocket: WebSocketServerProtocol) -> Any: return await cast( diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 4a21f7cea..51a74734b 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -480,8 +480,7 @@ async def handler_with_path(ws, path): with self.temp_server( handler=handler_with_path, - # Enable deprecation warning and announce deprecation in 11.0. - # deprecation_warnings=["remove second argument of ws_handler"], + deprecation_warnings=["remove second argument of ws_handler"], ): with self.temp_client("/path"): self.assertEqual( @@ -497,8 +496,7 @@ async def handler_with_path(ws, path, extra): with self.temp_server( handler=bound_handler_with_path, - # Enable deprecation warning and announce deprecation in 11.0. - # deprecation_warnings=["remove second argument of ws_handler"], + deprecation_warnings=["remove second argument of ws_handler"], ): with self.temp_client("/path"): self.assertEqual( From aa33161cd9498bfca39d64fc36319bc1fbce68f2 Mon Sep 17 00:00:00 2001 From: MtkN1 <51289448+MtkN1@users.noreply.github.com> Date: Wed, 7 Feb 2024 13:14:22 +0900 Subject: [PATCH 495/762] Fix wrong RFC number --- src/websockets/legacy/client.py | 2 +- tests/test_protocol.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index b85d22867..255696580 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -599,7 +599,7 @@ async def __aiter__(self) -> AsyncIterator[WebSocketClientProtocol]: yield protocol except Exception: # Add a random initial delay between 0 and 5 seconds. - # See 7.2.3. Recovering from Abnormal Closure in RFC 6544. + # See 7.2.3. Recovering from Abnormal Closure in RFC 6455. if backoff_delay == self.BACKOFF_MIN: initial_delay = random.random() * self.BACKOFF_INITIAL self.logger.info( diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 1d5dab7a0..e1527525b 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -714,7 +714,7 @@ class CloseTests(ProtocolTestCase): """ Test close frames. - See RFC 6544: + See RFC 6455: 5.5.1. Close 7.1.6. The WebSocket Connection Close Reason @@ -994,7 +994,7 @@ def test_server_sends_close_after_connection_is_closed(self): class PingTests(ProtocolTestCase): """ - Test ping. See 5.5.2. Ping in RFC 6544. + Test ping. See 5.5.2. Ping in RFC 6455. """ @@ -1153,7 +1153,7 @@ def test_server_sends_ping_after_connection_is_closed(self): class PongTests(ProtocolTestCase): """ - Test pong frames. See 5.5.3. Pong in RFC 6544. + Test pong frames. See 5.5.3. Pong in RFC 6455. """ @@ -1298,7 +1298,7 @@ class FailTests(ProtocolTestCase): """ Test failing the connection. - See 7.1.7. Fail the WebSocket Connection in RFC 6544. + See 7.1.7. Fail the WebSocket Connection in RFC 6455. """ @@ -1321,7 +1321,7 @@ class FragmentationTests(ProtocolTestCase): """ Test message fragmentation. - See 5.4. Fragmentation in RFC 6544. + See 5.4. Fragmentation in RFC 6455. """ From 87f58c7190025521e5dc380945b0cc536169bd0c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 11 Feb 2024 17:25:00 +0100 Subject: [PATCH 496/762] Fix make clean. --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index cf3b53393..bf8c8dc58 100644 --- a/Makefile +++ b/Makefile @@ -30,6 +30,6 @@ build: python setup.py build_ext --inplace clean: - find . -name '*.pyc' -o -name '*.so' -delete + find . -name '*.pyc' -delete -o -name '*.so' -delete find . -name __pycache__ -delete rm -rf .coverage .mypy_cache build compliance/reports dist docs/_build htmlcov MANIFEST src/websockets.egg-info From 9b5273c68323dd63598dfcba97339f03f61d3d0f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 11 Feb 2024 17:34:57 +0100 Subject: [PATCH 497/762] Move CLIENT/SERVER_CONTEXT to utils. Then we can reuse them for testing other implementations. --- tests/sync/client.py | 8 -------- tests/sync/server.py | 17 ----------------- tests/sync/test_client.py | 12 +++++++++--- tests/sync/test_server.py | 11 ++++++++--- tests/utils.py | 18 ++++++++++++++++++ 5 files changed, 35 insertions(+), 31 deletions(-) diff --git a/tests/sync/client.py b/tests/sync/client.py index bb4855c7f..72eb5b8d2 100644 --- a/tests/sync/client.py +++ b/tests/sync/client.py @@ -1,23 +1,15 @@ import contextlib -import ssl from websockets.sync.client import * from websockets.sync.server import WebSocketServer -from ..utils import CERTIFICATE - __all__ = [ - "CLIENT_CONTEXT", "run_client", "run_unix_client", ] -CLIENT_CONTEXT = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) -CLIENT_CONTEXT.load_verify_locations(CERTIFICATE) - - @contextlib.contextmanager def run_client(wsuri_or_server, secure=None, resource_name="/", **kwargs): if isinstance(wsuri_or_server, str): diff --git a/tests/sync/server.py b/tests/sync/server.py index a9a77438c..10ab789c2 100644 --- a/tests/sync/server.py +++ b/tests/sync/server.py @@ -1,25 +1,8 @@ import contextlib -import ssl import threading from websockets.sync.server import * -from ..utils import CERTIFICATE - - -SERVER_CONTEXT = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) -SERVER_CONTEXT.load_cert_chain(CERTIFICATE) - -# Work around https://github.com/openssl/openssl/issues/7967 - -# This bug causes connect() to hang in tests for the client. Including this -# workaround acknowledges that the issue could happen outside of the test suite. - -# It shouldn't happen too often, or else OpenSSL 1.1.1 would be unusable. If it -# happens, we can look for a library-level fix, but it won't be easy. - -SERVER_CONTEXT.num_tickets = 0 - def crash(ws): raise RuntimeError diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index c403b9632..bebf68aa5 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -7,9 +7,15 @@ from websockets.extensions.permessage_deflate import PerMessageDeflate from websockets.sync.client import * -from ..utils import MS, DeprecationTestCase, temp_unix_socket_path -from .client import CLIENT_CONTEXT, run_client, run_unix_client -from .server import SERVER_CONTEXT, do_nothing, run_server, run_unix_server +from ..utils import ( + CLIENT_CONTEXT, + MS, + SERVER_CONTEXT, + DeprecationTestCase, + temp_unix_socket_path, +) +from .client import run_client, run_unix_client +from .server import do_nothing, run_server, run_unix_server class ClientTests(unittest.TestCase): diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index f9f30baf1..490a3f63e 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -14,10 +14,15 @@ from websockets.http11 import Request, Response from websockets.sync.server import * -from ..utils import MS, DeprecationTestCase, temp_unix_socket_path -from .client import CLIENT_CONTEXT, run_client, run_unix_client -from .server import ( +from ..utils import ( + CLIENT_CONTEXT, + MS, SERVER_CONTEXT, + DeprecationTestCase, + temp_unix_socket_path, +) +from .client import run_client, run_unix_client +from .server import ( EvalShellMixin, crash, do_nothing, diff --git a/tests/utils.py b/tests/utils.py index 2937a2f15..bd3b61d7b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -3,6 +3,7 @@ import os import pathlib import platform +import ssl import tempfile import time import unittest @@ -17,6 +18,23 @@ CERTIFICATE = bytes(pathlib.Path(__file__).with_name("test_localhost.pem")) +CLIENT_CONTEXT = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) +CLIENT_CONTEXT.load_verify_locations(CERTIFICATE) + + +SERVER_CONTEXT = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) +SERVER_CONTEXT.load_cert_chain(CERTIFICATE) + +# Work around https://github.com/openssl/openssl/issues/7967 + +# This bug causes connect() to hang in tests for the client. Including this +# workaround acknowledges that the issue could happen outside of the test suite. + +# It shouldn't happen too often, or else OpenSSL 1.1.1 would be unusable. If it +# happens, we can look for a library-level fix, but it won't be easy. + +SERVER_CONTEXT.num_tickets = 0 + DATE = email.utils.formatdate(usegmt=True) From de768cf65e7e2b1a3b67854fb9e08816a5ff7050 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 11 Feb 2024 21:08:45 +0100 Subject: [PATCH 498/762] Improve tests for sync implementation. --- tests/sync/server.py | 8 +++--- tests/sync/test_client.py | 59 ++++++++++++++++++++------------------- tests/sync/test_server.py | 36 ++++++++++++------------ 3 files changed, 53 insertions(+), 50 deletions(-) diff --git a/tests/sync/server.py b/tests/sync/server.py index 10ab789c2..d5295ccd8 100644 --- a/tests/sync/server.py +++ b/tests/sync/server.py @@ -25,8 +25,8 @@ def assertEval(self, client, expr, value): @contextlib.contextmanager -def run_server(ws_handler=eval_shell, host="localhost", port=0, **kwargs): - with serve(ws_handler, host, port, **kwargs) as server: +def run_server(handler=eval_shell, host="localhost", port=0, **kwargs): + with serve(handler, host, port, **kwargs) as server: thread = threading.Thread(target=server.serve_forever) thread.start() try: @@ -37,8 +37,8 @@ def run_server(ws_handler=eval_shell, host="localhost", port=0, **kwargs): @contextlib.contextmanager -def run_unix_server(path, ws_handler=eval_shell, **kwargs): - with unix_serve(ws_handler, path, **kwargs) as server: +def run_unix_server(path, handler=eval_shell, **kwargs): + with unix_serve(handler, path, **kwargs) as server: thread = threading.Thread(target=server.serve_forever) thread.start() try: diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index bebf68aa5..03f4e972f 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -3,7 +3,7 @@ import threading import unittest -from websockets.exceptions import InvalidHandshake +from websockets.exceptions import InvalidHandshake, InvalidURI from websockets.extensions.permessage_deflate import PerMessageDeflate from websockets.sync.client import * @@ -25,29 +25,6 @@ def test_connection(self): with run_client(server) as client: self.assertEqual(client.protocol.state.name, "OPEN") - def test_connection_fails(self): - """Client connects to server but the handshake fails.""" - - def remove_accept_header(self, request, response): - del response.headers["Sec-WebSocket-Accept"] - - # The connection will be open for the server but failed for the client. - # Use a connection handler that exits immediately to avoid an exception. - with run_server(do_nothing, process_response=remove_accept_header) as server: - with self.assertRaises(InvalidHandshake) as raised: - with run_client(server, close_timeout=MS): - self.fail("did not raise") - self.assertEqual( - str(raised.exception), - "missing Sec-WebSocket-Accept header", - ) - - def test_tcp_connection_fails(self): - """Client fails to connect to server.""" - with self.assertRaises(OSError): - with run_client("ws://localhost:54321"): # invalid port - self.fail("did not raise") - def test_existing_socket(self): """Client connects using a pre-existing socket.""" with run_server() as server: @@ -103,6 +80,35 @@ def create_connection(*args, **kwargs): with run_client(server, create_connection=create_connection) as client: self.assertTrue(client.create_connection_ran) + def test_invalid_uri(self): + """Client receives an invalid URI.""" + with self.assertRaises(InvalidURI): + with run_client("http://localhost"): # invalid scheme + self.fail("did not raise") + + def test_tcp_connection_fails(self): + """Client fails to connect to server.""" + with self.assertRaises(OSError): + with run_client("ws://localhost:54321"): # invalid port + self.fail("did not raise") + + def test_handshake_fails(self): + """Client connects to server but the handshake fails.""" + + def remove_accept_header(self, request, response): + del response.headers["Sec-WebSocket-Accept"] + + # The connection will be open for the server but failed for the client. + # Use a connection handler that exits immediately to avoid an exception. + with run_server(do_nothing, process_response=remove_accept_header) as server: + with self.assertRaises(InvalidHandshake) as raised: + with run_client(server, close_timeout=MS): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "missing Sec-WebSocket-Accept header", + ) + def test_timeout_during_handshake(self): """Client times out before receiving handshake response from server.""" gate = threading.Event() @@ -115,10 +121,7 @@ def stall_connection(self, request): with run_server(do_nothing, process_request=stall_connection) as server: try: with self.assertRaises(TimeoutError) as raised: - # While it shouldn't take 50ms to open a connection, this - # test becomes flaky in CI when setting a smaller timeout, - # even after increasing WEBSOCKETS_TESTS_TIMEOUT_FACTOR. - with run_client(server, open_timeout=5 * MS): + with run_client(server, open_timeout=2 * MS): self.fail("did not raise") self.assertEqual( str(raised.exception), diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index 490a3f63e..9d509a5c4 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -39,21 +39,6 @@ def test_connection(self): with run_client(server) as client: self.assertEval(client, "ws.protocol.state.name", "OPEN") - def test_connection_fails(self): - """Server receives connection from client but the handshake fails.""" - - def remove_key_header(self, request): - del request.headers["Sec-WebSocket-Key"] - - with run_server(process_request=remove_key_header) as server: - with self.assertRaises(InvalidStatus) as raised: - with run_client(server): - self.fail("did not raise") - self.assertEqual( - str(raised.exception), - "server rejected WebSocket connection: HTTP 400", - ) - def test_connection_handler_returns(self): """Connection handler returns.""" with run_server(do_nothing) as server: @@ -81,8 +66,8 @@ def test_existing_socket(self): """Server receives connection using a pre-existing socket.""" with socket.create_server(("localhost", 0)) as sock: with run_server(sock=sock): - # Build WebSocket URI to ensure we connect to the right socket. - with run_client("ws://{}:{}/".format(*sock.getsockname())) as client: + uri = "ws://{}:{}/".format(*sock.getsockname()) + with run_client(uri) as client: self.assertEval(client, "ws.protocol.state.name", "OPEN") def test_select_subprotocol(self): @@ -185,7 +170,7 @@ def process_response(ws, request, response): self.assertEval(client, "ws.process_response_ran", "True") def test_process_response_override_response(self): - """Server runs process_response after processing the handshake.""" + """Server runs process_response and overrides the handshake response.""" def process_response(ws, request, response): headers = response.headers.copy() @@ -253,6 +238,21 @@ def create_connection(*args, **kwargs): with run_client(server) as client: self.assertEval(client, "ws.create_connection_ran", "True") + def test_handshake_fails(self): + """Server receives connection from client but the handshake fails.""" + + def remove_key_header(self, request): + del request.headers["Sec-WebSocket-Key"] + + with run_server(process_request=remove_key_header) as server: + with self.assertRaises(InvalidStatus) as raised: + with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 400", + ) + def test_timeout_during_handshake(self): """Server times out before receiving handshake request from client.""" with run_server(open_timeout=MS) as server: From 50b6d20d7a652d39cffc7aea9f8c0abc88fb8f37 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 11 Feb 2024 21:19:24 +0100 Subject: [PATCH 499/762] Various cleanups in sync implementation. --- src/websockets/sync/client.py | 9 +++-- src/websockets/sync/connection.py | 58 +++++++++++++++---------------- src/websockets/sync/server.py | 27 +++++++------- 3 files changed, 45 insertions(+), 49 deletions(-) diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 6faca7789..0bb7a76fd 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -25,7 +25,7 @@ class ClientConnection(Connection): """ - Threaded implementation of a WebSocket client connection. + :mod:`threading` implementation of a WebSocket client connection. :class:`ClientConnection` provides :meth:`recv` and :meth:`send` methods for receiving and sending messages. @@ -157,7 +157,7 @@ def connect( :func:`connect` may be used as a context manager:: - async with websockets.sync.client.connect(...) as websocket: + with websockets.sync.client.connect(...) as websocket: ... The connection is closed automatically when exiting the context. @@ -273,19 +273,18 @@ def connect( sock = ssl.wrap_socket(sock, server_hostname=server_hostname) sock.settimeout(None) - # Initialize WebSocket connection + # Initialize WebSocket protocol protocol = ClientProtocol( wsuri, origin=origin, extensions=extensions, subprotocols=subprotocols, - state=CONNECTING, max_size=max_size, logger=logger, ) - # Initialize WebSocket protocol + # Initialize WebSocket connection connection = create_connection( sock, diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 62aa17ffd..6ac40cd7c 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -21,12 +21,10 @@ __all__ = ["Connection"] -logger = logging.getLogger(__name__) - class Connection: """ - Threaded implementation of a WebSocket connection. + :mod:`threading` implementation of a WebSocket connection. :class:`Connection` provides APIs shared between WebSocket servers and clients. @@ -82,7 +80,7 @@ def __init__( self.close_deadline: Optional[Deadline] = None # Mapping of ping IDs to pong waiters, in chronological order. - self.pings: Dict[bytes, threading.Event] = {} + self.ping_waiters: Dict[bytes, threading.Event] = {} # Receiving events from the socket. self.recv_events_thread = threading.Thread(target=self.recv_events) @@ -90,7 +88,7 @@ def __init__( # Exception raised in recv_events, to be chained to ConnectionClosed # in the user thread in order to show why the TCP connection dropped. - self.recv_events_exc: Optional[BaseException] = None + self.recv_exc: Optional[BaseException] = None # Public attributes @@ -198,7 +196,7 @@ def recv(self, timeout: Optional[float] = None) -> Data: try: return self.recv_messages.get(timeout) except EOFError: - raise self.protocol.close_exc from self.recv_events_exc + raise self.protocol.close_exc from self.recv_exc except RuntimeError: raise RuntimeError( "cannot call recv while another thread " @@ -229,9 +227,10 @@ def recv_streaming(self) -> Iterator[Data]: """ try: - yield from self.recv_messages.get_iter() + for frame in self.recv_messages.get_iter(): + yield frame except EOFError: - raise self.protocol.close_exc from self.recv_events_exc + raise self.protocol.close_exc from self.recv_exc except RuntimeError: raise RuntimeError( "cannot call recv_streaming while another thread " @@ -273,7 +272,7 @@ def send(self, message: Union[Data, Iterable[Data]]) -> None: Raises: ConnectionClosed: When the connection is closed. - RuntimeError: If a connection is busy sending a fragmented message. + RuntimeError: If the connection is sending a fragmented message. TypeError: If ``message`` doesn't have a supported type. """ @@ -449,15 +448,15 @@ def ping(self, data: Optional[Data] = None) -> threading.Event: with self.send_context(): # Protect against duplicates if a payload is explicitly set. - if data in self.pings: + if data in self.ping_waiters: raise RuntimeError("already waiting for a pong with the same data") # Generate a unique random payload otherwise. - while data is None or data in self.pings: + while data is None or data in self.ping_waiters: data = struct.pack("!I", random.getrandbits(32)) pong_waiter = threading.Event() - self.pings[data] = pong_waiter + self.ping_waiters[data] = pong_waiter self.protocol.send_ping(data) return pong_waiter @@ -504,22 +503,22 @@ def acknowledge_pings(self, data: bytes) -> None: """ with self.protocol_mutex: # Ignore unsolicited pong. - if data not in self.pings: + if data not in self.ping_waiters: return # Sending a pong for only the most recent ping is legal. # Acknowledge all previous pings too in that case. ping_id = None ping_ids = [] - for ping_id, ping in self.pings.items(): + for ping_id, ping in self.ping_waiters.items(): ping_ids.append(ping_id) ping.set() if ping_id == data: break else: raise AssertionError("solicited pong not found in pings") - # Remove acknowledged pings from self.pings. + # Remove acknowledged pings from self.ping_waiters. for ping_id in ping_ids: - del self.pings[ping_id] + del self.ping_waiters[ping_id] def recv_events(self) -> None: """ @@ -541,10 +540,10 @@ def recv_events(self) -> None: self.logger.debug("error while receiving data", exc_info=True) # When the closing handshake is initiated by our side, # recv() may block until send_context() closes the socket. - # In that case, send_context() already set recv_events_exc. - # Calling set_recv_events_exc() avoids overwriting it. + # In that case, send_context() already set recv_exc. + # Calling set_recv_exc() avoids overwriting it. with self.protocol_mutex: - self.set_recv_events_exc(exc) + self.set_recv_exc(exc) break if data == b"": @@ -552,7 +551,7 @@ def recv_events(self) -> None: # Acquire the connection lock. with self.protocol_mutex: - # Feed incoming data to the connection. + # Feed incoming data to the protocol. self.protocol.receive_data(data) # This isn't expected to raise an exception. @@ -568,7 +567,7 @@ def recv_events(self) -> None: # set by send_context(), in case of a race condition # i.e. send_context() closes the socket after recv() # returns above but before send_data() calls send(). - self.set_recv_events_exc(exc) + self.set_recv_exc(exc) break if self.protocol.close_expected(): @@ -595,7 +594,7 @@ def recv_events(self) -> None: # Breaking out of the while True: ... loop means that we believe # that the socket doesn't work anymore. with self.protocol_mutex: - # Feed the end of the data stream to the connection. + # Feed the end of the data stream to the protocol. self.protocol.receive_eof() # This isn't expected to generate events. @@ -609,7 +608,7 @@ def recv_events(self) -> None: # This branch should never run. It's a safety net in case of bugs. self.logger.error("unexpected internal error", exc_info=True) with self.protocol_mutex: - self.set_recv_events_exc(exc) + self.set_recv_exc(exc) # We don't know where we crashed. Force protocol state to CLOSED. self.protocol.state = CLOSED finally: @@ -668,7 +667,6 @@ def send_context( wait_for_close = True # If the connection is expected to close soon, set the # close deadline based on the close timeout. - # Since we tested earlier that protocol.state was OPEN # (or CONNECTING) and we didn't release protocol_mutex, # it is certain that self.close_deadline is still None. @@ -710,11 +708,11 @@ def send_context( # original_exc is never set when wait_for_close is True. assert original_exc is None original_exc = TimeoutError("timed out while closing connection") - # Set recv_events_exc before closing the socket in order to get + # Set recv_exc before closing the socket in order to get # proper exception reporting. raise_close_exc = True with self.protocol_mutex: - self.set_recv_events_exc(original_exc) + self.set_recv_exc(original_exc) # If an error occurred, close the socket to terminate the connection and # raise an exception. @@ -745,16 +743,16 @@ def send_data(self) -> None: except OSError: # socket already closed pass - def set_recv_events_exc(self, exc: Optional[BaseException]) -> None: + def set_recv_exc(self, exc: Optional[BaseException]) -> None: """ - Set recv_events_exc, if not set yet. + Set recv_exc, if not set yet. This method requires holding protocol_mutex. """ assert self.protocol_mutex.locked() - if self.recv_events_exc is None: - self.recv_events_exc = exc + if self.recv_exc is None: + self.recv_exc = exc def close_socket(self) -> None: """ diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index fa6087d54..a070edf18 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -30,7 +30,7 @@ class ServerConnection(Connection): """ - Threaded implementation of a WebSocket server connection. + :mod:`threading` implementation of a WebSocket server connection. :class:`ServerConnection` provides :meth:`recv` and :meth:`send` methods for receiving and sending messages. @@ -188,6 +188,8 @@ class WebSocketServer: handler: Handler for one connection. Receives the socket and address returned by :meth:`~socket.socket.accept`. logger: Logger for this server. + It defaults to ``logging.getLogger("websockets.server")``. + See the :doc:`logging guide <../../topics/logging>` for details. """ @@ -311,16 +313,16 @@ def serve( Whenever a client connects, the server creates a :class:`ServerConnection`, performs the opening handshake, and delegates to the ``handler``. - The handler receives a :class:`ServerConnection` instance, which you can use - to send and receive messages. + The handler receives the :class:`ServerConnection` instance, which you can + use to send and receive messages. Once the handler completes, either normally or with an exception, the server performs the closing handshake and closes the connection. - :class:`WebSocketServer` mirrors the API of + This function returns a :class:`WebSocketServer` whose API mirrors :class:`~socketserver.BaseServer`. Treat it as a context manager to ensure - that it will be closed and call the :meth:`~WebSocketServer.serve_forever` - method to serve requests:: + that it will be closed and call :meth:`~WebSocketServer.serve_forever` to + serve requests:: def handler(websocket): ... @@ -454,15 +456,13 @@ def conn_handler(sock: socket.socket, addr: Any) -> None: sock.do_handshake() sock.settimeout(None) - # Create a closure so that select_subprotocol has access to self. - + # Create a closure to give select_subprotocol access to connection. protocol_select_subprotocol: Optional[ Callable[ [ServerProtocol, Sequence[Subprotocol]], Optional[Subprotocol], ] ] = None - if select_subprotocol is not None: def protocol_select_subprotocol( @@ -475,19 +475,18 @@ def protocol_select_subprotocol( assert protocol is connection.protocol return select_subprotocol(connection, subprotocols) - # Initialize WebSocket connection + # Initialize WebSocket protocol protocol = ServerProtocol( origins=origins, extensions=extensions, subprotocols=subprotocols, select_subprotocol=protocol_select_subprotocol, - state=CONNECTING, max_size=max_size, logger=logger, ) - # Initialize WebSocket protocol + # Initialize WebSocket connection assert create_connection is not None # help mypy connection = create_connection( @@ -522,7 +521,7 @@ def protocol_select_subprotocol( def unix_serve( - handler: Callable[[ServerConnection], Any], + handler: Callable[[ServerConnection], None], path: Optional[str] = None, **kwargs: Any, ) -> WebSocketServer: @@ -541,4 +540,4 @@ def unix_serve( path: File system path to the Unix socket. """ - return serve(handler, path=path, unix=True, **kwargs) + return serve(handler, unix=True, path=path, **kwargs) From e217458ef8b692e45ca6f66c5aeb7fad0aee97ee Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 11 Feb 2024 21:20:48 +0100 Subject: [PATCH 500/762] Small cleanups in legacy implementation. --- src/websockets/legacy/client.py | 2 +- src/websockets/legacy/server.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 255696580..e5da8b13a 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -60,7 +60,7 @@ class WebSocketClientProtocol(WebSocketCommonProtocol): :class:`WebSocketClientProtocol` provides :meth:`recv` and :meth:`send` coroutines for receiving and sending messages. - It supports asynchronous iteration to receive incoming messages:: + It supports asynchronous iteration to receive messages:: async for message in websocket: await process(message) diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 4659ed9a6..0f3c1c150 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -649,9 +649,7 @@ class WebSocketServer: """ WebSocket server returned by :func:`serve`. - This class provides the same interface as :class:`~asyncio.Server`, - notably the :meth:`~asyncio.Server.close` - and :meth:`~asyncio.Server.wait_closed` methods. + This class mirrors the API of :class:`~asyncio.Server`. It keeps track of WebSocket connections in order to close them properly when shutting down. From 5f24866bfeefbe561fa76f7e5a494996d95a2757 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 16 Apr 2024 08:48:42 +0200 Subject: [PATCH 501/762] Always mark background threads as daemon. Fix #1455. --- src/websockets/sync/connection.py | 9 +++++++-- src/websockets/sync/server.py | 3 +++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 6ac40cd7c..b41202dc9 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -82,8 +82,13 @@ def __init__( # Mapping of ping IDs to pong waiters, in chronological order. self.ping_waiters: Dict[bytes, threading.Event] = {} - # Receiving events from the socket. - self.recv_events_thread = threading.Thread(target=self.recv_events) + # Receiving events from the socket. This thread explicitly is marked as + # to support creating a connection in a non-daemon thread then using it + # in a daemon thread; this shouldn't block the intpreter from exiting. + self.recv_events_thread = threading.Thread( + target=self.recv_events, + daemon=True, + ) self.recv_events_thread.start() # Exception raised in recv_events, to be chained to ConnectionClosed diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index a070edf18..fd4f5d3bd 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -233,6 +233,9 @@ def serve_forever(self) -> None: sock, addr = self.socket.accept() except OSError: break + # Since there isn't a mechanism for tracking connections and waiting + # for them to terminate, we cannot use daemon threads, or else all + # connections would be terminate brutally when closing the server. thread = threading.Thread(target=self.handler, args=(sock, addr)) thread.start() From 2774fabc13f09311dec345cc8513aa7b93200b92 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexis=20M=C3=A9taireau?= Date: Tue, 16 Apr 2024 16:52:01 +0200 Subject: [PATCH 502/762] docs(nginx): Fix a typo --- docs/howto/nginx.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/howto/nginx.rst b/docs/howto/nginx.rst index 30545fbc7..ff42c3c2b 100644 --- a/docs/howto/nginx.rst +++ b/docs/howto/nginx.rst @@ -17,9 +17,9 @@ Save this app to ``app.py``: .. literalinclude:: ../../example/deployment/nginx/app.py :emphasize-lines: 21,23 -We'd like to nginx to connect to websockets servers via Unix sockets in order -to avoid the overhead of TCP for communicating between processes running in -the same OS. +We'd like nginx to connect to websockets servers via Unix sockets in order to +avoid the overhead of TCP for communicating between processes running in the +same OS. We start the app with :func:`~websockets.server.unix_serve`. Each server process listens on a different socket thanks to an environment variable set From 0fdc694a980ede0e91286ea5ea1d4f9c62bb42fb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 16 Apr 2024 08:55:36 +0200 Subject: [PATCH 503/762] Make it easy to monkey-patch length of frames repr. Fix #1451. --- src/websockets/frames.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 201bc5068..862eef3aa 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -146,6 +146,9 @@ class Frame: rsv2: bool = False rsv3: bool = False + # Monkey-patch if you want to see more in logs. Should be a multiple of 3. + MAX_LOG = 75 + def __str__(self) -> str: """ Return a human-readable representation of a frame. @@ -163,8 +166,9 @@ def __str__(self) -> str: # We'll show at most the first 16 bytes and the last 8 bytes. # Encode just what we need, plus two dummy bytes to elide later. binary = self.data - if len(binary) > 25: - binary = b"".join([binary[:16], b"\x00\x00", binary[-8:]]) + if len(binary) > self.MAX_LOG // 3: + cut = (self.MAX_LOG // 3 - 1) // 3 # by default cut = 8 + binary = b"".join([binary[: 2 * cut], b"\x00\x00", binary[-cut:]]) data = " ".join(f"{byte:02x}" for byte in binary) elif self.opcode is OP_CLOSE: data = str(Close.parse(self.data)) @@ -179,15 +183,17 @@ def __str__(self) -> str: coding = "text" except (UnicodeDecodeError, AttributeError): binary = self.data - if len(binary) > 25: - binary = b"".join([binary[:16], b"\x00\x00", binary[-8:]]) + if len(binary) > self.MAX_LOG // 3: + cut = (self.MAX_LOG // 3 - 1) // 3 # by default cut = 8 + binary = b"".join([binary[: 2 * cut], b"\x00\x00", binary[-cut:]]) data = " ".join(f"{byte:02x}" for byte in binary) coding = "binary" else: data = "''" - if len(data) > 75: - data = data[:48] + "..." + data[-24:] + if len(data) > self.MAX_LOG: + cut = self.MAX_LOG // 3 - 1 # by default cut = 24 + data = data[: 2 * cut] + "..." + data[-cut:] metadata = ", ".join(filter(None, [coding, length, non_final])) From f0398141d2efd28f64d8e1d6d9adc179a9e5e334 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 16 Apr 2024 19:41:14 +0200 Subject: [PATCH 504/762] Bump asyncio_timeout to 4.0.3. This makes type checking pass again. --- src/websockets/legacy/async_timeout.py | 39 ++++++++++++++++++-------- 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/src/websockets/legacy/async_timeout.py b/src/websockets/legacy/async_timeout.py index 8264094f5..6ffa89969 100644 --- a/src/websockets/legacy/async_timeout.py +++ b/src/websockets/legacy/async_timeout.py @@ -9,12 +9,12 @@ from typing import Optional, Type -# From https://github.com/python/typing_extensions/blob/main/src/typing_extensions.py -# Licensed under the Python Software Foundation License (PSF-2.0) - if sys.version_info >= (3, 11): from typing import final else: + # From https://github.com/python/typing_extensions/blob/main/src/typing_extensions.py + # Licensed under the Python Software Foundation License (PSF-2.0) + # @final exists in 3.8+, but we backport it for all versions # before 3.11 to keep support for the __final__ attribute. # See https://bugs.python.org/issue46342 @@ -49,10 +49,21 @@ class Other(Leaf): # Error reported by type checker pass return f + # End https://github.com/python/typing_extensions/blob/main/src/typing_extensions.py + + +if sys.version_info >= (3, 11): + + def _uncancel_task(task: "asyncio.Task[object]") -> None: + task.uncancel() + +else: + + def _uncancel_task(task: "asyncio.Task[object]") -> None: + pass -# End https://github.com/aio-libs/async-timeout/blob/master/async_timeout/__init__.py -__version__ = "4.0.2" +__version__ = "4.0.3" __all__ = ("timeout", "timeout_at", "Timeout") @@ -124,7 +135,7 @@ class Timeout: # The purpose is to time out as soon as possible # without waiting for the next await expression. - __slots__ = ("_deadline", "_loop", "_state", "_timeout_handler") + __slots__ = ("_deadline", "_loop", "_state", "_timeout_handler", "_task") def __init__( self, deadline: Optional[float], loop: asyncio.AbstractEventLoop @@ -132,6 +143,7 @@ def __init__( self._loop = loop self._state = _State.INIT + self._task: Optional["asyncio.Task[object]"] = None self._timeout_handler = None # type: Optional[asyncio.Handle] if deadline is None: self._deadline = None # type: Optional[float] @@ -187,6 +199,7 @@ def reject(self) -> None: self._reject() def _reject(self) -> None: + self._task = None if self._timeout_handler is not None: self._timeout_handler.cancel() self._timeout_handler = None @@ -234,11 +247,11 @@ def _reschedule(self) -> None: if self._timeout_handler is not None: self._timeout_handler.cancel() - task = asyncio.current_task() + self._task = asyncio.current_task() if deadline <= now: - self._timeout_handler = self._loop.call_soon(self._on_timeout, task) + self._timeout_handler = self._loop.call_soon(self._on_timeout) else: - self._timeout_handler = self._loop.call_at(deadline, self._on_timeout, task) + self._timeout_handler = self._loop.call_at(deadline, self._on_timeout) def _do_enter(self) -> None: if self._state != _State.INIT: @@ -248,15 +261,19 @@ def _do_enter(self) -> None: def _do_exit(self, exc_type: Optional[Type[BaseException]]) -> None: if exc_type is asyncio.CancelledError and self._state == _State.TIMEOUT: + assert self._task is not None + _uncancel_task(self._task) self._timeout_handler = None + self._task = None raise asyncio.TimeoutError # timeout has not expired self._state = _State.EXIT self._reject() return None - def _on_timeout(self, task: "asyncio.Task[None]") -> None: - task.cancel() + def _on_timeout(self) -> None: + assert self._task is not None + self._task.cancel() self._state = _State.TIMEOUT # drop the reference early self._timeout_handler = None From 33997631a04320a5f8d57fac0f2645dc2d654c29 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 12 Jul 2024 08:35:58 +0200 Subject: [PATCH 505/762] Update ruff. --- Makefile | 2 +- pyproject.toml | 4 ++-- tox.ini | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/Makefile b/Makefile index bf8c8dc58..dacfe2a0b 100644 --- a/Makefile +++ b/Makefile @@ -8,7 +8,7 @@ default: style types tests style: black src tests - ruff --fix src tests + ruff check --fix src tests types: mypy --strict src diff --git a/pyproject.toml b/pyproject.toml index c4c5412c5..2367849ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,7 +70,7 @@ exclude_lines = [ "@unittest.skip", ] -[tool.ruff] +[tool.ruff.lint] select = [ "E", # pycodestyle "F", # Pyflakes @@ -82,6 +82,6 @@ ignore = [ "F405", ] -[tool.ruff.isort] +[tool.ruff.lint.isort] combine-as-imports = true lines-after-imports = 2 diff --git a/tox.ini b/tox.ini index 538b638d9..b0e4a5931 100644 --- a/tox.ini +++ b/tox.ini @@ -32,7 +32,7 @@ commands = black --check src tests deps = black [testenv:ruff] -commands = ruff src tests +commands = ruff check src tests deps = ruff [testenv:mypy] From 2d195baaa632efd9fb87f09813d01af28464eb8c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 12 Jul 2024 09:01:24 +0200 Subject: [PATCH 506/762] Don't run tests on Python 3.7. Forgotten in 1bf73423. --- tox.ini | 1 - 1 file changed, 1 deletion(-) diff --git a/tox.ini b/tox.ini index b0e4a5931..06003c85b 100644 --- a/tox.ini +++ b/tox.ini @@ -1,6 +1,5 @@ [tox] envlist = - py37 py38 py39 py310 From 7f402303fe1703767d9236494aacc3f197fbc708 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 12 Jul 2024 09:02:07 +0200 Subject: [PATCH 507/762] Switch from typing.Optional to | None. --- src/websockets/client.py | 16 +-- src/websockets/exceptions.py | 19 ++- src/websockets/extensions/base.py | 4 +- .../extensions/permessage_deflate.py | 32 ++--- src/websockets/frames.py | 8 +- src/websockets/headers.py | 6 +- src/websockets/http11.py | 14 +- src/websockets/imports.py | 6 +- src/websockets/legacy/auth.py | 18 +-- src/websockets/legacy/client.py | 77 +++++----- src/websockets/legacy/framing.py | 8 +- src/websockets/legacy/protocol.py | 65 ++++----- src/websockets/legacy/server.py | 135 +++++++++--------- src/websockets/protocol.py | 28 ++-- src/websockets/server.py | 35 ++--- src/websockets/sync/client.py | 44 +++--- src/websockets/sync/connection.py | 28 ++-- src/websockets/sync/messages.py | 12 +- src/websockets/sync/server.py | 92 ++++++------ src/websockets/sync/utils.py | 7 +- src/websockets/typing.py | 2 +- src/websockets/uri.py | 8 +- 22 files changed, 334 insertions(+), 330 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 633b1960b..cfb441fd9 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -1,7 +1,7 @@ from __future__ import annotations import warnings -from typing import Any, Generator, List, Optional, Sequence +from typing import Any, Generator, List, Sequence from .datastructures import Headers, MultipleValuesError from .exceptions import ( @@ -73,12 +73,12 @@ def __init__( self, wsuri: WebSocketURI, *, - origin: Optional[Origin] = None, - extensions: Optional[Sequence[ClientExtensionFactory]] = None, - subprotocols: Optional[Sequence[Subprotocol]] = None, + origin: Origin | None = None, + extensions: Sequence[ClientExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, state: State = CONNECTING, - max_size: Optional[int] = 2**20, - logger: Optional[LoggerLike] = None, + max_size: int | None = 2**20, + logger: LoggerLike | None = None, ): super().__init__( side=CLIENT, @@ -261,7 +261,7 @@ def process_extensions(self, headers: Headers) -> List[Extension]: return accepted_extensions - def process_subprotocol(self, headers: Headers) -> Optional[Subprotocol]: + def process_subprotocol(self, headers: Headers) -> Subprotocol | None: """ Handle the Sec-WebSocket-Protocol HTTP response header. @@ -274,7 +274,7 @@ def process_subprotocol(self, headers: Headers) -> Optional[Subprotocol]: Subprotocol, if one was selected. """ - subprotocol: Optional[Subprotocol] = None + subprotocol: Subprotocol | None = None subprotocols = headers.get_all("Sec-WebSocket-Protocol") diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index f7169e3b1..adb66e262 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -31,7 +31,6 @@ from __future__ import annotations import http -from typing import Optional from . import datastructures, frames, http11 from .typing import StatusLike @@ -78,11 +77,11 @@ class ConnectionClosed(WebSocketException): Raised when trying to interact with a closed connection. Attributes: - rcvd (Optional[Close]): if a close frame was received, its code and + rcvd (Close | None): if a close frame was received, its code and reason are available in ``rcvd.code`` and ``rcvd.reason``. - sent (Optional[Close]): if a close frame was sent, its code and reason + sent (Close | None): if a close frame was sent, its code and reason are available in ``sent.code`` and ``sent.reason``. - rcvd_then_sent (Optional[bool]): if close frames were received and + rcvd_then_sent (bool | None): if close frames were received and sent, this attribute tells in which order this happened, from the perspective of this side of the connection. @@ -90,9 +89,9 @@ class ConnectionClosed(WebSocketException): def __init__( self, - rcvd: Optional[frames.Close], - sent: Optional[frames.Close], - rcvd_then_sent: Optional[bool] = None, + rcvd: frames.Close | None, + sent: frames.Close | None, + rcvd_then_sent: bool | None = None, ) -> None: self.rcvd = rcvd self.sent = sent @@ -181,7 +180,7 @@ class InvalidHeader(InvalidHandshake): """ - def __init__(self, name: str, value: Optional[str] = None) -> None: + def __init__(self, name: str, value: str | None = None) -> None: self.name = name self.value = value @@ -221,7 +220,7 @@ class InvalidOrigin(InvalidHeader): """ - def __init__(self, origin: Optional[str]) -> None: + def __init__(self, origin: str | None) -> None: super().__init__("Origin", origin) @@ -301,7 +300,7 @@ class InvalidParameterValue(NegotiationError): """ - def __init__(self, name: str, value: Optional[str]) -> None: + def __init__(self, name: str, value: str | None) -> None: self.name = name self.value = value diff --git a/src/websockets/extensions/base.py b/src/websockets/extensions/base.py index 7446c990c..5b5528a09 100644 --- a/src/websockets/extensions/base.py +++ b/src/websockets/extensions/base.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import List, Optional, Sequence, Tuple +from typing import List, Sequence, Tuple from .. import frames from ..typing import ExtensionName, ExtensionParameter @@ -22,7 +22,7 @@ def decode( self, frame: frames.Frame, *, - max_size: Optional[int] = None, + max_size: int | None = None, ) -> frames.Frame: """ Decode an incoming frame. diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index edccac3ca..e95b1064b 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -2,7 +2,7 @@ import dataclasses import zlib -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, List, Sequence, Tuple, Union from .. import exceptions, frames from ..typing import ExtensionName, ExtensionParameter @@ -36,7 +36,7 @@ def __init__( local_no_context_takeover: bool, remote_max_window_bits: int, local_max_window_bits: int, - compress_settings: Optional[Dict[Any, Any]] = None, + compress_settings: Dict[Any, Any] | None = None, ) -> None: """ Configure the Per-Message Deflate extension. @@ -84,7 +84,7 @@ def decode( self, frame: frames.Frame, *, - max_size: Optional[int] = None, + max_size: int | None = None, ) -> frames.Frame: """ Decode an incoming frame. @@ -174,8 +174,8 @@ def encode(self, frame: frames.Frame) -> frames.Frame: def _build_parameters( server_no_context_takeover: bool, client_no_context_takeover: bool, - server_max_window_bits: Optional[int], - client_max_window_bits: Optional[Union[int, bool]], + server_max_window_bits: int | None, + client_max_window_bits: Union[int, bool] | None, ) -> List[ExtensionParameter]: """ Build a list of ``(name, value)`` pairs for some compression parameters. @@ -197,7 +197,7 @@ def _build_parameters( def _extract_parameters( params: Sequence[ExtensionParameter], *, is_server: bool -) -> Tuple[bool, bool, Optional[int], Optional[Union[int, bool]]]: +) -> Tuple[bool, bool, int | None, Union[int, bool] | None]: """ Extract compression parameters from a list of ``(name, value)`` pairs. @@ -207,8 +207,8 @@ def _extract_parameters( """ server_no_context_takeover: bool = False client_no_context_takeover: bool = False - server_max_window_bits: Optional[int] = None - client_max_window_bits: Optional[Union[int, bool]] = None + server_max_window_bits: int | None = None + client_max_window_bits: Union[int, bool] | None = None for name, value in params: if name == "server_no_context_takeover": @@ -286,9 +286,9 @@ def __init__( self, server_no_context_takeover: bool = False, client_no_context_takeover: bool = False, - server_max_window_bits: Optional[int] = None, - client_max_window_bits: Optional[Union[int, bool]] = True, - compress_settings: Optional[Dict[str, Any]] = None, + server_max_window_bits: int | None = None, + client_max_window_bits: Union[int, bool] | None = True, + compress_settings: Dict[str, Any] | None = None, ) -> None: """ Configure the Per-Message Deflate extension factory. @@ -433,7 +433,7 @@ def process_response_params( def enable_client_permessage_deflate( - extensions: Optional[Sequence[ClientExtensionFactory]], + extensions: Sequence[ClientExtensionFactory] | None, ) -> Sequence[ClientExtensionFactory]: """ Enable Per-Message Deflate with default settings in client extensions. @@ -489,9 +489,9 @@ def __init__( self, server_no_context_takeover: bool = False, client_no_context_takeover: bool = False, - server_max_window_bits: Optional[int] = None, - client_max_window_bits: Optional[int] = None, - compress_settings: Optional[Dict[str, Any]] = None, + server_max_window_bits: int | None = None, + client_max_window_bits: int | None = None, + compress_settings: Dict[str, Any] | None = None, require_client_max_window_bits: bool = False, ) -> None: """ @@ -635,7 +635,7 @@ def process_request_params( def enable_server_permessage_deflate( - extensions: Optional[Sequence[ServerExtensionFactory]], + extensions: Sequence[ServerExtensionFactory] | None, ) -> Sequence[ServerExtensionFactory]: """ Enable Per-Message Deflate with default settings in server extensions. diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 862eef3aa..5a304d6a7 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -5,7 +5,7 @@ import io import secrets import struct -from typing import Callable, Generator, Optional, Sequence, Tuple +from typing import Callable, Generator, Sequence, Tuple from . import exceptions, extensions from .typing import Data @@ -205,8 +205,8 @@ def parse( read_exact: Callable[[int], Generator[None, None, bytes]], *, mask: bool, - max_size: Optional[int] = None, - extensions: Optional[Sequence[extensions.Extension]] = None, + max_size: int | None = None, + extensions: Sequence[extensions.Extension] | None = None, ) -> Generator[None, None, Frame]: """ Parse a WebSocket frame. @@ -280,7 +280,7 @@ def serialize( self, *, mask: bool, - extensions: Optional[Sequence[extensions.Extension]] = None, + extensions: Sequence[extensions.Extension] | None = None, ) -> bytes: """ Serialize a WebSocket frame. diff --git a/src/websockets/headers.py b/src/websockets/headers.py index 463df3061..3b316e0bf 100644 --- a/src/websockets/headers.py +++ b/src/websockets/headers.py @@ -4,7 +4,7 @@ import binascii import ipaddress import re -from typing import Callable, List, Optional, Sequence, Tuple, TypeVar, cast +from typing import Callable, List, Sequence, Tuple, TypeVar, cast from . import exceptions from .typing import ( @@ -63,7 +63,7 @@ def build_host(host: str, port: int, secure: bool) -> str: # https://www.rfc-editor.org/rfc/rfc7230.html#appendix-B. -def peek_ahead(header: str, pos: int) -> Optional[str]: +def peek_ahead(header: str, pos: int) -> str | None: """ Return the next character from ``header`` at the given position. @@ -314,7 +314,7 @@ def parse_extension_item_param( name, pos = parse_token(header, pos, header_name) pos = parse_OWS(header, pos) # Extract parameter value, if there is one. - value: Optional[str] = None + value: str | None = None if peek_ahead(header, pos) == "=": pos = parse_OWS(header, pos + 1) if peek_ahead(header, pos) == '"': diff --git a/src/websockets/http11.py b/src/websockets/http11.py index 6fe775eec..a7e9ae682 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -3,7 +3,7 @@ import dataclasses import re import warnings -from typing import Callable, Generator, Optional +from typing import Callable, Generator from . import datastructures, exceptions @@ -62,10 +62,10 @@ class Request: headers: datastructures.Headers # body isn't useful is the context of this library. - _exception: Optional[Exception] = None + _exception: Exception | None = None @property - def exception(self) -> Optional[Exception]: # pragma: no cover + def exception(self) -> Exception | None: # pragma: no cover warnings.warn( "Request.exception is deprecated; " "use ServerProtocol.handshake_exc instead", @@ -164,12 +164,12 @@ class Response: status_code: int reason_phrase: str headers: datastructures.Headers - body: Optional[bytes] = None + body: bytes | None = None - _exception: Optional[Exception] = None + _exception: Exception | None = None @property - def exception(self) -> Optional[Exception]: # pragma: no cover + def exception(self) -> Exception | None: # pragma: no cover warnings.warn( "Response.exception is deprecated; " "use ClientProtocol.handshake_exc instead", @@ -245,7 +245,7 @@ def parse( if 100 <= status_code < 200 or status_code == 204 or status_code == 304: body = None else: - content_length: Optional[int] + content_length: int | None try: # MultipleValuesError is sufficiently unlikely that we don't # attempt to handle it. Instead we document that its parent diff --git a/src/websockets/imports.py b/src/websockets/imports.py index a6a59d4c2..9c05234f5 100644 --- a/src/websockets/imports.py +++ b/src/websockets/imports.py @@ -1,7 +1,7 @@ from __future__ import annotations import warnings -from typing import Any, Dict, Iterable, Optional +from typing import Any, Dict, Iterable __all__ = ["lazy_import"] @@ -30,8 +30,8 @@ def import_name(name: str, source: str, namespace: Dict[str, Any]) -> Any: def lazy_import( namespace: Dict[str, Any], - aliases: Optional[Dict[str, str]] = None, - deprecated_aliases: Optional[Dict[str, str]] = None, + aliases: Dict[str, str] | None = None, + deprecated_aliases: Dict[str, str] | None = None, ) -> None: """ Provide lazy, module-level imports. diff --git a/src/websockets/legacy/auth.py b/src/websockets/legacy/auth.py index 8217afedd..067f9c78c 100644 --- a/src/websockets/legacy/auth.py +++ b/src/websockets/legacy/auth.py @@ -3,7 +3,7 @@ import functools import hmac import http -from typing import Any, Awaitable, Callable, Iterable, Optional, Tuple, Union, cast +from typing import Any, Awaitable, Callable, Iterable, Tuple, Union, cast from ..datastructures import Headers from ..exceptions import InvalidHeader @@ -39,14 +39,14 @@ class BasicAuthWebSocketServerProtocol(WebSocketServerProtocol): encoding of non-ASCII characters is undefined. """ - username: Optional[str] = None + username: str | None = None """Username of the authenticated user.""" def __init__( self, *args: Any, - realm: Optional[str] = None, - check_credentials: Optional[Callable[[str, str], Awaitable[bool]]] = None, + realm: str | None = None, + check_credentials: Callable[[str, str], Awaitable[bool]] | None = None, **kwargs: Any, ) -> None: if realm is not None: @@ -79,7 +79,7 @@ async def process_request( self, path: str, request_headers: Headers, - ) -> Optional[HTTPResponse]: + ) -> HTTPResponse | None: """ Check HTTP Basic Auth and return an HTTP 401 response if needed. @@ -115,10 +115,10 @@ async def process_request( def basic_auth_protocol_factory( - realm: Optional[str] = None, - credentials: Optional[Union[Credentials, Iterable[Credentials]]] = None, - check_credentials: Optional[Callable[[str, str], Awaitable[bool]]] = None, - create_protocol: Optional[Callable[..., BasicAuthWebSocketServerProtocol]] = None, + realm: str | None = None, + credentials: Union[Credentials, Iterable[Credentials]] | None = None, + check_credentials: Callable[[str, str], Awaitable[bool]] | None = None, + create_protocol: Callable[..., BasicAuthWebSocketServerProtocol] | None = None, ) -> Callable[..., BasicAuthWebSocketServerProtocol]: """ Protocol factory that enforces HTTP Basic Auth. diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index e5da8b13a..f7464368f 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -13,7 +13,6 @@ Callable, Generator, List, - Optional, Sequence, Tuple, Type, @@ -86,12 +85,12 @@ class WebSocketClientProtocol(WebSocketCommonProtocol): def __init__( self, *, - logger: Optional[LoggerLike] = None, - origin: Optional[Origin] = None, - extensions: Optional[Sequence[ClientExtensionFactory]] = None, - subprotocols: Optional[Sequence[Subprotocol]] = None, - extra_headers: Optional[HeadersLike] = None, - user_agent_header: Optional[str] = USER_AGENT, + logger: LoggerLike | None = None, + origin: Origin | None = None, + extensions: Sequence[ClientExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + extra_headers: HeadersLike | None = None, + user_agent_header: str | None = USER_AGENT, **kwargs: Any, ) -> None: if logger is None: @@ -152,7 +151,7 @@ async def read_http_response(self) -> Tuple[int, Headers]: @staticmethod def process_extensions( headers: Headers, - available_extensions: Optional[Sequence[ClientExtensionFactory]], + available_extensions: Sequence[ClientExtensionFactory] | None, ) -> List[Extension]: """ Handle the Sec-WebSocket-Extensions HTTP response header. @@ -224,8 +223,8 @@ def process_extensions( @staticmethod def process_subprotocol( - headers: Headers, available_subprotocols: Optional[Sequence[Subprotocol]] - ) -> Optional[Subprotocol]: + headers: Headers, available_subprotocols: Sequence[Subprotocol] | None + ) -> Subprotocol | None: """ Handle the Sec-WebSocket-Protocol HTTP response header. @@ -234,7 +233,7 @@ def process_subprotocol( Return the selected subprotocol. """ - subprotocol: Optional[Subprotocol] = None + subprotocol: Subprotocol | None = None header_values = headers.get_all("Sec-WebSocket-Protocol") @@ -260,10 +259,10 @@ def process_subprotocol( async def handshake( self, wsuri: WebSocketURI, - origin: Optional[Origin] = None, - available_extensions: Optional[Sequence[ClientExtensionFactory]] = None, - available_subprotocols: Optional[Sequence[Subprotocol]] = None, - extra_headers: Optional[HeadersLike] = None, + origin: Origin | None = None, + available_extensions: Sequence[ClientExtensionFactory] | None = None, + available_subprotocols: Sequence[Subprotocol] | None = None, + extra_headers: HeadersLike | None = None, ) -> None: """ Perform the client side of the opening handshake. @@ -427,26 +426,26 @@ def __init__( self, uri: str, *, - create_protocol: Optional[Callable[..., WebSocketClientProtocol]] = None, - logger: Optional[LoggerLike] = None, - compression: Optional[str] = "deflate", - origin: Optional[Origin] = None, - extensions: Optional[Sequence[ClientExtensionFactory]] = None, - subprotocols: Optional[Sequence[Subprotocol]] = None, - extra_headers: Optional[HeadersLike] = None, - user_agent_header: Optional[str] = USER_AGENT, - open_timeout: Optional[float] = 10, - ping_interval: Optional[float] = 20, - ping_timeout: Optional[float] = 20, - close_timeout: Optional[float] = None, - max_size: Optional[int] = 2**20, - max_queue: Optional[int] = 2**5, + create_protocol: Callable[..., WebSocketClientProtocol] | None = None, + logger: LoggerLike | None = None, + compression: str | None = "deflate", + origin: Origin | None = None, + extensions: Sequence[ClientExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + extra_headers: HeadersLike | None = None, + user_agent_header: str | None = USER_AGENT, + open_timeout: float | None = 10, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, + close_timeout: float | None = None, + max_size: int | None = 2**20, + max_queue: int | None = 2**5, read_limit: int = 2**16, write_limit: int = 2**16, **kwargs: Any, ) -> None: # Backwards compatibility: close_timeout used to be called timeout. - timeout: Optional[float] = kwargs.pop("timeout", None) + timeout: float | None = kwargs.pop("timeout", None) if timeout is None: timeout = 10 else: @@ -456,7 +455,7 @@ def __init__( close_timeout = timeout # Backwards compatibility: create_protocol used to be called klass. - klass: Optional[Type[WebSocketClientProtocol]] = kwargs.pop("klass", None) + klass: Type[WebSocketClientProtocol] | None = kwargs.pop("klass", None) if klass is None: klass = WebSocketClientProtocol else: @@ -469,7 +468,7 @@ def __init__( legacy_recv: bool = kwargs.pop("legacy_recv", False) # Backwards compatibility: the loop parameter used to be supported. - _loop: Optional[asyncio.AbstractEventLoop] = kwargs.pop("loop", None) + _loop: asyncio.AbstractEventLoop | None = kwargs.pop("loop", None) if _loop is None: loop = asyncio.get_event_loop() else: @@ -516,13 +515,13 @@ def __init__( ) if kwargs.pop("unix", False): - path: Optional[str] = kwargs.pop("path", None) + path: str | None = kwargs.pop("path", None) create_connection = functools.partial( loop.create_unix_connection, factory, path, **kwargs ) else: - host: Optional[str] - port: Optional[int] + host: str | None + port: int | None if kwargs.get("sock") is None: host, port = wsuri.host, wsuri.port else: @@ -630,9 +629,9 @@ async def __aenter__(self) -> WebSocketClientProtocol: async def __aexit__( self, - exc_type: Optional[Type[BaseException]], - exc_value: Optional[BaseException], - traceback: Optional[TracebackType], + exc_type: Type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, ) -> None: await self.protocol.close() @@ -679,7 +678,7 @@ async def __await_impl__(self) -> WebSocketClientProtocol: def unix_connect( - path: Optional[str] = None, + path: str | None = None, uri: str = "ws://localhost/", **kwargs: Any, ) -> Connect: diff --git a/src/websockets/legacy/framing.py b/src/websockets/legacy/framing.py index b77b869e3..8a13fa446 100644 --- a/src/websockets/legacy/framing.py +++ b/src/websockets/legacy/framing.py @@ -1,7 +1,7 @@ from __future__ import annotations import struct -from typing import Any, Awaitable, Callable, NamedTuple, Optional, Sequence, Tuple +from typing import Any, Awaitable, Callable, NamedTuple, Sequence, Tuple from .. import extensions, frames from ..exceptions import PayloadTooBig, ProtocolError @@ -44,8 +44,8 @@ async def read( reader: Callable[[int], Awaitable[bytes]], *, mask: bool, - max_size: Optional[int] = None, - extensions: Optional[Sequence[extensions.Extension]] = None, + max_size: int | None = None, + extensions: Sequence[extensions.Extension] | None = None, ) -> Frame: """ Read a WebSocket frame. @@ -122,7 +122,7 @@ def write( write: Callable[[bytes], Any], *, mask: bool, - extensions: Optional[Sequence[extensions.Extension]] = None, + extensions: Sequence[extensions.Extension] | None = None, ) -> None: """ Write a WebSocket frame. diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 26d50a2cc..94d42cfdb 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -22,7 +22,6 @@ Iterable, List, Mapping, - Optional, Tuple, Union, cast, @@ -173,21 +172,21 @@ class WebSocketCommonProtocol(asyncio.Protocol): def __init__( self, *, - logger: Optional[LoggerLike] = None, - ping_interval: Optional[float] = 20, - ping_timeout: Optional[float] = 20, - close_timeout: Optional[float] = None, - max_size: Optional[int] = 2**20, - max_queue: Optional[int] = 2**5, + logger: LoggerLike | None = None, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, + close_timeout: float | None = None, + max_size: int | None = 2**20, + max_queue: int | None = 2**5, read_limit: int = 2**16, write_limit: int = 2**16, # The following arguments are kept only for backwards compatibility. - host: Optional[str] = None, - port: Optional[int] = None, - secure: Optional[bool] = None, + host: str | None = None, + port: int | None = None, + secure: bool | None = None, legacy_recv: bool = False, - loop: Optional[asyncio.AbstractEventLoop] = None, - timeout: Optional[float] = None, + loop: asyncio.AbstractEventLoop | None = None, + timeout: float | None = None, ) -> None: if legacy_recv: # pragma: no cover warnings.warn("legacy_recv is deprecated", DeprecationWarning) @@ -243,7 +242,7 @@ def __init__( # Copied from asyncio.FlowControlMixin self._paused = False - self._drain_waiter: Optional[asyncio.Future[None]] = None + self._drain_waiter: asyncio.Future[None] | None = None self._drain_lock = asyncio.Lock() @@ -265,13 +264,13 @@ def __init__( # WebSocket protocol parameters. self.extensions: List[Extension] = [] - self.subprotocol: Optional[Subprotocol] = None + self.subprotocol: Subprotocol | None = None """Subprotocol, if one was negotiated.""" # Close code and reason, set when a close frame is sent or received. - self.close_rcvd: Optional[Close] = None - self.close_sent: Optional[Close] = None - self.close_rcvd_then_sent: Optional[bool] = None + self.close_rcvd: Close | None = None + self.close_sent: Close | None = None + self.close_rcvd_then_sent: bool | None = None # Completed when the connection state becomes CLOSED. Translates the # :meth:`connection_lost` callback to a :class:`~asyncio.Future` @@ -281,11 +280,11 @@ def __init__( # Queue of received messages. self.messages: Deque[Data] = collections.deque() - self._pop_message_waiter: Optional[asyncio.Future[None]] = None - self._put_message_waiter: Optional[asyncio.Future[None]] = None + self._pop_message_waiter: asyncio.Future[None] | None = None + self._put_message_waiter: asyncio.Future[None] | None = None # Protect sending fragmented messages. - self._fragmented_message_waiter: Optional[asyncio.Future[None]] = None + self._fragmented_message_waiter: asyncio.Future[None] | None = None # Mapping of ping IDs to pong waiters, in chronological order. self.pings: Dict[bytes, Tuple[asyncio.Future[float], float]] = {} @@ -306,7 +305,7 @@ def __init__( self.transfer_data_task: asyncio.Task[None] # Exception that occurred during data transfer, if any. - self.transfer_data_exc: Optional[BaseException] = None + self.transfer_data_exc: BaseException | None = None # Task sending keepalive pings. self.keepalive_ping_task: asyncio.Task[None] @@ -363,19 +362,19 @@ def connection_open(self) -> None: self.close_connection_task = self.loop.create_task(self.close_connection()) @property - def host(self) -> Optional[str]: + def host(self) -> str | None: alternative = "remote_address" if self.is_client else "local_address" warnings.warn(f"use {alternative}[0] instead of host", DeprecationWarning) return self._host @property - def port(self) -> Optional[int]: + def port(self) -> int | None: alternative = "remote_address" if self.is_client else "local_address" warnings.warn(f"use {alternative}[1] instead of port", DeprecationWarning) return self._port @property - def secure(self) -> Optional[bool]: + def secure(self) -> bool | None: warnings.warn("don't use secure", DeprecationWarning) return self._secure @@ -447,7 +446,7 @@ def closed(self) -> bool: return self.state is State.CLOSED @property - def close_code(self) -> Optional[int]: + def close_code(self) -> int | None: """ WebSocket close code, defined in `section 7.1.5 of RFC 6455`_. @@ -465,7 +464,7 @@ def close_code(self) -> Optional[int]: return self.close_rcvd.code @property - def close_reason(self) -> Optional[str]: + def close_reason(self) -> str | None: """ WebSocket close reason, defined in `section 7.1.6 of RFC 6455`_. @@ -804,7 +803,7 @@ async def wait_closed(self) -> None: """ await asyncio.shield(self.connection_lost_waiter) - async def ping(self, data: Optional[Data] = None) -> Awaitable[float]: + async def ping(self, data: Data | None = None) -> Awaitable[float]: """ Send a Ping_. @@ -1017,7 +1016,7 @@ async def transfer_data(self) -> None: self.transfer_data_exc = exc self.fail_connection(CloseCode.INTERNAL_ERROR) - async def read_message(self) -> Optional[Data]: + async def read_message(self) -> Data | None: """ Read a single message from the connection. @@ -1090,7 +1089,7 @@ def append(frame: Frame) -> None: return ("" if text else b"").join(fragments) - async def read_data_frame(self, max_size: Optional[int]) -> Optional[Frame]: + async def read_data_frame(self, max_size: int | None) -> Frame | None: """ Read a single data frame from the connection. @@ -1153,7 +1152,7 @@ async def read_data_frame(self, max_size: Optional[int]) -> Optional[Frame]: else: return frame - async def read_frame(self, max_size: Optional[int]) -> Frame: + async def read_frame(self, max_size: int | None) -> Frame: """ Read a single frame from the connection. @@ -1204,9 +1203,7 @@ async def write_frame( self.write_frame_sync(fin, opcode, data) await self.drain() - async def write_close_frame( - self, close: Close, data: Optional[bytes] = None - ) -> None: + async def write_close_frame(self, close: Close, data: bytes | None = None) -> None: """ Write a close frame if and only if the connection state is OPEN. @@ -1484,7 +1481,7 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None: # Copied from asyncio.StreamReaderProtocol self.reader.set_transport(transport) - def connection_lost(self, exc: Optional[Exception]) -> None: + def connection_lost(self, exc: Exception | None) -> None: """ 7.1.4. The WebSocket Connection is Closed. diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 0f3c1c150..551115174 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -16,7 +16,6 @@ Generator, Iterable, List, - Optional, Sequence, Set, Tuple, @@ -103,19 +102,19 @@ def __init__( ], ws_server: WebSocketServer, *, - logger: Optional[LoggerLike] = None, - origins: Optional[Sequence[Optional[Origin]]] = None, - extensions: Optional[Sequence[ServerExtensionFactory]] = None, - subprotocols: Optional[Sequence[Subprotocol]] = None, - extra_headers: Optional[HeadersLikeOrCallable] = None, - server_header: Optional[str] = USER_AGENT, - process_request: Optional[ - Callable[[str, Headers], Awaitable[Optional[HTTPResponse]]] - ] = None, - select_subprotocol: Optional[ - Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] - ] = None, - open_timeout: Optional[float] = 10, + logger: LoggerLike | None = None, + origins: Sequence[Origin | None] | None = None, + extensions: Sequence[ServerExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + extra_headers: HeadersLikeOrCallable | None = None, + server_header: str | None = USER_AGENT, + process_request: ( + Callable[[str, Headers], Awaitable[HTTPResponse | None]] | None + ) = None, + select_subprotocol: ( + Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] | None + ) = None, + open_timeout: float | None = 10, **kwargs: Any, ) -> None: if logger is None: @@ -293,7 +292,7 @@ async def read_http_request(self) -> Tuple[str, Headers]: return path, headers def write_http_response( - self, status: http.HTTPStatus, headers: Headers, body: Optional[bytes] = None + self, status: http.HTTPStatus, headers: Headers, body: bytes | None = None ) -> None: """ Write status line and headers to the HTTP response. @@ -322,7 +321,7 @@ def write_http_response( async def process_request( self, path: str, request_headers: Headers - ) -> Optional[HTTPResponse]: + ) -> HTTPResponse | None: """ Intercept the HTTP request and return an HTTP response if appropriate. @@ -371,8 +370,8 @@ async def process_request( @staticmethod def process_origin( - headers: Headers, origins: Optional[Sequence[Optional[Origin]]] = None - ) -> Optional[Origin]: + headers: Headers, origins: Sequence[Origin | None] | None = None + ) -> Origin | None: """ Handle the Origin HTTP request header. @@ -387,9 +386,11 @@ def process_origin( # "The user agent MUST NOT include more than one Origin header field" # per https://www.rfc-editor.org/rfc/rfc6454.html#section-7.3. try: - origin = cast(Optional[Origin], headers.get("Origin")) + origin = headers.get("Origin") except MultipleValuesError as exc: raise InvalidHeader("Origin", "more than one Origin header found") from exc + if origin is not None: + origin = cast(Origin, origin) if origins is not None: if origin not in origins: raise InvalidOrigin(origin) @@ -398,8 +399,8 @@ def process_origin( @staticmethod def process_extensions( headers: Headers, - available_extensions: Optional[Sequence[ServerExtensionFactory]], - ) -> Tuple[Optional[str], List[Extension]]: + available_extensions: Sequence[ServerExtensionFactory] | None, + ) -> Tuple[str | None, List[Extension]]: """ Handle the Sec-WebSocket-Extensions HTTP request header. @@ -435,7 +436,7 @@ def process_extensions( InvalidHandshake: To abort the handshake with an HTTP 400 error. """ - response_header_value: Optional[str] = None + response_header_value: str | None = None extension_headers: List[ExtensionHeader] = [] accepted_extensions: List[Extension] = [] @@ -479,8 +480,8 @@ def process_extensions( # Not @staticmethod because it calls self.select_subprotocol() def process_subprotocol( - self, headers: Headers, available_subprotocols: Optional[Sequence[Subprotocol]] - ) -> Optional[Subprotocol]: + self, headers: Headers, available_subprotocols: Sequence[Subprotocol] | None + ) -> Subprotocol | None: """ Handle the Sec-WebSocket-Protocol HTTP request header. @@ -495,7 +496,7 @@ def process_subprotocol( InvalidHandshake: To abort the handshake with an HTTP 400 error. """ - subprotocol: Optional[Subprotocol] = None + subprotocol: Subprotocol | None = None header_values = headers.get_all("Sec-WebSocket-Protocol") @@ -514,7 +515,7 @@ def select_subprotocol( self, client_subprotocols: Sequence[Subprotocol], server_subprotocols: Sequence[Subprotocol], - ) -> Optional[Subprotocol]: + ) -> Subprotocol | None: """ Pick a subprotocol among those supported by the client and the server. @@ -552,10 +553,10 @@ def select_subprotocol( async def handshake( self, - origins: Optional[Sequence[Optional[Origin]]] = None, - available_extensions: Optional[Sequence[ServerExtensionFactory]] = None, - available_subprotocols: Optional[Sequence[Subprotocol]] = None, - extra_headers: Optional[HeadersLikeOrCallable] = None, + origins: Sequence[Origin | None] | None = None, + available_extensions: Sequence[ServerExtensionFactory] | None = None, + available_subprotocols: Sequence[Subprotocol] | None = None, + extra_headers: HeadersLikeOrCallable | None = None, ) -> str: """ Perform the server side of the opening handshake. @@ -661,7 +662,7 @@ class WebSocketServer: """ - def __init__(self, logger: Optional[LoggerLike] = None): + def __init__(self, logger: LoggerLike | None = None): if logger is None: logger = logging.getLogger("websockets.server") self.logger = logger @@ -670,7 +671,7 @@ def __init__(self, logger: Optional[LoggerLike] = None): self.websockets: Set[WebSocketServerProtocol] = set() # Task responsible for closing the server and terminating connections. - self.close_task: Optional[asyncio.Task[None]] = None + self.close_task: asyncio.Task[None] | None = None # Completed when the server is closed and connections are terminated. self.closed_waiter: asyncio.Future[None] @@ -869,9 +870,9 @@ async def __aenter__(self) -> WebSocketServer: # pragma: no cover async def __aexit__( self, - exc_type: Optional[Type[BaseException]], - exc_value: Optional[BaseException], - traceback: Optional[TracebackType], + exc_type: Type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, ) -> None: # pragma: no cover self.close() await self.wait_closed() @@ -941,8 +942,8 @@ class Serve: server_header: Value of the ``Server`` response header. It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to :obj:`None` removes the header. - process_request (Optional[Callable[[str, Headers], \ - Awaitable[Optional[Tuple[StatusLike, HeadersLike, bytes]]]]]): + process_request (Callable[[str, Headers], \ + Awaitable[Tuple[StatusLike, HeadersLike, bytes] | None]] | None): Intercept HTTP request before the opening handshake. See :meth:`~WebSocketServerProtocol.process_request` for details. select_subprotocol: Select a subprotocol supported by the client. @@ -975,35 +976,35 @@ def __init__( Callable[[WebSocketServerProtocol], Awaitable[Any]], Callable[[WebSocketServerProtocol, str], Awaitable[Any]], # deprecated ], - host: Optional[Union[str, Sequence[str]]] = None, - port: Optional[int] = None, + host: Union[str, Sequence[str]] | None = None, + port: int | None = None, *, - create_protocol: Optional[Callable[..., WebSocketServerProtocol]] = None, - logger: Optional[LoggerLike] = None, - compression: Optional[str] = "deflate", - origins: Optional[Sequence[Optional[Origin]]] = None, - extensions: Optional[Sequence[ServerExtensionFactory]] = None, - subprotocols: Optional[Sequence[Subprotocol]] = None, - extra_headers: Optional[HeadersLikeOrCallable] = None, - server_header: Optional[str] = USER_AGENT, - process_request: Optional[ - Callable[[str, Headers], Awaitable[Optional[HTTPResponse]]] - ] = None, - select_subprotocol: Optional[ - Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] - ] = None, - open_timeout: Optional[float] = 10, - ping_interval: Optional[float] = 20, - ping_timeout: Optional[float] = 20, - close_timeout: Optional[float] = None, - max_size: Optional[int] = 2**20, - max_queue: Optional[int] = 2**5, + create_protocol: Callable[..., WebSocketServerProtocol] | None = None, + logger: LoggerLike | None = None, + compression: str | None = "deflate", + origins: Sequence[Origin | None] | None = None, + extensions: Sequence[ServerExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + extra_headers: HeadersLikeOrCallable | None = None, + server_header: str | None = USER_AGENT, + process_request: ( + Callable[[str, Headers], Awaitable[HTTPResponse | None]] | None + ) = None, + select_subprotocol: ( + Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] | None + ) = None, + open_timeout: float | None = 10, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, + close_timeout: float | None = None, + max_size: int | None = 2**20, + max_queue: int | None = 2**5, read_limit: int = 2**16, write_limit: int = 2**16, **kwargs: Any, ) -> None: # Backwards compatibility: close_timeout used to be called timeout. - timeout: Optional[float] = kwargs.pop("timeout", None) + timeout: float | None = kwargs.pop("timeout", None) if timeout is None: timeout = 10 else: @@ -1013,7 +1014,7 @@ def __init__( close_timeout = timeout # Backwards compatibility: create_protocol used to be called klass. - klass: Optional[Type[WebSocketServerProtocol]] = kwargs.pop("klass", None) + klass: Type[WebSocketServerProtocol] | None = kwargs.pop("klass", None) if klass is None: klass = WebSocketServerProtocol else: @@ -1026,7 +1027,7 @@ def __init__( legacy_recv: bool = kwargs.pop("legacy_recv", False) # Backwards compatibility: the loop parameter used to be supported. - _loop: Optional[asyncio.AbstractEventLoop] = kwargs.pop("loop", None) + _loop: asyncio.AbstractEventLoop | None = kwargs.pop("loop", None) if _loop is None: loop = asyncio.get_event_loop() else: @@ -1076,7 +1077,7 @@ def __init__( ) if kwargs.pop("unix", False): - path: Optional[str] = kwargs.pop("path", None) + path: str | None = kwargs.pop("path", None) # unix_serve(path) must not specify host and port parameters. assert host is None and port is None create_server = functools.partial( @@ -1098,9 +1099,9 @@ async def __aenter__(self) -> WebSocketServer: async def __aexit__( self, - exc_type: Optional[Type[BaseException]], - exc_value: Optional[BaseException], - traceback: Optional[TracebackType], + exc_type: Type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, ) -> None: self.ws_server.close() await self.ws_server.wait_closed() @@ -1129,7 +1130,7 @@ def unix_serve( Callable[[WebSocketServerProtocol], Awaitable[Any]], Callable[[WebSocketServerProtocol, str], Awaitable[Any]], # deprecated ], - path: Optional[str] = None, + path: str | None = None, **kwargs: Any, ) -> Serve: """ diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 0b36202e5..8aa222eeb 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -3,7 +3,7 @@ import enum import logging import uuid -from typing import Generator, List, Optional, Type, Union +from typing import Generator, List, Type, Union from .exceptions import ( ConnectionClosed, @@ -89,8 +89,8 @@ def __init__( side: Side, *, state: State = OPEN, - max_size: Optional[int] = 2**20, - logger: Optional[LoggerLike] = None, + max_size: int | None = 2**20, + logger: LoggerLike | None = None, ) -> None: # Unique identifier. For logs. self.id: uuid.UUID = uuid.uuid4() @@ -116,24 +116,24 @@ def __init__( # Current size of incoming message in bytes. Only set while reading a # fragmented message i.e. a data frames with the FIN bit not set. - self.cur_size: Optional[int] = None + self.cur_size: int | None = None # True while sending a fragmented message i.e. a data frames with the # FIN bit not set. self.expect_continuation_frame = False # WebSocket protocol parameters. - self.origin: Optional[Origin] = None + self.origin: Origin | None = None self.extensions: List[Extension] = [] - self.subprotocol: Optional[Subprotocol] = None + self.subprotocol: Subprotocol | None = None # Close code and reason, set when a close frame is sent or received. - self.close_rcvd: Optional[Close] = None - self.close_sent: Optional[Close] = None - self.close_rcvd_then_sent: Optional[bool] = None + self.close_rcvd: Close | None = None + self.close_sent: Close | None = None + self.close_rcvd_then_sent: bool | None = None # Track if an exception happened during the handshake. - self.handshake_exc: Optional[Exception] = None + self.handshake_exc: Exception | None = None """ Exception to raise if the opening handshake failed. @@ -150,7 +150,7 @@ def __init__( self.writes: List[bytes] = [] self.parser = self.parse() next(self.parser) # start coroutine - self.parser_exc: Optional[Exception] = None + self.parser_exc: Exception | None = None @property def state(self) -> State: @@ -169,7 +169,7 @@ def state(self, state: State) -> None: self._state = state @property - def close_code(self) -> Optional[int]: + def close_code(self) -> int | None: """ `WebSocket close code`_. @@ -187,7 +187,7 @@ def close_code(self) -> Optional[int]: return self.close_rcvd.code @property - def close_reason(self) -> Optional[str]: + def close_reason(self) -> str | None: """ `WebSocket close reason`_. @@ -348,7 +348,7 @@ def send_binary(self, data: bytes, fin: bool = True) -> None: self.expect_continuation_frame = not fin self.send_frame(Frame(OP_BINARY, data, fin)) - def send_close(self, code: Optional[int] = None, reason: str = "") -> None: + def send_close(self, code: int | None = None, reason: str = "") -> None: """ Send a `Close frame`_. diff --git a/src/websockets/server.py b/src/websockets/server.py index 330e54f37..a92541085 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -5,7 +5,7 @@ import email.utils import http import warnings -from typing import Any, Callable, Generator, List, Optional, Sequence, Tuple, cast +from typing import Any, Callable, Generator, List, Sequence, Tuple, cast from .datastructures import Headers, MultipleValuesError from .exceptions import ( @@ -77,18 +77,19 @@ class ServerProtocol(Protocol): def __init__( self, *, - origins: Optional[Sequence[Optional[Origin]]] = None, - extensions: Optional[Sequence[ServerExtensionFactory]] = None, - subprotocols: Optional[Sequence[Subprotocol]] = None, - select_subprotocol: Optional[ + origins: Sequence[Origin | None] | None = None, + extensions: Sequence[ServerExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + select_subprotocol: ( Callable[ [ServerProtocol, Sequence[Subprotocol]], - Optional[Subprotocol], + Subprotocol | None, ] - ] = None, + | None + ) = None, state: State = CONNECTING, - max_size: Optional[int] = 2**20, - logger: Optional[LoggerLike] = None, + max_size: int | None = 2**20, + logger: LoggerLike | None = None, ): super().__init__( side=SERVER, @@ -200,7 +201,7 @@ def accept(self, request: Request) -> Response: def process_request( self, request: Request, - ) -> Tuple[str, Optional[str], Optional[str]]: + ) -> Tuple[str, str | None, str | None]: """ Check a handshake request and negotiate extensions and subprotocol. @@ -285,7 +286,7 @@ def process_request( protocol_header, ) - def process_origin(self, headers: Headers) -> Optional[Origin]: + def process_origin(self, headers: Headers) -> Origin | None: """ Handle the Origin HTTP request header. @@ -303,9 +304,11 @@ def process_origin(self, headers: Headers) -> Optional[Origin]: # "The user agent MUST NOT include more than one Origin header field" # per https://www.rfc-editor.org/rfc/rfc6454.html#section-7.3. try: - origin = cast(Optional[Origin], headers.get("Origin")) + origin = headers.get("Origin") except MultipleValuesError as exc: raise InvalidHeader("Origin", "more than one Origin header found") from exc + if origin is not None: + origin = cast(Origin, origin) if self.origins is not None: if origin not in self.origins: raise InvalidOrigin(origin) @@ -314,7 +317,7 @@ def process_origin(self, headers: Headers) -> Optional[Origin]: def process_extensions( self, headers: Headers, - ) -> Tuple[Optional[str], List[Extension]]: + ) -> Tuple[str | None, List[Extension]]: """ Handle the Sec-WebSocket-Extensions HTTP request header. @@ -350,7 +353,7 @@ def process_extensions( InvalidHandshake: If the Sec-WebSocket-Extensions header is invalid. """ - response_header_value: Optional[str] = None + response_header_value: str | None = None extension_headers: List[ExtensionHeader] = [] accepted_extensions: List[Extension] = [] @@ -392,7 +395,7 @@ def process_extensions( return response_header_value, accepted_extensions - def process_subprotocol(self, headers: Headers) -> Optional[Subprotocol]: + def process_subprotocol(self, headers: Headers) -> Subprotocol | None: """ Handle the Sec-WebSocket-Protocol HTTP request header. @@ -420,7 +423,7 @@ def process_subprotocol(self, headers: Headers) -> Optional[Subprotocol]: def select_subprotocol( self, subprotocols: Sequence[Subprotocol], - ) -> Optional[Subprotocol]: + ) -> Subprotocol | None: """ Pick a subprotocol among those offered by the client. diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 0bb7a76fd..60b49ebc3 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -4,7 +4,7 @@ import ssl as ssl_module import threading import warnings -from typing import Any, Optional, Sequence, Type +from typing import Any, Sequence, Type from ..client import ClientProtocol from ..datastructures import HeadersLike @@ -52,7 +52,7 @@ def __init__( socket: socket.socket, protocol: ClientProtocol, *, - close_timeout: Optional[float] = 10, + close_timeout: float | None = 10, ) -> None: self.protocol: ClientProtocol self.response_rcvd = threading.Event() @@ -64,9 +64,9 @@ def __init__( def handshake( self, - additional_headers: Optional[HeadersLike] = None, - user_agent_header: Optional[str] = USER_AGENT, - timeout: Optional[float] = None, + additional_headers: HeadersLike | None = None, + user_agent_header: str | None = USER_AGENT, + timeout: float | None = None, ) -> None: """ Perform the opening handshake. @@ -128,25 +128,25 @@ def connect( uri: str, *, # TCP/TLS - sock: Optional[socket.socket] = None, - ssl: Optional[ssl_module.SSLContext] = None, - server_hostname: Optional[str] = None, + sock: socket.socket | None = None, + ssl: ssl_module.SSLContext | None = None, + server_hostname: str | None = None, # WebSocket - origin: Optional[Origin] = None, - extensions: Optional[Sequence[ClientExtensionFactory]] = None, - subprotocols: Optional[Sequence[Subprotocol]] = None, - additional_headers: Optional[HeadersLike] = None, - user_agent_header: Optional[str] = USER_AGENT, - compression: Optional[str] = "deflate", + origin: Origin | None = None, + extensions: Sequence[ClientExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + additional_headers: HeadersLike | None = None, + user_agent_header: str | None = USER_AGENT, + compression: str | None = "deflate", # Timeouts - open_timeout: Optional[float] = 10, - close_timeout: Optional[float] = 10, + open_timeout: float | None = 10, + close_timeout: float | None = 10, # Limits - max_size: Optional[int] = 2**20, + max_size: int | None = 2**20, # Logging - logger: Optional[LoggerLike] = None, + logger: LoggerLike | None = None, # Escape hatch for advanced customization - create_connection: Optional[Type[ClientConnection]] = None, + create_connection: Type[ClientConnection] | None = None, **kwargs: Any, ) -> ClientConnection: """ @@ -219,7 +219,7 @@ def connect( # Private APIs for unix_connect() unix: bool = kwargs.pop("unix", False) - path: Optional[str] = kwargs.pop("path", None) + path: str | None = kwargs.pop("path", None) if unix: if path is None and sock is None: @@ -307,8 +307,8 @@ def connect( def unix_connect( - path: Optional[str] = None, - uri: Optional[str] = None, + path: str | None = None, + uri: str | None = None, **kwargs: Any, ) -> ClientConnection: """ diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index b41202dc9..bb9743181 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -8,7 +8,7 @@ import threading import uuid from types import TracebackType -from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, Type, Union +from typing import Any, Dict, Iterable, Iterator, Mapping, Type, Union from ..exceptions import ConnectionClosed, ConnectionClosedOK, ProtocolError from ..frames import DATA_OPCODES, BytesLike, CloseCode, Frame, Opcode, prepare_ctrl @@ -42,7 +42,7 @@ def __init__( socket: socket.socket, protocol: Protocol, *, - close_timeout: Optional[float] = 10, + close_timeout: float | None = 10, ) -> None: self.socket = socket self.protocol = protocol @@ -62,9 +62,9 @@ def __init__( self.debug = self.protocol.debug # HTTP handshake request and response. - self.request: Optional[Request] = None + self.request: Request | None = None """Opening handshake request.""" - self.response: Optional[Response] = None + self.response: Response | None = None """Opening handshake response.""" # Mutex serializing interactions with the protocol. @@ -77,7 +77,7 @@ def __init__( self.send_in_progress = False # Deadline for the closing handshake. - self.close_deadline: Optional[Deadline] = None + self.close_deadline: Deadline | None = None # Mapping of ping IDs to pong waiters, in chronological order. self.ping_waiters: Dict[bytes, threading.Event] = {} @@ -93,7 +93,7 @@ def __init__( # Exception raised in recv_events, to be chained to ConnectionClosed # in the user thread in order to show why the TCP connection dropped. - self.recv_exc: Optional[BaseException] = None + self.recv_exc: BaseException | None = None # Public attributes @@ -124,7 +124,7 @@ def remote_address(self) -> Any: return self.socket.getpeername() @property - def subprotocol(self) -> Optional[Subprotocol]: + def subprotocol(self) -> Subprotocol | None: """ Subprotocol negotiated during the opening handshake. @@ -140,9 +140,9 @@ def __enter__(self) -> Connection: def __exit__( self, - exc_type: Optional[Type[BaseException]], - exc_value: Optional[BaseException], - traceback: Optional[TracebackType], + exc_type: Type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, ) -> None: if exc_type is None: self.close() @@ -166,7 +166,7 @@ def __iter__(self) -> Iterator[Data]: except ConnectionClosedOK: return - def recv(self, timeout: Optional[float] = None) -> Data: + def recv(self, timeout: float | None = None) -> Data: """ Receive the next message. @@ -420,7 +420,7 @@ def close(self, code: int = CloseCode.NORMAL_CLOSURE, reason: str = "") -> None: # They mean that the connection is closed, which was the goal. pass - def ping(self, data: Optional[Data] = None) -> threading.Event: + def ping(self, data: Data | None = None) -> threading.Event: """ Send a Ping_. @@ -647,7 +647,7 @@ def send_context( # Should we close the socket and raise ConnectionClosed? raise_close_exc = False # What exception should we chain ConnectionClosed to? - original_exc: Optional[BaseException] = None + original_exc: BaseException | None = None # Acquire the protocol lock. with self.protocol_mutex: @@ -748,7 +748,7 @@ def send_data(self) -> None: except OSError: # socket already closed pass - def set_recv_exc(self, exc: Optional[BaseException]) -> None: + def set_recv_exc(self, exc: BaseException | None) -> None: """ Set recv_exc, if not set yet. diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index dcba183d9..2c604ba09 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -3,7 +3,7 @@ import codecs import queue import threading -from typing import Iterator, List, Optional, cast +from typing import Iterator, List, cast from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame from ..typing import Data @@ -41,7 +41,7 @@ def __init__(self) -> None: self.put_in_progress = False # Decoder for text frames, None for binary frames. - self.decoder: Optional[codecs.IncrementalDecoder] = None + self.decoder: codecs.IncrementalDecoder | None = None # Buffer of frames belonging to the same message. self.chunks: List[Data] = [] @@ -54,12 +54,12 @@ def __init__(self) -> None: # Stream data from frames belonging to the same message. # Remove quotes around type when dropping Python < 3.9. - self.chunks_queue: Optional["queue.SimpleQueue[Optional[Data]]"] = None + self.chunks_queue: "queue.SimpleQueue[Data | None] | None" = None # This flag marks the end of the connection. self.closed = False - def get(self, timeout: Optional[float] = None) -> Data: + def get(self, timeout: float | None = None) -> Data: """ Read the next message. @@ -151,7 +151,7 @@ def get_iter(self) -> Iterator[Data]: self.chunks = [] self.chunks_queue = cast( # Remove quotes around type when dropping Python < 3.9. - "queue.SimpleQueue[Optional[Data]]", + "queue.SimpleQueue[Data | None]", queue.SimpleQueue(), ) @@ -164,7 +164,7 @@ def get_iter(self) -> Iterator[Data]: self.get_in_progress = True # Locking with get_in_progress ensures only one thread can get here. - chunk: Optional[Data] + chunk: Data | None for chunk in chunks: yield chunk while (chunk := self.chunks_queue.get()) is not None: diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index fd4f5d3bd..b801510b4 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -10,7 +10,7 @@ import threading import warnings from types import TracebackType -from typing import Any, Callable, Optional, Sequence, Type +from typing import Any, Callable, Sequence, Type from ..extensions.base import ServerExtensionFactory from ..extensions.permessage_deflate import enable_server_permessage_deflate @@ -57,7 +57,7 @@ def __init__( socket: socket.socket, protocol: ServerProtocol, *, - close_timeout: Optional[float] = 10, + close_timeout: float | None = 10, ) -> None: self.protocol: ServerProtocol self.request_rcvd = threading.Event() @@ -69,20 +69,22 @@ def __init__( def handshake( self, - process_request: Optional[ + process_request: ( Callable[ [ServerConnection, Request], - Optional[Response], + Response | None, ] - ] = None, - process_response: Optional[ + | None + ) = None, + process_response: ( Callable[ [ServerConnection, Request, Response], - Optional[Response], + Response | None, ] - ] = None, - server_header: Optional[str] = USER_AGENT, - timeout: Optional[float] = None, + | None + ) = None, + server_header: str | None = USER_AGENT, + timeout: float | None = None, ) -> None: """ Perform the opening handshake. @@ -197,7 +199,7 @@ def __init__( self, socket: socket.socket, handler: Callable[[socket.socket, Any], None], - logger: Optional[LoggerLike] = None, + logger: LoggerLike | None = None, ): self.socket = socket self.handler = handler @@ -260,54 +262,57 @@ def __enter__(self) -> WebSocketServer: def __exit__( self, - exc_type: Optional[Type[BaseException]], - exc_value: Optional[BaseException], - traceback: Optional[TracebackType], + exc_type: Type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, ) -> None: self.shutdown() def serve( handler: Callable[[ServerConnection], None], - host: Optional[str] = None, - port: Optional[int] = None, + host: str | None = None, + port: int | None = None, *, # TCP/TLS - sock: Optional[socket.socket] = None, - ssl: Optional[ssl_module.SSLContext] = None, + sock: socket.socket | None = None, + ssl: ssl_module.SSLContext | None = None, # WebSocket - origins: Optional[Sequence[Optional[Origin]]] = None, - extensions: Optional[Sequence[ServerExtensionFactory]] = None, - subprotocols: Optional[Sequence[Subprotocol]] = None, - select_subprotocol: Optional[ + origins: Sequence[Origin | None] | None = None, + extensions: Sequence[ServerExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + select_subprotocol: ( Callable[ [ServerConnection, Sequence[Subprotocol]], - Optional[Subprotocol], + Subprotocol | None, ] - ] = None, - process_request: Optional[ + | None + ) = None, + process_request: ( Callable[ [ServerConnection, Request], - Optional[Response], + Response | None, ] - ] = None, - process_response: Optional[ + | None + ) = None, + process_response: ( Callable[ [ServerConnection, Request, Response], - Optional[Response], + Response | None, ] - ] = None, - server_header: Optional[str] = USER_AGENT, - compression: Optional[str] = "deflate", + | None + ) = None, + server_header: str | None = USER_AGENT, + compression: str | None = "deflate", # Timeouts - open_timeout: Optional[float] = 10, - close_timeout: Optional[float] = 10, + open_timeout: float | None = 10, + close_timeout: float | None = 10, # Limits - max_size: Optional[int] = 2**20, + max_size: int | None = 2**20, # Logging - logger: Optional[LoggerLike] = None, + logger: LoggerLike | None = None, # Escape hatch for advanced customization - create_connection: Optional[Type[ServerConnection]] = None, + create_connection: Type[ServerConnection] | None = None, **kwargs: Any, ) -> WebSocketServer: """ @@ -412,7 +417,7 @@ def handler(websocket): # Private APIs for unix_connect() unix: bool = kwargs.pop("unix", False) - path: Optional[str] = kwargs.pop("path", None) + path: str | None = kwargs.pop("path", None) if sock is None: if unix: @@ -460,18 +465,19 @@ def conn_handler(sock: socket.socket, addr: Any) -> None: sock.settimeout(None) # Create a closure to give select_subprotocol access to connection. - protocol_select_subprotocol: Optional[ + protocol_select_subprotocol: ( Callable[ [ServerProtocol, Sequence[Subprotocol]], - Optional[Subprotocol], + Subprotocol | None, ] - ] = None + | None + ) = None if select_subprotocol is not None: def protocol_select_subprotocol( protocol: ServerProtocol, subprotocols: Sequence[Subprotocol], - ) -> Optional[Subprotocol]: + ) -> Subprotocol | None: # mypy doesn't know that select_subprotocol is immutable. assert select_subprotocol is not None # Ensure this function is only used in the intended context. @@ -525,7 +531,7 @@ def protocol_select_subprotocol( def unix_serve( handler: Callable[[ServerConnection], None], - path: Optional[str] = None, + path: str | None = None, **kwargs: Any, ) -> WebSocketServer: """ diff --git a/src/websockets/sync/utils.py b/src/websockets/sync/utils.py index 3364bdc2d..00bce2cc6 100644 --- a/src/websockets/sync/utils.py +++ b/src/websockets/sync/utils.py @@ -1,7 +1,6 @@ from __future__ import annotations import time -from typing import Optional __all__ = ["Deadline"] @@ -16,14 +15,14 @@ class Deadline: """ - def __init__(self, timeout: Optional[float]) -> None: - self.deadline: Optional[float] + def __init__(self, timeout: float | None) -> None: + self.deadline: float | None if timeout is None: self.deadline = None else: self.deadline = time.monotonic() + timeout - def timeout(self, *, raise_if_elapsed: bool = True) -> Optional[float]: + def timeout(self, *, raise_if_elapsed: bool = True) -> float | None: """ Calculate a timeout from a deadline. diff --git a/src/websockets/typing.py b/src/websockets/typing.py index 5dfecf66f..7c5b3664d 100644 --- a/src/websockets/typing.py +++ b/src/websockets/typing.py @@ -53,7 +53,7 @@ ExtensionName = NewType("ExtensionName", str) """Name of a WebSocket extension.""" - +# Change to str | None when dropping Python < 3.10. ExtensionParameter = Tuple[str, Optional[str]] """Parameter of a WebSocket extension.""" diff --git a/src/websockets/uri.py b/src/websockets/uri.py index 8cf581743..902716066 100644 --- a/src/websockets/uri.py +++ b/src/websockets/uri.py @@ -2,7 +2,7 @@ import dataclasses import urllib.parse -from typing import Optional, Tuple +from typing import Tuple from . import exceptions @@ -33,8 +33,8 @@ class WebSocketURI: port: int path: str query: str - username: Optional[str] = None - password: Optional[str] = None + username: str | None = None + password: str | None = None @property def resource_name(self) -> str: @@ -47,7 +47,7 @@ def resource_name(self) -> str: return resource_name @property - def user_info(self) -> Optional[Tuple[str, str]]: + def user_info(self) -> Tuple[str, str] | None: if self.username is None: return None assert self.password is not None From 63a2d8eff62fe487c42a8f3176528730b7eed727 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 12 Jul 2024 09:16:17 +0200 Subject: [PATCH 508/762] Switch from typing.Union to |. --- src/websockets/datastructures.py | 1 + .../extensions/permessage_deflate.py | 10 ++--- src/websockets/legacy/auth.py | 4 +- src/websockets/legacy/protocol.py | 3 +- src/websockets/legacy/server.py | 37 ++++++++++--------- src/websockets/protocol.py | 1 + src/websockets/sync/connection.py | 4 +- src/websockets/typing.py | 3 ++ 8 files changed, 34 insertions(+), 29 deletions(-) diff --git a/src/websockets/datastructures.py b/src/websockets/datastructures.py index aef11bf23..5605772d8 100644 --- a/src/websockets/datastructures.py +++ b/src/websockets/datastructures.py @@ -177,6 +177,7 @@ def keys(self) -> Iterable[str]: ... def __getitem__(self, key: str) -> str: ... +# Change to Headers | Mapping[str, str] | ... when dropping Python < 3.10. HeadersLike = Union[ Headers, Mapping[str, str], diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index e95b1064b..48a6a0833 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -2,7 +2,7 @@ import dataclasses import zlib -from typing import Any, Dict, List, Sequence, Tuple, Union +from typing import Any, Dict, List, Sequence, Tuple from .. import exceptions, frames from ..typing import ExtensionName, ExtensionParameter @@ -175,7 +175,7 @@ def _build_parameters( server_no_context_takeover: bool, client_no_context_takeover: bool, server_max_window_bits: int | None, - client_max_window_bits: Union[int, bool] | None, + client_max_window_bits: int | bool | None, ) -> List[ExtensionParameter]: """ Build a list of ``(name, value)`` pairs for some compression parameters. @@ -197,7 +197,7 @@ def _build_parameters( def _extract_parameters( params: Sequence[ExtensionParameter], *, is_server: bool -) -> Tuple[bool, bool, int | None, Union[int, bool] | None]: +) -> Tuple[bool, bool, int | None, int | bool | None]: """ Extract compression parameters from a list of ``(name, value)`` pairs. @@ -208,7 +208,7 @@ def _extract_parameters( server_no_context_takeover: bool = False client_no_context_takeover: bool = False server_max_window_bits: int | None = None - client_max_window_bits: Union[int, bool] | None = None + client_max_window_bits: int | bool | None = None for name, value in params: if name == "server_no_context_takeover": @@ -287,7 +287,7 @@ def __init__( server_no_context_takeover: bool = False, client_no_context_takeover: bool = False, server_max_window_bits: int | None = None, - client_max_window_bits: Union[int, bool] | None = True, + client_max_window_bits: int | bool | None = True, compress_settings: Dict[str, Any] | None = None, ) -> None: """ diff --git a/src/websockets/legacy/auth.py b/src/websockets/legacy/auth.py index 067f9c78c..9d685d9f4 100644 --- a/src/websockets/legacy/auth.py +++ b/src/websockets/legacy/auth.py @@ -3,7 +3,7 @@ import functools import hmac import http -from typing import Any, Awaitable, Callable, Iterable, Tuple, Union, cast +from typing import Any, Awaitable, Callable, Iterable, Tuple, cast from ..datastructures import Headers from ..exceptions import InvalidHeader @@ -116,7 +116,7 @@ async def process_request( def basic_auth_protocol_factory( realm: str | None = None, - credentials: Union[Credentials, Iterable[Credentials]] | None = None, + credentials: Credentials | Iterable[Credentials] | None = None, check_credentials: Callable[[str, str], Awaitable[bool]] | None = None, create_protocol: Callable[..., BasicAuthWebSocketServerProtocol] | None = None, ) -> Callable[..., BasicAuthWebSocketServerProtocol]: diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 94d42cfdb..f4c5901dc 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -23,7 +23,6 @@ List, Mapping, Tuple, - Union, cast, ) @@ -578,7 +577,7 @@ async def recv(self) -> Data: async def send( self, - message: Union[Data, Iterable[Data], AsyncIterable[Data]], + message: Data | Iterable[Data] | AsyncIterable[Data], ) -> None: """ Send a message. diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 551115174..13a6f5591 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -54,6 +54,7 @@ __all__ = ["serve", "unix_serve", "WebSocketServerProtocol", "WebSocketServer"] +# Change to HeadersLike | ... when dropping Python < 3.10. HeadersLikeOrCallable = Union[HeadersLike, Callable[[str, Headers], HeadersLike]] HTTPResponse = Tuple[StatusLike, HeadersLike, bytes] @@ -96,10 +97,10 @@ class WebSocketServerProtocol(WebSocketCommonProtocol): def __init__( self, - ws_handler: Union[ - Callable[[WebSocketServerProtocol], Awaitable[Any]], - Callable[[WebSocketServerProtocol, str], Awaitable[Any]], # deprecated - ], + ws_handler: ( + Callable[[WebSocketServerProtocol], Awaitable[Any]] + | Callable[[WebSocketServerProtocol, str], Awaitable[Any]] # deprecated + ), ws_server: WebSocketServer, *, logger: LoggerLike | None = None, @@ -934,7 +935,7 @@ class Serve: should be negotiated and run. subprotocols: List of supported subprotocols, in order of decreasing preference. - extra_headers (Union[HeadersLike, Callable[[str, Headers], HeadersLike]]): + extra_headers (HeadersLike | Callable[[str, Headers] | HeadersLike]): Arbitrary HTTP headers to add to the response. This can be a :data:`~websockets.datastructures.HeadersLike` or a callable taking the request path and headers in arguments and returning @@ -972,11 +973,11 @@ class Serve: def __init__( self, - ws_handler: Union[ - Callable[[WebSocketServerProtocol], Awaitable[Any]], - Callable[[WebSocketServerProtocol, str], Awaitable[Any]], # deprecated - ], - host: Union[str, Sequence[str]] | None = None, + ws_handler: ( + Callable[[WebSocketServerProtocol], Awaitable[Any]] + | Callable[[WebSocketServerProtocol, str], Awaitable[Any]] # deprecated + ), + host: str | Sequence[str] | None = None, port: int | None = None, *, create_protocol: Callable[..., WebSocketServerProtocol] | None = None, @@ -1126,10 +1127,10 @@ async def __await_impl__(self) -> WebSocketServer: def unix_serve( - ws_handler: Union[ - Callable[[WebSocketServerProtocol], Awaitable[Any]], - Callable[[WebSocketServerProtocol, str], Awaitable[Any]], # deprecated - ], + ws_handler: ( + Callable[[WebSocketServerProtocol], Awaitable[Any]] + | Callable[[WebSocketServerProtocol, str], Awaitable[Any]] # deprecated + ), path: str | None = None, **kwargs: Any, ) -> Serve: @@ -1152,10 +1153,10 @@ def unix_serve( def remove_path_argument( - ws_handler: Union[ - Callable[[WebSocketServerProtocol], Awaitable[Any]], - Callable[[WebSocketServerProtocol, str], Awaitable[Any]], - ] + ws_handler: ( + Callable[[WebSocketServerProtocol], Awaitable[Any]] + | Callable[[WebSocketServerProtocol, str], Awaitable[Any]] + ) ) -> Callable[[WebSocketServerProtocol], Awaitable[Any]]: try: inspect.signature(ws_handler).bind(None) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 8aa222eeb..f288a2733 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -38,6 +38,7 @@ "SEND_EOF", ] +# Change to Request | Response | Frame when dropping Python < 3.10. Event = Union[Request, Response, Frame] """Events that :meth:`~Protocol.events_received` may return.""" diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index bb9743181..7a750331d 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -8,7 +8,7 @@ import threading import uuid from types import TracebackType -from typing import Any, Dict, Iterable, Iterator, Mapping, Type, Union +from typing import Any, Dict, Iterable, Iterator, Mapping, Type from ..exceptions import ConnectionClosed, ConnectionClosedOK, ProtocolError from ..frames import DATA_OPCODES, BytesLike, CloseCode, Frame, Opcode, prepare_ctrl @@ -242,7 +242,7 @@ def recv_streaming(self) -> Iterator[Data]: "is already running recv or recv_streaming" ) from None - def send(self, message: Union[Data, Iterable[Data]]) -> None: + def send(self, message: Data | Iterable[Data]) -> None: """ Send a message. diff --git a/src/websockets/typing.py b/src/websockets/typing.py index 7c5b3664d..73d4a4754 100644 --- a/src/websockets/typing.py +++ b/src/websockets/typing.py @@ -19,6 +19,7 @@ # Public types used in the signature of public APIs +# Change to str | bytes when dropping Python < 3.10. Data = Union[str, bytes] """Types supported in a WebSocket message: :class:`str` for a Text_ frame, :class:`bytes` for a Binary_. @@ -29,6 +30,7 @@ """ +# Change to logging.Logger | ... when dropping Python < 3.10. if typing.TYPE_CHECKING: LoggerLike = Union[logging.Logger, logging.LoggerAdapter[Any]] """Types accepted where a :class:`~logging.Logger` is expected.""" @@ -37,6 +39,7 @@ """Types accepted where a :class:`~logging.Logger` is expected.""" +# Change to http.HTTPStatus | int when dropping Python < 3.10. StatusLike = Union[http.HTTPStatus, int] """ Types accepted where an :class:`~http.HTTPStatus` is expected.""" From cd059d5633eed129775572054a7458d8e3f07166 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 12 Jul 2024 09:28:36 +0200 Subject: [PATCH 509/762] Switch from typing.Dict/List/Tuple/Type/Set to native types. --- src/websockets/client.py | 12 +++---- src/websockets/datastructures.py | 11 +++--- src/websockets/extensions/base.py | 6 ++-- .../extensions/permessage_deflate.py | 18 +++++----- src/websockets/frames.py | 4 +-- src/websockets/headers.py | 34 +++++++++---------- src/websockets/imports.py | 10 +++--- src/websockets/legacy/auth.py | 1 + src/websockets/legacy/client.py | 15 ++++---- src/websockets/legacy/framing.py | 4 +-- src/websockets/legacy/handshake.py | 9 +++-- src/websockets/legacy/http.py | 5 ++- src/websockets/legacy/protocol.py | 9 ++--- src/websockets/legacy/server.py | 28 +++++++-------- src/websockets/protocol.py | 14 ++++---- src/websockets/server.py | 16 ++++----- src/websockets/sync/client.py | 4 +-- src/websockets/sync/connection.py | 6 ++-- src/websockets/sync/messages.py | 4 +-- src/websockets/sync/server.py | 6 ++-- src/websockets/typing.py | 4 ++- src/websockets/uri.py | 3 +- 22 files changed, 107 insertions(+), 116 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index cfb441fd9..8f78ac320 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -1,7 +1,7 @@ from __future__ import annotations import warnings -from typing import Any, Generator, List, Sequence +from typing import Any, Generator, Sequence from .datastructures import Headers, MultipleValuesError from .exceptions import ( @@ -153,7 +153,7 @@ def process_response(self, response: Response) -> None: headers = response.headers - connection: List[ConnectionOption] = sum( + connection: list[ConnectionOption] = sum( [parse_connection(value) for value in headers.get_all("Connection")], [] ) @@ -162,7 +162,7 @@ def process_response(self, response: Response) -> None: "Connection", ", ".join(connection) if connection else None ) - upgrade: List[UpgradeProtocol] = sum( + upgrade: list[UpgradeProtocol] = sum( [parse_upgrade(value) for value in headers.get_all("Upgrade")], [] ) @@ -188,7 +188,7 @@ def process_response(self, response: Response) -> None: self.subprotocol = self.process_subprotocol(headers) - def process_extensions(self, headers: Headers) -> List[Extension]: + def process_extensions(self, headers: Headers) -> list[Extension]: """ Handle the Sec-WebSocket-Extensions HTTP response header. @@ -219,7 +219,7 @@ def process_extensions(self, headers: Headers) -> List[Extension]: InvalidHandshake: To abort the handshake. """ - accepted_extensions: List[Extension] = [] + accepted_extensions: list[Extension] = [] extensions = headers.get_all("Sec-WebSocket-Extensions") @@ -227,7 +227,7 @@ def process_extensions(self, headers: Headers) -> List[Extension]: if self.available_extensions is None: raise InvalidHandshake("no extensions supported") - parsed_extensions: List[ExtensionHeader] = sum( + parsed_extensions: list[ExtensionHeader] = sum( [parse_extension(header_value) for header_value in extensions], [] ) diff --git a/src/websockets/datastructures.py b/src/websockets/datastructures.py index 5605772d8..3d64d951e 100644 --- a/src/websockets/datastructures.py +++ b/src/websockets/datastructures.py @@ -2,10 +2,8 @@ from typing import ( Any, - Dict, Iterable, Iterator, - List, Mapping, MutableMapping, Protocol, @@ -72,8 +70,8 @@ class Headers(MutableMapping[str, str]): # Like dict, Headers accepts an optional "mapping or iterable" argument. def __init__(self, *args: HeadersLike, **kwargs: str) -> None: - self._dict: Dict[str, List[str]] = {} - self._list: List[Tuple[str, str]] = [] + self._dict: dict[str, list[str]] = {} + self._list: list[tuple[str, str]] = [] self.update(*args, **kwargs) def __str__(self) -> str: @@ -147,7 +145,7 @@ def update(self, *args: HeadersLike, **kwargs: str) -> None: # Methods for handling multiple values - def get_all(self, key: str) -> List[str]: + def get_all(self, key: str) -> list[str]: """ Return the (possibly empty) list of all values for a header. @@ -157,7 +155,7 @@ def get_all(self, key: str) -> List[str]: """ return self._dict.get(key.lower(), []) - def raw_items(self) -> Iterator[Tuple[str, str]]: + def raw_items(self) -> Iterator[tuple[str, str]]: """ Return an iterator of all values as ``(name, value)`` pairs. @@ -181,6 +179,7 @@ def __getitem__(self, key: str) -> str: ... HeadersLike = Union[ Headers, Mapping[str, str], + # Change to tuple[str, str] when dropping Python < 3.9. Iterable[Tuple[str, str]], SupportsKeysAndGetItem, ] diff --git a/src/websockets/extensions/base.py b/src/websockets/extensions/base.py index 5b5528a09..a6c76c3d4 100644 --- a/src/websockets/extensions/base.py +++ b/src/websockets/extensions/base.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import List, Sequence, Tuple +from typing import Sequence from .. import frames from ..typing import ExtensionName, ExtensionParameter @@ -63,7 +63,7 @@ class ClientExtensionFactory: name: ExtensionName """Extension identifier.""" - def get_request_params(self) -> List[ExtensionParameter]: + def get_request_params(self) -> list[ExtensionParameter]: """ Build parameters to send to the server for this extension. @@ -108,7 +108,7 @@ def process_request_params( self, params: Sequence[ExtensionParameter], accepted_extensions: Sequence[Extension], - ) -> Tuple[List[ExtensionParameter], Extension]: + ) -> tuple[list[ExtensionParameter], Extension]: """ Process parameters received from the client. diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index 48a6a0833..579262f02 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -2,7 +2,7 @@ import dataclasses import zlib -from typing import Any, Dict, List, Sequence, Tuple +from typing import Any, Sequence from .. import exceptions, frames from ..typing import ExtensionName, ExtensionParameter @@ -36,7 +36,7 @@ def __init__( local_no_context_takeover: bool, remote_max_window_bits: int, local_max_window_bits: int, - compress_settings: Dict[Any, Any] | None = None, + compress_settings: dict[Any, Any] | None = None, ) -> None: """ Configure the Per-Message Deflate extension. @@ -176,12 +176,12 @@ def _build_parameters( client_no_context_takeover: bool, server_max_window_bits: int | None, client_max_window_bits: int | bool | None, -) -> List[ExtensionParameter]: +) -> list[ExtensionParameter]: """ Build a list of ``(name, value)`` pairs for some compression parameters. """ - params: List[ExtensionParameter] = [] + params: list[ExtensionParameter] = [] if server_no_context_takeover: params.append(("server_no_context_takeover", None)) if client_no_context_takeover: @@ -197,7 +197,7 @@ def _build_parameters( def _extract_parameters( params: Sequence[ExtensionParameter], *, is_server: bool -) -> Tuple[bool, bool, int | None, int | bool | None]: +) -> tuple[bool, bool, int | None, int | bool | None]: """ Extract compression parameters from a list of ``(name, value)`` pairs. @@ -288,7 +288,7 @@ def __init__( client_no_context_takeover: bool = False, server_max_window_bits: int | None = None, client_max_window_bits: int | bool | None = True, - compress_settings: Dict[str, Any] | None = None, + compress_settings: dict[str, Any] | None = None, ) -> None: """ Configure the Per-Message Deflate extension factory. @@ -314,7 +314,7 @@ def __init__( self.client_max_window_bits = client_max_window_bits self.compress_settings = compress_settings - def get_request_params(self) -> List[ExtensionParameter]: + def get_request_params(self) -> list[ExtensionParameter]: """ Build request parameters. @@ -491,7 +491,7 @@ def __init__( client_no_context_takeover: bool = False, server_max_window_bits: int | None = None, client_max_window_bits: int | None = None, - compress_settings: Dict[str, Any] | None = None, + compress_settings: dict[str, Any] | None = None, require_client_max_window_bits: bool = False, ) -> None: """ @@ -524,7 +524,7 @@ def process_request_params( self, params: Sequence[ExtensionParameter], accepted_extensions: Sequence[Extension], - ) -> Tuple[List[ExtensionParameter], PerMessageDeflate]: + ) -> tuple[list[ExtensionParameter], PerMessageDeflate]: """ Process request parameters. diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 5a304d6a7..0da676432 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -5,7 +5,7 @@ import io import secrets import struct -from typing import Callable, Generator, Sequence, Tuple +from typing import Callable, Generator, Sequence from . import exceptions, extensions from .typing import Data @@ -353,7 +353,7 @@ def check(self) -> None: raise exceptions.ProtocolError("fragmented control frame") -def prepare_data(data: Data) -> Tuple[int, bytes]: +def prepare_data(data: Data) -> tuple[int, bytes]: """ Convert a string or byte-like object to an opcode and a bytes-like object. diff --git a/src/websockets/headers.py b/src/websockets/headers.py index 3b316e0bf..bc42e0b72 100644 --- a/src/websockets/headers.py +++ b/src/websockets/headers.py @@ -4,7 +4,7 @@ import binascii import ipaddress import re -from typing import Callable, List, Sequence, Tuple, TypeVar, cast +from typing import Callable, Sequence, TypeVar, cast from . import exceptions from .typing import ( @@ -96,7 +96,7 @@ def parse_OWS(header: str, pos: int) -> int: _token_re = re.compile(r"[-!#$%&\'*+.^_`|~0-9a-zA-Z]+") -def parse_token(header: str, pos: int, header_name: str) -> Tuple[str, int]: +def parse_token(header: str, pos: int, header_name: str) -> tuple[str, int]: """ Parse a token from ``header`` at the given position. @@ -120,7 +120,7 @@ def parse_token(header: str, pos: int, header_name: str) -> Tuple[str, int]: _unquote_re = re.compile(r"\\([\x09\x20-\x7e\x80-\xff])") -def parse_quoted_string(header: str, pos: int, header_name: str) -> Tuple[str, int]: +def parse_quoted_string(header: str, pos: int, header_name: str) -> tuple[str, int]: """ Parse a quoted string from ``header`` at the given position. @@ -158,11 +158,11 @@ def build_quoted_string(value: str) -> str: def parse_list( - parse_item: Callable[[str, int, str], Tuple[T, int]], + parse_item: Callable[[str, int, str], tuple[T, int]], header: str, pos: int, header_name: str, -) -> List[T]: +) -> list[T]: """ Parse a comma-separated list from ``header`` at the given position. @@ -227,7 +227,7 @@ def parse_list( def parse_connection_option( header: str, pos: int, header_name: str -) -> Tuple[ConnectionOption, int]: +) -> tuple[ConnectionOption, int]: """ Parse a Connection option from ``header`` at the given position. @@ -241,7 +241,7 @@ def parse_connection_option( return cast(ConnectionOption, item), pos -def parse_connection(header: str) -> List[ConnectionOption]: +def parse_connection(header: str) -> list[ConnectionOption]: """ Parse a ``Connection`` header. @@ -264,7 +264,7 @@ def parse_connection(header: str) -> List[ConnectionOption]: def parse_upgrade_protocol( header: str, pos: int, header_name: str -) -> Tuple[UpgradeProtocol, int]: +) -> tuple[UpgradeProtocol, int]: """ Parse an Upgrade protocol from ``header`` at the given position. @@ -282,7 +282,7 @@ def parse_upgrade_protocol( return cast(UpgradeProtocol, match.group()), match.end() -def parse_upgrade(header: str) -> List[UpgradeProtocol]: +def parse_upgrade(header: str) -> list[UpgradeProtocol]: """ Parse an ``Upgrade`` header. @@ -300,7 +300,7 @@ def parse_upgrade(header: str) -> List[UpgradeProtocol]: def parse_extension_item_param( header: str, pos: int, header_name: str -) -> Tuple[ExtensionParameter, int]: +) -> tuple[ExtensionParameter, int]: """ Parse a single extension parameter from ``header`` at the given position. @@ -336,7 +336,7 @@ def parse_extension_item_param( def parse_extension_item( header: str, pos: int, header_name: str -) -> Tuple[ExtensionHeader, int]: +) -> tuple[ExtensionHeader, int]: """ Parse an extension definition from ``header`` at the given position. @@ -359,7 +359,7 @@ def parse_extension_item( return (cast(ExtensionName, name), parameters), pos -def parse_extension(header: str) -> List[ExtensionHeader]: +def parse_extension(header: str) -> list[ExtensionHeader]: """ Parse a ``Sec-WebSocket-Extensions`` header. @@ -389,7 +389,7 @@ def parse_extension(header: str) -> List[ExtensionHeader]: def build_extension_item( - name: ExtensionName, parameters: List[ExtensionParameter] + name: ExtensionName, parameters: list[ExtensionParameter] ) -> str: """ Build an extension definition. @@ -424,7 +424,7 @@ def build_extension(extensions: Sequence[ExtensionHeader]) -> str: def parse_subprotocol_item( header: str, pos: int, header_name: str -) -> Tuple[Subprotocol, int]: +) -> tuple[Subprotocol, int]: """ Parse a subprotocol from ``header`` at the given position. @@ -438,7 +438,7 @@ def parse_subprotocol_item( return cast(Subprotocol, item), pos -def parse_subprotocol(header: str) -> List[Subprotocol]: +def parse_subprotocol(header: str) -> list[Subprotocol]: """ Parse a ``Sec-WebSocket-Protocol`` header. @@ -498,7 +498,7 @@ def build_www_authenticate_basic(realm: str) -> str: _token68_re = re.compile(r"[A-Za-z0-9-._~+/]+=*") -def parse_token68(header: str, pos: int, header_name: str) -> Tuple[str, int]: +def parse_token68(header: str, pos: int, header_name: str) -> tuple[str, int]: """ Parse a token68 from ``header`` at the given position. @@ -525,7 +525,7 @@ def parse_end(header: str, pos: int, header_name: str) -> None: raise exceptions.InvalidHeaderFormat(header_name, "trailing data", header, pos) -def parse_authorization_basic(header: str) -> Tuple[str, str]: +def parse_authorization_basic(header: str) -> tuple[str, str]: """ Parse an ``Authorization`` header for HTTP Basic Auth. diff --git a/src/websockets/imports.py b/src/websockets/imports.py index 9c05234f5..bb80e4eac 100644 --- a/src/websockets/imports.py +++ b/src/websockets/imports.py @@ -1,13 +1,13 @@ from __future__ import annotations import warnings -from typing import Any, Dict, Iterable +from typing import Any, Iterable __all__ = ["lazy_import"] -def import_name(name: str, source: str, namespace: Dict[str, Any]) -> Any: +def import_name(name: str, source: str, namespace: dict[str, Any]) -> Any: """ Import ``name`` from ``source`` in ``namespace``. @@ -29,9 +29,9 @@ def import_name(name: str, source: str, namespace: Dict[str, Any]) -> Any: def lazy_import( - namespace: Dict[str, Any], - aliases: Dict[str, str] | None = None, - deprecated_aliases: Dict[str, str] | None = None, + namespace: dict[str, Any], + aliases: dict[str, str] | None = None, + deprecated_aliases: dict[str, str] | None = None, ) -> None: """ Provide lazy, module-level imports. diff --git a/src/websockets/legacy/auth.py b/src/websockets/legacy/auth.py index 9d685d9f4..c2d30e4b4 100644 --- a/src/websockets/legacy/auth.py +++ b/src/websockets/legacy/auth.py @@ -13,6 +13,7 @@ __all__ = ["BasicAuthWebSocketServerProtocol", "basic_auth_protocol_factory"] +# Change to tuple[str, str] when dropping Python < 3.9. Credentials = Tuple[str, str] diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index f7464368f..d9d69fdaa 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -12,10 +12,7 @@ AsyncIterator, Callable, Generator, - List, Sequence, - Tuple, - Type, cast, ) @@ -122,7 +119,7 @@ def write_http_request(self, path: str, headers: Headers) -> None: self.transport.write(request.encode()) - async def read_http_response(self) -> Tuple[int, Headers]: + async def read_http_response(self) -> tuple[int, Headers]: """ Read status line and headers from the HTTP response. @@ -152,7 +149,7 @@ async def read_http_response(self) -> Tuple[int, Headers]: def process_extensions( headers: Headers, available_extensions: Sequence[ClientExtensionFactory] | None, - ) -> List[Extension]: + ) -> list[Extension]: """ Handle the Sec-WebSocket-Extensions HTTP response header. @@ -179,7 +176,7 @@ def process_extensions( order of extensions, may be implemented by overriding this method. """ - accepted_extensions: List[Extension] = [] + accepted_extensions: list[Extension] = [] header_values = headers.get_all("Sec-WebSocket-Extensions") @@ -187,7 +184,7 @@ def process_extensions( if available_extensions is None: raise InvalidHandshake("no extensions supported") - parsed_header_values: List[ExtensionHeader] = sum( + parsed_header_values: list[ExtensionHeader] = sum( [parse_extension(header_value) for header_value in header_values], [] ) @@ -455,7 +452,7 @@ def __init__( close_timeout = timeout # Backwards compatibility: create_protocol used to be called klass. - klass: Type[WebSocketClientProtocol] | None = kwargs.pop("klass", None) + klass: type[WebSocketClientProtocol] | None = kwargs.pop("klass", None) if klass is None: klass = WebSocketClientProtocol else: @@ -629,7 +626,7 @@ async def __aenter__(self) -> WebSocketClientProtocol: async def __aexit__( self, - exc_type: Type[BaseException] | None, + exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, ) -> None: diff --git a/src/websockets/legacy/framing.py b/src/websockets/legacy/framing.py index 8a13fa446..1aaca5cc6 100644 --- a/src/websockets/legacy/framing.py +++ b/src/websockets/legacy/framing.py @@ -1,7 +1,7 @@ from __future__ import annotations import struct -from typing import Any, Awaitable, Callable, NamedTuple, Sequence, Tuple +from typing import Any, Awaitable, Callable, NamedTuple, Sequence from .. import extensions, frames from ..exceptions import PayloadTooBig, ProtocolError @@ -152,7 +152,7 @@ def write( ) -def parse_close(data: bytes) -> Tuple[int, str]: +def parse_close(data: bytes) -> tuple[int, str]: """ Parse the payload from a close frame. diff --git a/src/websockets/legacy/handshake.py b/src/websockets/legacy/handshake.py index 5853c31db..2a39c1b03 100644 --- a/src/websockets/legacy/handshake.py +++ b/src/websockets/legacy/handshake.py @@ -2,7 +2,6 @@ import base64 import binascii -from typing import List from ..datastructures import Headers, MultipleValuesError from ..exceptions import InvalidHeader, InvalidHeaderValue, InvalidUpgrade @@ -55,14 +54,14 @@ def check_request(headers: Headers) -> str: Then, the server must return a 400 Bad Request error. """ - connection: List[ConnectionOption] = sum( + connection: list[ConnectionOption] = sum( [parse_connection(value) for value in headers.get_all("Connection")], [] ) if not any(value.lower() == "upgrade" for value in connection): raise InvalidUpgrade("Connection", ", ".join(connection)) - upgrade: List[UpgradeProtocol] = sum( + upgrade: list[UpgradeProtocol] = sum( [parse_upgrade(value) for value in headers.get_all("Upgrade")], [] ) @@ -135,14 +134,14 @@ def check_response(headers: Headers, key: str) -> None: InvalidHandshake: If the handshake response is invalid. """ - connection: List[ConnectionOption] = sum( + connection: list[ConnectionOption] = sum( [parse_connection(value) for value in headers.get_all("Connection")], [] ) if not any(value.lower() == "upgrade" for value in connection): raise InvalidUpgrade("Connection", " ".join(connection)) - upgrade: List[UpgradeProtocol] = sum( + upgrade: list[UpgradeProtocol] = sum( [parse_upgrade(value) for value in headers.get_all("Upgrade")], [] ) diff --git a/src/websockets/legacy/http.py b/src/websockets/legacy/http.py index 2ac7f7092..9a553e175 100644 --- a/src/websockets/legacy/http.py +++ b/src/websockets/legacy/http.py @@ -2,7 +2,6 @@ import asyncio import re -from typing import Tuple from ..datastructures import Headers from ..exceptions import SecurityError @@ -42,7 +41,7 @@ def d(value: bytes) -> str: _value_re = re.compile(rb"[\x09\x20-\x7e\x80-\xff]*") -async def read_request(stream: asyncio.StreamReader) -> Tuple[str, Headers]: +async def read_request(stream: asyncio.StreamReader) -> tuple[str, Headers]: """ Read an HTTP/1.1 GET request and return ``(path, headers)``. @@ -91,7 +90,7 @@ async def read_request(stream: asyncio.StreamReader) -> Tuple[str, Headers]: return path, headers -async def read_response(stream: asyncio.StreamReader) -> Tuple[int, str, Headers]: +async def read_response(stream: asyncio.StreamReader) -> tuple[int, str, Headers]: """ Read an HTTP/1.1 response and return ``(status_code, reason, headers)``. diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index f4c5901dc..67161019f 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -18,11 +18,8 @@ Awaitable, Callable, Deque, - Dict, Iterable, - List, Mapping, - Tuple, cast, ) @@ -262,7 +259,7 @@ def __init__( """Opening handshake response headers.""" # WebSocket protocol parameters. - self.extensions: List[Extension] = [] + self.extensions: list[Extension] = [] self.subprotocol: Subprotocol | None = None """Subprotocol, if one was negotiated.""" @@ -286,7 +283,7 @@ def __init__( self._fragmented_message_waiter: asyncio.Future[None] | None = None # Mapping of ping IDs to pong waiters, in chronological order. - self.pings: Dict[bytes, Tuple[asyncio.Future[float], float]] = {} + self.pings: dict[bytes, tuple[asyncio.Future[float], float]] = {} self.latency: float = 0 """ @@ -1042,7 +1039,7 @@ async def read_message(self) -> Data | None: return frame.data.decode("utf-8") if text else frame.data # 5.4. Fragmentation - fragments: List[Data] = [] + fragments: list[Data] = [] max_size = self.max_size if text: decoder_factory = codecs.getincrementaldecoder("utf-8") diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 13a6f5591..c0ea6a764 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -15,11 +15,8 @@ Callable, Generator, Iterable, - List, Sequence, - Set, Tuple, - Type, Union, cast, ) @@ -57,6 +54,7 @@ # Change to HeadersLike | ... when dropping Python < 3.10. HeadersLikeOrCallable = Union[HeadersLike, Callable[[str, Headers], HeadersLike]] +# Change to tuple[...] when dropping Python < 3.9. HTTPResponse = Tuple[StatusLike, HeadersLike, bytes] @@ -263,7 +261,7 @@ async def handler(self) -> None: self.ws_server.unregister(self) self.logger.info("connection closed") - async def read_http_request(self) -> Tuple[str, Headers]: + async def read_http_request(self) -> tuple[str, Headers]: """ Read request line and headers from the HTTP request. @@ -349,7 +347,7 @@ async def process_request( request_headers: Request headers. Returns: - Tuple[StatusLike, HeadersLike, bytes] | None: :obj:`None` to + tuple[StatusLike, HeadersLike, bytes] | None: :obj:`None` to continue the WebSocket handshake normally. An HTTP response, represented by a 3-uple of the response status, @@ -401,7 +399,7 @@ def process_origin( def process_extensions( headers: Headers, available_extensions: Sequence[ServerExtensionFactory] | None, - ) -> Tuple[str | None, List[Extension]]: + ) -> tuple[str | None, list[Extension]]: """ Handle the Sec-WebSocket-Extensions HTTP request header. @@ -439,13 +437,13 @@ def process_extensions( """ response_header_value: str | None = None - extension_headers: List[ExtensionHeader] = [] - accepted_extensions: List[Extension] = [] + extension_headers: list[ExtensionHeader] = [] + accepted_extensions: list[Extension] = [] header_values = headers.get_all("Sec-WebSocket-Extensions") if header_values and available_extensions: - parsed_header_values: List[ExtensionHeader] = sum( + parsed_header_values: list[ExtensionHeader] = sum( [parse_extension(header_value) for header_value in header_values], [] ) @@ -502,7 +500,7 @@ def process_subprotocol( header_values = headers.get_all("Sec-WebSocket-Protocol") if header_values and available_subprotocols: - parsed_header_values: List[Subprotocol] = sum( + parsed_header_values: list[Subprotocol] = sum( [parse_subprotocol(header_value) for header_value in header_values], [] ) @@ -669,7 +667,7 @@ def __init__(self, logger: LoggerLike | None = None): self.logger = logger # Keep track of active connections. - self.websockets: Set[WebSocketServerProtocol] = set() + self.websockets: set[WebSocketServerProtocol] = set() # Task responsible for closing the server and terminating connections. self.close_task: asyncio.Task[None] | None = None @@ -871,7 +869,7 @@ async def __aenter__(self) -> WebSocketServer: # pragma: no cover async def __aexit__( self, - exc_type: Type[BaseException] | None, + exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, ) -> None: # pragma: no cover @@ -944,7 +942,7 @@ class Serve: It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to :obj:`None` removes the header. process_request (Callable[[str, Headers], \ - Awaitable[Tuple[StatusLike, HeadersLike, bytes] | None]] | None): + Awaitable[tuple[StatusLike, HeadersLike, bytes] | None]] | None): Intercept HTTP request before the opening handshake. See :meth:`~WebSocketServerProtocol.process_request` for details. select_subprotocol: Select a subprotocol supported by the client. @@ -1015,7 +1013,7 @@ def __init__( close_timeout = timeout # Backwards compatibility: create_protocol used to be called klass. - klass: Type[WebSocketServerProtocol] | None = kwargs.pop("klass", None) + klass: type[WebSocketServerProtocol] | None = kwargs.pop("klass", None) if klass is None: klass = WebSocketServerProtocol else: @@ -1100,7 +1098,7 @@ async def __aenter__(self) -> WebSocketServer: async def __aexit__( self, - exc_type: Type[BaseException] | None, + exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, ) -> None: diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index f288a2733..2f5542f6e 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -3,7 +3,7 @@ import enum import logging import uuid -from typing import Generator, List, Type, Union +from typing import Generator, Union from .exceptions import ( ConnectionClosed, @@ -125,7 +125,7 @@ def __init__( # WebSocket protocol parameters. self.origin: Origin | None = None - self.extensions: List[Extension] = [] + self.extensions: list[Extension] = [] self.subprotocol: Subprotocol | None = None # Close code and reason, set when a close frame is sent or received. @@ -147,8 +147,8 @@ def __init__( # Parser state. self.reader = StreamReader() - self.events: List[Event] = [] - self.writes: List[bytes] = [] + self.events: list[Event] = [] + self.writes: list[bytes] = [] self.parser = self.parse() next(self.parser) # start coroutine self.parser_exc: Exception | None = None @@ -222,7 +222,7 @@ def close_exc(self) -> ConnectionClosed: """ assert self.state is CLOSED, "connection isn't closed yet" - exc_type: Type[ConnectionClosed] + exc_type: type[ConnectionClosed] if ( self.close_rcvd is not None and self.close_sent is not None @@ -458,7 +458,7 @@ def fail(self, code: int, reason: str = "") -> None: # Public method for getting incoming events after receiving data. - def events_received(self) -> List[Event]: + def events_received(self) -> list[Event]: """ Fetch events generated from data received from the network. @@ -474,7 +474,7 @@ def events_received(self) -> List[Event]: # Public method for getting outgoing data after receiving data or sending events. - def data_to_send(self) -> List[bytes]: + def data_to_send(self) -> list[bytes]: """ Obtain data to send to the network. diff --git a/src/websockets/server.py b/src/websockets/server.py index a92541085..f976ebad7 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -5,7 +5,7 @@ import email.utils import http import warnings -from typing import Any, Callable, Generator, List, Sequence, Tuple, cast +from typing import Any, Callable, Generator, Sequence, cast from .datastructures import Headers, MultipleValuesError from .exceptions import ( @@ -201,7 +201,7 @@ def accept(self, request: Request) -> Response: def process_request( self, request: Request, - ) -> Tuple[str, str | None, str | None]: + ) -> tuple[str, str | None, str | None]: """ Check a handshake request and negotiate extensions and subprotocol. @@ -224,7 +224,7 @@ def process_request( """ headers = request.headers - connection: List[ConnectionOption] = sum( + connection: list[ConnectionOption] = sum( [parse_connection(value) for value in headers.get_all("Connection")], [] ) @@ -233,7 +233,7 @@ def process_request( "Connection", ", ".join(connection) if connection else None ) - upgrade: List[UpgradeProtocol] = sum( + upgrade: list[UpgradeProtocol] = sum( [parse_upgrade(value) for value in headers.get_all("Upgrade")], [] ) @@ -317,7 +317,7 @@ def process_origin(self, headers: Headers) -> Origin | None: def process_extensions( self, headers: Headers, - ) -> Tuple[str | None, List[Extension]]: + ) -> tuple[str | None, list[Extension]]: """ Handle the Sec-WebSocket-Extensions HTTP request header. @@ -355,13 +355,13 @@ def process_extensions( """ response_header_value: str | None = None - extension_headers: List[ExtensionHeader] = [] - accepted_extensions: List[Extension] = [] + extension_headers: list[ExtensionHeader] = [] + accepted_extensions: list[Extension] = [] header_values = headers.get_all("Sec-WebSocket-Extensions") if header_values and self.available_extensions: - parsed_header_values: List[ExtensionHeader] = sum( + parsed_header_values: list[ExtensionHeader] = sum( [parse_extension(header_value) for header_value in header_values], [] ) diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 60b49ebc3..c97a09402 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -4,7 +4,7 @@ import ssl as ssl_module import threading import warnings -from typing import Any, Sequence, Type +from typing import Any, Sequence from ..client import ClientProtocol from ..datastructures import HeadersLike @@ -146,7 +146,7 @@ def connect( # Logging logger: LoggerLike | None = None, # Escape hatch for advanced customization - create_connection: Type[ClientConnection] | None = None, + create_connection: type[ClientConnection] | None = None, **kwargs: Any, ) -> ClientConnection: """ diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 7a750331d..33d8299e2 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -8,7 +8,7 @@ import threading import uuid from types import TracebackType -from typing import Any, Dict, Iterable, Iterator, Mapping, Type +from typing import Any, Iterable, Iterator, Mapping from ..exceptions import ConnectionClosed, ConnectionClosedOK, ProtocolError from ..frames import DATA_OPCODES, BytesLike, CloseCode, Frame, Opcode, prepare_ctrl @@ -80,7 +80,7 @@ def __init__( self.close_deadline: Deadline | None = None # Mapping of ping IDs to pong waiters, in chronological order. - self.ping_waiters: Dict[bytes, threading.Event] = {} + self.ping_waiters: dict[bytes, threading.Event] = {} # Receiving events from the socket. This thread explicitly is marked as # to support creating a connection in a non-daemon thread then using it @@ -140,7 +140,7 @@ def __enter__(self) -> Connection: def __exit__( self, - exc_type: Type[BaseException] | None, + exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, ) -> None: diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index 2c604ba09..a6e78e7fd 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -3,7 +3,7 @@ import codecs import queue import threading -from typing import Iterator, List, cast +from typing import Iterator, cast from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame from ..typing import Data @@ -44,7 +44,7 @@ def __init__(self) -> None: self.decoder: codecs.IncrementalDecoder | None = None # Buffer of frames belonging to the same message. - self.chunks: List[Data] = [] + self.chunks: list[Data] = [] # When switching from "buffering" to "streaming", we use a thread-safe # queue for transferring frames from the writing thread (library code) diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index b801510b4..4f088b63a 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -10,7 +10,7 @@ import threading import warnings from types import TracebackType -from typing import Any, Callable, Sequence, Type +from typing import Any, Callable, Sequence from ..extensions.base import ServerExtensionFactory from ..extensions.permessage_deflate import enable_server_permessage_deflate @@ -262,7 +262,7 @@ def __enter__(self) -> WebSocketServer: def __exit__( self, - exc_type: Type[BaseException] | None, + exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, ) -> None: @@ -312,7 +312,7 @@ def serve( # Logging logger: LoggerLike | None = None, # Escape hatch for advanced customization - create_connection: Type[ServerConnection] | None = None, + create_connection: type[ServerConnection] | None = None, **kwargs: Any, ) -> WebSocketServer: """ diff --git a/src/websockets/typing.py b/src/websockets/typing.py index 73d4a4754..6360c7a0a 100644 --- a/src/websockets/typing.py +++ b/src/websockets/typing.py @@ -56,13 +56,15 @@ ExtensionName = NewType("ExtensionName", str) """Name of a WebSocket extension.""" -# Change to str | None when dropping Python < 3.10. +# Change to tuple[str, Optional[str]] when dropping Python < 3.9. +# Change to tuple[str, str | None] when dropping Python < 3.10. ExtensionParameter = Tuple[str, Optional[str]] """Parameter of a WebSocket extension.""" # Private types +# Change to tuple[.., list[...]] when dropping Python < 3.9. ExtensionHeader = Tuple[ExtensionName, List[ExtensionParameter]] """Extension in a ``Sec-WebSocket-Extensions`` header.""" diff --git a/src/websockets/uri.py b/src/websockets/uri.py index 902716066..5cb38a9cc 100644 --- a/src/websockets/uri.py +++ b/src/websockets/uri.py @@ -2,7 +2,6 @@ import dataclasses import urllib.parse -from typing import Tuple from . import exceptions @@ -47,7 +46,7 @@ def resource_name(self) -> str: return resource_name @property - def user_info(self) -> Tuple[str, str] | None: + def user_info(self) -> tuple[str, str] | None: if self.username is None: return None assert self.password is not None From f45286b3b2d54f8b79087b060858042b2488688b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 12 Jul 2024 09:37:09 +0200 Subject: [PATCH 510/762] Pick changes suggested by `pyupgrade --py38-plus`. Other changes were ignored, on purpose. --- src/websockets/sync/messages.py | 3 +-- tests/legacy/test_client_server.py | 2 +- tests/legacy/utils.py | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index a6e78e7fd..6cbff2595 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -53,8 +53,7 @@ def __init__(self) -> None: # value marking the end of the message, superseding message_complete. # Stream data from frames belonging to the same message. - # Remove quotes around type when dropping Python < 3.9. - self.chunks_queue: "queue.SimpleQueue[Data | None] | None" = None + self.chunks_queue: queue.SimpleQueue[Data | None] | None = None # This flag marks the end of the connection. self.closed = False diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 51a74734b..c38086572 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -65,7 +65,7 @@ async def default_handler(ws): await ws.wait_closed() await asyncio.sleep(2 * MS) else: - await ws.send((await ws.recv())) + await ws.send(await ws.recv()) async def redirect_request(path, headers, test, status): diff --git a/tests/legacy/utils.py b/tests/legacy/utils.py index 4a21dcaeb..28bc90df3 100644 --- a/tests/legacy/utils.py +++ b/tests/legacy/utils.py @@ -79,6 +79,6 @@ def assertDeprecationWarnings(self, recorded_warnings, expected_warnings): for recorded in recorded_warnings: self.assertEqual(type(recorded.message), DeprecationWarning) self.assertEqual( - set(str(recorded.message) for recorded in recorded_warnings), + {str(recorded.message) for recorded in recorded_warnings}, set(expected_warnings), ) From 650d08caf1c5f84c77a8bf8780a1b407a1432357 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 21 Jul 2024 08:19:32 +0200 Subject: [PATCH 511/762] Upgrade to mypy 1.11. --- src/websockets/legacy/auth.py | 5 +++++ src/websockets/legacy/client.py | 6 ++++-- src/websockets/legacy/server.py | 3 +++ 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/src/websockets/legacy/auth.py b/src/websockets/legacy/auth.py index c2d30e4b4..8526bad6b 100644 --- a/src/websockets/legacy/auth.py +++ b/src/websockets/legacy/auth.py @@ -178,6 +178,11 @@ async def check_credentials(username: str, password: str) -> bool: if create_protocol is None: create_protocol = BasicAuthWebSocketServerProtocol + # Help mypy and avoid this error: "type[BasicAuthWebSocketServerProtocol] | + # Callable[..., BasicAuthWebSocketServerProtocol]" not callable [misc] + create_protocol = cast( + Callable[..., BasicAuthWebSocketServerProtocol], create_protocol + ) return functools.partial( create_protocol, realm=realm, diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index d9d69fdaa..b15eddf75 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -489,6 +489,9 @@ def __init__( if subprotocols is not None: validate_subprotocols(subprotocols) + # Help mypy and avoid this error: "type[WebSocketClientProtocol] | + # Callable[..., WebSocketClientProtocol]" not callable [misc] + create_protocol = cast(Callable[..., WebSocketClientProtocol], create_protocol) factory = functools.partial( create_protocol, logger=logger, @@ -641,8 +644,7 @@ def __await__(self) -> Generator[Any, None, WebSocketClientProtocol]: async def __await_impl__(self) -> WebSocketClientProtocol: async with asyncio_timeout(self.open_timeout): for _redirects in range(self.MAX_REDIRECTS_ALLOWED): - _transport, _protocol = await self._create_connection() - protocol = cast(WebSocketClientProtocol, _protocol) + _transport, protocol = await self._create_connection() try: await protocol.handshake( self._wsuri, diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index c0ea6a764..08c82df25 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -1045,6 +1045,9 @@ def __init__( if subprotocols is not None: validate_subprotocols(subprotocols) + # Help mypy and avoid this error: "type[WebSocketServerProtocol] | + # Callable[..., WebSocketServerProtocol]" not callable [misc] + create_protocol = cast(Callable[..., WebSocketServerProtocol], create_protocol) factory = functools.partial( create_protocol, # For backwards compatibility with 10.0 or earlier. Done here in From e05f6dc83434dae7d91fc0db822ab15aa1e4c00b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 21 Jul 2024 08:05:24 +0200 Subject: [PATCH 512/762] Support ws:// to wss:// redirects. Fix #1454. --- src/websockets/legacy/client.py | 14 ++++++++++---- tests/legacy/test_client_server.py | 13 ++++++++++++- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index b15eddf75..b0e15b543 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -558,21 +558,27 @@ def handle_redirect(self, uri: str) -> None: raise SecurityError("redirect from WSS to WS") same_origin = ( - old_wsuri.host == new_wsuri.host and old_wsuri.port == new_wsuri.port + old_wsuri.secure == new_wsuri.secure + and old_wsuri.host == new_wsuri.host + and old_wsuri.port == new_wsuri.port ) - # Rewrite the host and port arguments for cross-origin redirects. + # Rewrite secure, host, and port for cross-origin redirects. # This preserves connection overrides with the host and port # arguments if the redirect points to the same host and port. if not same_origin: - # Replace the host and port argument passed to the protocol factory. factory = self._create_connection.args[0] + # Support TLS upgrade. + if not old_wsuri.secure and new_wsuri.secure: + factory.keywords["secure"] = True + self._create_connection.keywords.setdefault("ssl", True) + # Replace secure, host, and port arguments of the protocol factory. factory = functools.partial( factory.func, *factory.args, **dict(factory.keywords, host=new_wsuri.host, port=new_wsuri.port), ) - # Replace the host and port argument passed to create_connection. + # Replace secure, host, and port arguments of create_connection. self._create_connection = functools.partial( self._create_connection.func, *(factory, new_wsuri.host, new_wsuri.port), diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index c38086572..09b3b361a 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -75,6 +75,8 @@ async def redirect_request(path, headers, test, status): location = "/" elif path == "/infinite": location = get_server_uri(test.server, test.secure, "/infinite") + elif path == "/force_secure": + location = get_server_uri(test.server, True, "/") elif path == "/force_insecure": location = get_server_uri(test.server, False, "/") elif path == "/missing_location": @@ -1290,7 +1292,16 @@ def test_connection_error_during_closing_handshake(self, close): class ClientServerTests( CommonClientServerTests, ClientServerTestsMixin, AsyncioTestCase ): - pass + + def test_redirect_secure(self): + with temp_test_redirecting_server(self): + # websockets doesn't support serving non-TLS and TLS connections + # from the same server and this test suite makes it difficult to + # run two servers. Therefore, we expect the redirect to create a + # TLS client connection to a non-TLS server, which will fail. + with self.assertRaises(ssl.SSLError): + with self.temp_client("/force_secure"): + self.fail("did not raise") class SecureClientServerTests( From 61b69db60cceff6c46ff308d2c10f7f81480788c Mon Sep 17 00:00:00 2001 From: Antonio Curado Date: Mon, 20 May 2024 18:29:55 +0200 Subject: [PATCH 513/762] Correct handle exceptions in `legacy/broadcast` --- src/websockets/legacy/protocol.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 67161019f..3d09440e1 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -1630,5 +1630,5 @@ def broadcast( exc_info=True, ) - if raise_exceptions: + if raise_exceptions and exceptions: raise ExceptionGroup("skipped broadcast", exceptions) From 96d3adf6617fd53fc7a7adcc5a560eeeb8493473 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 21 Jul 2024 08:30:01 +0200 Subject: [PATCH 514/762] Add tests for previous commit. --- tests/legacy/test_protocol.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index f3dcd9ac7..05d2f3795 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -1472,10 +1472,24 @@ def test_broadcast_text(self): broadcast([self.protocol], "café") self.assertOneFrameSent(True, OP_TEXT, "café".encode()) + @unittest.skipIf( + sys.version_info[:2] < (3, 11), "raise_exceptions requires Python 3.11+" + ) + def test_broadcast_text_reports_no_errors(self): + broadcast([self.protocol], "café", raise_exceptions=True) + self.assertOneFrameSent(True, OP_TEXT, "café".encode()) + def test_broadcast_binary(self): broadcast([self.protocol], b"tea") self.assertOneFrameSent(True, OP_BINARY, b"tea") + @unittest.skipIf( + sys.version_info[:2] < (3, 11), "raise_exceptions requires Python 3.11+" + ) + def test_broadcast_binary_reports_no_errors(self): + broadcast([self.protocol], b"tea", raise_exceptions=True) + self.assertOneFrameSent(True, OP_BINARY, b"tea") + def test_broadcast_type_error(self): with self.assertRaises(TypeError): broadcast([self.protocol], ["ca", "fé"]) From ee997c157d3214f758c2422fc44c2a582153f58a Mon Sep 17 00:00:00 2001 From: xuanzhi33 <37460139+xuanzhi33@users.noreply.github.com> Date: Wed, 28 Feb 2024 17:36:11 +0800 Subject: [PATCH 515/762] docs: Correct the example for "Starting a server" in the API reference --- src/websockets/legacy/server.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 08c82df25..fb91265d8 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -899,7 +899,8 @@ class Serve: server = await serve(...) await stop - await server.close() + server.close() + await server.wait_closed() :func:`serve` can be used as an asynchronous context manager. Then, the server is shut down automatically when exiting the context:: From 1210ee81e470bd8df7700d459ba101263fb7413c Mon Sep 17 00:00:00 2001 From: xuanzhi33 <37460139+xuanzhi33@users.noreply.github.com> Date: Wed, 28 Feb 2024 18:30:01 +0800 Subject: [PATCH 516/762] Update server.rst --- docs/faq/server.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/faq/server.rst b/docs/faq/server.rst index 08b412d30..cba1cd35f 100644 --- a/docs/faq/server.rst +++ b/docs/faq/server.rst @@ -300,7 +300,8 @@ Here's how to adapt the example just above:: server = await websockets.serve(echo, "localhost", 8765) await stop - await server.close(close_connections=False) + server.close(close_connections=False) + await server.wait_closed() How do I implement a health check? ---------------------------------- From 41c42b8681dc1245a65e2db8491a573bba1827dc Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 21 Jul 2024 13:39:14 +0200 Subject: [PATCH 517/762] Make it easier to debug version numbers. --- src/websockets/version.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/websockets/version.py b/src/websockets/version.py index f1de3cbf4..145c7a9ed 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -55,7 +55,8 @@ def get_version(tag: str) -> str: else: description_re = r"[0-9.]+-([0-9]+)-(g[0-9a-f]{7,}(?:-dirty)?)" match = re.fullmatch(description_re, description) - assert match is not None + if match is None: + raise ValueError(f"Unexpected git description: {description}") distance, remainder = match.groups() remainder = remainder.replace("-", ".") # required by PEP 440 return f"{tag}.dev{distance}+{remainder}" @@ -75,7 +76,8 @@ def get_commit(tag: str, version: str) -> str: # Extract commit from version, falling back to tag if not available. version_re = r"[0-9.]+\.dev[0-9]+\+g([0-9a-f]{7,}|unknown)(?:\.dirty)?" match = re.fullmatch(version_re, version) - assert match is not None + if match is None: + raise ValueError(f"Unexpected version: {version}") (commit,) = match.groups() return tag if commit == "unknown" else commit From eaa64c07676a9c28d17c9538242fab4638754584 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 21 Jul 2024 14:41:49 +0200 Subject: [PATCH 518/762] Avoid reading the wrong version. This was causing builds to fail on Read the Docs since sphinx-autobuild added websockets as a dependency. --- src/websockets/version.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/src/websockets/version.py b/src/websockets/version.py index 145c7a9ed..46ae34a47 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -34,8 +34,20 @@ def get_version(tag: str) -> str: file_path = pathlib.Path(__file__) root_dir = file_path.parents[0 if file_path.name == "setup.py" else 2] - # Read version from git if available. This prevents reading stale - # information from src/websockets.egg-info after building a sdist. + # Read version from package metadata if it is installed. + try: + version = importlib.metadata.version("websockets") + except ImportError: + pass + else: + # Check that this file belongs to the installed package. + files = importlib.metadata.files("websockets") + if files: + version_file = [f for f in files if f.name == file_path.name][0] + if version_file.locate() == file_path: + return version + + # Read version from git if available. try: description = subprocess.run( ["git", "describe", "--dirty", "--tags", "--long"], @@ -61,12 +73,6 @@ def get_version(tag: str) -> str: remainder = remainder.replace("-", ".") # required by PEP 440 return f"{tag}.dev{distance}+{remainder}" - # Read version from package metadata if it is installed. - try: - return importlib.metadata.version("websockets") - except ImportError: - pass - # Avoid crashing if the development version cannot be determined. return f"{tag}.dev0+gunknown" From e10eebaec368b04f28102df513e26e933ed5a6fd Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 21 Jul 2024 15:24:25 +0200 Subject: [PATCH 519/762] Unshallow git clone on RtD. This is required for get_version to find the last tag. --- .readthedocs.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.readthedocs.yml b/.readthedocs.yml index 0369e0656..28c990c5c 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -4,6 +4,9 @@ build: os: ubuntu-20.04 tools: python: "3.10" + jobs: + post_checkout: + - git fetch --unshallow sphinx: configuration: docs/conf.py From c8c0a9bfee962540eb3c9c228e36d4ef7bd7ed42 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 21 Jul 2024 15:56:42 +0200 Subject: [PATCH 520/762] Improve error reporting when header is too long. Refs #1471. --- src/websockets/legacy/server.py | 6 +++++- src/websockets/server.py | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index fb91265d8..c0c138767 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -199,10 +199,14 @@ async def handler(self) -> None: elif isinstance(exc, InvalidHandshake): if self.debug: self.logger.debug("! invalid handshake", exc_info=True) + exc_str = f"{exc}" + while exc.__cause__ is not None: + exc = exc.__cause__ + exc_str += f"; {exc}" status, headers, body = ( http.HTTPStatus.BAD_REQUEST, Headers(), - f"Failed to open a WebSocket connection: {exc}.\n".encode(), + f"Failed to open a WebSocket connection: {exc_str}.\n".encode(), ) else: self.logger.error("opening handshake failed", exc_info=True) diff --git a/src/websockets/server.py b/src/websockets/server.py index f976ebad7..7f5631230 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -163,9 +163,13 @@ def accept(self, request: Request) -> Response: self.handshake_exc = exc if self.debug: self.logger.debug("! invalid handshake", exc_info=True) + exc_str = f"{exc}" + while exc.__cause__ is not None: + exc = exc.__cause__ + exc_str += f"; {exc}" return self.reject( http.HTTPStatus.BAD_REQUEST, - f"Failed to open a WebSocket connection: {exc}.\n", + f"Failed to open a WebSocket connection: {exc_str}.\n", ) except Exception as exc: # Handle exceptions raised by user-provided select_subprotocol and From d26bac47eac98e6f3b77358b8c836ed02e493fc6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 30 Jul 2024 09:04:22 +0200 Subject: [PATCH 521/762] Make eaa64c07 more robust. This avoids crashing on ossfuzz, which uses a custom loader. --- src/websockets/version.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/websockets/version.py b/src/websockets/version.py index 46ae34a47..44709a91b 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -43,9 +43,11 @@ def get_version(tag: str) -> str: # Check that this file belongs to the installed package. files = importlib.metadata.files("websockets") if files: - version_file = [f for f in files if f.name == file_path.name][0] - if version_file.locate() == file_path: - return version + version_files = [f for f in files if f.name == file_path.name] + if version_files: + version_file = version_files[0] + if version_file.locate() == file_path: + return version # Read version from git if available. try: From d2710227cd2464162861b584f5dd83c472208929 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 3 Aug 2024 10:04:43 +0200 Subject: [PATCH 522/762] Make mypy happy. --- src/websockets/legacy/server.py | 9 +++++---- src/websockets/server.py | 9 +++++---- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index c0c138767..93698e1cb 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -199,10 +199,11 @@ async def handler(self) -> None: elif isinstance(exc, InvalidHandshake): if self.debug: self.logger.debug("! invalid handshake", exc_info=True) - exc_str = f"{exc}" - while exc.__cause__ is not None: - exc = exc.__cause__ - exc_str += f"; {exc}" + exc_chain = cast(BaseException, exc) + exc_str = f"{exc_chain}" + while exc_chain.__cause__ is not None: + exc_chain = exc_chain.__cause__ + exc_str += f"; {exc_chain}" status, headers, body = ( http.HTTPStatus.BAD_REQUEST, Headers(), diff --git a/src/websockets/server.py b/src/websockets/server.py index 7f5631230..baab400d4 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -163,10 +163,11 @@ def accept(self, request: Request) -> Response: self.handshake_exc = exc if self.debug: self.logger.debug("! invalid handshake", exc_info=True) - exc_str = f"{exc}" - while exc.__cause__ is not None: - exc = exc.__cause__ - exc_str += f"; {exc}" + exc_chain = cast(BaseException, exc) + exc_str = f"{exc_chain}" + while exc_chain.__cause__ is not None: + exc_chain = exc_chain.__cause__ + exc_str += f"; {exc_chain}" return self.reject( http.HTTPStatus.BAD_REQUEST, f"Failed to open a WebSocket connection: {exc_str}.\n", From 309e62fa89311de51083e1a62adf06d4450fc5f2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 3 Aug 2024 10:09:24 +0200 Subject: [PATCH 523/762] Test against current PyPy versions. --- .github/workflows/tests.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 8161f1cbb..15a45bdfb 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -62,15 +62,15 @@ jobs: - "3.10" - "3.11" - "3.12" - - "pypy-3.8" - "pypy-3.9" + - "pypy-3.10" is_main: - ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }} exclude: - - python: "pypy-3.8" - is_main: false - python: "pypy-3.9" is_main: false + - python: "pypy-3.10" + is_main: false steps: - name: Check out repository uses: actions/checkout@v4 From fab77d60b660585bcdd996a10ec904f79c901085 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 3 Aug 2024 10:38:17 +0200 Subject: [PATCH 524/762] Annotate __init__ methods consistently. --- src/websockets/client.py | 2 +- src/websockets/legacy/server.py | 2 +- src/websockets/server.py | 2 +- src/websockets/sync/server.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 8f78ac320..07d1d34ed 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -79,7 +79,7 @@ def __init__( state: State = CONNECTING, max_size: int | None = 2**20, logger: LoggerLike | None = None, - ): + ) -> None: super().__init__( side=CLIENT, state=state, diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 93698e1cb..39464be6c 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -666,7 +666,7 @@ class WebSocketServer: """ - def __init__(self, logger: LoggerLike | None = None): + def __init__(self, logger: LoggerLike | None = None) -> None: if logger is None: logger = logging.getLogger("websockets.server") self.logger = logger diff --git a/src/websockets/server.py b/src/websockets/server.py index baab400d4..7211d3cbf 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -90,7 +90,7 @@ def __init__( state: State = CONNECTING, max_size: int | None = 2**20, logger: LoggerLike | None = None, - ): + ) -> None: super().__init__( side=SERVER, state=state, diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 4f088b63a..7fb46f5aa 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -200,7 +200,7 @@ def __init__( socket: socket.socket, handler: Callable[[socket.socket, Any], None], logger: LoggerLike | None = None, - ): + ) -> None: self.socket = socket self.handler = handler if logger is None: From 14cca7699971c19c30a38cad260aeb5f26e0c3ca Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 3 Aug 2024 10:46:04 +0200 Subject: [PATCH 525/762] Bugs in coverage were fixed \o/ --- src/websockets/legacy/protocol.py | 3 +-- tests/legacy/test_client_server.py | 1 - tests/legacy/test_protocol.py | 18 ++++++------------ 3 files changed, 7 insertions(+), 15 deletions(-) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 3d09440e1..de9ea59b6 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -1349,8 +1349,7 @@ async def close_transport(self) -> None: # Abort the TCP connection. Buffers are discarded. if self.debug: self.logger.debug("x aborting TCP connection") - # Due to a bug in coverage, this is erroneously reported as not covered. - self.transport.abort() # pragma: no cover + self.transport.abort() # connection_lost() is called quickly after aborting. await self.wait_for_connection_lost() diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 09b3b361a..0c3f22156 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -1584,7 +1584,6 @@ async def run_client(): else: # Exit block with an exception. raise Exception("BOOM") - pass # work around bug in coverage with self.assertLogs("websockets", logging.INFO) as logs: with self.assertRaises(Exception) as raised: diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index 05d2f3795..d6303dcc7 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -1613,8 +1613,7 @@ def test_local_close_connection_lost_timeout_after_write_eof(self): self.receive_frame(self.close_frame) self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason="close")) - # Due to a bug in coverage, this is erroneously reported as not covered. - self.assertConnectionClosed( # pragma: no cover + self.assertConnectionClosed( CloseCode.NORMAL_CLOSURE, "close", ) @@ -1634,8 +1633,7 @@ def test_local_close_connection_lost_timeout_after_close(self): self.receive_frame(self.close_frame) self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason="close")) - # Due to a bug in coverage, this is erroneously reported as not covered. - self.assertConnectionClosed( # pragma: no cover + self.assertConnectionClosed( CloseCode.NORMAL_CLOSURE, "close", ) @@ -1656,8 +1654,7 @@ def test_local_close_send_close_frame_timeout(self): # Check the timing within -1/+9ms for robustness. with self.assertCompletesWithin(19 * MS, 29 * MS): self.loop.run_until_complete(self.protocol.close(reason="close")) - # Due to a bug in coverage, this is erroneously reported as not covered. - self.assertConnectionClosed( # pragma: no cover + self.assertConnectionClosed( CloseCode.ABNORMAL_CLOSURE, "", ) @@ -1670,8 +1667,7 @@ def test_local_close_receive_close_frame_timeout(self): # Check the timing within -1/+9ms for robustness. with self.assertCompletesWithin(19 * MS, 29 * MS): self.loop.run_until_complete(self.protocol.close(reason="close")) - # Due to a bug in coverage, this is erroneously reported as not covered. - self.assertConnectionClosed( # pragma: no cover + self.assertConnectionClosed( CloseCode.ABNORMAL_CLOSURE, "", ) @@ -1689,8 +1685,7 @@ def test_local_close_connection_lost_timeout_after_write_eof(self): self.receive_frame(self.close_frame) self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason="close")) - # Due to a bug in coverage, this is erroneously reported as not covered. - self.assertConnectionClosed( # pragma: no cover + self.assertConnectionClosed( CloseCode.NORMAL_CLOSURE, "close", ) @@ -1713,8 +1708,7 @@ def test_local_close_connection_lost_timeout_after_close(self): self.receive_frame(self.close_frame) self.run_loop_once() self.loop.run_until_complete(self.protocol.close(reason="close")) - # Due to a bug in coverage, this is erroneously reported as not covered. - self.assertConnectionClosed( # pragma: no cover + self.assertConnectionClosed( CloseCode.NORMAL_CLOSURE, "close", ) From 7bb18a6ea84d2651b68ad45f5e9464a47d314b6b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 5 Aug 2024 07:53:06 +0200 Subject: [PATCH 526/762] Update references to Python's bug tracker. --- src/websockets/__main__.py | 2 +- src/websockets/legacy/protocol.py | 6 +++--- src/websockets/legacy/server.py | 3 ++- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/websockets/__main__.py b/src/websockets/__main__.py index f2ea5cf4e..8647481d0 100644 --- a/src/websockets/__main__.py +++ b/src/websockets/__main__.py @@ -22,7 +22,7 @@ def win_enable_vt100() -> None: """ Enable VT-100 for console output on Windows. - See also https://bugs.python.org/issue29059. + See also https://github.com/python/cpython/issues/73245. """ import ctypes diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index de9ea59b6..57cb4e770 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -1175,9 +1175,9 @@ def write_frame_sync(self, fin: bool, opcode: int, data: bytes) -> None: async def drain(self) -> None: try: - # drain() cannot be called concurrently by multiple coroutines: - # http://bugs.python.org/issue29930. Remove this lock when no - # version of Python where this bugs exists is supported anymore. + # drain() cannot be called concurrently by multiple coroutines. + # See https://github.com/python/cpython/issues/74116 for details. + # This workaround can be removed when dropping Python < 3.10. async with self._drain_lock: # Handle flow control automatically. await self._drain() diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 39464be6c..f4442fecc 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -764,7 +764,8 @@ async def _close(self, close_connections: bool) -> None: self.server.close() # Wait until all accepted connections reach connection_made() and call - # register(). See https://bugs.python.org/issue34852 for details. + # register(). See https://github.com/python/cpython/issues/79033 for + # details. This workaround can be removed when dropping Python < 3.11. await asyncio.sleep(0) if close_connections: From 273db5bcc4113061bd7d8f0a4edbf6c4d76c4d84 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 4 Aug 2024 15:18:07 +0200 Subject: [PATCH 527/762] Make it easier to enable logs while running tests. --- tests/__init__.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/tests/__init__.py b/tests/__init__.py index dd78609f5..bb1866f2d 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,5 +1,14 @@ import logging +import os -# Avoid displaying stack traces at the ERROR logging level. -logging.basicConfig(level=logging.CRITICAL) +format = "%(asctime)s %(levelname)s %(name)s %(message)s" + +if bool(os.environ.get("WEBSOCKETS_DEBUG")): # pragma: no cover + # Display every frame sent or received in debug mode. + level = logging.DEBUG +else: + # Hide stack traces of exceptions. + level = logging.CRITICAL + +logging.basicConfig(format=format, level=level) From cbcb7fd715be0f1efb98102302739e9d9f3ca08c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 5 Aug 2024 15:59:16 +0200 Subject: [PATCH 528/762] Pass WEBSOCKETS_TESTS_TIMEOUT_FACTOR to tox. Previously, despite being declared in .github/workflows/tests.yml, it had no effect because tox insulates test runs from the environment. --- tox.ini | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 06003c85b..b00833e73 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = +env_list = py38 py39 py310 @@ -12,6 +12,7 @@ envlist = [testenv] commands = python -W error::DeprecationWarning -W error::PendingDeprecationWarning -m unittest {posargs} +pass_env = WEBSOCKETS_* [testenv:coverage] commands = From 8c4fd9c24a701b7050681f786b2918d205b91338 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 5 Aug 2024 16:04:10 +0200 Subject: [PATCH 529/762] Remove superfluous `coverage erase`. `coverage run` starts clean unless `--append` is specified. --- tox.ini | 1 - 1 file changed, 1 deletion(-) diff --git a/tox.ini b/tox.ini index b00833e73..1edcfe261 100644 --- a/tox.ini +++ b/tox.ini @@ -16,7 +16,6 @@ pass_env = WEBSOCKETS_* [testenv:coverage] commands = - python -m coverage erase python -m coverage run --source {envsitepackagesdir}/websockets,tests -m unittest {posargs} python -m coverage report --show-missing --fail-under=100 deps = coverage From 02b333829e385d4fef42fa0565996adcfea653b3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 5 Aug 2024 08:59:51 +0200 Subject: [PATCH 530/762] Make Protocol.receive_eof idempotent. This removes the need for keeping track of whether you called it or not, especially in an asyncio context where it may be called in eof_received or in connection_lost. --- src/websockets/protocol.py | 5 +++-- tests/test_protocol.py | 8 ++------ 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 2f5542f6e..7f2b45c74 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -270,10 +270,11 @@ def receive_eof(self) -> None: - You aren't expected to call :meth:`events_received`; it won't return any new events. - Raises: - EOFError: If :meth:`receive_eof` was called earlier. + :meth:`receive_eof` is idempotent. """ + if self.reader.eof: + return self.reader.feed_eof() next(self.parser) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index e1527525b..7f1276bb2 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -1590,18 +1590,14 @@ def test_client_receives_eof_after_eof(self): client.receive_data(b"\x88\x00") self.assertConnectionClosing(client) client.receive_eof() - with self.assertRaises(EOFError) as raised: - client.receive_eof() - self.assertEqual(str(raised.exception), "stream ended") + client.receive_eof() # this is idempotent def test_server_receives_eof_after_eof(self): server = Protocol(SERVER) server.receive_data(b"\x88\x80\x3c\x3c\x3c\x3c") self.assertConnectionClosing(server) server.receive_eof() - with self.assertRaises(EOFError) as raised: - server.receive_eof() - self.assertEqual(str(raised.exception), "stream ended") + server.receive_eof() # this is idempotent class TCPCloseTests(ProtocolTestCase): From 3ad92b50515e5c83344f0771a34e8d9e7cd8ff4e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 7 Aug 2024 15:28:18 +0200 Subject: [PATCH 531/762] Don't specify the encoding when it's utf-8. --- src/websockets/frames.py | 2 +- src/websockets/legacy/protocol.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 0da676432..af56d3f8f 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -447,7 +447,7 @@ def parse(cls, data: bytes) -> Close: """ if len(data) >= 2: (code,) = struct.unpack("!H", data[:2]) - reason = data[2:].decode("utf-8") + reason = data[2:].decode() close = cls(code, reason) close.check() return close diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 57cb4e770..c28bdcf48 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -1036,7 +1036,7 @@ async def read_message(self) -> Data | None: # Shortcut for the common case - no fragmentation if frame.fin: - return frame.data.decode("utf-8") if text else frame.data + return frame.data.decode() if text else frame.data # 5.4. Fragmentation fragments: list[Data] = [] From d0fd9cf61432a8dcf1cd639139ebb57d4b522c01 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 4 Aug 2024 08:08:35 +0200 Subject: [PATCH 532/762] Improve tests for sync implementation slightly. --- src/websockets/sync/connection.py | 3 +- tests/sync/connection.py | 4 +- tests/sync/test_connection.py | 76 +++++++++++++++++++++---------- 3 files changed, 57 insertions(+), 26 deletions(-) diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 33d8299e2..2bcb3aa0e 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -373,6 +373,7 @@ def send(self, message: Data | Iterable[Data]) -> None: except RuntimeError: # We didn't start sending a fragmented message. + # The connection is still usable. raise except Exception: @@ -756,7 +757,7 @@ def set_recv_exc(self, exc: BaseException | None) -> None: """ assert self.protocol_mutex.locked() - if self.recv_exc is None: + if self.recv_exc is None: # pragma: no branch self.recv_exc = exc def close_socket(self) -> None: diff --git a/tests/sync/connection.py b/tests/sync/connection.py index 89d4909ee..9c8bacea0 100644 --- a/tests/sync/connection.py +++ b/tests/sync/connection.py @@ -8,7 +8,7 @@ class InterceptingConnection(Connection): """ Connection subclass that can intercept outgoing packets. - By interfacing with this connection, you can simulate network conditions + By interfacing with this connection, we simulate network conditions affecting what the component being tested receives during a test. """ @@ -80,7 +80,7 @@ def drop_eof_sent(self): class InterceptingSocket: """ - Socket wrapper that intercepts calls to sendall and shutdown. + Socket wrapper that intercepts calls to ``sendall()`` and ``shutdown()``. This is coupled to the implementation, which relies on these two methods. diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index 953c8c253..88cbcd669 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -246,13 +246,15 @@ def test_recv_streaming_connection_closed_ok(self): """recv_streaming raises ConnectionClosedOK after a normal closure.""" self.remote_connection.close() with self.assertRaises(ConnectionClosedOK): - list(self.connection.recv_streaming()) + for _ in self.connection.recv_streaming(): + self.fail("did not raise") def test_recv_streaming_connection_closed_error(self): """recv_streaming raises ConnectionClosedError after an error.""" self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) with self.assertRaises(ConnectionClosedError): - list(self.connection.recv_streaming()) + for _ in self.connection.recv_streaming(): + self.fail("did not raise") def test_recv_streaming_during_recv(self): """recv_streaming raises RuntimeError when called concurrently with recv.""" @@ -260,7 +262,8 @@ def test_recv_streaming_during_recv(self): recv_thread.start() with self.assertRaises(RuntimeError) as raised: - list(self.connection.recv_streaming()) + for _ in self.connection.recv_streaming(): + self.fail("did not raise") self.assertEqual( str(raised.exception), "cannot call recv_streaming while another thread " @@ -278,7 +281,8 @@ def test_recv_streaming_during_recv_streaming(self): recv_streaming_thread.start() with self.assertRaises(RuntimeError) as raised: - list(self.connection.recv_streaming()) + for _ in self.connection.recv_streaming(): + self.fail("did not raise") self.assertEqual( str(raised.exception), r"cannot call recv_streaming while another thread " @@ -374,7 +378,7 @@ def test_send_empty_iterable(self): """send does nothing when called with an empty iterable.""" self.connection.send([]) self.connection.close() - self.assertEqual(list(iter(self.remote_connection)), []) + self.assertEqual(list(self.remote_connection), []) def test_send_mixed_iterable(self): """send raises TypeError when called with an iterable of inconsistent types.""" @@ -437,7 +441,7 @@ def test_close_waits_for_connection_closed(self): def test_close_timeout_waiting_for_close_frame(self): """close times out if no close frame is received.""" - with self.drop_eof_rcvd(), self.drop_frames_rcvd(): + with self.drop_frames_rcvd(), self.drop_eof_rcvd(): self.connection.close() with self.assertRaises(ConnectionClosedError) as raised: @@ -464,6 +468,10 @@ def test_close_timeout_waiting_for_connection_closed(self): self.assertIsInstance(exc.__cause__, (socket.timeout, TimeoutError)) def test_close_waits_for_recv(self): + # The sync implementation doesn't have a buffer for incoming messsages. + # It requires reading incoming frames until the close frame is reached. + # This behavior — close() blocks until recv() is called — is less than + # ideal and inconsistent with the asyncio implementation. self.remote_connection.send("😀") close_thread = threading.Thread(target=self.connection.close) @@ -547,6 +555,25 @@ def closer(): close_thread.join() + def test_close_during_recv(self): + """close aborts recv when called concurrently with recv.""" + + def closer(): + time.sleep(MS) + self.connection.close() + + close_thread = threading.Thread(target=closer) + close_thread.start() + + with self.assertRaises(ConnectionClosedOK) as raised: + self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + close_thread.join() + def test_close_during_send(self): """close fails the connection when called concurrently with send.""" close_gate = threading.Event() @@ -599,42 +626,45 @@ def test_ping_explicit_binary(self): self.connection.ping(b"ping") self.assertFrameSent(Frame(Opcode.PING, b"ping")) - def test_ping_duplicate_payload(self): - """ping rejects the same payload until receiving the pong.""" - with self.remote_connection.protocol_mutex: # block response to ping - pong_waiter = self.connection.ping("idem") - with self.assertRaises(RuntimeError) as raised: - self.connection.ping("idem") - self.assertEqual( - str(raised.exception), - "already waiting for a pong with the same data", - ) - self.assertTrue(pong_waiter.wait(MS)) - self.connection.ping("idem") # doesn't raise an exception - def test_acknowledge_ping(self): """ping is acknowledged by a pong with the same payload.""" - with self.drop_frames_rcvd(): + with self.drop_frames_rcvd(): # drop automatic response to ping pong_waiter = self.connection.ping("this") - self.assertFalse(pong_waiter.wait(MS)) self.remote_connection.pong("this") self.assertTrue(pong_waiter.wait(MS)) def test_acknowledge_ping_non_matching_pong(self): """ping isn't acknowledged by a pong with a different payload.""" - with self.drop_frames_rcvd(): + with self.drop_frames_rcvd(): # drop automatic response to ping pong_waiter = self.connection.ping("this") self.remote_connection.pong("that") self.assertFalse(pong_waiter.wait(MS)) def test_acknowledge_previous_ping(self): """ping is acknowledged by a pong with the same payload as a later ping.""" - with self.drop_frames_rcvd(): + with self.drop_frames_rcvd(): # drop automatic response to ping pong_waiter = self.connection.ping("this") self.connection.ping("that") self.remote_connection.pong("that") self.assertTrue(pong_waiter.wait(MS)) + def test_ping_duplicate_payload(self): + """ping rejects the same payload until receiving the pong.""" + with self.drop_frames_rcvd(): # drop automatic response to ping + pong_waiter = self.connection.ping("idem") + + with self.assertRaises(RuntimeError) as raised: + self.connection.ping("idem") + self.assertEqual( + str(raised.exception), + "already waiting for a pong with the same data", + ) + + self.remote_connection.pong("idem") + self.assertTrue(pong_waiter.wait(MS)) + + self.connection.ping("idem") # doesn't raise an exception + # Test pong. def test_pong(self): From c92fba02db87a88af54a47e7f5bae050587490dd Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 7 Apr 2023 23:11:11 +0200 Subject: [PATCH 533/762] Move asyncio compatibility to a new package. --- pyproject.toml | 4 ++-- src/websockets/asyncio/__init__.py | 0 src/websockets/{legacy => asyncio}/async_timeout.py | 0 src/websockets/{legacy => asyncio}/compatibility.py | 0 src/websockets/legacy/client.py | 2 +- src/websockets/legacy/protocol.py | 2 +- src/websockets/legacy/server.py | 2 +- tests/legacy/test_client_server.py | 2 +- tests/maxi_cov.py | 6 ++++-- 9 files changed, 10 insertions(+), 8 deletions(-) create mode 100644 src/websockets/asyncio/__init__.py rename src/websockets/{legacy => asyncio}/async_timeout.py (100%) rename src/websockets/{legacy => asyncio}/compatibility.py (100%) diff --git a/pyproject.toml b/pyproject.toml index 2367849ca..de8acd6a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,8 +47,8 @@ branch = true omit = [ # */websockets matches src/websockets and .tox/**/site-packages/websockets "*/websockets/__main__.py", - "*/websockets/legacy/async_timeout.py", - "*/websockets/legacy/compatibility.py", + "*/websockets/asyncio/async_timeout.py", + "*/websockets/asyncio/compatibility.py", "tests/maxi_cov.py", ] diff --git a/src/websockets/asyncio/__init__.py b/src/websockets/asyncio/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/websockets/legacy/async_timeout.py b/src/websockets/asyncio/async_timeout.py similarity index 100% rename from src/websockets/legacy/async_timeout.py rename to src/websockets/asyncio/async_timeout.py diff --git a/src/websockets/legacy/compatibility.py b/src/websockets/asyncio/compatibility.py similarity index 100% rename from src/websockets/legacy/compatibility.py rename to src/websockets/asyncio/compatibility.py diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index b0e15b543..d1d8d5608 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -16,6 +16,7 @@ cast, ) +from ..asyncio.compatibility import asyncio_timeout from ..datastructures import Headers, HeadersLike from ..exceptions import ( InvalidHandshake, @@ -40,7 +41,6 @@ from ..http import USER_AGENT from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol from ..uri import WebSocketURI, parse_uri -from .compatibility import asyncio_timeout from .handshake import build_request, check_response from .http import read_response from .protocol import WebSocketCommonProtocol diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index c28bdcf48..120ff8e73 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -23,6 +23,7 @@ cast, ) +from ..asyncio.compatibility import asyncio_timeout from ..datastructures import Headers from ..exceptions import ( ConnectionClosed, @@ -49,7 +50,6 @@ ) from ..protocol import State from ..typing import Data, LoggerLike, Subprotocol -from .compatibility import asyncio_timeout from .framing import Frame diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index f4442fecc..208ffa780 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -21,6 +21,7 @@ cast, ) +from ..asyncio.compatibility import asyncio_timeout from ..datastructures import Headers, HeadersLike, MultipleValuesError from ..exceptions import ( AbortHandshake, @@ -42,7 +43,6 @@ from ..http import USER_AGENT from ..protocol import State from ..typing import ExtensionHeader, LoggerLike, Origin, StatusLike, Subprotocol -from .compatibility import asyncio_timeout from .handshake import build_response, check_request from .http import read_request from .protocol import WebSocketCommonProtocol diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 0c3f22156..b5c5d726a 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -14,6 +14,7 @@ import urllib.request import warnings +from websockets.asyncio.compatibility import asyncio_timeout from websockets.datastructures import Headers from websockets.exceptions import ( ConnectionClosed, @@ -30,7 +31,6 @@ from websockets.frames import CloseCode from websockets.http import USER_AGENT from websockets.legacy.client import * -from websockets.legacy.compatibility import asyncio_timeout from websockets.legacy.handshake import build_response from websockets.legacy.http import read_response from websockets.legacy.server import * diff --git a/tests/maxi_cov.py b/tests/maxi_cov.py index bc4a44e8c..83686c3d3 100755 --- a/tests/maxi_cov.py +++ b/tests/maxi_cov.py @@ -52,8 +52,9 @@ def get_mapping(src_dir="src"): os.path.relpath(src_file, src_dir) for src_file in sorted(src_files) if "legacy" not in os.path.dirname(src_file) - if os.path.basename(src_file) != "__init__.py" + and os.path.basename(src_file) != "__init__.py" and os.path.basename(src_file) != "__main__.py" + and os.path.basename(src_file) != "async_timeout.py" and os.path.basename(src_file) != "compatibility.py" ] test_files = [ @@ -102,7 +103,8 @@ def get_ignored_files(src_dir="src"): "*/websockets/typing.py", # We don't test compatibility modules with previous versions of Python # or websockets (import locations). - "*/websockets/*/compatibility.py", + "*/websockets/asyncio/async_timeout.py", + "*/websockets/asyncio/compatibility.py", "*/websockets/auth.py", # This approach isn't applicable to the test suite of the legacy # implementation, due to the huge test_client_server test module. From 9f8f2f27218e4dc7ad4126109e6ffe012946b71b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 7 Apr 2023 23:12:28 +0200 Subject: [PATCH 534/762] Add asyncio message assembler. --- src/websockets/asyncio/compatibility.py | 20 +- src/websockets/asyncio/messages.py | 283 ++++++++++++++ src/websockets/sync/messages.py | 4 +- tests/asyncio/__init__.py | 0 tests/asyncio/test_messages.py | 471 ++++++++++++++++++++++++ tests/asyncio/utils.py | 5 + 6 files changed, 777 insertions(+), 6 deletions(-) create mode 100644 src/websockets/asyncio/messages.py create mode 100644 tests/asyncio/__init__.py create mode 100644 tests/asyncio/test_messages.py create mode 100644 tests/asyncio/utils.py diff --git a/src/websockets/asyncio/compatibility.py b/src/websockets/asyncio/compatibility.py index 6bd01e70d..390f00ac7 100644 --- a/src/websockets/asyncio/compatibility.py +++ b/src/websockets/asyncio/compatibility.py @@ -3,10 +3,22 @@ import sys -__all__ = ["asyncio_timeout"] +__all__ = ["TimeoutError", "aiter", "anext", "asyncio_timeout"] if sys.version_info[:2] >= (3, 11): - from asyncio import timeout as asyncio_timeout # noqa: F401 -else: - from .async_timeout import timeout as asyncio_timeout # noqa: F401 + TimeoutError = TimeoutError + aiter = aiter + anext = anext + from asyncio import timeout as asyncio_timeout + +else: # Python < 3.11 + from asyncio import TimeoutError + + def aiter(async_iterable): + return type(async_iterable).__aiter__(async_iterable) + + async def anext(async_iterator): + return await type(async_iterator).__anext__(async_iterator) + + from .async_timeout import timeout as asyncio_timeout diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py new file mode 100644 index 000000000..2a9c4d37d --- /dev/null +++ b/src/websockets/asyncio/messages.py @@ -0,0 +1,283 @@ +from __future__ import annotations + +import asyncio +import codecs +import collections +from typing import ( + Any, + AsyncIterator, + Callable, + Generic, + Iterable, + TypeVar, +) + +from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame +from ..typing import Data + + +__all__ = ["Assembler"] + +UTF8Decoder = codecs.getincrementaldecoder("utf-8") + +T = TypeVar("T") + + +class SimpleQueue(Generic[T]): + """ + Simplified version of :class:`asyncio.Queue`. + + Provides only the subset of functionality needed by :class:`Assembler`. + + """ + + def __init__(self) -> None: + self.loop = asyncio.get_running_loop() + self.get_waiter: asyncio.Future[None] | None = None + self.queue: collections.deque[T] = collections.deque() + + def __len__(self) -> int: + return len(self.queue) + + def put(self, item: T) -> None: + """Put an item into the queue without waiting.""" + self.queue.append(item) + if self.get_waiter is not None and not self.get_waiter.done(): + self.get_waiter.set_result(None) + + async def get(self) -> T: + """Remove and return an item from the queue, waiting if necessary.""" + if not self.queue: + if self.get_waiter is not None: + raise RuntimeError("get is already running") + self.get_waiter = self.loop.create_future() + try: + await self.get_waiter + finally: + self.get_waiter.cancel() + self.get_waiter = None + return self.queue.popleft() + + def reset(self, items: Iterable[T]) -> None: + """Put back items into an empty, idle queue.""" + assert self.get_waiter is None, "cannot reset() while get() is running" + assert not self.queue, "cannot reset() while queue isn't empty" + self.queue.extend(items) + + def abort(self) -> None: + if self.get_waiter is not None and not self.get_waiter.done(): + self.get_waiter.set_exception(EOFError("stream of frames ended")) + # Clear the queue to avoid storing unnecessary data in memory. + self.queue.clear() + + +class Assembler: + """ + Assemble messages from frames. + + :class:`Assembler` expects only data frames. The stream of frames must + respect the protocol; if it doesn't, the behavior is undefined. + + Args: + pause: Called when the buffer of frames goes above the high water mark; + should pause reading from the network. + resume: Called when the buffer of frames goes below the low water mark; + should resume reading from the network. + + """ + + # coverage reports incorrectly: "line NN didn't jump to the function exit" + def __init__( # pragma: no cover + self, + pause: Callable[[], Any] = lambda: None, + resume: Callable[[], Any] = lambda: None, + ) -> None: + # Queue of incoming messages. Each item is a queue of frames. + self.frames: SimpleQueue[Frame] = SimpleQueue() + + # We cannot put a hard limit on the size of the queue because a single + # call to Protocol.data_received() could produce thousands of frames, + # which must be buffered. Instead, we pause reading when the buffer goes + # above the high limit and we resume when it goes under the low limit. + self.high = 16 + self.low = 4 + self.paused = False + self.pause = pause + self.resume = resume + + # This flag prevents concurrent calls to get() by user code. + self.get_in_progress = False + + # This flag marks the end of the connection. + self.closed = False + + async def get(self, decode: bool | None = None) -> Data: + """ + Read the next message. + + :meth:`get` returns a single :class:`str` or :class:`bytes`. + + If the message is fragmented, :meth:`get` waits until the last frame is + received, then it reassembles the message and returns it. To receive + messages frame by frame, use :meth:`get_iter` instead. + + Raises: + EOFError: If the stream of frames has ended. + RuntimeError: If two coroutines run :meth:`get` or :meth:`get_iter` + concurrently. + + """ + if self.closed: + raise EOFError("stream of frames ended") + + if self.get_in_progress: + raise RuntimeError("get() or get_iter() is already running") + + # Locking with get_in_progress ensures only one coroutine can get here. + self.get_in_progress = True + + # First frame + try: + frame = await self.frames.get() + except asyncio.CancelledError: + self.get_in_progress = False + raise + self.maybe_resume() + assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY + if decode is None: + decode = frame.opcode is OP_TEXT + frames = [frame] + + # Following frames, for fragmented messages + while not frame.fin: + try: + frame = await self.frames.get() + except asyncio.CancelledError: + # Put frames already received back into the queue + # so that future calls to get() can return them. + self.frames.reset(frames) + self.get_in_progress = False + raise + self.maybe_resume() + assert frame.opcode is OP_CONT + frames.append(frame) + + self.get_in_progress = False + + data = b"".join(frame.data for frame in frames) + if decode: + return data.decode() + else: + return data + + async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: + """ + Stream the next message. + + Iterating the return value of :meth:`get_iter` asynchronously yields a + :class:`str` or :class:`bytes` for each frame in the message. + + The iterator must be fully consumed before calling :meth:`get_iter` or + :meth:`get` again. Else, :exc:`RuntimeError` is raised. + + This method only makes sense for fragmented messages. If messages aren't + fragmented, use :meth:`get` instead. + + Raises: + EOFError: If the stream of frames has ended. + RuntimeError: If two coroutines run :meth:`get` or :meth:`get_iter` + concurrently. + + """ + if self.closed: + raise EOFError("stream of frames ended") + + if self.get_in_progress: + raise RuntimeError("get() or get_iter() is already running") + + # Locking with get_in_progress ensures only one coroutine can get here. + self.get_in_progress = True + + # First frame + try: + frame = await self.frames.get() + except asyncio.CancelledError: + self.get_in_progress = False + raise + self.maybe_resume() + assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY + if decode is None: + decode = frame.opcode is OP_TEXT + if decode: + decoder = UTF8Decoder() + yield decoder.decode(frame.data, frame.fin) + else: + yield frame.data + + # Following frames, for fragmented messages + while not frame.fin: + # We cannot handle asyncio.CancelledError because we don't buffer + # previous fragments — we're streaming them. Canceling get_iter() + # here will leave the assembler in a stuck state. Future calls to + # get() or get_iter() will raise RuntimeError. + frame = await self.frames.get() + self.maybe_resume() + assert frame.opcode is OP_CONT + if decode: + yield decoder.decode(frame.data, frame.fin) + else: + yield frame.data + + self.get_in_progress = False + + def put(self, frame: Frame) -> None: + """ + Add ``frame`` to the next message. + + Raises: + EOFError: If the stream of frames has ended. + + """ + if self.closed: + raise EOFError("stream of frames ended") + + self.frames.put(frame) + self.maybe_pause() + + def get_limits(self) -> tuple[int, int]: + """Return low and high water marks for flow control.""" + return self.low, self.high + + def set_limits(self, low: int = 4, high: int = 16) -> None: + """Configure low and high water marks for flow control.""" + self.low, self.high = low, high + + def maybe_pause(self) -> None: + """Pause the writer if queue is above the high water mark.""" + # Check for "> high" to support high = 0 + if len(self.frames) > self.high and not self.paused: + self.paused = True + self.pause() + + def maybe_resume(self) -> None: + """Resume the writer if queue is below the low water mark.""" + # Check for "<= low" to support low = 0 + if len(self.frames) <= self.low and self.paused: + self.paused = False + self.resume() + + def close(self) -> None: + """ + End the stream of frames. + + Callling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`, + or :meth:`put` is safe. They will raise :exc:`EOFError`. + + """ + if self.closed: + return + + self.closed = True + + # Unblock get() or get_iter(). + self.frames.abort() diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index 6cbff2595..ff90345ac 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -85,7 +85,7 @@ def get(self, timeout: float | None = None) -> Data: raise EOFError("stream of frames ended") if self.get_in_progress: - raise RuntimeError("get or get_iter is already running") + raise RuntimeError("get() or get_iter() is already running") self.get_in_progress = True @@ -144,7 +144,7 @@ def get_iter(self) -> Iterator[Data]: raise EOFError("stream of frames ended") if self.get_in_progress: - raise RuntimeError("get or get_iter is already running") + raise RuntimeError("get() or get_iter() is already running") chunks = self.chunks self.chunks = [] diff --git a/tests/asyncio/__init__.py b/tests/asyncio/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/asyncio/test_messages.py b/tests/asyncio/test_messages.py new file mode 100644 index 000000000..c8a2d7cd5 --- /dev/null +++ b/tests/asyncio/test_messages.py @@ -0,0 +1,471 @@ +import asyncio +import unittest +import unittest.mock + +from websockets.asyncio.compatibility import aiter, anext +from websockets.asyncio.messages import * +from websockets.asyncio.messages import SimpleQueue +from websockets.frames import OP_BINARY, OP_CONT, OP_TEXT, Frame + +from .utils import alist + + +class SimpleQueueTests(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + self.queue = SimpleQueue() + + async def test_len(self): + """__len__ returns queue length.""" + self.assertEqual(len(self.queue), 0) + self.queue.put(42) + self.assertEqual(len(self.queue), 1) + await self.queue.get() + self.assertEqual(len(self.queue), 0) + + async def test_put_then_get(self): + """get returns an item that is already put.""" + self.queue.put(42) + item = await self.queue.get() + self.assertEqual(item, 42) + + async def test_get_then_put(self): + """get returns an item when it is put.""" + getter_task = asyncio.create_task(self.queue.get()) + await asyncio.sleep(0) # let the task start + self.queue.put(42) + item = await getter_task + self.assertEqual(item, 42) + + async def test_get_concurrently(self): + """get cannot be called concurrently with itself.""" + getter_task = asyncio.create_task(self.queue.get()) + await asyncio.sleep(0) # let the task start + with self.assertRaises(RuntimeError): + await self.queue.get() + getter_task.cancel() + + async def test_reset(self): + """reset sets the content of the queue.""" + self.queue.reset([42]) + item = await self.queue.get() + self.assertEqual(item, 42) + + async def test_abort(self): + """abort throws an exception in get.""" + getter_task = asyncio.create_task(self.queue.get()) + await asyncio.sleep(0) # let the task start + self.queue.abort() + with self.assertRaises(EOFError): + await getter_task + + async def test_abort_clears_queue(self): + """abort clears buffered data from the queue.""" + self.queue.put(42) + self.assertEqual(len(self.queue), 1) + self.queue.abort() + self.assertEqual(len(self.queue), 0) + + +class AssemblerTests(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + self.pause = unittest.mock.Mock() + self.resume = unittest.mock.Mock() + self.assembler = Assembler(pause=self.pause, resume=self.resume) + self.assembler.set_limits(low=1, high=2) + + # Test get + + async def test_get_text_message_already_received(self): + """get returns a text message that is already received.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + message = await self.assembler.get() + self.assertEqual(message, "café") + + async def test_get_binary_message_already_received(self): + """get returns a binary message that is already received.""" + self.assembler.put(Frame(OP_BINARY, b"tea")) + message = await self.assembler.get() + self.assertEqual(message, b"tea") + + async def test_get_text_message_not_received_yet(self): + """get returns a text message when it is received.""" + getter_task = asyncio.create_task(self.assembler.get()) + await asyncio.sleep(0) # let the event loop start getter_task + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + message = await getter_task + self.assertEqual(message, "café") + + async def test_get_binary_message_not_received_yet(self): + """get returns a binary message when it is received.""" + getter_task = asyncio.create_task(self.assembler.get()) + await asyncio.sleep(0) # let the event loop start getter_task + self.assembler.put(Frame(OP_BINARY, b"tea")) + message = await getter_task + self.assertEqual(message, b"tea") + + async def test_get_fragmented_text_message_already_received(self): + """get reassembles a fragmented a text message that is already received.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + message = await self.assembler.get() + self.assertEqual(message, "café") + + async def test_get_fragmented_binary_message_already_received(self): + """get reassembles a fragmented binary message that is already received.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + message = await self.assembler.get() + self.assertEqual(message, b"tea") + + async def test_get_fragmented_text_message_not_received_yet(self): + """get reassembles a fragmented text message when it is received.""" + getter_task = asyncio.create_task(self.assembler.get()) + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + message = await getter_task + self.assertEqual(message, "café") + + async def test_get_fragmented_binary_message_not_received_yet(self): + """get reassembles a fragmented binary message when it is received.""" + getter_task = asyncio.create_task(self.assembler.get()) + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + message = await getter_task + self.assertEqual(message, b"tea") + + async def test_get_fragmented_text_message_being_received(self): + """get reassembles a fragmented text message that is partially received.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + getter_task = asyncio.create_task(self.assembler.get()) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + message = await getter_task + self.assertEqual(message, "café") + + async def test_get_fragmented_binary_message_being_received(self): + """get reassembles a fragmented binary message that is partially received.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + getter_task = asyncio.create_task(self.assembler.get()) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + message = await getter_task + self.assertEqual(message, b"tea") + + async def test_get_encoded_text_message(self): + """get returns a text message without UTF-8 decoding.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + message = await self.assembler.get(decode=False) + self.assertEqual(message, b"caf\xc3\xa9") + + async def test_get_decoded_binary_message(self): + """get returns a binary message with UTF-8 decoding.""" + self.assembler.put(Frame(OP_BINARY, b"tea")) + message = await self.assembler.get(decode=True) + self.assertEqual(message, "tea") + + async def test_get_resumes_reading(self): + """get resumes reading when queue goes below the high-water mark.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"more caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"water")) + + # queue is above the low-water mark + await self.assembler.get() + self.resume.assert_not_called() + + # queue is at the low-water mark + await self.assembler.get() + self.resume.assert_called_once_with() + + # queue is below the low-water mark + await self.assembler.get() + self.resume.assert_called_once_with() + + async def test_cancel_get_before_first_frame(self): + """get can be canceled safely before reading the first frame.""" + getter_task = asyncio.create_task(self.assembler.get()) + await asyncio.sleep(0) # let the event loop start getter_task + getter_task.cancel() + with self.assertRaises(asyncio.CancelledError): + await getter_task + + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + + message = await self.assembler.get() + self.assertEqual(message, "café") + + async def test_cancel_get_after_first_frame(self): + """get can be canceled safely after reading the first frame.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + + getter_task = asyncio.create_task(self.assembler.get()) + await asyncio.sleep(0) # let the event loop start getter_task + getter_task.cancel() + with self.assertRaises(asyncio.CancelledError): + await getter_task + + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + + message = await self.assembler.get() + self.assertEqual(message, "café") + + # Test get_iter + + async def test_get_iter_text_message_already_received(self): + """get_iter yields a text message that is already received.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, ["café"]) + + async def test_get_iter_binary_message_already_received(self): + """get_iter yields a binary message that is already received.""" + self.assembler.put(Frame(OP_BINARY, b"tea")) + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, [b"tea"]) + + async def test_get_iter_text_message_not_received_yet(self): + """get_iter yields a text message when it is received.""" + getter_task = asyncio.create_task(alist(self.assembler.get_iter())) + await asyncio.sleep(0) # let the event loop start getter_task + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + fragments = await getter_task + self.assertEqual(fragments, ["café"]) + + async def test_get_iter_binary_message_not_received_yet(self): + """get_iter yields a binary message when it is received.""" + getter_task = asyncio.create_task(alist(self.assembler.get_iter())) + await asyncio.sleep(0) # let the event loop start getter_task + self.assembler.put(Frame(OP_BINARY, b"tea")) + fragments = await getter_task + self.assertEqual(fragments, [b"tea"]) + + async def test_get_iter_fragmented_text_message_already_received(self): + """get_iter yields a fragmented text message that is already received.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, ["ca", "f", "é"]) + + async def test_get_iter_fragmented_binary_message_already_received(self): + """get_iter yields a fragmented binary message that is already received.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, [b"t", b"e", b"a"]) + + async def test_get_iter_fragmented_text_message_not_received_yet(self): + """get_iter yields a fragmented text message when it is received.""" + iterator = aiter(self.assembler.get_iter()) + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assertEqual(await anext(iterator), "ca") + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assertEqual(await anext(iterator), "f") + self.assembler.put(Frame(OP_CONT, b"\xa9")) + self.assertEqual(await anext(iterator), "é") + + async def test_get_iter_fragmented_binary_message_not_received_yet(self): + """get_iter yields a fragmented binary message when it is received.""" + iterator = aiter(self.assembler.get_iter()) + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assertEqual(await anext(iterator), b"t") + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assertEqual(await anext(iterator), b"e") + self.assembler.put(Frame(OP_CONT, b"a")) + self.assertEqual(await anext(iterator), b"a") + + async def test_get_iter_fragmented_text_message_being_received(self): + """get_iter yields a fragmented text message that is partially received.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + iterator = aiter(self.assembler.get_iter()) + self.assertEqual(await anext(iterator), "ca") + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assertEqual(await anext(iterator), "f") + self.assembler.put(Frame(OP_CONT, b"\xa9")) + self.assertEqual(await anext(iterator), "é") + + async def test_get_iter_fragmented_binary_message_being_received(self): + """get_iter yields a fragmented binary message that is partially received.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + iterator = aiter(self.assembler.get_iter()) + self.assertEqual(await anext(iterator), b"t") + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assertEqual(await anext(iterator), b"e") + self.assembler.put(Frame(OP_CONT, b"a")) + self.assertEqual(await anext(iterator), b"a") + + async def test_get_iter_encoded_text_message(self): + """get_iter yields a text message without UTF-8 decoding.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + fragments = await alist(self.assembler.get_iter(decode=False)) + self.assertEqual(fragments, [b"ca", b"f\xc3", b"\xa9"]) + + async def test_get_iter_decoded_binary_message(self): + """get_iter yields a binary message with UTF-8 decoding.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + fragments = await alist(self.assembler.get_iter(decode=True)) + self.assertEqual(fragments, ["t", "e", "a"]) + + async def test_get_iter_resumes_reading(self): + """get_iter resumes reading when queue goes below the high-water mark.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + + iterator = aiter(self.assembler.get_iter()) + + # queue is above the low-water mark + await anext(iterator) + self.resume.assert_not_called() + + # queue is at the low-water mark + await anext(iterator) + self.resume.assert_called_once_with() + + # queue is below the low-water mark + await anext(iterator) + self.resume.assert_called_once_with() + + async def test_cancel_get_iter_before_first_frame(self): + """get_iter can be canceled safely before reading the first frame.""" + getter_task = asyncio.create_task(alist(self.assembler.get_iter())) + await asyncio.sleep(0) # let the event loop start getter_task + getter_task.cancel() + with self.assertRaises(asyncio.CancelledError): + await getter_task + + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, ["café"]) + + async def test_cancel_get_iter_after_first_frame(self): + """get cannot be canceled after reading the first frame.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + + getter_task = asyncio.create_task(alist(self.assembler.get_iter())) + await asyncio.sleep(0) # let the event loop start getter_task + getter_task.cancel() + with self.assertRaises(asyncio.CancelledError): + await getter_task + + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + + with self.assertRaises(RuntimeError): + await alist(self.assembler.get_iter()) + + # Test put + + async def test_put_pauses_reading(self): + """put pauses reading when queue goes above the high-water mark.""" + # queue is below the high-water mark + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.pause.assert_not_called() + + # queue is at the high-water mark + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.pause.assert_called_once_with() + + # queue is above the high-water mark + self.assembler.put(Frame(OP_CONT, b"a")) + self.pause.assert_called_once_with() + + # Test termination + + async def test_get_fails_when_interrupted_by_close(self): + """get raises EOFError when close is called.""" + asyncio.get_running_loop().call_soon(self.assembler.close) + with self.assertRaises(EOFError): + await self.assembler.get() + + async def test_get_iter_fails_when_interrupted_by_close(self): + """get_iter raises EOFError when close is called.""" + asyncio.get_running_loop().call_soon(self.assembler.close) + with self.assertRaises(EOFError): + async for _ in self.assembler.get_iter(): + self.fail("no fragment expected") + + async def test_get_fails_after_close(self): + """get raises EOFError after close is called.""" + self.assembler.close() + with self.assertRaises(EOFError): + await self.assembler.get() + + async def test_get_iter_fails_after_close(self): + """get_iter raises EOFError after close is called.""" + self.assembler.close() + with self.assertRaises(EOFError): + async for _ in self.assembler.get_iter(): + self.fail("no fragment expected") + + async def test_put_fails_after_close(self): + """put raises EOFError after close is called.""" + self.assembler.close() + with self.assertRaises(EOFError): + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + + async def test_close_is_idempotent(self): + """close can be called multiple times safely.""" + self.assembler.close() + self.assembler.close() + + # Test (non-)concurrency + + async def test_get_fails_when_get_is_running(self): + """get cannot be called concurrently with itself.""" + asyncio.create_task(self.assembler.get()) + await asyncio.sleep(0) + with self.assertRaises(RuntimeError): + await self.assembler.get() + self.assembler.close() # let task terminate + + async def test_get_fails_when_get_iter_is_running(self): + """get cannot be called concurrently with get_iter.""" + asyncio.create_task(alist(self.assembler.get_iter())) + await asyncio.sleep(0) + with self.assertRaises(RuntimeError): + await self.assembler.get() + self.assembler.close() # let task terminate + + async def test_get_iter_fails_when_get_is_running(self): + """get_iter cannot be called concurrently with get.""" + asyncio.create_task(self.assembler.get()) + await asyncio.sleep(0) + with self.assertRaises(RuntimeError): + await alist(self.assembler.get_iter()) + self.assembler.close() # let task terminate + + async def test_get_iter_fails_when_get_iter_is_running(self): + """get_iter cannot be called concurrently with itself.""" + asyncio.create_task(alist(self.assembler.get_iter())) + await asyncio.sleep(0) + with self.assertRaises(RuntimeError): + await alist(self.assembler.get_iter()) + self.assembler.close() # let task terminate + + # Test getting and setting limits + + async def test_get_limits(self): + """get_limits returns low and high water marks.""" + low, high = self.assembler.get_limits() + self.assertEqual(low, 1) + self.assertEqual(high, 2) + + async def test_set_limits(self): + """set_limits changes low and high water marks.""" + self.assembler.set_limits(low=2, high=4) + low, high = self.assembler.get_limits() + self.assertEqual(low, 2) + self.assertEqual(high, 4) diff --git a/tests/asyncio/utils.py b/tests/asyncio/utils.py new file mode 100644 index 000000000..a611bfc4b --- /dev/null +++ b/tests/asyncio/utils.py @@ -0,0 +1,5 @@ +async def alist(async_iterable): + items = [] + async for item in async_iterable: + items.append(item) + return items From 4a981688198f91385281b8c8e1cdfc0197d43bf5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 8 Apr 2023 08:37:40 +0200 Subject: [PATCH 535/762] Add new asyncio-based implementation. --- docs/project/changelog.rst | 15 +- docs/reference/index.rst | 12 + docs/reference/new-asyncio/client.rst | 53 ++ docs/reference/new-asyncio/common.rst | 43 ++ docs/reference/new-asyncio/server.rst | 72 ++ src/websockets/asyncio/client.py | 331 +++++++++ src/websockets/asyncio/compatibility.py | 12 +- src/websockets/asyncio/connection.py | 883 ++++++++++++++++++++++ src/websockets/asyncio/server.py | 772 +++++++++++++++++++ tests/asyncio/client.py | 33 + tests/asyncio/connection.py | 115 +++ tests/asyncio/server.py | 50 ++ tests/asyncio/test_client.py | 306 ++++++++ tests/asyncio/test_connection.py | 948 ++++++++++++++++++++++++ tests/asyncio/test_server.py | 525 +++++++++++++ 15 files changed, 4165 insertions(+), 5 deletions(-) create mode 100644 docs/reference/new-asyncio/client.rst create mode 100644 docs/reference/new-asyncio/common.rst create mode 100644 docs/reference/new-asyncio/server.rst create mode 100644 src/websockets/asyncio/client.py create mode 100644 src/websockets/asyncio/connection.py create mode 100644 src/websockets/asyncio/server.py create mode 100644 tests/asyncio/client.py create mode 100644 tests/asyncio/connection.py create mode 100644 tests/asyncio/server.py create mode 100644 tests/asyncio/test_client.py create mode 100644 tests/asyncio/test_connection.py create mode 100644 tests/asyncio/test_server.py diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index fd186a5fc..108b7c9c0 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -52,7 +52,7 @@ Backwards-incompatible changes async def handler(request, path): ... - You should switch to the recommended pattern since 10.1:: + You should switch to the pattern recommended since version 10.1:: async def handler(request): path = request.path # only if handler() uses the path argument @@ -61,6 +61,16 @@ Backwards-incompatible changes New features ............ +.. admonition:: websockets 11.0 introduces a new :mod:`asyncio` implementation. + :class: important + + This new implementation is intended to be a drop-in replacement for the + current implementation. It will become the default in a future release. + Please try it and report any issue that you encounter! + + See :func:`websockets.asyncio.client.connect` and + :func:`websockets.asyncio.server.serve` for details. + * Validated compatibility with Python 3.12. 12.0 @@ -175,7 +185,8 @@ New features It is particularly suited to client applications that establish only one connection. It may be used for servers handling few connections. - See :func:`~sync.client.connect` and :func:`~sync.server.serve` for details. + See :func:`websockets.sync.client.connect` and + :func:`websockets.sync.server.serve` for details. * Added ``open_timeout`` to :func:`~server.serve`. diff --git a/docs/reference/index.rst b/docs/reference/index.rst index 0b80f087a..2486ac564 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -26,6 +26,18 @@ clients concurrently. asyncio/server asyncio/client +:mod:`asyncio` (new) +-------------------- + +This is a rewrite of the :mod:`asyncio` implementation. It will become the +default in the future. + +.. toctree:: + :titlesonly: + + new-asyncio/server + new-asyncio/client + :mod:`threading` ---------------- diff --git a/docs/reference/new-asyncio/client.rst b/docs/reference/new-asyncio/client.rst new file mode 100644 index 000000000..552d83b2f --- /dev/null +++ b/docs/reference/new-asyncio/client.rst @@ -0,0 +1,53 @@ +Client (:mod:`asyncio` - new) +============================= + +.. automodule:: websockets.asyncio.client + +Opening a connection +-------------------- + +.. autofunction:: connect + :async: + +.. autofunction:: unix_connect + :async: + +Using a connection +------------------ + +.. autoclass:: ClientConnection + + .. automethod:: __aiter__ + + .. automethod:: recv + + .. automethod:: recv_streaming + + .. automethod:: send + + .. automethod:: close + + .. automethod:: wait_closed + + .. automethod:: ping + + .. automethod:: pong + + WebSocket connection objects also provide these attributes: + + .. autoattribute:: id + + .. autoattribute:: logger + + .. autoproperty:: local_address + + .. autoproperty:: remote_address + + The following attributes are available after the opening handshake, + once the WebSocket connection is open: + + .. autoattribute:: request + + .. autoattribute:: response + + .. autoproperty:: subprotocol diff --git a/docs/reference/new-asyncio/common.rst b/docs/reference/new-asyncio/common.rst new file mode 100644 index 000000000..ba23552dc --- /dev/null +++ b/docs/reference/new-asyncio/common.rst @@ -0,0 +1,43 @@ +:orphan: + +Both sides (:mod:`asyncio` - new) +================================= + +.. automodule:: websockets.asyncio.connection + +.. autoclass:: Connection + + .. automethod:: __aiter__ + + .. automethod:: recv + + .. automethod:: recv_streaming + + .. automethod:: send + + .. automethod:: close + + .. automethod:: wait_closed + + .. automethod:: ping + + .. automethod:: pong + + WebSocket connection objects also provide these attributes: + + .. autoattribute:: id + + .. autoattribute:: logger + + .. autoproperty:: local_address + + .. autoproperty:: remote_address + + The following attributes are available after the opening handshake, + once the WebSocket connection is open: + + .. autoattribute:: request + + .. autoattribute:: response + + .. autoproperty:: subprotocol diff --git a/docs/reference/new-asyncio/server.rst b/docs/reference/new-asyncio/server.rst new file mode 100644 index 000000000..f3446fb80 --- /dev/null +++ b/docs/reference/new-asyncio/server.rst @@ -0,0 +1,72 @@ +Server (:mod:`asyncio` - new) +============================= + +.. automodule:: websockets.asyncio.server + +Creating a server +----------------- + +.. autofunction:: serve + :async: + +.. autofunction:: unix_serve + :async: + +Running a server +---------------- + +.. autoclass:: WebSocketServer + + .. automethod:: close + + .. automethod:: wait_closed + + .. automethod:: get_loop + + .. automethod:: is_serving + + .. automethod:: start_serving + + .. automethod:: serve_forever + + .. autoattribute:: sockets + +Using a connection +------------------ + +.. autoclass:: ServerConnection + + .. automethod:: __aiter__ + + .. automethod:: recv + + .. automethod:: recv_streaming + + .. automethod:: send + + .. automethod:: close + + .. automethod:: wait_closed + + .. automethod:: ping + + .. automethod:: pong + + WebSocket connection objects also provide these attributes: + + .. autoattribute:: id + + .. autoattribute:: logger + + .. autoproperty:: local_address + + .. autoproperty:: remote_address + + The following attributes are available after the opening handshake, + once the WebSocket connection is open: + + .. autoattribute:: request + + .. autoattribute:: response + + .. autoproperty:: subprotocol diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py new file mode 100644 index 000000000..040d68ece --- /dev/null +++ b/src/websockets/asyncio/client.py @@ -0,0 +1,331 @@ +from __future__ import annotations + +import asyncio +from types import TracebackType +from typing import Any, Generator, Sequence + +from ..client import ClientProtocol +from ..datastructures import HeadersLike +from ..extensions.base import ClientExtensionFactory +from ..extensions.permessage_deflate import enable_client_permessage_deflate +from ..headers import validate_subprotocols +from ..http import USER_AGENT +from ..http11 import Response +from ..protocol import CONNECTING, Event +from ..typing import LoggerLike, Origin, Subprotocol +from ..uri import parse_uri +from .compatibility import TimeoutError, asyncio_timeout +from .connection import Connection + + +__all__ = ["connect", "unix_connect", "ClientConnection"] + + +class ClientConnection(Connection): + """ + :mod:`asyncio` implementation of a WebSocket client connection. + + :class:`ClientConnection` provides :meth:`recv` and :meth:`send` coroutines + for receiving and sending messages. + + It supports asynchronous iteration to receive messages:: + + async for message in websocket: + await process(message) + + The iterator exits normally when the connection is closed with close code + 1000 (OK) or 1001 (going away) or without a close code. It raises a + :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is + closed with any other code. + + Args: + protocol: Sans-I/O connection. + close_timeout: Timeout for closing the connection in seconds. + :obj:`None` disables the timeout. + + """ + + def __init__( + self, + protocol: ClientProtocol, + *, + close_timeout: float | None = 10, + ) -> None: + self.protocol: ClientProtocol + super().__init__( + protocol, + close_timeout=close_timeout, + ) + self.response_rcvd: asyncio.Future[None] = self.loop.create_future() + + async def handshake( + self, + additional_headers: HeadersLike | None = None, + user_agent_header: str | None = USER_AGENT, + ) -> None: + """ + Perform the opening handshake. + + """ + async with self.send_context(expected_state=CONNECTING): + self.request = self.protocol.connect() + if additional_headers is not None: + self.request.headers.update(additional_headers) + if user_agent_header is not None: + self.request.headers["User-Agent"] = user_agent_header + self.protocol.send_request(self.request) + + # May raise CancelledError if open_timeout is exceeded. + await self.response_rcvd + + if self.response is None: + raise ConnectionError("connection closed during handshake") + + if self.protocol.handshake_exc is not None: + try: + async with asyncio_timeout(self.close_timeout): + await self.connection_lost_waiter + finally: + raise self.protocol.handshake_exc + + def process_event(self, event: Event) -> None: + """ + Process one incoming event. + + """ + # First event - handshake response. + if self.response is None: + assert isinstance(event, Response) + self.response = event + self.response_rcvd.set_result(None) + # Later events - frames. + else: + super().process_event(event) + + def connection_lost(self, exc: Exception | None) -> None: + try: + super().connection_lost(exc) + finally: + # If the connection is closed during the handshake, unblock it. + if not self.response_rcvd.done(): + self.response_rcvd.set_result(None) + + +# This is spelled in lower case because it's exposed as a callable in the API. +class connect: + """ + Connect to the WebSocket server at ``uri``. + + This coroutine returns a :class:`ClientConnection` instance, which you can + use to send and receive messages. + + :func:`connect` may be used as an asynchronous context manager:: + + async with websockets.asyncio.client.connect(...) as websocket: + ... + + The connection is closed automatically when exiting the context. + + Args: + uri: URI of the WebSocket server. + origin: Value of the ``Origin`` header, for servers that require it. + extensions: List of supported extensions, in order in which they + should be negotiated and run. + subprotocols: List of supported subprotocols, in order of decreasing + preference. + additional_headers (HeadersLike | None): Arbitrary HTTP headers to add + to the handshake request. + user_agent_header: Value of the ``User-Agent`` request header. + It defaults to ``"Python/x.y.z websockets/X.Y"``. + Setting it to :obj:`None` removes the header. + compression: The "permessage-deflate" extension is enabled by default. + Set ``compression`` to :obj:`None` to disable it. See the + :doc:`compression guide <../../topics/compression>` for details. + open_timeout: Timeout for opening the connection in seconds. + :obj:`None` disables the timeout. + close_timeout: Timeout for closing the connection in seconds. + :obj:`None` disables the timeout. + max_size: Maximum size of incoming messages in bytes. + :obj:`None` disables the limit. + logger: Logger for this client. + It defaults to ``logging.getLogger("websockets.client")``. + See the :doc:`logging guide <../../topics/logging>` for details. + create_connection: Factory for the :class:`ClientConnection` managing + the connection. Set it to a wrapper or a subclass to customize + connection handling. + + Any other keyword arguments are passed to the event loop's + :meth:`~asyncio.loop.create_connection` method. + + For example: + + * You can set ``ssl`` to a :class:`~ssl.SSLContext` to enforce TLS settings. + When connecting to a ``wss://`` URI, if ``ssl`` isn't provided, a TLS + context is created with :func:`~ssl.create_default_context`. + + * You can set ``server_hostname`` to override the host name from ``uri`` in + the TLS handshake. + + * You can set ``host`` and ``port`` to connect to a different host and port + from those found in ``uri``. This only changes the destination of the TCP + connection. The host name from ``uri`` is still used in the TLS handshake + for secure connections and in the ``Host`` header. + + * You can set ``sock`` to provide a preexisting TCP socket. You may call + :func:`socket.create_connection` (not to be confused with the event loop's + :meth:`~asyncio.loop.create_connection` method) to create a suitable + client socket and customize it. + + Raises: + InvalidURI: If ``uri`` isn't a valid WebSocket URI. + OSError: If the TCP connection fails. + InvalidHandshake: If the opening handshake fails. + TimeoutError: If the opening handshake times out. + + """ + + def __init__( + self, + uri: str, + *, + # WebSocket + origin: Origin | None = None, + extensions: Sequence[ClientExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + additional_headers: HeadersLike | None = None, + user_agent_header: str | None = USER_AGENT, + compression: str | None = "deflate", + # Timeouts + open_timeout: float | None = 10, + close_timeout: float | None = 10, + # Limits + max_size: int | None = 2**20, + # Logging + logger: LoggerLike | None = None, + # Escape hatch for advanced customization + create_connection: type[ClientConnection] | None = None, + # Other keyword arguments are passed to loop.create_connection + **kwargs: Any, + ) -> None: + + wsuri = parse_uri(uri) + + if wsuri.secure: + kwargs.setdefault("ssl", True) + kwargs.setdefault("server_hostname", wsuri.host) + if kwargs.get("ssl") is None: + raise TypeError("ssl=None is incompatible with a wss:// URI") + else: + if kwargs.get("ssl") is not None: + raise TypeError("ssl argument is incompatible with a ws:// URI") + + if subprotocols is not None: + validate_subprotocols(subprotocols) + + if compression == "deflate": + extensions = enable_client_permessage_deflate(extensions) + elif compression is not None: + raise ValueError(f"unsupported compression: {compression}") + + if create_connection is None: + create_connection = ClientConnection + + def factory() -> ClientConnection: + # This is a protocol in the Sans-I/O implementation of websockets. + protocol = ClientProtocol( + wsuri, + origin=origin, + extensions=extensions, + subprotocols=subprotocols, + max_size=max_size, + logger=logger, + ) + # This is a connection in websockets and a protocol in asyncio. + connection = create_connection( + protocol, + close_timeout=close_timeout, + ) + return connection + + loop = asyncio.get_running_loop() + if kwargs.pop("unix", False): + self._create_connection = loop.create_unix_connection(factory, **kwargs) + else: + if kwargs.get("sock") is None: + kwargs.setdefault("host", wsuri.host) + kwargs.setdefault("port", wsuri.port) + self._create_connection = loop.create_connection(factory, **kwargs) + + self._handshake_args = ( + additional_headers, + user_agent_header, + ) + + self._open_timeout = open_timeout + + # async with connect(...) as ...: ... + + async def __aenter__(self) -> ClientConnection: + return await self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + await self.connection.close() + + # ... = await connect(...) + + def __await__(self) -> Generator[Any, None, ClientConnection]: + # Create a suitable iterator by calling __await__ on a coroutine. + return self.__await_impl__().__await__() + + async def __await_impl__(self) -> ClientConnection: + try: + async with asyncio_timeout(self._open_timeout): + _transport, self.connection = await self._create_connection + try: + await self.connection.handshake(*self._handshake_args) + except (Exception, asyncio.CancelledError): + self.connection.transport.close() + raise + else: + return self.connection + except TimeoutError: + # Re-raise exception with an informative error message. + raise TimeoutError("timed out during handshake") from None + + # ... = yield from connect(...) - remove when dropping Python < 3.10 + + __iter__ = __await__ + + +def unix_connect( + path: str | None = None, + uri: str | None = None, + **kwargs: Any, +) -> connect: + """ + Connect to a WebSocket server listening on a Unix socket. + + This function accepts the same keyword arguments as :func:`connect`. + + It's only available on Unix. + + It's mainly useful for debugging servers listening on Unix sockets. + + Args: + path: File system path to the Unix socket. + uri: URI of the WebSocket server. ``uri`` defaults to + ``ws://localhost/`` or, when a ``ssl`` argument is provided, to + ``wss://localhost/``. + + """ + if uri is None: + if kwargs.get("ssl") is None: + uri = "ws://localhost/" + else: + uri = "wss://localhost/" + return connect(uri=uri, unix=True, path=path, **kwargs) diff --git a/src/websockets/asyncio/compatibility.py b/src/websockets/asyncio/compatibility.py index 390f00ac7..e17000069 100644 --- a/src/websockets/asyncio/compatibility.py +++ b/src/websockets/asyncio/compatibility.py @@ -3,14 +3,17 @@ import sys -__all__ = ["TimeoutError", "aiter", "anext", "asyncio_timeout"] +__all__ = ["TimeoutError", "aiter", "anext", "asyncio_timeout", "asyncio_timeout_at"] if sys.version_info[:2] >= (3, 11): TimeoutError = TimeoutError aiter = aiter anext = anext - from asyncio import timeout as asyncio_timeout + from asyncio import ( + timeout as asyncio_timeout, # noqa: F401 + timeout_at as asyncio_timeout_at, # noqa: F401 + ) else: # Python < 3.11 from asyncio import TimeoutError @@ -21,4 +24,7 @@ def aiter(async_iterable): async def anext(async_iterator): return await type(async_iterator).__anext__(async_iterator) - from .async_timeout import timeout as asyncio_timeout + from .async_timeout import ( + timeout as asyncio_timeout, # noqa: F401 + timeout_at as asyncio_timeout_at, # noqa: F401 + ) diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py new file mode 100644 index 000000000..550c0ac97 --- /dev/null +++ b/src/websockets/asyncio/connection.py @@ -0,0 +1,883 @@ +from __future__ import annotations + +import asyncio +import collections +import contextlib +import logging +import random +import struct +import uuid +from types import TracebackType +from typing import ( + Any, + AsyncIterable, + AsyncIterator, + Awaitable, + Iterable, + Mapping, + cast, +) + +from ..exceptions import ConnectionClosed, ConnectionClosedOK, ProtocolError +from ..frames import DATA_OPCODES, BytesLike, CloseCode, Frame, Opcode, prepare_ctrl +from ..http11 import Request, Response +from ..protocol import CLOSED, OPEN, Event, Protocol, State +from ..typing import Data, LoggerLike, Subprotocol +from .compatibility import TimeoutError, aiter, anext, asyncio_timeout_at +from .messages import Assembler + + +__all__ = ["Connection"] + + +class Connection(asyncio.Protocol): + """ + :mod:`asyncio` implementation of a WebSocket connection. + + :class:`Connection` provides APIs shared between WebSocket servers and + clients. + + You shouldn't use it directly. Instead, use + :class:`~websockets.asyncio.client.ClientConnection` or + :class:`~websockets.asyncio.server.ServerConnection`. + + """ + + def __init__( + self, + protocol: Protocol, + *, + close_timeout: float | None = 10, + ) -> None: + self.protocol = protocol + self.close_timeout = close_timeout + + # Inject reference to this instance in the protocol's logger. + self.protocol.logger = logging.LoggerAdapter( + self.protocol.logger, + {"websocket": self}, + ) + + # Copy attributes from the protocol for convenience. + self.id: uuid.UUID = self.protocol.id + """Unique identifier of the connection. Useful in logs.""" + self.logger: LoggerLike = self.protocol.logger + """Logger for this connection.""" + self.debug = self.protocol.debug + + # HTTP handshake request and response. + self.request: Request | None = None + """Opening handshake request.""" + self.response: Response | None = None + """Opening handshake response.""" + + # Event loop running this connection. + self.loop = asyncio.get_running_loop() + + # Assembler turning frames into messages and serializing reads. + self.recv_messages: Assembler # initialized in connection_made + + # Deadline for the closing handshake. + self.close_deadline: float | None = None + + # Protect sending fragmented messages. + self.fragmented_send_waiter: asyncio.Future[None] | None = None + + # Mapping of ping IDs to pong waiters, in chronological order. + self.pong_waiters: dict[bytes, tuple[asyncio.Future[float], float]] = {} + + # Exception raised while reading from the connection, to be chained to + # ConnectionClosed in order to show why the TCP connection dropped. + self.recv_exc: BaseException | None = None + + # Completed when the TCP connection is closed and the WebSocket + # connection state becomes CLOSED. + self.connection_lost_waiter: asyncio.Future[None] = self.loop.create_future() + + # Adapted from asyncio.FlowControlMixin + self.paused: bool = False + self.drain_waiters: collections.deque[asyncio.Future[None]] = ( + collections.deque() + ) + + # Public attributes + + @property + def local_address(self) -> Any: + """ + Local address of the connection. + + For IPv4 connections, this is a ``(host, port)`` tuple. + + The format of the address depends on the address family. + See :meth:`~socket.socket.getsockname`. + + """ + return self.transport.get_extra_info("sockname") + + @property + def remote_address(self) -> Any: + """ + Remote address of the connection. + + For IPv4 connections, this is a ``(host, port)`` tuple. + + The format of the address depends on the address family. + See :meth:`~socket.socket.getpeername`. + + """ + return self.transport.get_extra_info("peername") + + @property + def subprotocol(self) -> Subprotocol | None: + """ + Subprotocol negotiated during the opening handshake. + + :obj:`None` if no subprotocol was negotiated. + + """ + return self.protocol.subprotocol + + # Public methods + + async def __aenter__(self) -> Connection: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + if exc_type is None: + await self.close() + else: + await self.close(CloseCode.INTERNAL_ERROR) + + async def __aiter__(self) -> AsyncIterator[Data]: + """ + Iterate on incoming messages. + + The iterator calls :meth:`recv` and yields messages asynchronously in an + infinite loop. + + It exits when the connection is closed normally. It raises a + :exc:`~websockets.exceptions.ConnectionClosedError` exception after a + protocol error or a network failure. + + """ + try: + while True: + yield await self.recv() + except ConnectionClosedOK: + return + + async def recv(self) -> Data: + """ + Receive the next message. + + When the connection is closed, :meth:`recv` raises + :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it raises + :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal closure + and :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol + error or a network failure. This is how you detect the end of the + message stream. + + Canceling :meth:`recv` is safe. There's no risk of losing data. The next + invocation of :meth:`recv` will return the next message. + + This makes it possible to enforce a timeout by wrapping :meth:`recv` in + :func:`~asyncio.timeout` or :func:`~asyncio.wait_for`. + + When the message is fragmented, :meth:`recv` waits until all fragments + are received, reassembles them, and returns the whole message. + + Returns: + A string (:class:`str`) for a Text_ frame or a bytestring + (:class:`bytes`) for a Binary_ frame. + + .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + + Raises: + ConnectionClosed: When the connection is closed. + RuntimeError: If two coroutines call :meth:`recv` or + :meth:`recv_streaming` concurrently. + + """ + try: + return await self.recv_messages.get() + except EOFError: + raise self.protocol.close_exc from self.recv_exc + except RuntimeError: + raise RuntimeError( + "cannot call recv while another coroutine " + "is already running recv or recv_streaming" + ) from None + + async def recv_streaming(self) -> AsyncIterator[Data]: + """ + Receive the next message frame by frame. + + This method is designed for receiving fragmented messages. It returns an + asynchronous iterator that yields each fragment as it is received. This + iterator must be fully consumed. Else, future calls to :meth:`recv` or + :meth:`recv_streaming` will raise :exc:`RuntimeError`, making the + connection unusable. + + :meth:`recv_streaming` raises the same exceptions as :meth:`recv`. + + Canceling :meth:`recv_streaming` before receiving the first frame is + safe. Canceling it after receiving one or more frames leaves the + iterator in a partially consumed state, making the connection unusable. + Instead, you should close the connection with :meth:`close`. + + Returns: + An iterator of strings (:class:`str`) for a Text_ frame or + bytestrings (:class:`bytes`) for a Binary_ frame. + + .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + + Raises: + ConnectionClosed: When the connection is closed. + RuntimeError: If two coroutines call :meth:`recv` or + :meth:`recv_streaming` concurrently. + + """ + try: + async for frame in self.recv_messages.get_iter(): + yield frame + except EOFError: + raise self.protocol.close_exc from self.recv_exc + except RuntimeError: + raise RuntimeError( + "cannot call recv_streaming while another coroutine " + "is already running recv or recv_streaming" + ) from None + + async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> None: + """ + Send a message. + + A string (:class:`str`) is sent as a Text_ frame. A bytestring or + bytes-like object (:class:`bytes`, :class:`bytearray`, or + :class:`memoryview`) is sent as a Binary_ frame. + + .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + + :meth:`send` also accepts an iterable or an asynchronous iterable of + strings, bytestrings, or bytes-like objects to enable fragmentation_. + Each item is treated as a message fragment and sent in its own frame. + All items must be of the same type, or else :meth:`send` will raise a + :exc:`TypeError` and the connection will be closed. + + .. _fragmentation: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.4 + + :meth:`send` rejects dict-like objects because this is often an error. + (If you really want to send the keys of a dict-like object as fragments, + call its :meth:`~dict.keys` method and pass the result to :meth:`send`.) + + Canceling :meth:`send` is discouraged. Instead, you should close the + connection with :meth:`close`. Indeed, there are only two situations + where :meth:`send` may yield control to the event loop and then get + canceled; in both cases, :meth:`close` has the same effect and is + more clear: + + 1. The write buffer is full. If you don't want to wait until enough + data is sent, your only alternative is to close the connection. + :meth:`close` will likely time out then abort the TCP connection. + 2. ``message`` is an asynchronous iterator that yields control. + Stopping in the middle of a fragmented message will cause a + protocol error and the connection will be closed. + + When the connection is closed, :meth:`send` raises + :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it + raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal + connection closure and + :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol + error or a network failure. + + Args: + message: Message to send. + + Raises: + ConnectionClosed: When the connection is closed. + RuntimeError: If the connection busy sending a fragmented message. + TypeError: If ``message`` doesn't have a supported type. + + """ + # While sending a fragmented message, prevent sending other messages + # until all fragments are sent. + while self.fragmented_send_waiter is not None: + await asyncio.shield(self.fragmented_send_waiter) + + # Unfragmented message -- this case must be handled first because + # strings and bytes-like objects are iterable. + + if isinstance(message, str): + async with self.send_context(): + self.protocol.send_text(message.encode()) + + elif isinstance(message, BytesLike): + async with self.send_context(): + self.protocol.send_binary(message) + + # Catch a common mistake -- passing a dict to send(). + + elif isinstance(message, Mapping): + raise TypeError("data is a dict-like object") + + # Fragmented message -- regular iterator. + + elif isinstance(message, Iterable): + chunks = iter(message) + try: + chunk = next(chunks) + except StopIteration: + return + + assert self.fragmented_send_waiter is None + self.fragmented_send_waiter = self.loop.create_future() + try: + # First fragment. + if isinstance(chunk, str): + text = True + async with self.send_context(): + self.protocol.send_text( + chunk.encode(), + fin=False, + ) + elif isinstance(chunk, BytesLike): + text = False + async with self.send_context(): + self.protocol.send_binary( + chunk, + fin=False, + ) + else: + raise TypeError("iterable must contain bytes or str") + + # Other fragments + for chunk in chunks: + if isinstance(chunk, str) and text: + async with self.send_context(): + self.protocol.send_continuation( + chunk.encode(), + fin=False, + ) + elif isinstance(chunk, BytesLike) and not text: + async with self.send_context(): + self.protocol.send_continuation( + chunk, + fin=False, + ) + else: + raise TypeError("iterable must contain uniform types") + + # Final fragment. + async with self.send_context(): + self.protocol.send_continuation(b"", fin=True) + + except Exception: + # We're half-way through a fragmented message and we can't + # complete it. This makes the connection unusable. + async with self.send_context(): + self.protocol.fail(1011, "error in fragmented message") + raise + + finally: + self.fragmented_send_waiter.set_result(None) + self.fragmented_send_waiter = None + + # Fragmented message -- async iterator. + + elif isinstance(message, AsyncIterable): + achunks = aiter(message) + try: + chunk = await anext(achunks) + except StopAsyncIteration: + return + + assert self.fragmented_send_waiter is None + self.fragmented_send_waiter = self.loop.create_future() + try: + # First fragment. + if isinstance(chunk, str): + text = True + async with self.send_context(): + self.protocol.send_text( + chunk.encode(), + fin=False, + ) + elif isinstance(chunk, BytesLike): + text = False + async with self.send_context(): + self.protocol.send_binary( + chunk, + fin=False, + ) + else: + raise TypeError("async iterable must contain bytes or str") + + # Other fragments + async for chunk in achunks: + if isinstance(chunk, str) and text: + async with self.send_context(): + self.protocol.send_continuation( + chunk.encode(), + fin=False, + ) + elif isinstance(chunk, BytesLike) and not text: + async with self.send_context(): + self.protocol.send_continuation( + chunk, + fin=False, + ) + else: + raise TypeError("async iterable must contain uniform types") + + # Final fragment. + async with self.send_context(): + self.protocol.send_continuation(b"", fin=True) + + except Exception: + # We're half-way through a fragmented message and we can't + # complete it. This makes the connection unusable. + async with self.send_context(): + self.protocol.fail(1011, "error in fragmented message") + raise + + finally: + self.fragmented_send_waiter.set_result(None) + self.fragmented_send_waiter = None + + else: + raise TypeError("data must be bytes, str, iterable, or async iterable") + + async def close(self, code: int = 1000, reason: str = "") -> None: + """ + Perform the closing handshake. + + :meth:`close` waits for the other end to complete the handshake and + for the TCP connection to terminate. + + :meth:`close` is idempotent: it doesn't do anything once the + connection is closed. + + Args: + code: WebSocket close code. + reason: WebSocket close reason. + + """ + try: + # The context manager takes care of waiting for the TCP connection + # to terminate after calling a method that sends a close frame. + async with self.send_context(): + if self.fragmented_send_waiter is not None: + self.protocol.fail(1011, "close during fragmented message") + else: + self.protocol.send_close(code, reason) + except ConnectionClosed: + # Ignore ConnectionClosed exceptions raised from send_context(). + # They mean that the connection is closed, which was the goal. + pass + + async def wait_closed(self) -> None: + """ + Wait until the connection is closed. + + :meth:`wait_closed` waits for the closing handshake to complete and for + the TCP connection to terminate. + + """ + await asyncio.shield(self.connection_lost_waiter) + + async def ping(self, data: Data | None = None) -> Awaitable[None]: + """ + Send a Ping_. + + .. _Ping: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 + + A ping may serve as a keepalive or as a check that the remote endpoint + received all messages up to this point + + Args: + data: Payload of the ping. A :class:`str` will be encoded to UTF-8. + If ``data`` is :obj:`None`, the payload is four random bytes. + + Returns: + A future that will be completed when the corresponding pong is + received. You can ignore it if you don't intend to wait. The result + of the future is the latency of the connection in seconds. + + :: + + pong_waiter = await ws.ping() + # only if you want to wait for the corresponding pong + latency = await pong_waiter + + Raises: + ConnectionClosed: When the connection is closed. + RuntimeError: If another ping was sent with the same data and + the corresponding pong wasn't received yet. + + """ + if data is not None: + data = prepare_ctrl(data) + + async with self.send_context(): + # Protect against duplicates if a payload is explicitly set. + if data in self.pong_waiters: + raise RuntimeError("already waiting for a pong with the same data") + + # Generate a unique random payload otherwise. + while data is None or data in self.pong_waiters: + data = struct.pack("!I", random.getrandbits(32)) + + pong_waiter = self.loop.create_future() + # The event loop's default clock is time.monotonic(). Its resolution + # is a bit low on Windows (~16ms). We cannot use time.perf_counter() + # because it doesn't count time elapsed while the process sleeps. + ping_timestamp = self.loop.time() + self.pong_waiters[data] = (pong_waiter, ping_timestamp) + self.protocol.send_ping(data) + return pong_waiter + + async def pong(self, data: Data = b"") -> None: + """ + Send a Pong_. + + .. _Pong: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 + + An unsolicited pong may serve as a unidirectional heartbeat. + + Args: + data: Payload of the pong. A :class:`str` will be encoded to UTF-8. + + Raises: + ConnectionClosed: When the connection is closed. + + """ + data = prepare_ctrl(data) + + async with self.send_context(): + self.protocol.send_pong(data) + + # Private methods + + def process_event(self, event: Event) -> None: + """ + Process one incoming event. + + This method is overridden in subclasses to handle the handshake. + + """ + assert isinstance(event, Frame) + if event.opcode in DATA_OPCODES: + self.recv_messages.put(event) + + if event.opcode is Opcode.PONG: + self.acknowledge_pings(bytes(event.data)) + + def acknowledge_pings(self, data: bytes) -> None: + """ + Acknowledge pings when receiving a pong. + + """ + # Ignore unsolicited pong. + if data not in self.pong_waiters: + return + + pong_timestamp = self.loop.time() + + # Sending a pong for only the most recent ping is legal. + # Acknowledge all previous pings too in that case. + ping_id = None + ping_ids = [] + for ping_id, (pong_waiter, ping_timestamp) in self.pong_waiters.items(): + ping_ids.append(ping_id) + pong_waiter.set_result(pong_timestamp - ping_timestamp) + if ping_id == data: + break + else: + raise AssertionError("solicited pong not found in pings") + + # Remove acknowledged pings from self.pong_waiters. + for ping_id in ping_ids: + del self.pong_waiters[ping_id] + + def abort_pings(self) -> None: + """ + Raise ConnectionClosed in pending pings. + + They'll never receive a pong once the connection is closed. + + """ + assert self.protocol.state is CLOSED + exc = self.protocol.close_exc + + for pong_waiter, _ping_timestamp in self.pong_waiters.values(): + pong_waiter.set_exception(exc) + # If the exception is never retrieved, it will be logged when ping + # is garbage-collected. This is confusing for users. + # Given that ping is done (with an exception), canceling it does + # nothing, but it prevents logging the exception. + pong_waiter.cancel() + + self.pong_waiters.clear() + + @contextlib.asynccontextmanager + async def send_context( + self, + *, + expected_state: State = OPEN, # CONNECTING during the opening handshake + ) -> AsyncIterator[None]: + """ + Create a context for writing to the connection from user code. + + On entry, :meth:`send_context` checks that the connection is open; on + exit, it writes outgoing data to the socket:: + + async async with self.send_context(): + self.protocol.send_text(message.encode()) + + When the connection isn't open on entry, when the connection is expected + to close on exit, or when an unexpected error happens, terminating the + connection, :meth:`send_context` waits until the connection is closed + then raises :exc:`~websockets.exceptions.ConnectionClosed`. + + """ + # Should we wait until the connection is closed? + wait_for_close = False + # Should we close the transport and raise ConnectionClosed? + raise_close_exc = False + # What exception should we chain ConnectionClosed to? + original_exc: BaseException | None = None + + if self.protocol.state is expected_state: + # Let the caller interact with the protocol. + try: + yield + except (ProtocolError, RuntimeError): + # The protocol state wasn't changed. Exit immediately. + raise + except Exception as exc: + self.logger.error("unexpected internal error", exc_info=True) + # This branch should never run. It's a safety net in case of + # bugs. Since we don't know what happened, we will close the + # connection and raise the exception to the caller. + wait_for_close = False + raise_close_exc = True + original_exc = exc + else: + # Check if the connection is expected to close soon. + if self.protocol.close_expected(): + wait_for_close = True + # If the connection is expected to close soon, set the + # close deadline based on the close timeout. + # Since we tested earlier that protocol.state was OPEN + # (or CONNECTING), self.close_deadline is still None. + if self.close_timeout is not None: + assert self.close_deadline is None + self.close_deadline = self.loop.time() + self.close_timeout + # Write outgoing data to the socket and enforce flow control. + try: + self.send_data() + await self.drain() + except Exception as exc: + if self.debug: + self.logger.debug("error while sending data", exc_info=True) + # While the only expected exception here is OSError, + # other exceptions would be treated identically. + wait_for_close = False + raise_close_exc = True + original_exc = exc + + else: # self.protocol.state is not expected_state + # Minor layering violation: we assume that the connection + # will be closing soon if it isn't in the expected state. + wait_for_close = True + # Calculate close_deadline if it wasn't set yet. + if self.close_timeout is not None: + if self.close_deadline is None: + self.close_deadline = self.loop.time() + self.close_timeout + raise_close_exc = True + + # If the connection is expected to close soon and the close timeout + # elapses, close the socket to terminate the connection. + if wait_for_close: + try: + async with asyncio_timeout_at(self.close_deadline): + await asyncio.shield(self.connection_lost_waiter) + except TimeoutError: + # There's no risk to overwrite another error because + # original_exc is never set when wait_for_close is True. + assert original_exc is None + original_exc = TimeoutError("timed out while closing connection") + # Set recv_exc before closing the transport in order to get + # proper exception reporting. + raise_close_exc = True + self.set_recv_exc(original_exc) + + # If an error occurred, close the transport to terminate the connection and + # raise an exception. + if raise_close_exc: + self.close_transport() + await asyncio.shield(self.connection_lost_waiter) + raise self.protocol.close_exc from original_exc + + def send_data(self) -> None: + """ + Send outgoing data. + + Raises: + OSError: When a socket operations fails. + + """ + for data in self.protocol.data_to_send(): + if data: + self.transport.write(data) + else: + # Half-close the TCP connection when possible i.e. no TLS. + if self.transport.can_write_eof(): + if self.debug: + self.logger.debug("x half-closing TCP connection") + # write_eof() doesn't document which exceptions it raises. + # OSError is plausible. uvloop can raise RuntimeError here. + try: + self.transport.write_eof() + except (OSError, RuntimeError): # pragma: no cover + pass + # Else, close the TCP connection. + else: # pragma: no cover + if self.debug: + self.logger.debug("x closing TCP connection") + self.transport.close() + + def set_recv_exc(self, exc: BaseException | None) -> None: + """ + Set recv_exc, if not set yet. + + """ + if self.recv_exc is None: + self.recv_exc = exc + + def close_transport(self) -> None: + """ + Close transport and message assembler. + + """ + self.transport.close() + self.recv_messages.close() + + # asyncio.Protocol methods + + # Connection callbacks + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + transport = cast(asyncio.Transport, transport) + self.transport = transport + self.recv_messages = Assembler( + pause=self.transport.pause_reading, + resume=self.transport.resume_reading, + ) + + def connection_lost(self, exc: Exception | None) -> None: + self.protocol.receive_eof() # receive_eof is idempotent + self.recv_messages.close() + self.set_recv_exc(exc) + # If self.connection_lost_waiter isn't pending, that's a bug, because: + # - it's set only here in connection_lost() which is called only once; + # - it must never be canceled. + self.connection_lost_waiter.set_result(None) + self.abort_pings() + + # Adapted from asyncio.streams.FlowControlMixin + if self.paused: # pragma: no cover + self.paused = False + for waiter in self.drain_waiters: + if not waiter.done(): + if exc is None: + waiter.set_result(None) + else: + waiter.set_exception(exc) + + # Flow control callbacks + + def pause_writing(self) -> None: # pragma: no cover + # Adapted from asyncio.streams.FlowControlMixin + assert not self.paused + self.paused = True + + def resume_writing(self) -> None: # pragma: no cover + # Adapted from asyncio.streams.FlowControlMixin + assert self.paused + self.paused = False + for waiter in self.drain_waiters: + if not waiter.done(): + waiter.set_result(None) + + async def drain(self) -> None: # pragma: no cover + # We don't check if the connection is closed because we call drain() + # immediately after write() and write() would fail in that case. + + # Adapted from asyncio.streams.StreamWriter + # Yield to the event loop so that connection_lost() may be called. + if self.transport.is_closing(): + await asyncio.sleep(0) + + # Adapted from asyncio.streams.FlowControlMixin + if self.paused: + waiter = self.loop.create_future() + self.drain_waiters.append(waiter) + try: + await waiter + finally: + self.drain_waiters.remove(waiter) + + # Streaming protocol callbacks + + def data_received(self, data: bytes) -> None: + # Feed incoming data to the protocol. + self.protocol.receive_data(data) + + # This isn't expected to raise an exception. + events = self.protocol.events_received() + + # Write outgoing data to the transport. + try: + self.send_data() + except Exception as exc: + if self.debug: + self.logger.debug("error while sending data", exc_info=True) + self.set_recv_exc(exc) + + if self.protocol.close_expected(): + # If the connection is expected to close soon, set the + # close deadline based on the close timeout. + if self.close_timeout is not None: + if self.close_deadline is None: + self.close_deadline = self.loop.time() + self.close_timeout + + for event in events: + # This isn't expected to raise an exception. + self.process_event(event) + + def eof_received(self) -> None: + # Feed the end of the data stream to the connection. + self.protocol.receive_eof() + + # This isn't expected to generate events. + assert not self.protocol.events_received() + + # There is no error handling because send_data() can only write + # the end of the data stream here and it shouldn't raise errors. + self.send_data() + + # The WebSocket protocol has its own closing handshake: endpoints close + # the TCP or TLS connection after sending and receiving a close frame. + # As a consequence, they never need to write after receiving EOF, so + # there's no reason to keep the transport open by returning True. + # Besides, that doesn't work on TLS connections. diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py new file mode 100644 index 000000000..aa175f775 --- /dev/null +++ b/src/websockets/asyncio/server.py @@ -0,0 +1,772 @@ +from __future__ import annotations + +import asyncio +import http +import logging +import socket +import sys +from types import TracebackType +from typing import ( + Any, + Awaitable, + Callable, + Generator, + Iterable, + Sequence, +) + +from websockets.frames import CloseCode + +from ..extensions.base import ServerExtensionFactory +from ..extensions.permessage_deflate import enable_server_permessage_deflate +from ..headers import validate_subprotocols +from ..http import USER_AGENT +from ..http11 import Request, Response +from ..protocol import CONNECTING, Event +from ..server import ServerProtocol +from ..typing import LoggerLike, Origin, Subprotocol +from .compatibility import asyncio_timeout +from .connection import Connection + + +__all__ = ["serve", "unix_serve", "ServerConnection", "WebSocketServer"] + + +class ServerConnection(Connection): + """ + :mod:`asyncio` implementation of a WebSocket server connection. + + :class:`ServerConnection` provides :meth:`recv` and :meth:`send` methods for + receiving and sending messages. + + It supports asynchronous iteration to receive messages:: + + async for message in websocket: + await process(message) + + The iterator exits normally when the connection is closed with close code + 1000 (OK) or 1001 (going away) or without a close code. It raises a + :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is + closed with any other code. + + Args: + protocol: Sans-I/O connection. + server: Server that manages this connection. + close_timeout: Timeout for closing connections in seconds. + :obj:`None` disables the timeout. + + """ + + def __init__( + self, + protocol: ServerProtocol, + server: WebSocketServer, + *, + close_timeout: float | None = 10, + ) -> None: + self.protocol: ServerProtocol + super().__init__( + protocol, + close_timeout=close_timeout, + ) + self.server = server + self.request_rcvd: asyncio.Future[None] = self.loop.create_future() + + async def handshake( + self, + process_request: ( + Callable[ + [ServerConnection, Request], + Awaitable[Response | None] | Response | None, + ] + | None + ) = None, + process_response: ( + Callable[ + [ServerConnection, Request, Response], + Awaitable[Response | None] | Response | None, + ] + | None + ) = None, + server_header: str | None = USER_AGENT, + ) -> None: + """ + Perform the opening handshake. + + """ + # May raise CancelledError if open_timeout is exceeded. + await self.request_rcvd + + if self.request is None: + raise ConnectionError("connection closed during handshake") + + async with self.send_context(expected_state=CONNECTING): + response = None + + if process_request is not None: + try: + response = process_request(self, self.request) + if isinstance(response, Awaitable): + response = await response + except Exception as exc: + self.protocol.handshake_exc = exc + self.logger.error("opening handshake failed", exc_info=True) + response = self.protocol.reject( + http.HTTPStatus.INTERNAL_SERVER_ERROR, + ( + "Failed to open a WebSocket connection.\n" + "See server log for more information.\n" + ), + ) + + if response is None: + if self.server.is_serving(): + self.response = self.protocol.accept(self.request) + else: + self.response = self.protocol.reject( + http.HTTPStatus.SERVICE_UNAVAILABLE, + "Server is shutting down.\n", + ) + else: + assert isinstance(response, Response) # help mypy + self.response = response + + if server_header is not None: + self.response.headers["Server"] = server_header + + response = None + + if process_response is not None: + try: + response = process_response(self, self.request, self.response) + if isinstance(response, Awaitable): + response = await response + except Exception as exc: + self.protocol.handshake_exc = exc + self.logger.error("opening handshake failed", exc_info=True) + response = self.protocol.reject( + http.HTTPStatus.INTERNAL_SERVER_ERROR, + ( + "Failed to open a WebSocket connection.\n" + "See server log for more information.\n" + ), + ) + + if response is not None: + assert isinstance(response, Response) # help mypy + self.response = response + + self.protocol.send_response(self.response) + + if self.protocol.handshake_exc is not None: + try: + async with asyncio_timeout(self.close_timeout): + await self.connection_lost_waiter + finally: + raise self.protocol.handshake_exc + + def process_event(self, event: Event) -> None: + """ + Process one incoming event. + + """ + # First event - handshake request. + if self.request is None: + assert isinstance(event, Request) + self.request = event + self.request_rcvd.set_result(None) + # Later events - frames. + else: + super().process_event(event) + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + super().connection_made(transport) + self.server.start_connection_handler(self) + + def connection_lost(self, exc: Exception | None) -> None: + try: + super().connection_lost(exc) + finally: + # If the connection is closed during the handshake, unblock it. + if not self.request_rcvd.done(): + self.request_rcvd.set_result(None) + + +class WebSocketServer: + """ + WebSocket server returned by :func:`serve`. + + This class mirrors the API of :class:`~asyncio.Server`. + + It keeps track of WebSocket connections in order to close them properly + when shutting down. + + Args: + handler: Connection handler. It receives the WebSocket connection, + which is a :class:`ServerConnection`, in argument. + process_request: Intercept the request during the opening handshake. + Return an HTTP response to force the response or :obj:`None` to + continue normally. When you force an HTTP 101 Continue response, the + handshake is successful. Else, the connection is aborted. + ``process_request`` may be a function or a coroutine. + process_response: Intercept the response during the opening handshake. + Return an HTTP response to force the response or :obj:`None` to + continue normally. When you force an HTTP 101 Continue response, the + handshake is successful. Else, the connection is aborted. + ``process_response`` may be a function or a coroutine. + server_header: Value of the ``Server`` response header. + It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to + :obj:`None` removes the header. + open_timeout: Timeout for opening connections in seconds. + :obj:`None` disables the timeout. + logger: Logger for this server. + It defaults to ``logging.getLogger("websockets.server")``. + See the :doc:`logging guide <../../topics/logging>` for details. + + """ + + def __init__( + self, + handler: Callable[[ServerConnection], Awaitable[None]], + *, + process_request: ( + Callable[ + [ServerConnection, Request], + Awaitable[Response | None] | Response | None, + ] + | None + ) = None, + process_response: ( + Callable[ + [ServerConnection, Request, Response], + Awaitable[Response | None] | Response | None, + ] + | None + ) = None, + server_header: str | None = USER_AGENT, + open_timeout: float | None = 10, + logger: LoggerLike | None = None, + ) -> None: + self.loop = asyncio.get_running_loop() + self.handler = handler + self.process_request = process_request + self.process_response = process_response + self.server_header = server_header + self.open_timeout = open_timeout + if logger is None: + logger = logging.getLogger("websockets.server") + self.logger = logger + + # Keep track of active connections. + self.handlers: dict[ServerConnection, asyncio.Task[None]] = {} + + # Task responsible for closing the server and terminating connections. + self.close_task: asyncio.Task[None] | None = None + + # Completed when the server is closed and connections are terminated. + self.closed_waiter: asyncio.Future[None] = self.loop.create_future() + + def wrap(self, server: asyncio.Server) -> None: + """ + Attach to a given :class:`~asyncio.Server`. + + Since :meth:`~asyncio.loop.create_server` doesn't support injecting a + custom ``Server`` class, the easiest solution that doesn't rely on + private :mod:`asyncio` APIs is to: + + - instantiate a :class:`WebSocketServer` + - give the protocol factory a reference to that instance + - call :meth:`~asyncio.loop.create_server` with the factory + - attach the resulting :class:`~asyncio.Server` with this method + + """ + self.server = server + for sock in server.sockets: + if sock.family == socket.AF_INET: + name = "%s:%d" % sock.getsockname() + elif sock.family == socket.AF_INET6: + name = "[%s]:%d" % sock.getsockname()[:2] + elif sock.family == socket.AF_UNIX: + name = sock.getsockname() + # In the unlikely event that someone runs websockets over a + # protocol other than IP or Unix sockets, avoid crashing. + else: # pragma: no cover + name = str(sock.getsockname()) + self.logger.info("server listening on %s", name) + + async def conn_handler(self, connection: ServerConnection) -> None: + """ + Handle the lifecycle of a WebSocket connection. + + Since this method doesn't have a caller that can handle exceptions, + it attempts to log relevant ones. + + It guarantees that the TCP connection is closed before exiting. + + """ + try: + # On failure, handshake() closes the transport, raises an + # exception, and logs it. + async with asyncio_timeout(self.open_timeout): + await connection.handshake( + self.process_request, + self.process_response, + self.server_header, + ) + + try: + await self.handler(connection) + except Exception: + self.logger.error("connection handler failed", exc_info=True) + await connection.close(CloseCode.INTERNAL_ERROR) + else: + await connection.close() + + except Exception: + # Don't leak connections on errors. + connection.transport.abort() + + finally: + # Registration is tied to the lifecycle of conn_handler() because + # the server waits for connection handlers to terminate, even if + # all connections are already closed. + del self.handlers[connection] + + def start_connection_handler(self, connection: ServerConnection) -> None: + """ + Register a connection with this server. + + """ + # The connection must be registered in self.handlers immediately. + # If it was registered in conn_handler(), a race condition could + # happen when closing the server after scheduling conn_handler() + # but before it starts executing. + self.handlers[connection] = self.loop.create_task(self.conn_handler(connection)) + + def close(self, close_connections: bool = True) -> None: + """ + Close the server. + + * Close the underlying :class:`~asyncio.Server`. + * When ``close_connections`` is :obj:`True`, which is the default, + close existing connections. Specifically: + + * Reject opening WebSocket connections with an HTTP 503 (service + unavailable) error. This happens when the server accepted the TCP + connection but didn't complete the opening handshake before closing. + * Close open WebSocket connections with close code 1001 (going away). + + * Wait until all connection handlers terminate. + + :meth:`close` is idempotent. + + """ + if self.close_task is None: + self.close_task = self.get_loop().create_task( + self._close(close_connections) + ) + + async def _close(self, close_connections: bool) -> None: + """ + Implementation of :meth:`close`. + + This calls :meth:`~asyncio.Server.close` on the underlying + :class:`~asyncio.Server` object to stop accepting new connections and + then closes open connections with close code 1001. + + """ + self.logger.info("server closing") + + # Stop accepting new connections. + self.server.close() + + # Wait until all accepted connections reach connection_made() and call + # register(). See https://github.com/python/cpython/issues/79033 for + # details. This workaround can be removed when dropping Python < 3.11. + await asyncio.sleep(0) + + if close_connections: + # Close OPEN connections with close code 1001. After server.close(), + # handshake() closes OPENING connections with an HTTP 503 error. + close_tasks = [ + asyncio.create_task(connection.close(1001)) + for connection in self.handlers + if connection.protocol.state is not CONNECTING + ] + # asyncio.wait doesn't accept an empty first argument. + if close_tasks: + await asyncio.wait(close_tasks) + + # Wait until all TCP connections are closed. + await self.server.wait_closed() + + # Wait until all connection handlers terminate. + # asyncio.wait doesn't accept an empty first argument. + if self.handlers: + await asyncio.wait(self.handlers.values()) + + # Tell wait_closed() to return. + self.closed_waiter.set_result(None) + + self.logger.info("server closed") + + async def wait_closed(self) -> None: + """ + Wait until the server is closed. + + When :meth:`wait_closed` returns, all TCP connections are closed and + all connection handlers have returned. + + To ensure a fast shutdown, a connection handler should always be + awaiting at least one of: + + * :meth:`~ServerConnection.recv`: when the connection is closed, + it raises :exc:`~websockets.exceptions.ConnectionClosedOK`; + * :meth:`~ServerConnection.wait_closed`: when the connection is + closed, it returns. + + Then the connection handler is immediately notified of the shutdown; + it can clean up and exit. + + """ + await asyncio.shield(self.closed_waiter) + + def get_loop(self) -> asyncio.AbstractEventLoop: + """ + See :meth:`asyncio.Server.get_loop`. + + """ + return self.server.get_loop() + + def is_serving(self) -> bool: # pragma: no cover + """ + See :meth:`asyncio.Server.is_serving`. + + """ + return self.server.is_serving() + + async def start_serving(self) -> None: # pragma: no cover + """ + See :meth:`asyncio.Server.start_serving`. + + Typical use:: + + server = await serve(..., start_serving=False) + # perform additional setup here... + # ... then start the server + await server.start_serving() + + """ + await self.server.start_serving() + + async def serve_forever(self) -> None: # pragma: no cover + """ + See :meth:`asyncio.Server.serve_forever`. + + Typical use:: + + server = await serve(...) + # this coroutine doesn't return + # canceling it stops the server + await server.serve_forever() + + This is an alternative to using :func:`serve` as an asynchronous context + manager. Shutdown is triggered by canceling :meth:`serve_forever` + instead of exiting a :func:`serve` context. + + """ + await self.server.serve_forever() + + @property + def sockets(self) -> Iterable[socket.socket]: + """ + See :attr:`asyncio.Server.sockets`. + + """ + return self.server.sockets + + async def __aenter__(self) -> WebSocketServer: # pragma: no cover + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: # pragma: no cover + self.close() + await self.wait_closed() + + +# This is spelled in lower case because it's exposed as a callable in the API. +class serve: + """ + Create a WebSocket server listening on ``host`` and ``port``. + + Whenever a client connects, the server creates a :class:`ServerConnection`, + performs the opening handshake, and delegates to the ``handler`` coroutine. + + The handler receives the :class:`ServerConnection` instance, which you can + use to send and receive messages. + + Once the handler completes, either normally or with an exception, the server + performs the closing handshake and closes the connection. + + This coroutine returns a :class:`WebSocketServer` whose API mirrors + :class:`~asyncio.Server`. Treat it as an asynchronous context manager to + ensure that the server will be closed:: + + def handler(websocket): + ... + + # set this future to exit the server + stop = asyncio.get_running_loop().create_future() + + async with websockets.asyncio.server.serve(handler, host, port): + await stop + + Alternatively, call :meth:`~WebSocketServer.serve_forever` to serve requests + and cancel it to stop the server:: + + server = await websockets.asyncio.server.serve(handler, host, port) + await server.serve_forever() + + Args: + handler: Connection handler. It receives the WebSocket connection, + which is a :class:`ServerConnection`, in argument. + host: Network interfaces the server binds to. + See :meth:`~asyncio.loop.create_server` for details. + port: TCP port the server listens on. + See :meth:`~asyncio.loop.create_server` for details. + origins: Acceptable values of the ``Origin`` header, for defending + against Cross-Site WebSocket Hijacking attacks. Include :obj:`None` + in the list if the lack of an origin is acceptable. + extensions: List of supported extensions, in order in which they + should be negotiated and run. + subprotocols: List of supported subprotocols, in order of decreasing + preference. + select_subprotocol: Callback for selecting a subprotocol among + those supported by the client and the server. It receives a + :class:`ServerConnection` (not a + :class:`~websockets.server.ServerProtocol`!) instance and a list of + subprotocols offered by the client. Other than the first argument, + it has the same behavior as the + :meth:`ServerProtocol.select_subprotocol + ` method. + process_request: Intercept the request during the opening handshake. + Return an HTTP response to force the response or :obj:`None` to + continue normally. When you force an HTTP 101 Continue response, the + handshake is successful. Else, the connection is aborted. + ``process_request`` may be a function or a coroutine. + process_response: Intercept the response during the opening handshake. + Return an HTTP response to force the response or :obj:`None` to + continue normally. When you force an HTTP 101 Continue response, the + handshake is successful. Else, the connection is aborted. + ``process_response`` may be a function or a coroutine. + server_header: Value of the ``Server`` response header. + It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to + :obj:`None` removes the header. + compression: The "permessage-deflate" extension is enabled by default. + Set ``compression`` to :obj:`None` to disable it. See the + :doc:`compression guide <../../topics/compression>` for details. + open_timeout: Timeout for opening connections in seconds. + :obj:`None` disables the timeout. + close_timeout: Timeout for closing connections in seconds. + :obj:`None` disables the timeout. + max_size: Maximum size of incoming messages in bytes. + :obj:`None` disables the limit. + logger: Logger for this server. + It defaults to ``logging.getLogger("websockets.server")``. See the + :doc:`logging guide <../../topics/logging>` for details. + create_connection: Factory for the :class:`ServerConnection` managing + the connection. Set it to a wrapper or a subclass to customize + connection handling. + + Any other keyword arguments are passed to the event loop's + :meth:`~asyncio.loop.create_server` method. + + For example: + + * You can set ``ssl`` to a :class:`~ssl.SSLContext` to enable TLS. + + * You can set ``sock`` to provide a preexisting TCP socket. You may call + :func:`socket.create_server` (not to be confused with the event loop's + :meth:`~asyncio.loop.create_server` method) to create a suitable server + socket and customize it. + + * You can set ``start_serving`` to ``False`` to start accepting connections + only after you call :meth:`~WebSocketServer.start_serving()` or + :meth:`~WebSocketServer.serve_forever()`. + + """ + + def __init__( + self, + handler: Callable[[ServerConnection], Awaitable[None]], + host: str | None = None, + port: int | None = None, + *, + # WebSocket + origins: Sequence[Origin | None] | None = None, + extensions: Sequence[ServerExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + select_subprotocol: ( + Callable[ + [ServerConnection, Sequence[Subprotocol]], + Subprotocol | None, + ] + | None + ) = None, + process_request: ( + Callable[ + [ServerConnection, Request], + Response | None, + ] + | None + ) = None, + process_response: ( + Callable[ + [ServerConnection, Request, Response], + Response | None, + ] + | None + ) = None, + server_header: str | None = USER_AGENT, + compression: str | None = "deflate", + # Timeouts + open_timeout: float | None = 10, + close_timeout: float | None = 10, + # Limits + max_size: int | None = 2**20, + # Logging + logger: LoggerLike | None = None, + # Escape hatch for advanced customization + create_connection: type[ServerConnection] | None = None, + # Other keyword arguments are passed to loop.create_server + **kwargs: Any, + ) -> None: + + if subprotocols is not None: + validate_subprotocols(subprotocols) + + if compression == "deflate": + extensions = enable_server_permessage_deflate(extensions) + elif compression is not None: + raise ValueError(f"unsupported compression: {compression}") + + if create_connection is None: + create_connection = ServerConnection + + self.server = WebSocketServer( + handler, + process_request=process_request, + process_response=process_response, + server_header=server_header, + open_timeout=open_timeout, + logger=logger, + ) + + if kwargs.get("ssl") is not None: + kwargs.setdefault("ssl_handshake_timeout", open_timeout) + if sys.version_info[:2] >= (3, 11): # pragma: no branch + kwargs.setdefault("ssl_shutdown_timeout", close_timeout) + + def factory() -> ServerConnection: + """ + Create an asyncio protocol for managing a WebSocket connection. + + """ + # Create a closure to give select_subprotocol access to connection. + protocol_select_subprotocol: ( + Callable[ + [ServerProtocol, Sequence[Subprotocol]], + Subprotocol | None, + ] + | None + ) = None + if select_subprotocol is not None: + + def protocol_select_subprotocol( + protocol: ServerProtocol, + subprotocols: Sequence[Subprotocol], + ) -> Subprotocol | None: + # mypy doesn't know that select_subprotocol is immutable. + assert select_subprotocol is not None + # Ensure this function is only used in the intended context. + assert protocol is connection.protocol + return select_subprotocol(connection, subprotocols) + + # This is a protocol in the Sans-I/O implementation of websockets. + protocol = ServerProtocol( + origins=origins, + extensions=extensions, + subprotocols=subprotocols, + select_subprotocol=protocol_select_subprotocol, + max_size=max_size, + logger=logger, + ) + # This is a connection in websockets and a protocol in asyncio. + connection = create_connection( + protocol, + self.server, + close_timeout=close_timeout, + ) + return connection + + loop = asyncio.get_running_loop() + if kwargs.pop("unix", False): + self._create_server = loop.create_unix_server(factory, **kwargs) + else: + # mypy cannot tell that kwargs must provide sock when port is None. + self._create_server = loop.create_server(factory, host, port, **kwargs) # type: ignore[arg-type] + + # async with serve(...) as ...: ... + + async def __aenter__(self) -> WebSocketServer: + return await self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + self.server.close() + await self.server.wait_closed() + + # ... = await serve(...) + + def __await__(self) -> Generator[Any, None, WebSocketServer]: + # Create a suitable iterator by calling __await__ on a coroutine. + return self.__await_impl__().__await__() + + async def __await_impl__(self) -> WebSocketServer: + server = await self._create_server + self.server.wrap(server) + return self.server + + # ... = yield from serve(...) - remove when dropping Python < 3.10 + + __iter__ = __await__ + + +def unix_serve( + handler: Callable[[ServerConnection], Awaitable[None]], + path: str | None = None, + **kwargs: Any, +) -> Awaitable[WebSocketServer]: + """ + Create a WebSocket server listening on a Unix socket. + + This function is identical to :func:`serve`, except the ``host`` and + ``port`` arguments are replaced by ``path``. It's only available on Unix. + + It's useful for deploying a server behind a reverse proxy such as nginx. + + Args: + handler: Connection handler. It receives the WebSocket connection, + which is a :class:`ServerConnection`, in argument. + path: File system path to the Unix socket. + + """ + return serve(handler, unix=True, path=path, **kwargs) diff --git a/tests/asyncio/client.py b/tests/asyncio/client.py new file mode 100644 index 000000000..e5826add7 --- /dev/null +++ b/tests/asyncio/client.py @@ -0,0 +1,33 @@ +import contextlib + +from websockets.asyncio.client import * +from websockets.asyncio.server import WebSocketServer + +from .server import get_server_host_port + + +__all__ = [ + "run_client", + "run_unix_client", +] + + +@contextlib.asynccontextmanager +async def run_client(wsuri_or_server, secure=None, resource_name="/", **kwargs): + if isinstance(wsuri_or_server, str): + wsuri = wsuri_or_server + else: + assert isinstance(wsuri_or_server, WebSocketServer) + if secure is None: + secure = "ssl" in kwargs + protocol = "wss" if secure else "ws" + host, port = get_server_host_port(wsuri_or_server) + wsuri = f"{protocol}://{host}:{port}{resource_name}" + async with connect(wsuri, **kwargs) as client: + yield client + + +@contextlib.asynccontextmanager +async def run_unix_client(path, **kwargs): + async with unix_connect(path, **kwargs) as client: + yield client diff --git a/tests/asyncio/connection.py b/tests/asyncio/connection.py new file mode 100644 index 000000000..ad1c121bf --- /dev/null +++ b/tests/asyncio/connection.py @@ -0,0 +1,115 @@ +import asyncio +import contextlib + +from websockets.asyncio.connection import Connection + + +class InterceptingConnection(Connection): + """ + Connection subclass that can intercept outgoing packets. + + By interfacing with this connection, we simulate network conditions + affecting what the component being tested receives during a test. + + """ + + def connection_made(self, transport): + super().connection_made(InterceptingTransport(transport)) + + @contextlib.contextmanager + def delay_frames_sent(self, delay): + """ + Add a delay before sending frames. + + This can result in out-of-order writes, which is unrealistic. + + """ + assert self.transport.delay_write is None + self.transport.delay_write = delay + try: + yield + finally: + self.transport.delay_write = None + + @contextlib.contextmanager + def delay_eof_sent(self, delay): + """ + Add a delay before sending EOF. + + This can result in out-of-order writes, which is unrealistic. + + """ + assert self.transport.delay_write_eof is None + self.transport.delay_write_eof = delay + try: + yield + finally: + self.transport.delay_write_eof = None + + @contextlib.contextmanager + def drop_frames_sent(self): + """ + Prevent frames from being sent. + + Since TCP is reliable, sending frames or EOF afterwards is unrealistic. + + """ + assert not self.transport.drop_write + self.transport.drop_write = True + try: + yield + finally: + self.transport.drop_write = False + + @contextlib.contextmanager + def drop_eof_sent(self): + """ + Prevent EOF from being sent. + + Since TCP is reliable, sending frames or EOF afterwards is unrealistic. + + """ + assert not self.transport.drop_write_eof + self.transport.drop_write_eof = True + try: + yield + finally: + self.transport.drop_write_eof = False + + +class InterceptingTransport: + """ + Transport wrapper that intercepts calls to ``write()`` and ``write_eof()``. + + This is coupled to the implementation, which relies on these two methods. + + Since ``write()`` and ``write_eof()`` are not coroutines, this effect is + achieved by scheduling writes at a later time, after the methods return. + This can easily result in out-of-order writes, which is unrealistic. + + """ + + def __init__(self, transport): + self.loop = asyncio.get_running_loop() + self.transport = transport + self.delay_write = None + self.delay_write_eof = None + self.drop_write = False + self.drop_write_eof = False + + def __getattr__(self, name): + return getattr(self.transport, name) + + def write(self, data): + if not self.drop_write: + if self.delay_write is not None: + self.loop.call_later(self.delay_write, self.transport.write, data) + else: + self.transport.write(data) + + def write_eof(self): + if not self.drop_write_eof: + if self.delay_write_eof is not None: + self.loop.call_later(self.delay_write_eof, self.transport.write_eof) + else: + self.transport.write_eof() diff --git a/tests/asyncio/server.py b/tests/asyncio/server.py new file mode 100644 index 000000000..0fe20dc65 --- /dev/null +++ b/tests/asyncio/server.py @@ -0,0 +1,50 @@ +import asyncio +import contextlib +import socket + +from websockets.asyncio.server import * + + +def get_server_host_port(server): + for sock in server.sockets: + if sock.family == socket.AF_INET: # pragma: no branch + return sock.getsockname() + raise AssertionError("expected at least one IPv4 socket") + + +async def eval_shell(ws): + async for expr in ws: + value = eval(expr) + await ws.send(str(value)) + + +class EvalShellMixin: + async def assertEval(self, client, expr, value): + await client.send(expr) + self.assertEqual(await client.recv(), value) + + +async def crash(ws): + raise RuntimeError + + +async def do_nothing(ws): + pass + + +async def keep_running(ws): + delay = float(await ws.recv()) + await ws.close() + await asyncio.sleep(delay) + + +@contextlib.asynccontextmanager +async def run_server(handler=eval_shell, host="localhost", port=0, **kwargs): + async with serve(handler, host, port, **kwargs) as server: + yield server + + +@contextlib.asynccontextmanager +async def run_unix_server(path, handler=eval_shell, **kwargs): + async with unix_serve(handler, path, **kwargs) as server: + yield server diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py new file mode 100644 index 000000000..aab65cd2e --- /dev/null +++ b/tests/asyncio/test_client.py @@ -0,0 +1,306 @@ +import asyncio +import socket +import ssl +import unittest + +from websockets.asyncio.client import * +from websockets.asyncio.compatibility import TimeoutError +from websockets.exceptions import InvalidHandshake, InvalidURI +from websockets.extensions.permessage_deflate import PerMessageDeflate + +from ..utils import CLIENT_CONTEXT, MS, SERVER_CONTEXT, temp_unix_socket_path +from .client import run_client, run_unix_client +from .server import do_nothing, get_server_host_port, run_server, run_unix_server + + +class ClientTests(unittest.IsolatedAsyncioTestCase): + async def test_connection(self): + """Client connects to server and the handshake succeeds.""" + async with run_server() as server: + async with run_client(server) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + + async def test_existing_socket(self): + """Client connects using a pre-existing socket.""" + async with run_server() as server: + with socket.create_connection(get_server_host_port(server)) as sock: + # Use a non-existing domain to ensure we connect to the right socket. + async with run_client("ws://invalid/", sock=sock) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + + async def test_additional_headers(self): + """Client can set additional headers with additional_headers.""" + async with run_server() as server: + async with run_client( + server, additional_headers={"Authorization": "Bearer ..."} + ) as client: + self.assertEqual(client.request.headers["Authorization"], "Bearer ...") + + async def test_override_user_agent(self): + """Client can override User-Agent header with user_agent_header.""" + async with run_server() as server: + async with run_client(server, user_agent_header="Smith") as client: + self.assertEqual(client.request.headers["User-Agent"], "Smith") + + async def test_remove_user_agent(self): + """Client can remove User-Agent header with user_agent_header.""" + async with run_server() as server: + async with run_client(server, user_agent_header=None) as client: + self.assertNotIn("User-Agent", client.request.headers) + + async def test_compression_is_enabled(self): + """Client enables compression by default.""" + async with run_server() as server: + async with run_client(server) as client: + self.assertEqual( + [type(ext) for ext in client.protocol.extensions], + [PerMessageDeflate], + ) + + async def test_disable_compression(self): + """Client disables compression.""" + async with run_server() as server: + async with run_client(server, compression=None) as client: + self.assertEqual(client.protocol.extensions, []) + + async def test_custom_connection_factory(self): + """Client runs ClientConnection factory provided in create_connection.""" + + def create_connection(*args, **kwargs): + client = ClientConnection(*args, **kwargs) + client.create_connection_ran = True + return client + + async with run_server() as server: + async with run_client( + server, create_connection=create_connection + ) as client: + self.assertTrue(client.create_connection_ran) + + async def test_invalid_uri(self): + """Client receives an invalid URI.""" + with self.assertRaises(InvalidURI): + async with run_client("http://localhost"): # invalid scheme + self.fail("did not raise") + + async def test_tcp_connection_fails(self): + """Client fails to connect to server.""" + with self.assertRaises(OSError): + async with run_client("ws://localhost:54321"): # invalid port + self.fail("did not raise") + + async def test_handshake_fails(self): + """Client connects to server but the handshake fails.""" + + def remove_accept_header(self, request, response): + del response.headers["Sec-WebSocket-Accept"] + + # The connection will be open for the server but failed for the client. + # Use a connection handler that exits immediately to avoid an exception. + async with run_server( + do_nothing, process_response=remove_accept_header + ) as server: + with self.assertRaises(InvalidHandshake) as raised: + async with run_client(server, close_timeout=MS): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "missing Sec-WebSocket-Accept header", + ) + + async def test_timeout_during_handshake(self): + """Client times out before receiving handshake response from server.""" + gate = asyncio.get_running_loop().create_future() + + async def stall_connection(self, request): + await gate + + # The connection will be open for the server but failed for the client. + # Use a connection handler that exits immediately to avoid an exception. + async with run_server(do_nothing, process_request=stall_connection) as server: + try: + with self.assertRaises(TimeoutError) as raised: + async with run_client(server, open_timeout=2 * MS): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "timed out during handshake", + ) + finally: + gate.set_result(None) + + async def test_connection_closed_during_handshake(self): + """Client reads EOF before receiving handshake response from server.""" + + def close_connection(self, request): + self.close_transport() + + async with run_server(process_request=close_connection) as server: + with self.assertRaises(ConnectionError) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "connection closed during handshake", + ) + + +class SecureClientTests(unittest.IsolatedAsyncioTestCase): + async def test_connection(self): + """Client connects to server securely.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + async with run_client(server, ssl=CLIENT_CONTEXT) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.version()[:3], "TLS") + + async def test_set_server_hostname_implicitly(self): + """Client sets server_hostname to the host in the WebSocket URI.""" + with temp_unix_socket_path() as path: + async with run_unix_server(path, ssl=SERVER_CONTEXT): + async with run_unix_client( + path, + ssl=CLIENT_CONTEXT, + uri="wss://overridden/", + ) as client: + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.server_hostname, "overridden") + + async def test_set_server_hostname_explicitly(self): + """Client sets server_hostname to the value provided in argument.""" + with temp_unix_socket_path() as path: + async with run_unix_server(path, ssl=SERVER_CONTEXT): + async with run_unix_client( + path, + ssl=CLIENT_CONTEXT, + server_hostname="overridden", + ) as client: + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.server_hostname, "overridden") + + async def test_reject_invalid_server_certificate(self): + """Client rejects certificate where server certificate isn't trusted.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + with self.assertRaises(ssl.SSLCertVerificationError) as raised: + # The test certificate isn't trusted system-wide. + async with run_client(server, secure=True): + self.fail("did not raise") + self.assertIn( + "certificate verify failed: self signed certificate", + str(raised.exception).replace("-", " "), + ) + + async def test_reject_invalid_server_hostname(self): + """Client rejects certificate where server hostname doesn't match.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + with self.assertRaises(ssl.SSLCertVerificationError) as raised: + # This hostname isn't included in the test certificate. + async with run_client( + server, ssl=CLIENT_CONTEXT, server_hostname="invalid" + ): + self.fail("did not raise") + self.assertIn( + "certificate verify failed: Hostname mismatch", + str(raised.exception), + ) + + +@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") +class UnixClientTests(unittest.IsolatedAsyncioTestCase): + async def test_connection(self): + """Client connects to server over a Unix socket.""" + with temp_unix_socket_path() as path: + async with run_unix_server(path): + async with run_unix_client(path) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + + async def test_set_host_header(self): + """Client sets the Host header to the host in the WebSocket URI.""" + # This is part of the documented behavior of unix_connect(). + with temp_unix_socket_path() as path: + async with run_unix_server(path): + async with run_unix_client(path, uri="ws://overridden/") as client: + self.assertEqual(client.request.headers["Host"], "overridden") + + +@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") +class SecureUnixClientTests(unittest.IsolatedAsyncioTestCase): + async def test_connection(self): + """Client connects to server securely over a Unix socket.""" + with temp_unix_socket_path() as path: + async with run_unix_server(path, ssl=SERVER_CONTEXT): + async with run_unix_client(path, ssl=CLIENT_CONTEXT) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.version()[:3], "TLS") + + async def test_set_server_hostname(self): + """Client sets server_hostname to the host in the WebSocket URI.""" + # This is part of the documented behavior of unix_connect(). + with temp_unix_socket_path() as path: + async with run_unix_server(path, ssl=SERVER_CONTEXT): + async with run_unix_client( + path, + ssl=CLIENT_CONTEXT, + uri="wss://overridden/", + ) as client: + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.server_hostname, "overridden") + + +class ClientUsageErrorsTests(unittest.IsolatedAsyncioTestCase): + async def test_ssl_without_secure_uri(self): + """Client rejects ssl when URI isn't secure.""" + with self.assertRaises(TypeError) as raised: + await connect("ws://localhost/", ssl=CLIENT_CONTEXT) + self.assertEqual( + str(raised.exception), + "ssl argument is incompatible with a ws:// URI", + ) + + async def test_secure_uri_without_ssl(self): + """Client rejects no ssl when URI is secure.""" + with self.assertRaises(TypeError) as raised: + await connect("wss://localhost/", ssl=None) + self.assertEqual( + str(raised.exception), + "ssl=None is incompatible with a wss:// URI", + ) + + async def test_unix_without_path_or_sock(self): + """Unix client requires path when sock isn't provided.""" + with self.assertRaises(ValueError) as raised: + await unix_connect() + self.assertEqual( + str(raised.exception), + "no path and sock were specified", + ) + + async def test_unix_with_path_and_sock(self): + """Unix client rejects path when sock is provided.""" + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self.addCleanup(sock.close) + with self.assertRaises(ValueError) as raised: + await unix_connect(path="/", sock=sock) + self.assertEqual( + str(raised.exception), + "path and sock can not be specified at the same time", + ) + + async def test_invalid_subprotocol(self): + """Client rejects single value of subprotocols.""" + with self.assertRaises(TypeError) as raised: + await connect("ws://localhost/", subprotocols="chat") + self.assertEqual( + str(raised.exception), + "subprotocols must be a list, not a str", + ) + + async def test_unsupported_compression(self): + """Client rejects incorrect value of compression.""" + with self.assertRaises(ValueError) as raised: + await connect("ws://localhost/", compression=False) + self.assertEqual( + str(raised.exception), + "unsupported compression: False", + ) diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py new file mode 100644 index 000000000..a8b3980b4 --- /dev/null +++ b/tests/asyncio/test_connection.py @@ -0,0 +1,948 @@ +import asyncio +import contextlib +import logging +import socket +import unittest +import uuid +from unittest.mock import patch + +from websockets.asyncio.compatibility import TimeoutError, aiter, anext, asyncio_timeout +from websockets.asyncio.connection import * +from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK +from websockets.frames import CloseCode, Frame, Opcode +from websockets.protocol import CLIENT, SERVER, Protocol + +from ..protocol import RecordingProtocol +from ..utils import MS +from .connection import InterceptingConnection +from .utils import alist + + +# Connection implements symmetrical behavior between clients and servers. +# All tests run on the client side and the server side to validate this. + + +class ClientConnectionTests(unittest.IsolatedAsyncioTestCase): + LOCAL = CLIENT + REMOTE = SERVER + + async def asyncSetUp(self): + loop = asyncio.get_running_loop() + socket_, remote_socket = socket.socketpair() + self.transport, self.connection = await loop.create_connection( + lambda: Connection(Protocol(self.LOCAL), close_timeout=2 * MS), + sock=socket_, + ) + self.remote_transport, self.remote_connection = await loop.create_connection( + lambda: InterceptingConnection(RecordingProtocol(self.REMOTE)), + sock=remote_socket, + ) + + async def asyncTearDown(self): + await self.remote_connection.close() + await self.connection.close() + + # Test helpers built upon RecordingProtocol and InterceptingConnection. + + async def assertFrameSent(self, frame): + """Check that a single frame was sent.""" + # Let the remote side process messages. + # Two runs of the event loop are required for answering pings. + await asyncio.sleep(0) + await asyncio.sleep(0) + self.assertEqual(self.remote_connection.protocol.get_frames_rcvd(), [frame]) + + async def assertFramesSent(self, frames): + """Check that several frames were sent.""" + # Let the remote side process messages. + # Two runs of the event loop are required for answering pings. + await asyncio.sleep(0) + await asyncio.sleep(0) + self.assertEqual(self.remote_connection.protocol.get_frames_rcvd(), frames) + + async def assertNoFrameSent(self): + """Check that no frame was sent.""" + # Run the event loop twice for consistency with assertFrameSent. + await asyncio.sleep(0) + await asyncio.sleep(0) + self.assertEqual(self.remote_connection.protocol.get_frames_rcvd(), []) + + @contextlib.asynccontextmanager + async def delay_frames_rcvd(self, delay): + """Delay frames before they're received by the connection.""" + with self.remote_connection.delay_frames_sent(delay): + yield + await asyncio.sleep(MS) # let the remote side process messages + + @contextlib.asynccontextmanager + async def delay_eof_rcvd(self, delay): + """Delay EOF before it's received by the connection.""" + with self.remote_connection.delay_eof_sent(delay): + yield + await asyncio.sleep(MS) # let the remote side process messages + + @contextlib.asynccontextmanager + async def drop_frames_rcvd(self): + """Drop frames before they're received by the connection.""" + with self.remote_connection.drop_frames_sent(): + yield + await asyncio.sleep(MS) # let the remote side process messages + + @contextlib.asynccontextmanager + async def drop_eof_rcvd(self): + """Drop EOF before it's received by the connection.""" + with self.remote_connection.drop_eof_sent(): + yield + await asyncio.sleep(MS) # let the remote side process messages + + # Test __aenter__ and __aexit__. + + async def test_aenter(self): + """__aenter__ returns the connection itself.""" + async with self.connection as connection: + self.assertIs(connection, self.connection) + + async def test_aexit(self): + """__aexit__ closes the connection with code 1000.""" + async with self.connection: + await self.assertNoFrameSent() + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) + + async def test_exit_with_exception(self): + """__exit__ with an exception closes the connection with code 1011.""" + with self.assertRaises(RuntimeError): + async with self.connection: + raise RuntimeError + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xf3")) + + # Test __aiter__. + + async def test_aiter_text(self): + """__aiter__ yields text messages.""" + aiterator = aiter(self.connection) + await self.remote_connection.send("😀") + self.assertEqual(await anext(aiterator), "😀") + await self.remote_connection.send("😀") + self.assertEqual(await anext(aiterator), "😀") + + async def test_aiter_binary(self): + """__aiter__ yields binary messages.""" + aiterator = aiter(self.connection) + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await anext(aiterator), b"\x01\x02\xfe\xff") + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await anext(aiterator), b"\x01\x02\xfe\xff") + + async def test_aiter_mixed(self): + """__aiter__ yields a mix of text and binary messages.""" + aiterator = aiter(self.connection) + await self.remote_connection.send("😀") + self.assertEqual(await anext(aiterator), "😀") + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await anext(aiterator), b"\x01\x02\xfe\xff") + + async def test_aiter_connection_closed_ok(self): + """__aiter__ terminates after a normal closure.""" + aiterator = aiter(self.connection) + await self.remote_connection.close() + with self.assertRaises(StopAsyncIteration): + await anext(aiterator) + + async def test_aiter_connection_closed_error(self): + """__aiter__ raises ConnnectionClosedError after an error.""" + aiterator = aiter(self.connection) + await self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) + with self.assertRaises(ConnectionClosedError): + await anext(aiterator) + + # Test recv. + + async def test_recv_text(self): + """recv receives a text message.""" + await self.remote_connection.send("😀") + self.assertEqual(await self.connection.recv(), "😀") + + async def test_recv_binary(self): + """recv receives a binary message.""" + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await self.connection.recv(), b"\x01\x02\xfe\xff") + + async def test_recv_fragmented_text(self): + """recv receives a fragmented text message.""" + await self.remote_connection.send(["😀", "😀"]) + self.assertEqual(await self.connection.recv(), "😀😀") + + async def test_recv_fragmented_binary(self): + """recv receives a fragmented binary message.""" + await self.remote_connection.send([b"\x01\x02", b"\xfe\xff"]) + self.assertEqual(await self.connection.recv(), b"\x01\x02\xfe\xff") + + async def test_recv_connection_closed_ok(self): + """recv raises ConnectionClosedOK after a normal closure.""" + await self.remote_connection.close() + with self.assertRaises(ConnectionClosedOK): + await self.connection.recv() + + async def test_recv_connection_closed_error(self): + """recv raises ConnectionClosedError after an error.""" + await self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) + with self.assertRaises(ConnectionClosedError): + await self.connection.recv() + + async def test_recv_during_recv(self): + """recv raises RuntimeError when called concurrently with itself.""" + recv_task = asyncio.create_task(self.connection.recv()) + await asyncio.sleep(0) # let the event loop start recv_task + self.addCleanup(recv_task.cancel) + + with self.assertRaises(RuntimeError) as raised: + await self.connection.recv() + self.assertEqual( + str(raised.exception), + "cannot call recv while another coroutine " + "is already running recv or recv_streaming", + ) + + async def test_recv_during_recv_streaming(self): + """recv raises RuntimeError when called concurrently with recv_streaming.""" + recv_streaming_task = asyncio.create_task( + alist(self.connection.recv_streaming()) + ) + await asyncio.sleep(0) # let the event loop start recv_streaming_task + self.addCleanup(recv_streaming_task.cancel) + + with self.assertRaises(RuntimeError) as raised: + await self.connection.recv() + self.assertEqual( + str(raised.exception), + "cannot call recv while another coroutine " + "is already running recv or recv_streaming", + ) + + async def test_recv_cancellation_before_receiving(self): + """recv can be cancelled before receiving a frame.""" + recv_task = asyncio.create_task(self.connection.recv()) + await asyncio.sleep(0) # let the event loop start recv_task + + recv_task.cancel() + await asyncio.sleep(0) # let the event loop cancel recv_task + + # Running recv again receives the next message. + await self.remote_connection.send("😀") + self.assertEqual(await self.connection.recv(), "😀") + + async def test_recv_cancellation_while_receiving(self): + """recv cannot be cancelled after receiving a frame.""" + recv_task = asyncio.create_task(self.connection.recv()) + await asyncio.sleep(0) # let the event loop start recv_task + + gate = asyncio.get_running_loop().create_future() + + async def fragments(): + yield "⏳" + await gate + yield "⌛️" + + asyncio.create_task(self.remote_connection.send(fragments())) + await asyncio.sleep(MS) + + recv_task.cancel() + await asyncio.sleep(0) # let the event loop cancel recv_task + + # Running recv again receives the complete message. + gate.set_result(None) + self.assertEqual(await self.connection.recv(), "⏳⌛️") + + # Test recv_streaming. + + async def test_recv_streaming_text(self): + """recv_streaming receives a text message.""" + await self.remote_connection.send("😀") + self.assertEqual( + await alist(self.connection.recv_streaming()), + ["😀"], + ) + + async def test_recv_streaming_binary(self): + """recv_streaming receives a binary message.""" + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual( + await alist(self.connection.recv_streaming()), + [b"\x01\x02\xfe\xff"], + ) + + async def test_recv_streaming_fragmented_text(self): + """recv_streaming receives a fragmented text message.""" + await self.remote_connection.send(["😀", "😀"]) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.connection.recv_streaming()), + ["😀", "😀", ""], + ) + + async def test_recv_streaming_fragmented_binary(self): + """recv_streaming receives a fragmented binary message.""" + await self.remote_connection.send([b"\x01\x02", b"\xfe\xff"]) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.connection.recv_streaming()), + [b"\x01\x02", b"\xfe\xff", b""], + ) + + async def test_recv_streaming_connection_closed_ok(self): + """recv_streaming raises ConnectionClosedOK after a normal closure.""" + await self.remote_connection.close() + with self.assertRaises(ConnectionClosedOK): + async for _ in self.connection.recv_streaming(): + self.fail("did not raise") + + async def test_recv_streaming_connection_closed_error(self): + """recv_streaming raises ConnectionClosedError after an error.""" + await self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) + with self.assertRaises(ConnectionClosedError): + async for _ in self.connection.recv_streaming(): + self.fail("did not raise") + + async def test_recv_streaming_during_recv(self): + """recv_streaming raises RuntimeError when called concurrently with recv.""" + recv_task = asyncio.create_task(self.connection.recv()) + await asyncio.sleep(0) # let the event loop start recv_task + self.addCleanup(recv_task.cancel) + + with self.assertRaises(RuntimeError) as raised: + async for _ in self.connection.recv_streaming(): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "cannot call recv_streaming while another coroutine " + "is already running recv or recv_streaming", + ) + + async def test_recv_streaming_during_recv_streaming(self): + """recv_streaming raises RuntimeError when called concurrently with itself.""" + recv_streaming_task = asyncio.create_task( + alist(self.connection.recv_streaming()) + ) + await asyncio.sleep(0) # let the event loop start recv_streaming_task + self.addCleanup(recv_streaming_task.cancel) + + with self.assertRaises(RuntimeError) as raised: + async for _ in self.connection.recv_streaming(): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + r"cannot call recv_streaming while another coroutine " + r"is already running recv or recv_streaming", + ) + + async def test_recv_streaming_cancellation_before_receiving(self): + """recv_streaming can be cancelled before receiving a frame.""" + recv_streaming_task = asyncio.create_task( + alist(self.connection.recv_streaming()) + ) + await asyncio.sleep(0) # let the event loop start recv_streaming_task + + recv_streaming_task.cancel() + await asyncio.sleep(0) # let the event loop cancel recv_streaming_task + + # Running recv_streaming again receives the next message. + await self.remote_connection.send(["😀", "😀"]) + self.assertEqual( + await alist(self.connection.recv_streaming()), + ["😀", "😀", ""], + ) + + async def test_recv_streaming_cancellation_while_receiving(self): + """recv_streaming cannot be cancelled after receiving a frame.""" + recv_streaming_task = asyncio.create_task( + alist(self.connection.recv_streaming()) + ) + await asyncio.sleep(0) # let the event loop start recv_streaming_task + + gate = asyncio.get_running_loop().create_future() + + async def fragments(): + yield "⏳" + await gate + yield "⌛️" + + asyncio.create_task(self.remote_connection.send(fragments())) + await asyncio.sleep(MS) + + recv_streaming_task.cancel() + await asyncio.sleep(0) # let the event loop cancel recv_streaming_task + + gate.set_result(None) + # Running recv_streaming again fails. + with self.assertRaises(RuntimeError): + await alist(self.connection.recv_streaming()) + + # Test send. + + async def test_send_text(self): + """send sends a text message.""" + await self.connection.send("😀") + self.assertEqual(await self.remote_connection.recv(), "😀") + + async def test_send_binary(self): + """send sends a binary message.""" + await self.connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await self.remote_connection.recv(), b"\x01\x02\xfe\xff") + + async def test_send_fragmented_text(self): + """send sends a fragmented text message.""" + await self.connection.send(["😀", "😀"]) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀", "😀", ""], + ) + + async def test_send_fragmented_binary(self): + """send sends a fragmented binary message.""" + await self.connection.send([b"\x01\x02", b"\xfe\xff"]) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + [b"\x01\x02", b"\xfe\xff", b""], + ) + + async def test_send_async_fragmented_text(self): + """send sends a fragmented text message asynchronously.""" + + async def fragments(): + yield "😀" + yield "😀" + + await self.connection.send(fragments()) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀", "😀", ""], + ) + + async def test_send_async_fragmented_binary(self): + """send sends a fragmented binary message asynchronously.""" + + async def fragments(): + yield b"\x01\x02" + yield b"\xfe\xff" + + await self.connection.send(fragments()) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + [b"\x01\x02", b"\xfe\xff", b""], + ) + + async def test_send_connection_closed_ok(self): + """send raises ConnectionClosedOK after a normal closure.""" + await self.remote_connection.close() + with self.assertRaises(ConnectionClosedOK): + await self.connection.send("😀") + + async def test_send_connection_closed_error(self): + """send raises ConnectionClosedError after an error.""" + await self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) + with self.assertRaises(ConnectionClosedError): + await self.connection.send("😀") + + async def test_send_while_send_blocked(self): + """send waits for a previous call to send to complete.""" + # This test fails if the guard with fragmented_send_waiter is removed + # from send() in the case when message is an Iterable. + self.connection.pause_writing() + asyncio.create_task(self.connection.send(["⏳", "⌛️"])) + await asyncio.sleep(MS) + await self.assertFrameSent( + Frame(Opcode.TEXT, "⏳".encode(), fin=False), + ) + + asyncio.create_task(self.connection.send("✅")) + await asyncio.sleep(MS) + await self.assertNoFrameSent() + + self.connection.resume_writing() + await asyncio.sleep(MS) + await self.assertFramesSent( + [ + Frame(Opcode.CONT, "⌛️".encode(), fin=False), + Frame(Opcode.CONT, b"", fin=True), + Frame(Opcode.TEXT, "✅".encode()), + ] + ) + + async def test_send_while_send_async_blocked(self): + """send waits for a previous call to send to complete.""" + # This test fails if the guard with fragmented_send_waiter is removed + # from send() in the case when message is an AsyncIterable. + self.connection.pause_writing() + + async def fragments(): + yield "⏳" + yield "⌛️" + + asyncio.create_task(self.connection.send(fragments())) + await asyncio.sleep(MS) + await self.assertFrameSent( + Frame(Opcode.TEXT, "⏳".encode(), fin=False), + ) + + asyncio.create_task(self.connection.send("✅")) + await asyncio.sleep(MS) + await self.assertNoFrameSent() + + self.connection.resume_writing() + await asyncio.sleep(MS) + await self.assertFramesSent( + [ + Frame(Opcode.CONT, "⌛️".encode(), fin=False), + Frame(Opcode.CONT, b"", fin=True), + Frame(Opcode.TEXT, "✅".encode()), + ] + ) + + async def test_send_during_send_async(self): + """send waits for a previous call to send to complete.""" + # This test fails if the guard with fragmented_send_waiter is removed + # from send() in the case when message is an AsyncIterable. + gate = asyncio.get_running_loop().create_future() + + async def fragments(): + yield "⏳" + await gate + yield "⌛️" + + asyncio.create_task(self.connection.send(fragments())) + await asyncio.sleep(MS) + await self.assertFrameSent( + Frame(Opcode.TEXT, "⏳".encode(), fin=False), + ) + + asyncio.create_task(self.connection.send("✅")) + await asyncio.sleep(MS) + await self.assertNoFrameSent() + + gate.set_result(None) + await asyncio.sleep(MS) + await self.assertFramesSent( + [ + Frame(Opcode.CONT, "⌛️".encode(), fin=False), + Frame(Opcode.CONT, b"", fin=True), + Frame(Opcode.TEXT, "✅".encode()), + ] + ) + + async def test_send_empty_iterable(self): + """send does nothing when called with an empty iterable.""" + await self.connection.send([]) + await self.connection.close() + self.assertEqual(await alist(self.remote_connection), []) + + async def test_send_mixed_iterable(self): + """send raises TypeError when called with an iterable of inconsistent types.""" + with self.assertRaises(TypeError): + await self.connection.send(["😀", b"\xfe\xff"]) + + async def test_send_unsupported_iterable(self): + """send raises TypeError when called with an iterable of unsupported type.""" + with self.assertRaises(TypeError): + await self.connection.send([None]) + + async def test_send_empty_async_iterable(self): + """send does nothing when called with an empty async iterable.""" + + async def fragments(): + return + yield # pragma: no cover + + await self.connection.send(fragments()) + await self.connection.close() + self.assertEqual(await alist(self.remote_connection), []) + + async def test_send_mixed_async_iterable(self): + """send raises TypeError when called with an iterable of inconsistent types.""" + + async def fragments(): + yield "😀" + yield b"\xfe\xff" + + with self.assertRaises(TypeError): + await self.connection.send(fragments()) + + async def test_send_unsupported_async_iterable(self): + """send raises TypeError when called with an iterable of unsupported type.""" + + async def fragments(): + yield None + + with self.assertRaises(TypeError): + await self.connection.send(fragments()) + + async def test_send_dict(self): + """send raises TypeError when called with a dict.""" + with self.assertRaises(TypeError): + await self.connection.send({"type": "object"}) + + async def test_send_unsupported_type(self): + """send raises TypeError when called with an unsupported type.""" + with self.assertRaises(TypeError): + await self.connection.send(None) + + # Test close. + + async def test_close(self): + """close sends a close frame.""" + await self.connection.close() + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) + + async def test_close_explicit_code_reason(self): + """close sends a close frame with a given code and reason.""" + await self.connection.close(CloseCode.GOING_AWAY, "bye!") + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe9bye!")) + + async def test_close_waits_for_close_frame(self): + """close waits for a close frame (then EOF) before returning.""" + async with self.delay_frames_rcvd(MS), self.delay_eof_rcvd(MS): + await self.connection.close() + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_close_waits_for_connection_closed(self): + """close waits for EOF before returning.""" + if self.LOCAL is SERVER: + self.skipTest("only relevant on the client-side") + + async with self.delay_eof_rcvd(MS): + await self.connection.close() + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_close_no_timeout_waits_for_close_frame(self): + """close without timeout waits for a close frame (then EOF) before returning.""" + self.connection.close_timeout = None + + async with self.delay_frames_rcvd(MS), self.delay_eof_rcvd(MS): + await self.connection.close() + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_close_no_timeout_waits_for_connection_closed(self): + """close without timeout waits for EOF before returning.""" + if self.LOCAL is SERVER: + self.skipTest("only relevant on the client-side") + + self.connection.close_timeout = None + + async with self.delay_eof_rcvd(MS): + await self.connection.close() + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_close_timeout_waiting_for_close_frame(self): + """close times out if no close frame is received.""" + async with self.drop_eof_rcvd(), self.drop_frames_rcvd(): + await self.connection.close() + + with self.assertRaises(ConnectionClosedError) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); no close frame received") + self.assertIsInstance(exc.__cause__, TimeoutError) + + async def test_close_timeout_waiting_for_connection_closed(self): + """close times out if EOF isn't received.""" + if self.LOCAL is SERVER: + self.skipTest("only relevant on the client-side") + + async with self.drop_eof_rcvd(): + await self.connection.close() + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + # Remove socket.timeout when dropping Python < 3.10. + self.assertIsInstance(exc.__cause__, (socket.timeout, TimeoutError)) + + async def test_close_does_not_wait_for_recv(self): + # The asyncio implementation has a buffer for incoming messages. Closing + # the connection discards buffered messages. This is allowed by the RFC: + # > However, there is no guarantee that the endpoint that has already + # > sent a Close frame will continue to process data. + await self.remote_connection.send("😀") + await self.connection.close() + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_close_idempotency(self): + """close does nothing if the connection is already closed.""" + await self.connection.close() + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) + + await self.connection.close() + await self.assertNoFrameSent() + + async def test_close_during_recv(self): + """close aborts recv when called concurrently with recv.""" + recv_task = asyncio.create_task(self.connection.recv()) + await asyncio.sleep(MS) + await self.connection.close() + with self.assertRaises(ConnectionClosedOK) as raised: + await recv_task + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_close_during_send(self): + """close fails the connection when called concurrently with send.""" + gate = asyncio.get_running_loop().create_future() + + async def fragments(): + yield "⏳" + await gate + yield "⌛️" + + send_task = asyncio.create_task(self.connection.send(fragments())) + await asyncio.sleep(MS) + + asyncio.create_task(self.connection.close()) + await asyncio.sleep(MS) + + gate.set_result(None) + + with self.assertRaises(ConnectionClosedError) as raised: + await send_task + + exc = raised.exception + self.assertEqual( + str(exc), + "sent 1011 (internal error) close during fragmented message; " + "no close frame received", + ) + self.assertIsNone(exc.__cause__) + + # Test wait_closed. + + async def test_wait_closed(self): + """wait_closed waits for the connection to close.""" + wait_closed_task = asyncio.create_task(self.connection.wait_closed()) + await asyncio.sleep(0) # let the event loop start wait_closed_task + self.assertFalse(wait_closed_task.done()) + await self.connection.close() + self.assertTrue(wait_closed_task.done()) + + # Test ping. + + @patch("random.getrandbits") + async def test_ping(self, getrandbits): + """ping sends a ping frame with a random payload.""" + getrandbits.return_value = 1918987876 + await self.connection.ping() + getrandbits.assert_called_once_with(32) + await self.assertFrameSent(Frame(Opcode.PING, b"rand")) + + async def test_ping_explicit_text(self): + """ping sends a ping frame with a payload provided as text.""" + await self.connection.ping("ping") + await self.assertFrameSent(Frame(Opcode.PING, b"ping")) + + async def test_ping_explicit_binary(self): + """ping sends a ping frame with a payload provided as binary.""" + await self.connection.ping(b"ping") + await self.assertFrameSent(Frame(Opcode.PING, b"ping")) + + async def test_acknowledge_ping(self): + """ping is acknowledged by a pong with the same payload.""" + async with self.drop_frames_rcvd(): # drop automatic response to ping + pong_waiter = await self.connection.ping("this") + await self.remote_connection.pong("this") + async with asyncio_timeout(MS): + await pong_waiter + + async def test_acknowledge_ping_non_matching_pong(self): + """ping isn't acknowledged by a pong with a different payload.""" + async with self.drop_frames_rcvd(): # drop automatic response to ping + pong_waiter = await self.connection.ping("this") + await self.remote_connection.pong("that") + with self.assertRaises(TimeoutError): + async with asyncio_timeout(MS): + await pong_waiter + + async def test_acknowledge_previous_ping(self): + """ping is acknowledged by a pong with the same payload as a later ping.""" + async with self.drop_frames_rcvd(): # drop automatic response to ping + pong_waiter = await self.connection.ping("this") + await self.connection.ping("that") + await self.remote_connection.pong("that") + async with asyncio_timeout(MS): + await pong_waiter + + async def test_ping_duplicate_payload(self): + """ping rejects the same payload until receiving the pong.""" + async with self.drop_frames_rcvd(): # drop automatic response to ping + pong_waiter = await self.connection.ping("idem") + + with self.assertRaises(RuntimeError) as raised: + await self.connection.ping("idem") + self.assertEqual( + str(raised.exception), + "already waiting for a pong with the same data", + ) + + await self.remote_connection.pong("idem") + async with asyncio_timeout(MS): + await pong_waiter + + await self.connection.ping("idem") # doesn't raise an exception + + # Test pong. + + async def test_pong(self): + """pong sends a pong frame.""" + await self.connection.pong() + await self.assertFrameSent(Frame(Opcode.PONG, b"")) + + async def test_pong_explicit_text(self): + """pong sends a pong frame with a payload provided as text.""" + await self.connection.pong("pong") + await self.assertFrameSent(Frame(Opcode.PONG, b"pong")) + + async def test_pong_explicit_binary(self): + """pong sends a pong frame with a payload provided as binary.""" + await self.connection.pong(b"pong") + await self.assertFrameSent(Frame(Opcode.PONG, b"pong")) + + # Test attributes. + + async def test_id(self): + """Connection has an id attribute.""" + self.assertIsInstance(self.connection.id, uuid.UUID) + + async def test_logger(self): + """Connection has a logger attribute.""" + self.assertIsInstance(self.connection.logger, logging.LoggerAdapter) + + @unittest.mock.patch( + "asyncio.BaseTransport.get_extra_info", return_value=("sock", 1234) + ) + async def test_local_address(self, get_extra_info): + """Connection provides a local_address attribute.""" + self.assertEqual(self.connection.local_address, ("sock", 1234)) + get_extra_info.assert_called_with("sockname") + + @unittest.mock.patch( + "asyncio.BaseTransport.get_extra_info", return_value=("peer", 1234) + ) + async def test_remote_address(self, get_extra_info): + """Connection provides a remote_address attribute.""" + self.assertEqual(self.connection.remote_address, ("peer", 1234)) + get_extra_info.assert_called_with("peername") + + async def test_request(self): + """Connection has a request attribute.""" + self.assertIsNone(self.connection.request) + + async def test_response(self): + """Connection has a response attribute.""" + self.assertIsNone(self.connection.response) + + async def test_subprotocol(self): + """Connection has a subprotocol attribute.""" + self.assertIsNone(self.connection.subprotocol) + + # Test reporting of network errors. + + async def test_writing_in_data_received_fails(self): + """Error when responding to incoming frames is correctly reported.""" + # Inject a fault by shutting down the transport for writing — but not by + # closing it because that would terminate the connection. + self.transport.write_eof() + # Receive a ping. Responding with a pong will fail. + await self.remote_connection.ping() + # The connection closed exception reports the injected fault. + with self.assertRaises(ConnectionClosedError) as raised: + await self.connection.recv() + cause = raised.exception.__cause__ + self.assertEqual(str(cause), "Cannot call write() after write_eof()") + self.assertIsInstance(cause, RuntimeError) + + async def test_writing_in_send_context_fails(self): + """Error when sending outgoing frame is correctly reported.""" + # Inject a fault by shutting down the transport for writing — but not by + # closing it because that would terminate the connection. + self.transport.write_eof() + # Sending a pong will fail. + # The connection closed exception reports the injected fault. + with self.assertRaises(ConnectionClosedError) as raised: + await self.connection.pong() + cause = raised.exception.__cause__ + self.assertEqual(str(cause), "Cannot call write() after write_eof()") + self.assertIsInstance(cause, RuntimeError) + + # Test safety nets — catching all exceptions in case of bugs. + + @patch("websockets.protocol.Protocol.events_received") + async def test_unexpected_failure_in_data_received(self, events_received): + """Unexpected internal error in data_received() is correctly reported.""" + # Inject a fault in a random call in data_received(). + # This test is tightly coupled to the implementation. + events_received.side_effect = AssertionError + # Receive a message to trigger the fault. + await self.remote_connection.send("😀") + + with self.assertRaises(ConnectionClosedError) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "no close frame received or sent") + self.assertIsInstance(exc.__cause__, AssertionError) + + @patch("websockets.protocol.Protocol.send_text") + async def test_unexpected_failure_in_send_context(self, send_text): + """Unexpected internal error in send_context() is correctly reported.""" + # Inject a fault in a random call in send_context(). + # This test is tightly coupled to the implementation. + send_text.side_effect = AssertionError + + # Send a message to trigger the fault. + # The connection closed exception reports the injected fault. + with self.assertRaises(ConnectionClosedError) as raised: + await self.connection.send("😀") + + exc = raised.exception + self.assertEqual(str(exc), "no close frame received or sent") + self.assertIsInstance(exc.__cause__, AssertionError) + + +class ServerConnectionTests(ClientConnectionTests): + LOCAL = SERVER + REMOTE = CLIENT diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py new file mode 100644 index 000000000..2e59f49b1 --- /dev/null +++ b/tests/asyncio/test_server.py @@ -0,0 +1,525 @@ +import asyncio +import dataclasses +import http +import logging +import socket +import unittest + +from websockets.asyncio.compatibility import TimeoutError, asyncio_timeout +from websockets.asyncio.server import * +from websockets.exceptions import ( + ConnectionClosedError, + ConnectionClosedOK, + InvalidStatus, + NegotiationError, +) +from websockets.http11 import Request, Response + +from ..utils import ( + CLIENT_CONTEXT, + MS, + SERVER_CONTEXT, + temp_unix_socket_path, +) +from .client import run_client, run_unix_client +from .server import ( + EvalShellMixin, + crash, + do_nothing, + eval_shell, + get_server_host_port, + keep_running, + run_server, + run_unix_server, +) + + +class ServerTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): + async def test_connection(self): + """Server receives connection from client and the handshake succeeds.""" + async with run_server() as server: + async with run_client(server) as client: + await self.assertEval(client, "ws.protocol.state.name", "OPEN") + + async def test_connection_handler_returns(self): + """Connection handler returns.""" + async with run_server(do_nothing) as server: + async with run_client(server) as client: + with self.assertRaises(ConnectionClosedOK) as raised: + await client.recv() + self.assertEqual( + str(raised.exception), + "received 1000 (OK); then sent 1000 (OK)", + ) + + async def test_connection_handler_raises_exception(self): + """Connection handler raises an exception.""" + async with run_server(crash) as server: + async with run_client(server) as client: + with self.assertRaises(ConnectionClosedError) as raised: + await client.recv() + self.assertEqual( + str(raised.exception), + "received 1011 (internal error); " + "then sent 1011 (internal error)", + ) + + async def test_existing_socket(self): + """Server receives connection using a pre-existing socket.""" + with socket.create_server(("localhost", 0)) as sock: + async with run_server(sock=sock, host=None, port=None): + uri = "ws://{}:{}/".format(*sock.getsockname()) + async with run_client(uri) as client: + await self.assertEval(client, "ws.protocol.state.name", "OPEN") + + async def test_select_subprotocol(self): + """Server selects a subprotocol with the select_subprotocol callable.""" + + def select_subprotocol(ws, subprotocols): + ws.select_subprotocol_ran = True + assert "chat" in subprotocols + return "chat" + + async with run_server( + subprotocols=["chat"], + select_subprotocol=select_subprotocol, + ) as server: + async with run_client(server, subprotocols=["chat"]) as client: + await self.assertEval(client, "ws.select_subprotocol_ran", "True") + await self.assertEval(client, "ws.subprotocol", "chat") + + async def test_select_subprotocol_rejects_handshake(self): + """Server rejects handshake if select_subprotocol raises NegotiationError.""" + + def select_subprotocol(ws, subprotocols): + raise NegotiationError + + async with run_server(select_subprotocol=select_subprotocol) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 400", + ) + + async def test_select_subprotocol_raises_exception(self): + """Server returns an error if select_subprotocol raises an exception.""" + + def select_subprotocol(ws, subprotocols): + raise RuntimeError + + async with run_server(select_subprotocol=select_subprotocol) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + + async def test_process_request(self): + """Server runs process_request before processing the handshake.""" + + def process_request(ws, request): + self.assertIsInstance(request, Request) + ws.process_request_ran = True + + async with run_server(process_request=process_request) as server: + async with run_client(server) as client: + await self.assertEval(client, "ws.process_request_ran", "True") + + async def test_async_process_request(self): + """Server runs async process_request before processing the handshake.""" + + async def process_request(ws, request): + self.assertIsInstance(request, Request) + ws.process_request_ran = True + + async with run_server(process_request=process_request) as server: + async with run_client(server) as client: + await self.assertEval(client, "ws.process_request_ran", "True") + + async def test_process_request_abort_handshake(self): + """Server aborts handshake if process_request returns a response.""" + + def process_request(ws, request): + return ws.protocol.reject(http.HTTPStatus.FORBIDDEN, "Forbidden") + + async with run_server(process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 403", + ) + + async def test_async_process_request_abort_handshake(self): + """Server aborts handshake if async process_request returns a response.""" + + async def process_request(ws, request): + return ws.protocol.reject(http.HTTPStatus.FORBIDDEN, "Forbidden") + + async with run_server(process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 403", + ) + + async def test_process_request_raises_exception(self): + """Server returns an error if process_request raises an exception.""" + + def process_request(ws, request): + raise RuntimeError + + async with run_server(process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + + async def test_async_process_request_raises_exception(self): + """Server returns an error if async process_request raises an exception.""" + + async def process_request(ws, request): + raise RuntimeError + + async with run_server(process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + + async def test_process_response(self): + """Server runs process_response after processing the handshake.""" + + def process_response(ws, request, response): + self.assertIsInstance(request, Request) + self.assertIsInstance(response, Response) + ws.process_response_ran = True + + async with run_server(process_response=process_response) as server: + async with run_client(server) as client: + await self.assertEval(client, "ws.process_response_ran", "True") + + async def test_async_process_response(self): + """Server runs async process_response after processing the handshake.""" + + async def process_response(ws, request, response): + self.assertIsInstance(request, Request) + self.assertIsInstance(response, Response) + ws.process_response_ran = True + + async with run_server(process_response=process_response) as server: + async with run_client(server) as client: + await self.assertEval(client, "ws.process_response_ran", "True") + + async def test_process_response_override_response(self): + """Server runs process_response and overrides the handshake response.""" + + def process_response(ws, request, response): + headers = response.headers.copy() + headers["X-ProcessResponse-Ran"] = "true" + return dataclasses.replace(response, headers=headers) + + async with run_server(process_response=process_response) as server: + async with run_client(server) as client: + self.assertEqual( + client.response.headers["X-ProcessResponse-Ran"], "true" + ) + + async def test_async_process_response_override_response(self): + """Server runs async process_response and overrides the handshake response.""" + + async def process_response(ws, request, response): + headers = response.headers.copy() + headers["X-ProcessResponse-Ran"] = "true" + return dataclasses.replace(response, headers=headers) + + async with run_server(process_response=process_response) as server: + async with run_client(server) as client: + self.assertEqual( + client.response.headers["X-ProcessResponse-Ran"], "true" + ) + + async def test_process_response_raises_exception(self): + """Server returns an error if process_response raises an exception.""" + + def process_response(ws, request, response): + raise RuntimeError + + async with run_server(process_response=process_response) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + + async def test_async_process_response_raises_exception(self): + """Server returns an error if async process_response raises an exception.""" + + async def process_response(ws, request, response): + raise RuntimeError + + async with run_server(process_response=process_response) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + + async def test_override_server(self): + """Server can override Server header with server_header.""" + async with run_server(server_header="Neo") as server: + async with run_client(server) as client: + await self.assertEval(client, "ws.response.headers['Server']", "Neo") + + async def test_remove_server(self): + """Server can remove Server header with server_header.""" + async with run_server(server_header=None) as server: + async with run_client(server) as client: + await self.assertEval( + client, "'Server' in ws.response.headers", "False" + ) + + async def test_compression_is_enabled(self): + """Server enables compression by default.""" + async with run_server() as server: + async with run_client(server) as client: + await self.assertEval( + client, + "[type(ext).__name__ for ext in ws.protocol.extensions]", + "['PerMessageDeflate']", + ) + + async def test_disable_compression(self): + """Server disables compression.""" + async with run_server(compression=None) as server: + async with run_client(server) as client: + await self.assertEval(client, "ws.protocol.extensions", "[]") + + async def test_custom_connection_factory(self): + """Server runs ServerConnection factory provided in create_connection.""" + + def create_connection(*args, **kwargs): + server = ServerConnection(*args, **kwargs) + server.create_connection_ran = True + return server + + async with run_server(create_connection=create_connection) as server: + async with run_client(server) as client: + await self.assertEval(client, "ws.create_connection_ran", "True") + + async def test_handshake_fails(self): + """Server receives connection from client but the handshake fails.""" + + def remove_key_header(self, request): + del request.headers["Sec-WebSocket-Key"] + + async with run_server(process_request=remove_key_header) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 400", + ) + + async def test_timeout_during_handshake(self): + """Server times out before receiving handshake request from client.""" + async with run_server(open_timeout=MS) as server: + reader, writer = await asyncio.open_connection( + *get_server_host_port(server) + ) + try: + self.assertEqual(await reader.read(4096), b"") + finally: + writer.close() + + async def test_connection_closed_during_handshake(self): + """Server reads EOF before receiving handshake request from client.""" + async with run_server() as server: + _reader, writer = await asyncio.open_connection( + *get_server_host_port(server) + ) + writer.close() + + async def test_close_server_rejects_connecting_connections(self): + """Server rejects connecting connections with HTTP 503 when closing.""" + + async def process_request(ws, _request): + while ws.server.is_serving(): + await asyncio.sleep(0) + + async with run_server(process_request=process_request) as server: + asyncio.get_running_loop().call_later(MS, server.close) + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 503", + ) + + async def test_close_server_closes_open_connections(self): + """Server closes open connections with close code 1001 when closing.""" + async with run_server() as server: + async with run_client(server) as client: + server.close() + with self.assertRaises(ConnectionClosedOK) as raised: + await client.recv() + self.assertEqual( + str(raised.exception), + "received 1001 (going away); then sent 1001 (going away)", + ) + + async def test_close_server_keeps_connections_open(self): + """Server waits for client to close open connections when closing.""" + async with run_server() as server: + async with run_client(server) as client: + server.close(close_connections=False) + + # Server cannot receive new connections. + await asyncio.sleep(0) + self.assertFalse(server.sockets) + + # The server waits for the client to close the connection. + with self.assertRaises(TimeoutError): + async with asyncio_timeout(MS): + await server.wait_closed() + + # Once the client closes the connection, the server terminates. + await client.close() + async with asyncio_timeout(MS): + await server.wait_closed() + + async def test_close_server_keeps_handlers_running(self): + """Server waits for connection handlers to terminate.""" + async with run_server(keep_running) as server: + async with run_client(server) as client: + # Delay termination of connection handler. + await client.send(str(2 * MS)) + + server.close() + + # The server waits for the connection handler to terminate. + with self.assertRaises(TimeoutError): + async with asyncio_timeout(MS): + await server.wait_closed() + + async with asyncio_timeout(2 * MS): + await server.wait_closed() + + +SSL_OBJECT = "ws.transport.get_extra_info('ssl_object')" + + +class SecureServerTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): + async def test_connection(self): + """Server receives secure connection from client.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + async with run_client(server, ssl=CLIENT_CONTEXT) as client: + await self.assertEval(client, "ws.protocol.state.name", "OPEN") + await self.assertEval(client, SSL_OBJECT + ".version()[:3]", "TLS") + + async def test_timeout_during_tls_handshake(self): + """Server times out before receiving TLS handshake request from client.""" + async with run_server(ssl=SERVER_CONTEXT, open_timeout=MS) as server: + reader, writer = await asyncio.open_connection( + *get_server_host_port(server) + ) + try: + self.assertEqual(await reader.read(4096), b"") + finally: + writer.close() + + async def test_connection_closed_during_tls_handshake(self): + """Server reads EOF before receiving TLS handshake request from client.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + _reader, writer = await asyncio.open_connection( + *get_server_host_port(server) + ) + writer.close() + + +@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") +class UnixServerTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): + async def test_connection(self): + """Server receives connection from client over a Unix socket.""" + with temp_unix_socket_path() as path: + async with run_unix_server(path): + async with run_unix_client(path) as client: + await self.assertEval(client, "ws.protocol.state.name", "OPEN") + + +@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") +class SecureUnixServerTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): + async def test_connection(self): + """Server receives secure connection from client over a Unix socket.""" + with temp_unix_socket_path() as path: + async with run_unix_server(path, ssl=SERVER_CONTEXT): + async with run_unix_client(path, ssl=CLIENT_CONTEXT) as client: + await self.assertEval(client, "ws.protocol.state.name", "OPEN") + await self.assertEval(client, SSL_OBJECT + ".version()[:3]", "TLS") + + +class ServerUsageErrorsTests(unittest.IsolatedAsyncioTestCase): + async def test_unix_without_path_or_sock(self): + """Unix server requires path when sock isn't provided.""" + with self.assertRaises(ValueError) as raised: + await unix_serve(eval_shell) + self.assertEqual( + str(raised.exception), + "path was not specified, and no sock specified", + ) + + async def test_unix_with_path_and_sock(self): + """Unix server rejects path when sock is provided.""" + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self.addCleanup(sock.close) + with self.assertRaises(ValueError) as raised: + await unix_serve(eval_shell, path="/", sock=sock) + self.assertEqual( + str(raised.exception), + "path and sock can not be specified at the same time", + ) + + async def test_invalid_subprotocol(self): + """Server rejects single value of subprotocols.""" + with self.assertRaises(TypeError) as raised: + await serve(eval_shell, subprotocols="chat") + self.assertEqual( + str(raised.exception), + "subprotocols must be a list, not a str", + ) + + async def test_unsupported_compression(self): + """Server rejects incorrect value of compression.""" + with self.assertRaises(ValueError) as raised: + await serve(eval_shell, compression=False) + self.assertEqual( + str(raised.exception), + "unsupported compression: False", + ) + + +class WebSocketServerTests(unittest.IsolatedAsyncioTestCase): + async def test_logger(self): + """WebSocketServer accepts a logger argument.""" + logger = logging.getLogger("test") + async with run_server(logger=logger) as server: + self.assertIs(server.logger, logger) From d2120de4708d09d3cade465541f202d5a8bab722 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 7 Aug 2024 08:31:06 +0200 Subject: [PATCH 536/762] Add an option to disable decoding of text frames. Also support decoding binary frames. Fix #1376. --- src/websockets/asyncio/connection.py | 34 ++++++++++++++++++++++++---- src/websockets/asyncio/messages.py | 10 ++++++++ tests/asyncio/test_connection.py | 26 +++++++++++++++++++++ 3 files changed, 66 insertions(+), 4 deletions(-) diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 550c0ac97..152c6789e 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -172,7 +172,7 @@ async def __aiter__(self) -> AsyncIterator[Data]: except ConnectionClosedOK: return - async def recv(self) -> Data: + async def recv(self, decode: bool | None = None) -> Data: """ Receive the next message. @@ -192,6 +192,10 @@ async def recv(self) -> Data: When the message is fragmented, :meth:`recv` waits until all fragments are received, reassembles them, and returns the whole message. + Args: + decode: Set this flag to override the default behavior of returning + :class:`str` or :class:`bytes`. See below for details. + Returns: A string (:class:`str`) for a Text_ frame or a bytestring (:class:`bytes`) for a Binary_ frame. @@ -199,6 +203,15 @@ async def recv(self) -> Data: .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + You may override this behavior with the ``decode`` argument: + + * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames + and return a bytestring (:class:`bytes`). This may be useful to + optimize performance when decoding isn't needed. + * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames + and return a string (:class:`str`). This is useful for servers + that send binary frames instead of text frames. + Raises: ConnectionClosed: When the connection is closed. RuntimeError: If two coroutines call :meth:`recv` or @@ -206,7 +219,7 @@ async def recv(self) -> Data: """ try: - return await self.recv_messages.get() + return await self.recv_messages.get(decode) except EOFError: raise self.protocol.close_exc from self.recv_exc except RuntimeError: @@ -215,7 +228,7 @@ async def recv(self) -> Data: "is already running recv or recv_streaming" ) from None - async def recv_streaming(self) -> AsyncIterator[Data]: + async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data]: """ Receive the next message frame by frame. @@ -232,6 +245,10 @@ async def recv_streaming(self) -> AsyncIterator[Data]: iterator in a partially consumed state, making the connection unusable. Instead, you should close the connection with :meth:`close`. + Args: + decode: Set this flag to override the default behavior of returning + :class:`str` or :class:`bytes`. See below for details. + Returns: An iterator of strings (:class:`str`) for a Text_ frame or bytestrings (:class:`bytes`) for a Binary_ frame. @@ -239,6 +256,15 @@ async def recv_streaming(self) -> AsyncIterator[Data]: .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + You may override this behavior with the ``decode`` argument: + + * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames + and return bytestrings (:class:`bytes`). This may be useful to + optimize performance when decoding isn't needed. + * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames + and return strings (:class:`str`). This is useful for servers + that send binary frames instead of text frames. + Raises: ConnectionClosed: When the connection is closed. RuntimeError: If two coroutines call :meth:`recv` or @@ -246,7 +272,7 @@ async def recv_streaming(self) -> AsyncIterator[Data]: """ try: - async for frame in self.recv_messages.get_iter(): + async for frame in self.recv_messages.get_iter(decode): yield frame except EOFError: raise self.protocol.close_exc from self.recv_exc diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index 2a9c4d37d..bc33df8d7 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -121,6 +121,11 @@ async def get(self, decode: bool | None = None) -> Data: received, then it reassembles the message and returns it. To receive messages frame by frame, use :meth:`get_iter` instead. + Args: + decode: :obj:`False` disables UTF-8 decoding of text frames and + returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of + binary frames and returns :class:`str`. + Raises: EOFError: If the stream of frames has ended. RuntimeError: If two coroutines run :meth:`get` or :meth:`get_iter` @@ -183,6 +188,11 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: This method only makes sense for fragmented messages. If messages aren't fragmented, use :meth:`get` instead. + Args: + decode: :obj:`False` disables UTF-8 decoding of text frames and + returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of + binary frames and returns :class:`str`. + Raises: EOFError: If the stream of frames has ended. RuntimeError: If two coroutines run :meth:`get` or :meth:`get_iter` diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index a8b3980b4..2efd4e96d 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -167,6 +167,16 @@ async def test_recv_binary(self): await self.remote_connection.send(b"\x01\x02\xfe\xff") self.assertEqual(await self.connection.recv(), b"\x01\x02\xfe\xff") + async def test_recv_encoded_text(self): + """recv receives an UTF-8 encoded text message.""" + await self.remote_connection.send("😀") + self.assertEqual(await self.connection.recv(decode=False), "😀".encode()) + + async def test_recv_decoded_binary(self): + """recv receives an UTF-8 decoded binary message.""" + await self.remote_connection.send("😀".encode()) + self.assertEqual(await self.connection.recv(decode=True), "😀") + async def test_recv_fragmented_text(self): """recv receives a fragmented text message.""" await self.remote_connection.send(["😀", "😀"]) @@ -271,6 +281,22 @@ async def test_recv_streaming_binary(self): [b"\x01\x02\xfe\xff"], ) + async def test_recv_streaming_encoded_text(self): + """recv_streaming receives an UTF-8 encoded text message.""" + await self.remote_connection.send("😀") + self.assertEqual( + await alist(self.connection.recv_streaming(decode=False)), + ["😀".encode()], + ) + + async def test_recv_streaming_decoded_binary(self): + """recv_streaming receives a UTF-8 decoded binary message.""" + await self.remote_connection.send("😀".encode()) + self.assertEqual( + await alist(self.connection.recv_streaming(decode=True)), + ["😀"], + ) + async def test_recv_streaming_fragmented_text(self): """recv_streaming receives a fragmented text message.""" await self.remote_connection.send(["😀", "😀"]) From e35c15a2a70c347f6f7a3e503ff1181ac35e1298 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 7 Aug 2024 18:45:02 +0200 Subject: [PATCH 537/762] Reduce MS for situations with performance penalties. Nowadays it's tuned with WEBSOCKETS_TESTS_TIMEOUT_FACTOR. --- tests/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/utils.py b/tests/utils.py index bd3b61d7b..1793f3e8b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -45,11 +45,11 @@ # PyPy has a performance penalty for this test suite. if platform.python_implementation() == "PyPy": # pragma: no cover - MS *= 5 + MS *= 2 -# asyncio's debug mode has a 10x performance penalty for this test suite. +# asyncio's debug mode has a performance penalty for this test suite. if os.environ.get("PYTHONASYNCIODEBUG"): # pragma: no cover - MS *= 10 + MS *= 2 # Ensure that timeouts are larger than the clock's resolution (for Windows). MS = max(MS, 2.5 * time.get_clock_info("monotonic").resolution) From bbb316155a5aeb719f262873a5b29a98c19b25d9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 8 Aug 2024 10:07:44 +0200 Subject: [PATCH 538/762] Add env vars for configuring constants. --- docs/project/changelog.rst | 6 ++++ docs/reference/index.rst | 1 + docs/reference/variables.rst | 48 ++++++++++++++++++++++++++++++ docs/topics/logging.rst | 4 +++ docs/topics/security.rst | 38 +++++++++++++++++------ src/websockets/asyncio/client.py | 5 ++-- src/websockets/asyncio/server.py | 11 ++++--- src/websockets/frames.py | 17 ++++++----- src/websockets/http.py | 9 ------ src/websockets/http11.py | 36 +++++++++++++++++----- src/websockets/legacy/client.py | 4 +-- src/websockets/legacy/http.py | 9 +++--- src/websockets/legacy/server.py | 8 ++--- src/websockets/sync/client.py | 3 +- src/websockets/sync/server.py | 9 +++--- tests/legacy/test_client_server.py | 2 +- tests/test_http.py | 8 ----- 17 files changed, 149 insertions(+), 69 deletions(-) create mode 100644 docs/reference/variables.rst delete mode 100644 tests/test_http.py diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 108b7c9c0..8143e3483 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -73,6 +73,12 @@ New features * Validated compatibility with Python 3.12. +* Added :doc:`environment variables <../reference/variables>` to configure debug + logs, the ``Server`` and ``User-Agent`` headers, as well as security limits. + + If you were monkey-patching constants, be aware that they were renamed, which + will break your configuration. You must switch to the environment variables. + 12.0 ---- diff --git a/docs/reference/index.rst b/docs/reference/index.rst index 2486ac564..d3a0e935c 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -85,6 +85,7 @@ These low-level APIs are shared by all implementations. datastructures exceptions types + variables API stability ------------- diff --git a/docs/reference/variables.rst b/docs/reference/variables.rst new file mode 100644 index 000000000..4bca112da --- /dev/null +++ b/docs/reference/variables.rst @@ -0,0 +1,48 @@ +Environment variables +===================== + +Logging +------- + +.. envvar:: WEBSOCKETS_MAX_LOG_SIZE + + How much of each frame to show in debug logs. + + The default value is ``75``. + +See the :doc:`logging guide <../topics/logging>` for details. + +Security +........ + +.. envvar:: WEBSOCKETS_SERVER + + Server header sent by websockets. + + The default value uses the format ``"Python/x.y.z websockets/X.Y"``. + +.. envvar:: WEBSOCKETS_USER_AGENT + + User-Agent header sent by websockets. + + The default value uses the format ``"Python/x.y.z websockets/X.Y"``. + +.. envvar:: WEBSOCKETS_MAX_LINE_LENGTH + + Maximum length of the request or status line in the opening handshake. + + The default value is ``8192``. + +.. envvar:: WEBSOCKETS_MAX_NUM_HEADERS + + Maximum number of HTTP headers in the opening handshake. + + The default value is ``128``. + +.. envvar:: WEBSOCKETS_MAX_BODY_SIZE + + Maximum size of the body of an HTTP response in the opening handshake. + + The default value is ``1_048_576`` (1 MiB). + +See the :doc:`security guide <../topics/security>` for details. diff --git a/docs/topics/logging.rst b/docs/topics/logging.rst index e7abd96ce..873c852c2 100644 --- a/docs/topics/logging.rst +++ b/docs/topics/logging.rst @@ -76,6 +76,10 @@ Here's how to enable debug logs for development:: level=logging.DEBUG, ) +By default, websockets elides the content of messages to improve readability. +If you want to see more, you can increase the :envvar:`WEBSOCKETS_MAX_LOG_SIZE` +environment variable. The default value is 75. + Furthermore, websockets adds a ``websocket`` attribute to log records, so you can include additional information about the current connection in logs. diff --git a/docs/topics/security.rst b/docs/topics/security.rst index d3dec21bd..83d79e35b 100644 --- a/docs/topics/security.rst +++ b/docs/topics/security.rst @@ -1,6 +1,8 @@ Security ======== +.. currentmodule:: websockets + Encryption ---------- @@ -27,15 +29,33 @@ an amplification factor of 1000 between network traffic and memory usage. Configuring a server to :doc:`optimize memory usage ` will improve security in addition to improving performance. -Other limits ------------- +HTTP limits +----------- + +In the opening handshake, websockets applies limits to the amount of data that +it accepts in order to minimize exposure to denial of service attacks. + +The request or status line is limited to 8192 bytes. Each header line, including +the name and value, is limited to 8192 bytes too. No more than 128 HTTP headers +are allowed. When the HTTP response includes a body, it is limited to 1 MiB. + +You may change these limits by setting the :envvar:`WEBSOCKETS_MAX_LINE_LENGTH`, +:envvar:`WEBSOCKETS_MAX_NUM_HEADERS`, and :envvar:`WEBSOCKETS_MAX_BODY_SIZE` +environment variables respectively. + +Identification +-------------- + +By default, websockets identifies itself with a ``Server`` or ``User-Agent`` +header in the format ``"Python/x.y.z websockets/X.Y"``. -websockets implements additional limits on the amount of data it accepts in -order to minimize exposure to security vulnerabilities. +You can set the ``server_header`` argument of :func:`~server.serve` or the +``user_agent_header`` argument of :func:`~client.connect` to configure another +value. Setting them to :obj:`None` removes the header. -In the opening handshake, websockets limits the number of HTTP headers to 256 -and the size of an individual header to 4096 bytes. These limits are 10 to 20 -times larger than what's expected in standard use cases. They're hard-coded. +Alternatively, you can set the :envvar:`WEBSOCKETS_SERVER` and +:envvar:`WEBSOCKETS_USER_AGENT` environment variables respectively. Setting them +to an empty string removes the header. -If you need to change these limits, you can monkey-patch the constants in -``websockets.http11``. +If both the argument and the environment variable are set, the argument takes +precedence. diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 040d68ece..ac8ded8ca 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -9,8 +9,7 @@ from ..extensions.base import ClientExtensionFactory from ..extensions.permessage_deflate import enable_client_permessage_deflate from ..headers import validate_subprotocols -from ..http import USER_AGENT -from ..http11 import Response +from ..http11 import USER_AGENT, Response from ..protocol import CONNECTING, Event from ..typing import LoggerLike, Origin, Subprotocol from ..uri import parse_uri @@ -71,7 +70,7 @@ async def handshake( self.request = self.protocol.connect() if additional_headers is not None: self.request.headers.update(additional_headers) - if user_agent_header is not None: + if user_agent_header: self.request.headers["User-Agent"] = user_agent_header self.protocol.send_request(self.request) diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index aa175f775..0c8b8780b 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -20,8 +20,7 @@ from ..extensions.base import ServerExtensionFactory from ..extensions.permessage_deflate import enable_server_permessage_deflate from ..headers import validate_subprotocols -from ..http import USER_AGENT -from ..http11 import Request, Response +from ..http11 import SERVER, Request, Response from ..protocol import CONNECTING, Event from ..server import ServerProtocol from ..typing import LoggerLike, Origin, Subprotocol @@ -88,7 +87,7 @@ async def handshake( ] | None ) = None, - server_header: str | None = USER_AGENT, + server_header: str | None = SERVER, ) -> None: """ Perform the opening handshake. @@ -131,7 +130,7 @@ async def handshake( assert isinstance(response, Response) # help mypy self.response = response - if server_header is not None: + if server_header: self.response.headers["Server"] = server_header response = None @@ -243,7 +242,7 @@ def __init__( ] | None ) = None, - server_header: str | None = USER_AGENT, + server_header: str | None = SERVER, open_timeout: float | None = 10, logger: LoggerLike | None = None, ) -> None: @@ -631,7 +630,7 @@ def __init__( ] | None ) = None, - server_header: str | None = USER_AGENT, + server_header: str | None = SERVER, compression: str | None = "deflate", # Timeouts open_timeout: float | None = 10, diff --git a/src/websockets/frames.py b/src/websockets/frames.py index af56d3f8f..819fdd742 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -3,6 +3,7 @@ import dataclasses import enum import io +import os import secrets import struct from typing import Callable, Generator, Sequence @@ -146,8 +147,8 @@ class Frame: rsv2: bool = False rsv3: bool = False - # Monkey-patch if you want to see more in logs. Should be a multiple of 3. - MAX_LOG = 75 + # Configure if you want to see more in logs. Should be a multiple of 3. + MAX_LOG_SIZE = int(os.environ.get("WEBSOCKETS_MAX_LOG_SIZE", "75")) def __str__(self) -> str: """ @@ -166,8 +167,8 @@ def __str__(self) -> str: # We'll show at most the first 16 bytes and the last 8 bytes. # Encode just what we need, plus two dummy bytes to elide later. binary = self.data - if len(binary) > self.MAX_LOG // 3: - cut = (self.MAX_LOG // 3 - 1) // 3 # by default cut = 8 + if len(binary) > self.MAX_LOG_SIZE // 3: + cut = (self.MAX_LOG_SIZE // 3 - 1) // 3 # by default cut = 8 binary = b"".join([binary[: 2 * cut], b"\x00\x00", binary[-cut:]]) data = " ".join(f"{byte:02x}" for byte in binary) elif self.opcode is OP_CLOSE: @@ -183,16 +184,16 @@ def __str__(self) -> str: coding = "text" except (UnicodeDecodeError, AttributeError): binary = self.data - if len(binary) > self.MAX_LOG // 3: - cut = (self.MAX_LOG // 3 - 1) // 3 # by default cut = 8 + if len(binary) > self.MAX_LOG_SIZE // 3: + cut = (self.MAX_LOG_SIZE // 3 - 1) // 3 # by default cut = 8 binary = b"".join([binary[: 2 * cut], b"\x00\x00", binary[-cut:]]) data = " ".join(f"{byte:02x}" for byte in binary) coding = "binary" else: data = "''" - if len(data) > self.MAX_LOG: - cut = self.MAX_LOG // 3 - 1 # by default cut = 24 + if len(data) > self.MAX_LOG_SIZE: + cut = self.MAX_LOG_SIZE // 3 - 1 # by default cut = 24 data = data[: 2 * cut] + "..." + data[-cut:] metadata = ", ".join(filter(None, [coding, length, non_final])) diff --git a/src/websockets/http.py b/src/websockets/http.py index 9f86f6a1f..a24102307 100644 --- a/src/websockets/http.py +++ b/src/websockets/http.py @@ -1,10 +1,8 @@ from __future__ import annotations -import sys import typing from .imports import lazy_import -from .version import version as websockets_version # For backwards compatibility: @@ -26,10 +24,3 @@ "read_response": ".legacy.http", }, ) - - -__all__ = ["USER_AGENT"] - - -PYTHON_VERSION = "{}.{}".format(*sys.version_info) -USER_AGENT = f"Python/{PYTHON_VERSION} websockets/{websockets_version}" diff --git a/src/websockets/http11.py b/src/websockets/http11.py index a7e9ae682..ed49fcbf9 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -1,23 +1,43 @@ from __future__ import annotations import dataclasses +import os import re +import sys import warnings from typing import Callable, Generator from . import datastructures, exceptions +from .version import version as websockets_version +__all__ = ["SERVER", "USER_AGENT", "Request", "Response"] + + +PYTHON_VERSION = "{}.{}".format(*sys.version_info) + +# User-Agent header for HTTP requests. +USER_AGENT = os.environ.get( + "WEBSOCKETS_USER_AGENT", + f"Python/{PYTHON_VERSION} websockets/{websockets_version}", +) + +# Server header for HTTP responses. +SERVER = os.environ.get( + "WEBSOCKETS_SERVER", + f"Python/{PYTHON_VERSION} websockets/{websockets_version}", +) + # Maximum total size of headers is around 128 * 8 KiB = 1 MiB. -MAX_HEADERS = 128 +MAX_NUM_HEADERS = int(os.environ.get("WEBSOCKETS_MAX_NUM_HEADERS", "128")) # Limit request line and header lines. 8KiB is the most common default # configuration of popular HTTP servers. -MAX_LINE = 8192 +MAX_LINE_LENGTH = int(os.environ.get("WEBSOCKETS_MAX_LINE_LENGTH", "8192")) # Support for HTTP response bodies is intended to read an error message # returned by a server. It isn't designed to perform large file transfers. -MAX_BODY = 2**20 # 1 MiB +MAX_BODY_SIZE = int(os.environ.get("WEBSOCKETS_MAX_BODY_SIZE", "1_048_576")) # 1 MiB def d(value: bytes) -> str: @@ -258,12 +278,12 @@ def parse( if content_length is None: try: - body = yield from read_to_eof(MAX_BODY) + body = yield from read_to_eof(MAX_BODY_SIZE) except RuntimeError: raise exceptions.SecurityError( - f"body too large: over {MAX_BODY} bytes" + f"body too large: over {MAX_BODY_SIZE} bytes" ) - elif content_length > MAX_BODY: + elif content_length > MAX_BODY_SIZE: raise exceptions.SecurityError( f"body too large: {content_length} bytes" ) @@ -309,7 +329,7 @@ def parse_headers( # We don't attempt to support obsolete line folding. headers = datastructures.Headers() - for _ in range(MAX_HEADERS + 1): + for _ in range(MAX_NUM_HEADERS + 1): try: line = yield from parse_line(read_line) except EOFError as exc: @@ -355,7 +375,7 @@ def parse_line( """ try: - line = yield from read_line(MAX_LINE) + line = yield from read_line(MAX_LINE_LENGTH) except RuntimeError: raise exceptions.SecurityError("line too long") # Not mandatory but safe - https://www.rfc-editor.org/rfc/rfc7230.html#section-3.5 diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index d1d8d5608..b61126c81 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -38,7 +38,7 @@ parse_subprotocol, validate_subprotocols, ) -from ..http import USER_AGENT +from ..http11 import USER_AGENT from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol from ..uri import WebSocketURI, parse_uri from .handshake import build_request, check_response @@ -307,7 +307,7 @@ async def handshake( if self.extra_headers is not None: request_headers.update(self.extra_headers) - if self.user_agent_header is not None: + if self.user_agent_header: request_headers.setdefault("User-Agent", self.user_agent_header) self.write_http_request(wsuri.resource_name, request_headers) diff --git a/src/websockets/legacy/http.py b/src/websockets/legacy/http.py index 9a553e175..b5df7e4c4 100644 --- a/src/websockets/legacy/http.py +++ b/src/websockets/legacy/http.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import os import re from ..datastructures import Headers @@ -9,8 +10,8 @@ __all__ = ["read_request", "read_response"] -MAX_HEADERS = 128 -MAX_LINE = 8192 +MAX_NUM_HEADERS = int(os.environ.get("WEBSOCKETS_MAX_NUM_HEADERS", "128")) +MAX_LINE_LENGTH = int(os.environ.get("WEBSOCKETS_MAX_LINE_LENGTH", "8192")) def d(value: bytes) -> str: @@ -154,7 +155,7 @@ async def read_headers(stream: asyncio.StreamReader) -> Headers: # We don't attempt to support obsolete line folding. headers = Headers() - for _ in range(MAX_HEADERS + 1): + for _ in range(MAX_NUM_HEADERS + 1): try: line = await read_line(stream) except EOFError as exc: @@ -192,7 +193,7 @@ async def read_line(stream: asyncio.StreamReader) -> bytes: # Security: this is bounded by the StreamReader's limit (default = 32 KiB). line = await stream.readline() # Security: this guarantees header values are small (hard-coded = 8 KiB) - if len(line) > MAX_LINE: + if len(line) > MAX_LINE_LENGTH: raise SecurityError("line too long") # Not mandatory but safe - https://www.rfc-editor.org/rfc/rfc7230.html#section-3.5 if not line.endswith(b"\r\n"): diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 208ffa780..cd7980e00 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -40,7 +40,7 @@ parse_subprotocol, validate_subprotocols, ) -from ..http import USER_AGENT +from ..http11 import SERVER from ..protocol import State from ..typing import ExtensionHeader, LoggerLike, Origin, StatusLike, Subprotocol from .handshake import build_response, check_request @@ -106,7 +106,7 @@ def __init__( extensions: Sequence[ServerExtensionFactory] | None = None, subprotocols: Sequence[Subprotocol] | None = None, extra_headers: HeadersLikeOrCallable | None = None, - server_header: str | None = USER_AGENT, + server_header: str | None = SERVER, process_request: ( Callable[[str, Headers], Awaitable[HTTPResponse | None]] | None ) = None, @@ -221,7 +221,7 @@ async def handler(self) -> None: ) headers.setdefault("Date", email.utils.formatdate(usegmt=True)) - if self.server_header is not None: + if self.server_header: headers.setdefault("Server", self.server_header) headers.setdefault("Content-Length", str(len(body))) @@ -992,7 +992,7 @@ def __init__( extensions: Sequence[ServerExtensionFactory] | None = None, subprotocols: Sequence[Subprotocol] | None = None, extra_headers: HeadersLikeOrCallable | None = None, - server_header: str | None = USER_AGENT, + server_header: str | None = SERVER, process_request: ( Callable[[str, Headers], Awaitable[HTTPResponse | None]] | None ) = None, diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index c97a09402..e33d53f62 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -11,8 +11,7 @@ from ..extensions.base import ClientExtensionFactory from ..extensions.permessage_deflate import enable_client_permessage_deflate from ..headers import validate_subprotocols -from ..http import USER_AGENT -from ..http11 import Response +from ..http11 import USER_AGENT, Response from ..protocol import CONNECTING, OPEN, Event from ..typing import LoggerLike, Origin, Subprotocol from ..uri import parse_uri diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 7fb46f5aa..ebbbd0312 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -16,8 +16,7 @@ from ..extensions.permessage_deflate import enable_server_permessage_deflate from ..frames import CloseCode from ..headers import validate_subprotocols -from ..http import USER_AGENT -from ..http11 import Request, Response +from ..http11 import SERVER, Request, Response from ..protocol import CONNECTING, OPEN, Event from ..server import ServerProtocol from ..typing import LoggerLike, Origin, Subprotocol @@ -83,7 +82,7 @@ def handshake( ] | None ) = None, - server_header: str | None = USER_AGENT, + server_header: str | None = SERVER, timeout: float | None = None, ) -> None: """ @@ -120,7 +119,7 @@ def handshake( if self.response is None: self.response = self.protocol.accept(self.request) - if server_header is not None: + if server_header: self.response.headers["Server"] = server_header if process_response is not None: @@ -302,7 +301,7 @@ def serve( ] | None ) = None, - server_header: str | None = USER_AGENT, + server_header: str | None = SERVER, compression: str | None = "deflate", # Timeouts open_timeout: float | None = 10, diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index b5c5d726a..329f59286 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -29,7 +29,7 @@ ServerPerMessageDeflateFactory, ) from websockets.frames import CloseCode -from websockets.http import USER_AGENT +from websockets.http11 import USER_AGENT from websockets.legacy.client import * from websockets.legacy.handshake import build_response from websockets.legacy.http import read_response diff --git a/tests/test_http.py b/tests/test_http.py deleted file mode 100644 index baaa7d416..000000000 --- a/tests/test_http.py +++ /dev/null @@ -1,8 +0,0 @@ -import unittest - -from websockets.http import * - - -class HTTPTests(unittest.TestCase): - def test_user_agent(self): - USER_AGENT # exists From a7a5042bed89b96fa2e391f3be0e255a59bffb0a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 8 Aug 2024 10:21:23 +0200 Subject: [PATCH 539/762] Deprecate fully websockets.http. All public API within this module are deprecated since version 9.0 so there's nothing to document. --- src/websockets/connection.py | 1 - src/websockets/http.py | 31 ++++++++++--------------------- tests/test_http.py | 16 ++++++++++++++++ 3 files changed, 26 insertions(+), 22 deletions(-) create mode 100644 tests/test_http.py diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 88bcda1aa..7942c1a28 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -2,7 +2,6 @@ import warnings -# lazy_import doesn't support this use case. from .protocol import SEND_EOF, Protocol as Connection, Side, State # noqa: F401 diff --git a/src/websockets/http.py b/src/websockets/http.py index a24102307..3dc560062 100644 --- a/src/websockets/http.py +++ b/src/websockets/http.py @@ -1,26 +1,15 @@ from __future__ import annotations -import typing +import warnings -from .imports import lazy_import +from .datastructures import Headers, MultipleValuesError # noqa: F401 +from .legacy.http import read_request, read_response # noqa: F401 -# For backwards compatibility: - - -# When type checking, import non-deprecated aliases eagerly. Else, import on demand. -if typing.TYPE_CHECKING: - from .datastructures import Headers, MultipleValuesError # noqa: F401 -else: - lazy_import( - globals(), - # Headers and MultipleValuesError used to be defined in this module. - aliases={ - "Headers": ".datastructures", - "MultipleValuesError": ".datastructures", - }, - deprecated_aliases={ - "read_request": ".legacy.http", - "read_response": ".legacy.http", - }, - ) +warnings.warn( + "Headers and MultipleValuesError were moved " + "from websockets.http to websockets.datastructures" + "and read_request and read_response were moved " + "from websockets.http to websockets.legacy.http", + DeprecationWarning, +) diff --git a/tests/test_http.py b/tests/test_http.py new file mode 100644 index 000000000..6e81199fc --- /dev/null +++ b/tests/test_http.py @@ -0,0 +1,16 @@ +from websockets.datastructures import Headers + +from .utils import DeprecationTestCase + + +class BackwardsCompatibilityTests(DeprecationTestCase): + def test_headers_class(self): + with self.assertDeprecationWarning( + "Headers and MultipleValuesError were moved " + "from websockets.http to websockets.datastructures" + "and read_request and read_response were moved " + "from websockets.http to websockets.legacy.http", + ): + from websockets.http import Headers as OldHeaders + + self.assertIs(OldHeaders, Headers) From 5835da4967e130cd631a7601e75ea5228ab27537 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 9 Aug 2024 09:25:19 +0200 Subject: [PATCH 540/762] Adjust timings to avoid spurious failures. --- tests/asyncio/test_server.py | 6 +++--- tests/utils.py | 7 ++++--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index 2e59f49b1..535083cbc 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -412,16 +412,16 @@ async def test_close_server_keeps_handlers_running(self): async with run_server(keep_running) as server: async with run_client(server) as client: # Delay termination of connection handler. - await client.send(str(2 * MS)) + await client.send(str(3 * MS)) server.close() # The server waits for the connection handler to terminate. with self.assertRaises(TimeoutError): - async with asyncio_timeout(MS): + async with asyncio_timeout(2 * MS): await server.wait_closed() - async with asyncio_timeout(2 * MS): + async with asyncio_timeout(3 * MS): await server.wait_closed() diff --git a/tests/utils.py b/tests/utils.py index 1793f3e8b..960439135 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -43,13 +43,14 @@ # WEBSOCKETS_TESTS_TIMEOUT_FACTOR environment variable. MS = 0.001 * float(os.environ.get("WEBSOCKETS_TESTS_TIMEOUT_FACTOR", "1")) -# PyPy has a performance penalty for this test suite. +# PyPy, asyncio's debug mode, and coverage penalize performance of this +# test suite. Increase timeouts to reduce the risk of spurious failures. if platform.python_implementation() == "PyPy": # pragma: no cover MS *= 2 - -# asyncio's debug mode has a performance penalty for this test suite. if os.environ.get("PYTHONASYNCIODEBUG"): # pragma: no cover MS *= 2 +if os.environ.get("COVERAGE_RUN"): # pragma: no branch + MS *= 2 # Ensure that timeouts are larger than the clock's resolution (for Windows). MS = max(MS, 2.5 * time.get_clock_info("monotonic").resolution) From 906592908bb5850a4f78a5d4877fbc2412d611b7 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 9 Aug 2024 09:25:42 +0200 Subject: [PATCH 541/762] Avoid spurious coverage failures due to timing effects. --- tests/asyncio/test_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index 535083cbc..4a8a76a21 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -363,7 +363,7 @@ async def test_close_server_rejects_connecting_connections(self): async def process_request(ws, _request): while ws.server.is_serving(): - await asyncio.sleep(0) + await asyncio.sleep(0) # pragma: no cover async with run_server(process_request=process_request) as server: asyncio.get_running_loop().call_later(MS, server.close) From 84e8bd879b8dfc528b4e57517f2e1f8b7ad0a378 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 9 Aug 2024 09:49:55 +0200 Subject: [PATCH 542/762] Fix spurious exception while running tests. Due to a race condition between serve_forever and shutdown, test run logs randomly contained this exception: Exception in thread Thread-NNN (serve_forever): Traceback (most recent call last): ... ValueError: Invalid file descriptor: -1 --- src/websockets/sync/server.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index ebbbd0312..10fbe4859 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -222,7 +222,13 @@ def serve_forever(self) -> None: """ poller = selectors.DefaultSelector() - poller.register(self.socket, selectors.EVENT_READ) + try: + poller.register(self.socket, selectors.EVENT_READ) + except ValueError: # pragma: no cover + # If shutdown() is called before poller.register(), + # the socket is closed and poller.register() raises + # ValueError: Invalid file descriptor: -1 + return if sys.platform != "win32": poller.register(self.shutdown_watcher, selectors.EVENT_READ) From a3ed1604b0f331fe91df52641cfab2ae5349eb46 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 9 Aug 2024 10:24:24 +0200 Subject: [PATCH 543/762] Make test_reconnect robust to slower runs. This avoids failures with higher WEBSOCKETS_TESTS_TIMEOUT_FACTOR, notably on PyPy. Refs #1483. --- tests/legacy/test_client_server.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 329f59286..0c5d66c92 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -5,6 +5,7 @@ import logging import platform import random +import re import socket import ssl import sys @@ -1608,16 +1609,19 @@ async def run_client(): ) # Iteration 3 self.assertEqual( - [record.getMessage() for record in logs.records][4:-1], + [ + re.sub(r"[0-9\.]+ seconds", "X seconds", record.getMessage()) + for record in logs.records + ][4:-1], [ "connection rejected (503 Service Unavailable)", "connection closed", - "! connect failed; reconnecting in 0.0 seconds", + "! connect failed; reconnecting in X seconds", ] + [ "connection rejected (503 Service Unavailable)", "connection closed", - "! connect failed again; retrying in 0 seconds", + "! connect failed again; retrying in X seconds", ] * ((len(logs.records) - 8) // 3) + [ From 58787cc6a58a1f1baf4be3f78d868594108afebd Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 11 Aug 2024 14:35:38 +0200 Subject: [PATCH 544/762] Confirm support for Python 3.13. --- .github/workflows/release.yml | 2 +- .github/workflows/tests.yml | 2 ++ docs/faq/misc.rst | 3 +++ docs/project/changelog.rst | 2 +- pyproject.toml | 1 + src/websockets/asyncio/connection.py | 3 +-- tox.ini | 1 + 7 files changed, 10 insertions(+), 4 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 4a00bf8fc..ed52ddd80 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -56,7 +56,7 @@ jobs: with: platforms: all - name: Build wheels - uses: pypa/cibuildwheel@v2.16.2 + uses: pypa/cibuildwheel@v2.20.0 env: BUILD_EXTENSION: yes - name: Save wheels diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 15a45bdfb..b9172b7fb 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -62,6 +62,7 @@ jobs: - "3.10" - "3.11" - "3.12" + - "3.13" - "pypy-3.9" - "pypy-3.10" is_main: @@ -78,6 +79,7 @@ jobs: uses: actions/setup-python@v5 with: python-version: ${{ matrix.python }} + allow-prereleases: true - name: Install tox run: pip install tox - name: Run tests diff --git a/docs/faq/misc.rst b/docs/faq/misc.rst index ee5ad2372..0e74a784f 100644 --- a/docs/faq/misc.rst +++ b/docs/faq/misc.rst @@ -3,6 +3,9 @@ Miscellaneous .. currentmodule:: websockets +.. Remove this question when dropping Python < 3.13, which provides natively +.. a good error message in this case. + Why do I get the error: ``module 'websockets' has no attribute '...'``? ....................................................................... diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 8143e3483..00b055dd1 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -71,7 +71,7 @@ New features See :func:`websockets.asyncio.client.connect` and :func:`websockets.asyncio.server.serve` for details. -* Validated compatibility with Python 3.12. +* Validated compatibility with Python 3.12 and 3.13. * Added :doc:`environment variables <../reference/variables>` to configure debug logs, the ``Server`` and ``User-Agent`` headers, as well as security limits. diff --git a/pyproject.toml b/pyproject.toml index de8acd6a3..c1d34c90b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ classifiers = [ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", ] dynamic = ["version", "readme"] diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 152c6789e..4f44d798c 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -564,8 +564,7 @@ async def ping(self, data: Data | None = None) -> Awaitable[None]: pong_waiter = self.loop.create_future() # The event loop's default clock is time.monotonic(). Its resolution - # is a bit low on Windows (~16ms). We cannot use time.perf_counter() - # because it doesn't count time elapsed while the process sleeps. + # is a bit low on Windows (~16ms). This is improved in Python 3.13. ping_timestamp = self.loop.time() self.pong_waiters[data] = (pong_waiter, ping_timestamp) self.protocol.send_ping(data) diff --git a/tox.ini b/tox.ini index 1edcfe261..16d9c9f16 100644 --- a/tox.ini +++ b/tox.ini @@ -5,6 +5,7 @@ env_list = py310 py311 py312 + py313 coverage black ruff From 9ec785d6f12cb1a3a3bc43f543f4a831a635472b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 11 Aug 2024 15:07:40 +0200 Subject: [PATCH 545/762] Fix copy-paste error in tests. --- tests/asyncio/test_client.py | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index aab65cd2e..b74617ef0 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -156,27 +156,27 @@ async def test_connection(self): async def test_set_server_hostname_implicitly(self): """Client sets server_hostname to the host in the WebSocket URI.""" - with temp_unix_socket_path() as path: - async with run_unix_server(path, ssl=SERVER_CONTEXT): - async with run_unix_client( - path, - ssl=CLIENT_CONTEXT, - uri="wss://overridden/", - ) as client: - ssl_object = client.transport.get_extra_info("ssl_object") - self.assertEqual(ssl_object.server_hostname, "overridden") + async with run_server(ssl=SERVER_CONTEXT) as server: + host, port = get_server_host_port(server) + async with run_client( + "wss://overridden/", + host=host, + port=port, + ssl=CLIENT_CONTEXT, + ) as client: + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.server_hostname, "overridden") async def test_set_server_hostname_explicitly(self): """Client sets server_hostname to the value provided in argument.""" - with temp_unix_socket_path() as path: - async with run_unix_server(path, ssl=SERVER_CONTEXT): - async with run_unix_client( - path, - ssl=CLIENT_CONTEXT, - server_hostname="overridden", - ) as client: - ssl_object = client.transport.get_extra_info("ssl_object") - self.assertEqual(ssl_object.server_hostname, "overridden") + async with run_server(ssl=SERVER_CONTEXT) as server: + async with run_client( + server, + ssl=CLIENT_CONTEXT, + server_hostname="overridden", + ) as client: + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.server_hostname, "overridden") async def test_reject_invalid_server_certificate(self): """Client rejects certificate where server certificate isn't trusted.""" From 1853a9b2d0247573633e2749fe1169f764abe03c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 11 Aug 2024 15:59:59 +0200 Subject: [PATCH 546/762] Ignore ResourceWarning in test. This is expected to prevent a spurious test failure under PyPy. Refs #1483. --- tests/legacy/utils.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/tests/legacy/utils.py b/tests/legacy/utils.py index 28bc90df3..5bb56b26f 100644 --- a/tests/legacy/utils.py +++ b/tests/legacy/utils.py @@ -2,6 +2,7 @@ import contextlib import functools import logging +import sys import unittest @@ -76,8 +77,19 @@ def assertDeprecationWarnings(self, recorded_warnings, expected_warnings): Check recorded deprecation warnings match a list of expected messages. """ + # Work around https://github.com/python/cpython/issues/90476. + if sys.version_info[:2] < (3, 11): # pragma: no cover + recorded_warnings = [ + recorded + for recorded in recorded_warnings + if not ( + type(recorded.message) is ResourceWarning + and str(recorded.message).startswith("unclosed transport") + ) + ] + for recorded in recorded_warnings: - self.assertEqual(type(recorded.message), DeprecationWarning) + self.assertIs(type(recorded.message), DeprecationWarning) self.assertEqual( {str(recorded.message) for recorded in recorded_warnings}, set(expected_warnings), From 00b63afe7d921d17fd48abee2e25389050a2410c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 12 Aug 2024 09:44:02 +0200 Subject: [PATCH 547/762] Add new asyncio implementation to feature matrices. --- docs/reference/features.rst | 257 ++++++++++++++++++------------------ 1 file changed, 128 insertions(+), 129 deletions(-) diff --git a/docs/reference/features.rst b/docs/reference/features.rst index 98b3c0dda..946770fe3 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -14,9 +14,10 @@ Feature support matrices summarize which implementations support which features. .support-matrix-table td:not(:first-child) { text-align: center; } -.. |aio| replace:: :mod:`asyncio` +.. |aio| replace:: :mod:`asyncio` (new) .. |sync| replace:: :mod:`threading` .. |sans| replace:: `Sans-I/O`_ +.. |leg| replace:: :mod:`asyncio` (legacy) .. _Sans-I/O: https://sans-io.readthedocs.io/ Both sides @@ -25,60 +26,58 @@ Both sides .. table:: :class: support-matrix-table - +------------------------------------+--------+--------+--------+ - | | |aio| | |sync| | |sans| | - +====================================+========+========+========+ - | Perform the opening handshake | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Send a message | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Receive a message | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Iterate over received messages | ✅ | ✅ | ❌ | - +------------------------------------+--------+--------+--------+ - | Send a fragmented message | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Receive a fragmented message after | ✅ | ✅ | ❌ | - | reassembly | | | | - +------------------------------------+--------+--------+--------+ - | Receive a fragmented message frame | ❌ | ✅ | ✅ | - | by frame (`#479`_) | | | | - +------------------------------------+--------+--------+--------+ - | Send a ping | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Respond to pings automatically | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Send a pong | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Perform the closing handshake | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Report close codes and reasons | ❌ | ✅ | ✅ | - | from both sides | | | | - +------------------------------------+--------+--------+--------+ - | Compress messages (:rfc:`7692`) | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Tune memory usage for compression | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Negotiate extensions | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Implement custom extensions | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Negotiate a subprotocol | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Enforce security limits | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Log events | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Enforce opening timeout | ✅ | ✅ | — | - +------------------------------------+--------+--------+--------+ - | Enforce closing timeout | ✅ | ✅ | — | - +------------------------------------+--------+--------+--------+ - | Keepalive | ✅ | ❌ | — | - +------------------------------------+--------+--------+--------+ - | Heartbeat | ✅ | ❌ | — | - +------------------------------------+--------+--------+--------+ - -.. _#479: https://github.com/python-websockets/websockets/issues/479 + +------------------------------------+--------+--------+--------+--------+ + | | |aio| | |sync| | |sans| | |leg| | + +====================================+========+========+========+========+ + | Perform the opening handshake | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Enforce opening timeout | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Send a message | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Receive a message | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Iterate over received messages | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Send a fragmented message | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Receive a fragmented message frame | ✅ | ✅ | ✅ | ❌ | + | by frame | | | | | + +------------------------------------+--------+--------+--------+--------+ + | Receive a fragmented message after | ✅ | ✅ | — | ✅ | + | reassembly | | | | | + +------------------------------------+--------+--------+--------+--------+ + | Send a ping | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Respond to pings automatically | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Send a pong | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Keepalive | ❌ | ❌ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Heartbeat | ❌ | ❌ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Perform the closing handshake | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Enforce closing timeout | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Report close codes and reasons | ✅ | ✅ | ✅ | ❌ | + | from both sides | | | | | + +------------------------------------+--------+--------+--------+--------+ + | Compress messages (:rfc:`7692`) | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Tune memory usage for compression | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Negotiate extensions | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Implement custom extensions | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Negotiate a subprotocol | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Enforce security limits | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Log events | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ Server ------ @@ -86,39 +85,39 @@ Server .. table:: :class: support-matrix-table - +------------------------------------+--------+--------+--------+ - | | |aio| | |sync| | |sans| | - +====================================+========+========+========+ - | Listen on a TCP socket | ✅ | ✅ | — | - +------------------------------------+--------+--------+--------+ - | Listen on a Unix socket | ✅ | ✅ | — | - +------------------------------------+--------+--------+--------+ - | Listen using a preexisting socket | ✅ | ✅ | — | - +------------------------------------+--------+--------+--------+ - | Encrypt connection with TLS | ✅ | ✅ | — | - +------------------------------------+--------+--------+--------+ - | Close server on context exit | ✅ | ✅ | — | - +------------------------------------+--------+--------+--------+ - | Close connection on handler exit | ✅ | ✅ | — | - +------------------------------------+--------+--------+--------+ - | Shut down server gracefully | ✅ | ✅ | — | - +------------------------------------+--------+--------+--------+ - | Check ``Origin`` header | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Customize subprotocol selection | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Configure ``Server`` header | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Alter opening handshake request | ❌ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Alter opening handshake response | ❌ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Perform HTTP Basic Authentication | ✅ | ❌ | ❌ | - +------------------------------------+--------+--------+--------+ - | Perform HTTP Digest Authentication | ❌ | ❌ | ❌ | - +------------------------------------+--------+--------+--------+ - | Force HTTP response | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ + +------------------------------------+--------+--------+--------+--------+ + | | |aio| | |sync| | |sans| | |leg| | + +====================================+========+========+========+========+ + | Listen on a TCP socket | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Listen on a Unix socket | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Listen using a preexisting socket | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Encrypt connection with TLS | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Close server on context exit | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Close connection on handler exit | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Shut down server gracefully | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Check ``Origin`` header | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Customize subprotocol selection | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Configure ``Server`` header | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Alter opening handshake request | ✅ | ✅ | ✅ | ❌ | + +------------------------------------+--------+--------+--------+--------+ + | Alter opening handshake response | ✅ | ✅ | ✅ | ❌ | + +------------------------------------+--------+--------+--------+--------+ + | Force an HTTP response | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Perform HTTP Basic Authentication | ❌ | ❌ | ❌ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Perform HTTP Digest Authentication | ❌ | ❌ | ❌ | ❌ | + +------------------------------------+--------+--------+--------+--------+ Client ------ @@ -126,41 +125,43 @@ Client .. table:: :class: support-matrix-table - +------------------------------------+--------+--------+--------+ - | | |aio| | |sync| | |sans| | - +====================================+========+========+========+ - | Connect to a TCP socket | ✅ | ✅ | — | - +------------------------------------+--------+--------+--------+ - | Connect to a Unix socket | ✅ | ✅ | — | - +------------------------------------+--------+--------+--------+ - | Connect using a preexisting socket | ✅ | ✅ | — | - +------------------------------------+--------+--------+--------+ - | Encrypt connection with TLS | ✅ | ✅ | — | - +------------------------------------+--------+--------+--------+ - | Close connection on context exit | ✅ | ✅ | — | - +------------------------------------+--------+--------+--------+ - | Reconnect automatically | ✅ | ❌ | — | - +------------------------------------+--------+--------+--------+ - | Configure ``Origin`` header | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Configure ``User-Agent`` header | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Alter opening handshake request | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Connect to non-ASCII IRIs | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Perform HTTP Basic Authentication | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+ - | Perform HTTP Digest Authentication | ❌ | ❌ | ❌ | - | (`#784`_) | | | | - +------------------------------------+--------+--------+--------+ - | Follow HTTP redirects | ✅ | ❌ | — | - +------------------------------------+--------+--------+--------+ - | Connect via a HTTP proxy (`#364`_) | ❌ | ❌ | — | - +------------------------------------+--------+--------+--------+ - | Connect via a SOCKS5 proxy | ❌ | ❌ | — | - | (`#475`_) | | | | - +------------------------------------+--------+--------+--------+ + +------------------------------------+--------+--------+--------+--------+ + | | |aio| | |sync| | |sans| | |leg| | + +====================================+========+========+========+========+ + | Connect to a TCP socket | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Connect to a Unix socket | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Connect using a preexisting socket | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Encrypt connection with TLS | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Close connection on context exit | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Reconnect automatically | ❌ | ❌ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Configure ``Origin`` header | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Configure ``User-Agent`` header | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Modify opening handshake request | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Modify opening handshake response | ✅ | ✅ | ✅ | ❌ | + +------------------------------------+--------+--------+--------+--------+ + | Connect to non-ASCII IRIs | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Follow HTTP redirects | ❌ | ❌ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Perform HTTP Basic Authentication | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ + | Perform HTTP Digest Authentication | ❌ | ❌ | ❌ | ❌ | + | (`#784`_) | | | | | + +------------------------------------+--------+--------+--------+--------+ + | Connect via a HTTP proxy (`#364`_) | ❌ | ❌ | — | ❌ | + +------------------------------------+--------+--------+--------+--------+ + | Connect via a SOCKS5 proxy | ❌ | ❌ | — | ❌ | + | (`#475`_) | | | | | + +------------------------------------+--------+--------+--------+--------+ .. _#364: https://github.com/python-websockets/websockets/issues/364 .. _#475: https://github.com/python-websockets/websockets/issues/475 @@ -174,14 +175,12 @@ There is no way to control compression of outgoing frames on a per-frame basis .. _#538: https://github.com/python-websockets/websockets/issues/538 -The server doesn't check the Host header and respond with a HTTP 400 Bad Request -if it is missing or invalid (`#1246`). +The server doesn't check the Host header and doesn't respond with a HTTP 400 Bad +Request if it is missing or invalid (`#1246`). .. _#1246: https://github.com/python-websockets/websockets/issues/1246 The client API doesn't attempt to guarantee that there is no more than one connection to a given IP address in a CONNECTING state. This behavior is -`mandated by RFC 6455`_. However, :func:`~client.connect()` isn't the right -layer for enforcing this constraint. It's the caller's responsibility. - -.. _mandated by RFC 6455: https://www.rfc-editor.org/rfc/rfc6455.html#section-4.1 +mandated by :rfc:`6455`, section 4.1. However, :func:`~client.connect()` isn't +the right layer for enforcing this constraint. It's the caller's responsibility. From 7345b31edc82abc200ebb58dc0fbe856e65d447b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 12 Aug 2024 09:45:39 +0200 Subject: [PATCH 548/762] Cool URI don't change... except when they do. --- src/websockets/asyncio/connection.py | 18 ++++++------- .../extensions/permessage_deflate.py | 4 +-- src/websockets/headers.py | 18 ++++++------- src/websockets/http11.py | 14 +++++----- src/websockets/legacy/http.py | 10 +++---- src/websockets/legacy/protocol.py | 26 +++++++++---------- src/websockets/legacy/server.py | 2 +- src/websockets/protocol.py | 4 +-- src/websockets/server.py | 2 +- src/websockets/sync/connection.py | 18 ++++++------- src/websockets/typing.py | 4 +-- src/websockets/uri.py | 2 +- 12 files changed, 61 insertions(+), 61 deletions(-) diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 4f44d798c..0a3ddb9aa 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -200,8 +200,8 @@ async def recv(self, decode: bool | None = None) -> Data: A string (:class:`str`) for a Text_ frame or a bytestring (:class:`bytes`) for a Binary_ frame. - .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 - .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 You may override this behavior with the ``decode`` argument: @@ -253,8 +253,8 @@ async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data An iterator of strings (:class:`str`) for a Text_ frame or bytestrings (:class:`bytes`) for a Binary_ frame. - .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 - .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 You may override this behavior with the ``decode`` argument: @@ -290,8 +290,8 @@ async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> No bytes-like object (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) is sent as a Binary_ frame. - .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 - .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 :meth:`send` also accepts an iterable or an asynchronous iterable of strings, bytestrings, or bytes-like objects to enable fragmentation_. @@ -299,7 +299,7 @@ async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> No All items must be of the same type, or else :meth:`send` will raise a :exc:`TypeError` and the connection will be closed. - .. _fragmentation: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.4 + .. _fragmentation: https://datatracker.ietf.org/doc/html/rfc6455#section-5.4 :meth:`send` rejects dict-like objects because this is often an error. (If you really want to send the keys of a dict-like object as fragments, @@ -524,7 +524,7 @@ async def ping(self, data: Data | None = None) -> Awaitable[None]: """ Send a Ping_. - .. _Ping: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 + .. _Ping: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.2 A ping may serve as a keepalive or as a check that the remote endpoint received all messages up to this point @@ -574,7 +574,7 @@ async def pong(self, data: Data = b"") -> None: """ Send a Pong_. - .. _Pong: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 + .. _Pong: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3 An unsolicited pong may serve as a unidirectional heartbeat. diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index 579262f02..fea14131e 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -262,7 +262,7 @@ class ClientPerMessageDeflateFactory(ClientExtensionFactory): Parameters behave as described in `section 7.1 of RFC 7692`_. - .. _section 7.1 of RFC 7692: https://www.rfc-editor.org/rfc/rfc7692.html#section-7.1 + .. _section 7.1 of RFC 7692: https://datatracker.ietf.org/doc/html/rfc7692#section-7.1 Set them to :obj:`True` to include them in the negotiation offer without a value or to an integer value to include them with this value. @@ -462,7 +462,7 @@ class ServerPerMessageDeflateFactory(ServerExtensionFactory): Parameters behave as described in `section 7.1 of RFC 7692`_. - .. _section 7.1 of RFC 7692: https://www.rfc-editor.org/rfc/rfc7692.html#section-7.1 + .. _section 7.1 of RFC 7692: https://datatracker.ietf.org/doc/html/rfc7692#section-7.1 Set them to :obj:`True` to include them in the negotiation offer without a value or to an integer value to include them with this value. diff --git a/src/websockets/headers.py b/src/websockets/headers.py index bc42e0b72..0ffd65233 100644 --- a/src/websockets/headers.py +++ b/src/websockets/headers.py @@ -40,7 +40,7 @@ def build_host(host: str, port: int, secure: bool) -> str: Build a ``Host`` header. """ - # https://www.rfc-editor.org/rfc/rfc3986.html#section-3.2.2 + # https://datatracker.ietf.org/doc/html/rfc3986#section-3.2.2 # IPv6 addresses must be enclosed in brackets. try: address = ipaddress.ip_address(host) @@ -59,8 +59,8 @@ def build_host(host: str, port: int, secure: bool) -> str: # To avoid a dependency on a parsing library, we implement manually the ABNF -# described in https://www.rfc-editor.org/rfc/rfc6455.html#section-9.1 and -# https://www.rfc-editor.org/rfc/rfc7230.html#appendix-B. +# described in https://datatracker.ietf.org/doc/html/rfc6455#section-9.1 and +# https://datatracker.ietf.org/doc/html/rfc7230#appendix-B. def peek_ahead(header: str, pos: int) -> str | None: @@ -183,7 +183,7 @@ def parse_list( InvalidHeaderFormat: On invalid inputs. """ - # Per https://www.rfc-editor.org/rfc/rfc7230.html#section-7, "a recipient + # Per https://datatracker.ietf.org/doc/html/rfc7230#section-7, "a recipient # MUST parse and ignore a reasonable number of empty list elements"; # hence while loops that remove extra delimiters. @@ -320,7 +320,7 @@ def parse_extension_item_param( if peek_ahead(header, pos) == '"': pos_before = pos # for proper error reporting below value, pos = parse_quoted_string(header, pos, header_name) - # https://www.rfc-editor.org/rfc/rfc6455.html#section-9.1 says: + # https://datatracker.ietf.org/doc/html/rfc6455#section-9.1 says: # the value after quoted-string unescaping MUST conform to # the 'token' ABNF. if _token_re.fullmatch(value) is None: @@ -489,7 +489,7 @@ def build_www_authenticate_basic(realm: str) -> str: realm: Identifier of the protection space. """ - # https://www.rfc-editor.org/rfc/rfc7617.html#section-2 + # https://datatracker.ietf.org/doc/html/rfc7617#section-2 realm = build_quoted_string(realm) charset = build_quoted_string("UTF-8") return f"Basic realm={realm}, charset={charset}" @@ -539,8 +539,8 @@ def parse_authorization_basic(header: str) -> tuple[str, str]: InvalidHeaderValue: On unsupported inputs. """ - # https://www.rfc-editor.org/rfc/rfc7235.html#section-2.1 - # https://www.rfc-editor.org/rfc/rfc7617.html#section-2 + # https://datatracker.ietf.org/doc/html/rfc7235#section-2.1 + # https://datatracker.ietf.org/doc/html/rfc7617#section-2 scheme, pos = parse_token(header, 0, "Authorization") if scheme.lower() != "basic": raise exceptions.InvalidHeaderValue( @@ -580,7 +580,7 @@ def build_authorization_basic(username: str, password: str) -> str: This is the reverse of :func:`parse_authorization_basic`. """ - # https://www.rfc-editor.org/rfc/rfc7617.html#section-2 + # https://datatracker.ietf.org/doc/html/rfc7617#section-2 assert ":" not in username user_pass = f"{username}:{password}" basic_credentials = base64.b64encode(user_pass.encode()).decode() diff --git a/src/websockets/http11.py b/src/websockets/http11.py index ed49fcbf9..b86c6ca4a 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -48,7 +48,7 @@ def d(value: bytes) -> str: return value.decode(errors="backslashreplace") -# See https://www.rfc-editor.org/rfc/rfc7230.html#appendix-B. +# See https://datatracker.ietf.org/doc/html/rfc7230#appendix-B. # Regex for validating header names. @@ -122,7 +122,7 @@ def parse( ValueError: If the request isn't well formatted. """ - # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.1.1 + # https://datatracker.ietf.org/doc/html/rfc7230#section-3.1.1 # Parsing is simple because fixed values are expected for method and # version and because path isn't checked. Since WebSocket software tends @@ -146,7 +146,7 @@ def parse( headers = yield from parse_headers(read_line) - # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.3.3 + # https://datatracker.ietf.org/doc/html/rfc7230#section-3.3.3 if "Transfer-Encoding" in headers: raise NotImplementedError("transfer codings aren't supported") @@ -227,7 +227,7 @@ def parse( ValueError: If the response isn't well formatted. """ - # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.1.2 + # https://datatracker.ietf.org/doc/html/rfc7230#section-3.1.2 try: status_line = yield from parse_line(read_line) @@ -255,7 +255,7 @@ def parse( headers = yield from parse_headers(read_line) - # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.3.3 + # https://datatracker.ietf.org/doc/html/rfc7230#section-3.3.3 if "Transfer-Encoding" in headers: raise NotImplementedError("transfer codings aren't supported") @@ -324,7 +324,7 @@ def parse_headers( ValueError: If the request isn't well formatted. """ - # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.2 + # https://datatracker.ietf.org/doc/html/rfc7230#section-3.2 # We don't attempt to support obsolete line folding. @@ -378,7 +378,7 @@ def parse_line( line = yield from read_line(MAX_LINE_LENGTH) except RuntimeError: raise exceptions.SecurityError("line too long") - # Not mandatory but safe - https://www.rfc-editor.org/rfc/rfc7230.html#section-3.5 + # Not mandatory but safe - https://datatracker.ietf.org/doc/html/rfc7230#section-3.5 if not line.endswith(b"\r\n"): raise EOFError("line without CRLF") return line[:-2] diff --git a/src/websockets/legacy/http.py b/src/websockets/legacy/http.py index b5df7e4c4..a7c8a927e 100644 --- a/src/websockets/legacy/http.py +++ b/src/websockets/legacy/http.py @@ -22,7 +22,7 @@ def d(value: bytes) -> str: return value.decode(errors="backslashreplace") -# See https://www.rfc-editor.org/rfc/rfc7230.html#appendix-B. +# See https://datatracker.ietf.org/doc/html/rfc7230#appendix-B. # Regex for validating header names. @@ -64,7 +64,7 @@ async def read_request(stream: asyncio.StreamReader) -> tuple[str, Headers]: ValueError: If the request isn't well formatted. """ - # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.1.1 + # https://datatracker.ietf.org/doc/html/rfc7230#section-3.1.1 # Parsing is simple because fixed values are expected for method and # version and because path isn't checked. Since WebSocket software tends @@ -111,7 +111,7 @@ async def read_response(stream: asyncio.StreamReader) -> tuple[int, str, Headers ValueError: If the response isn't well formatted. """ - # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.1.2 + # https://datatracker.ietf.org/doc/html/rfc7230#section-3.1.2 # As in read_request, parsing is simple because a fixed value is expected # for version, status_code is a 3-digit number, and reason can be ignored. @@ -150,7 +150,7 @@ async def read_headers(stream: asyncio.StreamReader) -> Headers: Non-ASCII characters are represented with surrogate escapes. """ - # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.2 + # https://datatracker.ietf.org/doc/html/rfc7230#section-3.2 # We don't attempt to support obsolete line folding. @@ -195,7 +195,7 @@ async def read_line(stream: asyncio.StreamReader) -> bytes: # Security: this guarantees header values are small (hard-coded = 8 KiB) if len(line) > MAX_LINE_LENGTH: raise SecurityError("line too long") - # Not mandatory but safe - https://www.rfc-editor.org/rfc/rfc7230.html#section-3.5 + # Not mandatory but safe - https://datatracker.ietf.org/doc/html/rfc7230#section-3.5 if not line.endswith(b"\r\n"): raise EOFError("line without CRLF") return line[:-2] diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 120ff8e73..6f8916576 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -80,14 +80,14 @@ class WebSocketCommonProtocol(asyncio.Protocol): especially in the presence of proxies with short timeouts on inactive connections. Set ``ping_interval`` to :obj:`None` to disable this behavior. - .. _Ping: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 + .. _Ping: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.2 If the corresponding Pong_ frame isn't received within ``ping_timeout`` seconds, the connection is considered unusable and is closed with code 1011. This ensures that the remote endpoint remains responsive. Set ``ping_timeout`` to :obj:`None` to disable this behavior. - .. _Pong: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 + .. _Pong: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3 See the discussion of :doc:`timeouts <../../topics/timeouts>` for details. @@ -447,7 +447,7 @@ def close_code(self) -> int | None: WebSocket close code, defined in `section 7.1.5 of RFC 6455`_. .. _section 7.1.5 of RFC 6455: - https://www.rfc-editor.org/rfc/rfc6455.html#section-7.1.5 + https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.5 :obj:`None` if the connection isn't closed yet. @@ -465,7 +465,7 @@ def close_reason(self) -> str | None: WebSocket close reason, defined in `section 7.1.6 of RFC 6455`_. .. _section 7.1.6 of RFC 6455: - https://www.rfc-editor.org/rfc/rfc6455.html#section-7.1.6 + https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.6 :obj:`None` if the connection isn't closed yet. @@ -516,8 +516,8 @@ async def recv(self) -> Data: A string (:class:`str`) for a Text_ frame. A bytestring (:class:`bytes`) for a Binary_ frame. - .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 - .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 Raises: ConnectionClosed: When the connection is closed. @@ -583,8 +583,8 @@ async def send( bytes-like object (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) is sent as a Binary_ frame. - .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 - .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 :meth:`send` also accepts an iterable or an asynchronous iterable of strings, bytestrings, or bytes-like objects to enable fragmentation_. @@ -592,7 +592,7 @@ async def send( All items must be of the same type, or else :meth:`send` will raise a :exc:`TypeError` and the connection will be closed. - .. _fragmentation: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.4 + .. _fragmentation: https://datatracker.ietf.org/doc/html/rfc6455#section-5.4 :meth:`send` rejects dict-like objects because this is often an error. (If you want to send the keys of a dict-like object as fragments, call @@ -803,7 +803,7 @@ async def ping(self, data: Data | None = None) -> Awaitable[float]: """ Send a Ping_. - .. _Ping: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 + .. _Ping: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.2 A ping may serve as a keepalive, as a check that the remote endpoint received all messages up to this point, or to measure :attr:`latency`. @@ -862,7 +862,7 @@ async def pong(self, data: Data = b"") -> None: """ Send a Pong_. - .. _Pong: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 + .. _Pong: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3 An unsolicited pong may serve as a unidirectional heartbeat. @@ -1559,8 +1559,8 @@ def broadcast( object (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) is sent as a Binary_ frame. - .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 - .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 :func:`broadcast` pushes the message synchronously to all connections even if their write buffers are overflowing. There's no backpressure. diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index cd7980e00..d230f009e 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -388,7 +388,7 @@ def process_origin( """ # "The user agent MUST NOT include more than one Origin header field" - # per https://www.rfc-editor.org/rfc/rfc6454.html#section-7.3. + # per https://datatracker.ietf.org/doc/html/rfc6454#section-7.3. try: origin = headers.get("Origin") except MultipleValuesError as exc: diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 7f2b45c74..917c19163 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -175,7 +175,7 @@ def close_code(self) -> int | None: `WebSocket close code`_. .. _WebSocket close code: - https://www.rfc-editor.org/rfc/rfc6455.html#section-7.1.5 + https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.5 :obj:`None` if the connection isn't closed yet. @@ -193,7 +193,7 @@ def close_reason(self) -> str | None: `WebSocket close reason`_. .. _WebSocket close reason: - https://www.rfc-editor.org/rfc/rfc6455.html#section-7.1.6 + https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.6 :obj:`None` if the connection isn't closed yet. diff --git a/src/websockets/server.py b/src/websockets/server.py index 7211d3cbf..1b4c3bf29 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -307,7 +307,7 @@ def process_origin(self, headers: Headers) -> Origin | None: """ # "The user agent MUST NOT include more than one Origin header field" - # per https://www.rfc-editor.org/rfc/rfc6454.html#section-7.3. + # per https://datatracker.ietf.org/doc/html/rfc6454#section-7.3. try: origin = headers.get("Origin") except MultipleValuesError as exc: diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 2bcb3aa0e..a4826c785 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -189,8 +189,8 @@ def recv(self, timeout: float | None = None) -> Data: A string (:class:`str`) for a Text_ frame or a bytestring (:class:`bytes`) for a Binary_ frame. - .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 - .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 Raises: ConnectionClosed: When the connection is closed. @@ -222,8 +222,8 @@ def recv_streaming(self) -> Iterator[Data]: An iterator of strings (:class:`str`) for a Text_ frame or bytestrings (:class:`bytes`) for a Binary_ frame. - .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 - .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 Raises: ConnectionClosed: When the connection is closed. @@ -250,8 +250,8 @@ def send(self, message: Data | Iterable[Data]) -> None: bytes-like object (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) is sent as a Binary_ frame. - .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 - .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 :meth:`send` also accepts an iterable of strings, bytestrings, or bytes-like objects to enable fragmentation_. Each item is treated as a @@ -259,7 +259,7 @@ def send(self, message: Data | Iterable[Data]) -> None: same type, or else :meth:`send` will raise a :exc:`TypeError` and the connection will be closed. - .. _fragmentation: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.4 + .. _fragmentation: https://datatracker.ietf.org/doc/html/rfc6455#section-5.4 :meth:`send` rejects dict-like objects because this is often an error. (If you really want to send the keys of a dict-like object as fragments, @@ -425,7 +425,7 @@ def ping(self, data: Data | None = None) -> threading.Event: """ Send a Ping_. - .. _Ping: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 + .. _Ping: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.2 A ping may serve as a keepalive or as a check that the remote endpoint received all messages up to this point @@ -470,7 +470,7 @@ def pong(self, data: Data = b"") -> None: """ Send a Pong_. - .. _Pong: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 + .. _Pong: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3 An unsolicited pong may serve as a unidirectional heartbeat. diff --git a/src/websockets/typing.py b/src/websockets/typing.py index 6360c7a0a..447fe79da 100644 --- a/src/websockets/typing.py +++ b/src/websockets/typing.py @@ -24,8 +24,8 @@ """Types supported in a WebSocket message: :class:`str` for a Text_ frame, :class:`bytes` for a Binary_. -.. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 -.. _Binary : https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 +.. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 +.. _Binary : https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 """ diff --git a/src/websockets/uri.py b/src/websockets/uri.py index 5cb38a9cc..82b35f92a 100644 --- a/src/websockets/uri.py +++ b/src/websockets/uri.py @@ -23,7 +23,7 @@ class WebSocketURI: username: Available when the URI contains `User Information`_. password: Available when the URI contains `User Information`_. - .. _User Information: https://www.rfc-editor.org/rfc/rfc3986.html#section-3.2.1 + .. _User Information: https://datatracker.ietf.org/doc/html/rfc3986#section-3.2.1 """ From e2f0385119992317c0f49b32775ef50b2fa40218 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 13 Aug 2024 23:15:55 +0200 Subject: [PATCH 549/762] Add guide for upgrading to the new asyncio implementation. --- docs/howto/index.rst | 8 + docs/howto/upgrade.rst | 357 ++++++++++++++++++++++++++ docs/project/changelog.rst | 66 ++++- docs/reference/asyncio/client.rst | 4 +- docs/reference/asyncio/common.rst | 4 +- docs/reference/asyncio/server.rst | 10 +- docs/reference/new-asyncio/client.rst | 4 +- docs/reference/new-asyncio/common.rst | 4 +- docs/reference/new-asyncio/server.rst | 4 +- docs/spelling_wordlist.txt | 5 +- 10 files changed, 446 insertions(+), 20 deletions(-) create mode 100644 docs/howto/upgrade.rst diff --git a/docs/howto/index.rst b/docs/howto/index.rst index ddbe67d3a..863c1c63c 100644 --- a/docs/howto/index.rst +++ b/docs/howto/index.rst @@ -8,6 +8,14 @@ In a hurry? Check out these examples. quickstart +Upgrading from the legacy :mod:`asyncio` implementation to the new one? +Read this. + +.. toctree:: + :titlesonly: + + upgrade + If you're stuck, perhaps you'll find the answer here. .. toctree:: diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst new file mode 100644 index 000000000..bb4c59bc4 --- /dev/null +++ b/docs/howto/upgrade.rst @@ -0,0 +1,357 @@ +Upgrade to the new :mod:`asyncio` implementation +================================================ + +.. currentmodule:: websockets + +The new :mod:`asyncio` implementation is a rewrite of the original +implementation of websockets. + +It provides a very similar API. However, there are a few differences. + +The recommended upgrade process is: + +1. Make sure that your application doesn't use any `deprecated APIs`_. If it + doesn't raise any warnings, you can skip this step. +2. Check if your application depends on `missing features`_. If it does, you + should stick to the original implementation until they're added. +3. `Update import paths`_. For straightforward usage of websockets, this could + be the only step you need to take. Upgrading could be transparent. +4. `Review API changes`_ and adapt your application to preserve its current + functionality or take advantage of improvements in the new implementation. + +In the interest of brevity, only :func:`~asyncio.client.connect` and +:func:`~asyncio.server.serve` are discussed below but everything also applies +to :func:`~asyncio.client.unix_connect` and :func:`~asyncio.server.unix_serve` +respectively. + +.. admonition:: What will happen to the original implementation? + :class: hint + + The original implementation is now considered legacy. + + The next steps are: + + 1. Deprecating it once the new implementation reaches feature parity. + 2. Maintaining it for five years per the :ref:`backwards-compatibility + policy `. + 3. Removing it. This is expected to happen around 2030. + +.. _deprecated APIs: + +Deprecated APIs +--------------- + +Here's the list of deprecated behaviors that the original implementation still +supports and that the new implementation doesn't reproduce. + +If you're seeing a :class:`DeprecationWarning`, follow upgrade instructions from +the release notes of the version in which the feature was deprecated. + +* The ``path`` argument of connection handlers — unnecessary since :ref:`10.1` + and deprecated in :ref:`13.0`. +* The ``loop`` and ``legacy_recv`` arguments of :func:`~client.connect` and + :func:`~server.serve`, which were removed — deprecated in :ref:`10.0`. +* The ``timeout`` and ``klass`` arguments of :func:`~client.connect` and + :func:`~server.serve`, which were renamed to ``close_timeout`` and + ``create_protocol`` — deprecated in :ref:`7.0` and :ref:`3.4` respectively. +* An empty string in the ``origins`` argument of :func:`~server.serve` — + deprecated in :ref:`7.0`. +* The ``host``, ``port``, and ``secure`` attributes of connections — deprecated + in :ref:`8.0`. + +.. _missing features: + +Missing features +---------------- + +.. admonition:: All features listed below will be provided in a future release. + :class: tip + + If your application relies on one of them, you should stick to the original + implementation until the new implementation supports it in a future release. + +Broadcast +......... + +The new implementation doesn't support :doc:`broadcasting messages +<../topics/broadcast>` yet. + +Keepalive +......... + +The new implementation doesn't provide a :ref:`keepalive mechanism ` +yet. + +As a consequence, :func:`~asyncio.client.connect` and +:func:`~asyncio.server.serve` don't accept the ``ping_interval`` and +``ping_timeout`` arguments and the +:attr:`~legacy.protocol.WebSocketCommonProtocol.latency` property doesn't exist. + +HTTP Basic Authentication +......................... + +On the server side, :func:`~asyncio.server.serve` doesn't provide HTTP Basic +Authentication yet. + +For the avoidance of doubt, on the client side, :func:`~asyncio.client.connect` +performs HTTP Basic Authentication. + +Following redirects +................... + +The new implementation of :func:`~asyncio.client.connect` doesn't follow HTTP +redirects yet. + +Automatic reconnection +...................... + +The new implementation of :func:`~asyncio.client.connect` doesn't provide +automatic reconnection yet. + +In other words, the following pattern isn't supported:: + + from websockets.asyncio.client import connect + + async for websocket in connect(...): # this doesn't work yet + ... + +Configuring buffers +................... + +The new implementation doesn't provide a way to configure read and write buffers +yet. + +In practice, :func:`~asyncio.client.connect` and :func:`~asyncio.server.serve` +don't accept the ``max_queue``, ``read_limit``, and ``write_limit`` arguments. + +Here's the most likely outcome: + +* ``max_queue`` will be implemented but its semantics will change from "maximum + number of messages" to "maximum number of frames", which makes a difference + when messages are fragmented. +* ``read_limit`` won't be implemented because the buffer that it configured was + removed from the new implementation. The queue that ``max_queue`` configures + is the only read buffer now. +* ``write_limit`` will be implemented as in the original implementation. + Alternatively, the same functionality could be exposed with a different API. + +.. _Update import paths: + +Import paths +------------ + +For context, the ``websockets`` package is structured as follows: + +* The new implementation is found in the ``websockets.asyncio`` package. +* The original implementation was moved to the ``websockets.legacy`` package. +* The ``websockets`` package provides aliases for convenience. +* The ``websockets.client`` and ``websockets.server`` packages provide aliases + for backwards-compatibility with earlier versions of websockets. +* Currently, all aliases point to the original implementation. In the future, + they will point to the new implementation or they will be deprecated. + +To upgrade to the new :mod:`asyncio` implementation, change import paths as +shown in the tables below. + +.. |br| raw:: html + +
+ +Client APIs +........... + ++-------------------------------------------------------------------+-----------------------------------------------------+ +| Legacy :mod:`asyncio` implementation | New :mod:`asyncio` implementation | ++===================================================================+=====================================================+ +| ``websockets.connect()`` |br| | :func:`websockets.asyncio.client.connect` | +| :func:`websockets.client.connect` |br| | | +| ``websockets.legacy.client.connect()`` | | ++-------------------------------------------------------------------+-----------------------------------------------------+ +| ``websockets.unix_connect()`` |br| | :func:`websockets.asyncio.client.unix_connect` | +| :func:`websockets.client.unix_connect` |br| | | +| ``websockets.legacy.client.unix_connect()`` | | ++-------------------------------------------------------------------+-----------------------------------------------------+ +| ``websockets.WebSocketClientProtocol`` |br| | :class:`websockets.asyncio.client.ClientConnection` | +| :class:`websockets.client.WebSocketClientProtocol` |br| | | +| ``websockets.legacy.client.WebSocketClientProtocol`` | | ++-------------------------------------------------------------------+-----------------------------------------------------+ + +Server APIs +........... + ++-------------------------------------------------------------------+-----------------------------------------------------+ +| Legacy :mod:`asyncio` implementation | New :mod:`asyncio` implementation | ++===================================================================+=====================================================+ +| ``websockets.serve()`` |br| | :func:`websockets.asyncio.server.serve` | +| :func:`websockets.server.serve` |br| | | +| ``websockets.legacy.server.serve()`` | | ++-------------------------------------------------------------------+-----------------------------------------------------+ +| ``websockets.unix_serve()`` |br| | :func:`websockets.asyncio.server.unix_serve` | +| :func:`websockets.server.unix_serve` |br| | | +| ``websockets.legacy.server.unix_serve()`` | | ++-------------------------------------------------------------------+-----------------------------------------------------+ +| ``websockets.WebSocketServer`` |br| | :class:`websockets.asyncio.server.WebSocketServer` | +| :class:`websockets.server.WebSocketServer` |br| | | +| ``websockets.legacy.server.WebSocketServer`` | | ++-------------------------------------------------------------------+-----------------------------------------------------+ +| ``websockets.WebSocketServerProtocol`` |br| | :class:`websockets.asyncio.server.ServerConnection` | +| :class:`websockets.server.WebSocketServerProtocol` |br| | | +| ``websockets.legacy.server.WebSocketServerProtocol`` | | ++-------------------------------------------------------------------+-----------------------------------------------------+ +| :func:`websockets.broadcast` |br| | *not available yet* | +| ``websockets.legacy.protocol.broadcast()`` | | ++-------------------------------------------------------------------+-----------------------------------------------------+ +| ``websockets.BasicAuthWebSocketServerProtocol`` |br| | *not available yet* | +| :class:`websockets.auth.BasicAuthWebSocketServerProtocol` |br| | | +| ``websockets.legacy.auth.BasicAuthWebSocketServerProtocol`` | | ++-------------------------------------------------------------------+-----------------------------------------------------+ +| ``websockets.basic_auth_protocol_factory()`` |br| | *not available yet* | +| :func:`websockets.auth.basic_auth_protocol_factory` |br| | | +| ``websockets.legacy.auth.basic_auth_protocol_factory()`` | | ++-------------------------------------------------------------------+-----------------------------------------------------+ + +.. _Review API changes: + +API changes +----------- + +Controlling UTF-8 decoding +.......................... + +The new implementation of the :meth:`~asyncio.connection.Connection.recv` method +provides the ``decode`` argument to control UTF-8 decoding of messages. This +didn't exist in the original implementation. + +If you're calling :meth:`~str.encode` on a :class:`str` object returned by +:meth:`~asyncio.connection.Connection.recv`, using ``decode=False`` and removing +:meth:`~str.encode` saves a round-trip of UTF-8 decoding and encoding for text +messages. + +You can also force UTF-8 decoding of binary messages with ``decode=True``. This +is rarely useful and has no performance benefits over decoding a :class:`bytes` +object returned by :meth:`~asyncio.connection.Connection.recv`. + +Receiving fragmented messages +............................. + +The new implementation provides the +:meth:`~asyncio.connection.Connection.recv_streaming` method for receiving a +fragmented message frame by frame. There was no way to do this in the original +implementation. + +Depending on your use case, adopting this method may improve performance when +streaming large messages. Specifically, it could reduce memory usage. + +Customizing the opening handshake +................................. + +On the client side, if you're adding headers to the handshake request sent by +:func:`~client.connect` with the ``extra_headers`` argument, you must rename it +to ``additional_headers``. + +On the server side, if you're customizing how :func:`~server.serve` processes +the opening handshake with the ``process_request``, ``extra_headers``, or +``select_subprotocol``, you must update your code. ``process_response`` and +``select_subprotocol`` have new signatures; ``process_response`` replaces +``extra_headers`` and provides more flexibility. + +``process_request`` +~~~~~~~~~~~~~~~~~~~ + +The signature of ``process_request`` changed. This is easiest to illustrate with +an example:: + + import http + + # Original implementation + + def process_request(path, request_headers): + return http.HTTPStatus.OK, [], b"OK\n" + + serve(..., process_request=process_request, ...) + + # New implementation + + def process_request(connection, request): + return connection.protocol.reject(http.HTTPStatus.OK, "OK\n") + + serve(..., process_request=process_request, ...) + +``connection`` is always available in ``process_request``. In the original +implementation, you had to write a subclass of +:class:`~server.WebSocketServerProtocol` and pass it in the ``create_protocol`` +argument to make the connection object available in a ``process_request`` +method. This pattern isn't useful anymore; you can replace it with a +``process_request`` function or coroutine. + +``path`` and ``headers`` are available as attributes of the ``request`` object. + +``process_response`` +~~~~~~~~~~~~~~~~~~~~ + +``process_request`` replaces ``extra_headers`` and provides more flexibility. +In the most basic case, you would adapt your code as follows:: + + # Original implementation + + serve(..., extra_headers=HEADERS, ...) + + # New implementation + + def process_response(connection, request, response): + response.headers.update(HEADERS) + return response + + serve(..., process_response=process_response, ...) + +``connection`` is always available in ``process_response``, similar to +``process_request``. In the original implementation, there was no way to make +the connection object available. + +In addition, the ``request`` and ``response`` objects are available, which +enables a broader range of use cases (e.g., logging) and makes +``process_response`` more useful than ``extra_headers``. + +``select_subprotocol`` +~~~~~~~~~~~~~~~~~~~~~~ + +The signature of ``select_subprotocol`` changed. Here's an example:: + + # Original implementation + + def select_subprotocol(client_subprotocols, server_subprotocols): + if "chat" in client_subprotocols: + return "chat" + + # New implementation + + def select_subprotocol(connection, subprotocols): + if "chat" in subprotocols + return "chat" + + serve(..., select_subprotocol=select_subprotocol, ...) + +``connection`` is always available in ``select_subprotocol``. This brings the +same benefits as in ``process_request``. It may remove the need to subclass of +:class:`~server.WebSocketServerProtocol`. + +The ``subprotocols`` argument contains the list of subprotocols offered by the +client. The list of subprotocols supported by the server was removed because +``select_subprotocols`` already knows which subprotocols it may select and under +which conditions. + +Miscellaneous changes +..................... + +The first argument of :func:`~asyncio.server.serve` is called ``handler`` instead +of ``ws_handler``. It's usually passed as a positional argument, making this +change transparent. If you're passing it as a keyword argument, you must update +its name. + +The keyword argument of :func:`~asyncio.server.serve` for customizing the +creation of the connection object is called ``create_connection`` instead of +``create_protocol``. It must return a :class:`~asyncio.server.ServerConnection` +instead of a :class:`~server.WebSocketServerProtocol`. If you were customizing +connection objects, you should check the new implementation and possibly redo +your customization. Keep in mind that the changes to ``process_request`` and +``select_subprotocol`` remove most use cases for ``create_connection``. diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 00b055dd1..f033f5632 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -25,6 +25,8 @@ fixing regressions shortly after a release. Only documented APIs are public. Undocumented, private APIs may change without notice. +.. _13.0: + 13.0 ---- @@ -66,10 +68,10 @@ New features This new implementation is intended to be a drop-in replacement for the current implementation. It will become the default in a future release. - Please try it and report any issue that you encounter! - See :func:`websockets.asyncio.client.connect` and - :func:`websockets.asyncio.server.serve` for details. + Please try it and report any issue that you encounter! The :doc:`upgrade + guide <../howto/upgrade>` explains everything you need to know about the + upgrade process. * Validated compatibility with Python 3.12 and 3.13. @@ -79,6 +81,8 @@ New features If you were monkey-patching constants, be aware that they were renamed, which will break your configuration. You must switch to the environment variables. +.. _12.0: + 12.0 ---- @@ -135,6 +139,8 @@ Bug fixes * Restored the C extension in the source distribution. +.. _11.0: + 11.0 ---- @@ -211,6 +217,8 @@ Improvements * Set ``server_hostname`` automatically on TLS connections when providing a ``sock`` argument to :func:`~sync.client.connect`. +.. _10.4: + 10.4 ---- @@ -237,6 +245,8 @@ Improvements * Improved FAQ. +.. _10.3: + 10.3 ---- @@ -259,6 +269,8 @@ Improvements * Reduced noise in logs when :mod:`ssl` or :mod:`zlib` raise exceptions. +.. _10.2: + 10.2 ---- @@ -279,6 +291,8 @@ Bug fixes * Avoided leaking open sockets when :func:`~client.connect` is canceled. +.. _10.1: + 10.1 ---- @@ -328,6 +342,8 @@ Bug fixes * Avoided half-closing TCP connections that are already closed. +.. _10.0: + 10.0 ---- @@ -434,6 +450,8 @@ Bug fixes * Avoided a crash when receiving a ping while the connection is closing. +.. _9.1: + 9.1 --- @@ -472,6 +490,8 @@ Bug fixes * Fixed issues with the packaging of the 9.0 release. +.. _9.0: + 9.0 --- @@ -549,6 +569,8 @@ Bug fixes * Ensured cancellation always propagates, even on Python versions where :exc:`~asyncio.CancelledError` inherits :exc:`Exception`. +.. _8.1: + 8.1 --- @@ -583,6 +605,8 @@ Bug fixes * Restored the ability to import ``WebSocketProtocolError`` from ``websockets``. +.. _8.0: + 8.0 --- @@ -692,6 +716,8 @@ Bug fixes * Avoided a crash when a ``extra_headers`` callable returns :obj:`None`. +.. _7.0: + 7.0 --- @@ -786,6 +812,8 @@ Bug fixes :meth:`~legacy.protocol.WebSocketCommonProtocol.recv`: canceling it at the wrong time could result in messages being dropped. +.. _6.0: + 6.0 --- @@ -840,6 +868,8 @@ Bug fixes * Fixed a regression in 5.0 that broke some invocations of :func:`~server.serve` and :func:`~client.connect`. +.. _5.0: + 5.0 --- @@ -925,6 +955,8 @@ Bug fixes * Fixed issues with the packaging of the 4.0 release. +.. _4.0: + 4.0 --- @@ -984,6 +1016,8 @@ Bug fixes * Stopped leaking pending tasks when :meth:`~asyncio.Task.cancel` is called on a connection while it's being closed. +.. _3.4: + 3.4 --- @@ -1027,6 +1061,8 @@ Bug fixes * Providing a ``sock`` argument to :func:`~client.connect` no longer crashes. +.. _3.3: + 3.3 --- @@ -1047,6 +1083,8 @@ Bug fixes * Avoided crashing on concurrent writes on slow connections. +.. _3.2: + 3.2 --- @@ -1063,6 +1101,8 @@ Improvements * Made server shutdown more robust. +.. _3.1: + 3.1 --- @@ -1078,6 +1118,8 @@ Bug fixes * Avoided a warning when closing a connection before the opening handshake. +.. _3.0: + 3.0 --- @@ -1135,6 +1177,8 @@ Improvements * Improved documentation. +.. _2.7: + 2.7 --- @@ -1150,6 +1194,8 @@ Improvements * Refreshed documentation. +.. _2.6: + 2.6 --- @@ -1167,6 +1213,8 @@ Bug fixes * Avoided TCP fragmentation of small frames. +.. _2.5: + 2.5 --- @@ -1200,6 +1248,8 @@ Bug fixes * Canceling :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` no longer drops the next message. +.. _2.4: + 2.4 --- @@ -1213,6 +1263,8 @@ New features * Added ``loop`` argument to :func:`~client.connect` and :func:`~server.serve`. +.. _2.3: + 2.3 --- @@ -1223,6 +1275,8 @@ Improvements * Improved compliance of close codes. +.. _2.2: + 2.2 --- @@ -1233,6 +1287,8 @@ New features * Added support for limiting message size. +.. _2.1: + 2.1 --- @@ -1247,6 +1303,8 @@ New features .. _Origin: https://www.rfc-editor.org/rfc/rfc6455.html#section-10.2 +.. _2.0: + 2.0 --- @@ -1275,6 +1333,8 @@ New features * Added flow control for outgoing data. +.. _1.0: + 1.0 --- diff --git a/docs/reference/asyncio/client.rst b/docs/reference/asyncio/client.rst index 5086015b7..f9ce2f2d8 100644 --- a/docs/reference/asyncio/client.rst +++ b/docs/reference/asyncio/client.rst @@ -1,5 +1,5 @@ -Client (:mod:`asyncio`) -======================= +Client (legacy :mod:`asyncio`) +============================== .. automodule:: websockets.client diff --git a/docs/reference/asyncio/common.rst b/docs/reference/asyncio/common.rst index dc7a54ee1..aee774479 100644 --- a/docs/reference/asyncio/common.rst +++ b/docs/reference/asyncio/common.rst @@ -1,7 +1,7 @@ :orphan: -Both sides (:mod:`asyncio`) -=========================== +Both sides (legacy :mod:`asyncio`) +================================== .. automodule:: websockets.legacy.protocol diff --git a/docs/reference/asyncio/server.rst b/docs/reference/asyncio/server.rst index 106317916..4bd52b40b 100644 --- a/docs/reference/asyncio/server.rst +++ b/docs/reference/asyncio/server.rst @@ -1,15 +1,15 @@ -Server (:mod:`asyncio`) -======================= +Server (legacy :mod:`asyncio`) +============================== .. automodule:: websockets.server Starting a server ----------------- -.. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, logger=None, compression="deflate", origins=None, extensions=None, subprotocols=None, extra_headers=None, server_header="Python/x.y.z websockets/X.Y", process_request=None, select_subprotocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) +.. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, logger=None, compression="deflate", origins=None, extensions=None, subprotocols=None, extra_headers=None, server_header="Python/x.y.z websockets/X.Y", process_request=None, select_subprotocol=None, open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) :async: -.. autofunction:: unix_serve(ws_handler, path=None, *, create_protocol=None, logger=None, compression="deflate", origins=None, extensions=None, subprotocols=None, extra_headers=None, server_header="Python/x.y.z websockets/X.Y", process_request=None, select_subprotocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) +.. autofunction:: unix_serve(ws_handler, path=None, *, create_protocol=None, logger=None, compression="deflate", origins=None, extensions=None, subprotocols=None, extra_headers=None, server_header="Python/x.y.z websockets/X.Y", process_request=None, select_subprotocol=None, open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) :async: Stopping a server @@ -34,7 +34,7 @@ Stopping a server Using a connection ------------------ -.. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, logger=None, origins=None, extensions=None, subprotocols=None, extra_headers=None, server_header="Python/x.y.z websockets/X.Y", process_request=None, select_subprotocol=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) +.. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, logger=None, origins=None, extensions=None, subprotocols=None, extra_headers=None, server_header="Python/x.y.z websockets/X.Y", process_request=None, select_subprotocol=None, open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) .. automethod:: recv diff --git a/docs/reference/new-asyncio/client.rst b/docs/reference/new-asyncio/client.rst index 552d83b2f..196bda2b7 100644 --- a/docs/reference/new-asyncio/client.rst +++ b/docs/reference/new-asyncio/client.rst @@ -1,5 +1,5 @@ -Client (:mod:`asyncio` - new) -============================= +Client (new :mod:`asyncio`) +=========================== .. automodule:: websockets.asyncio.client diff --git a/docs/reference/new-asyncio/common.rst b/docs/reference/new-asyncio/common.rst index ba23552dc..4fa97dcf2 100644 --- a/docs/reference/new-asyncio/common.rst +++ b/docs/reference/new-asyncio/common.rst @@ -1,7 +1,7 @@ :orphan: -Both sides (:mod:`asyncio` - new) -================================= +Both sides (new :mod:`asyncio`) +=============================== .. automodule:: websockets.asyncio.connection diff --git a/docs/reference/new-asyncio/server.rst b/docs/reference/new-asyncio/server.rst index f3446fb80..c43673d33 100644 --- a/docs/reference/new-asyncio/server.rst +++ b/docs/reference/new-asyncio/server.rst @@ -1,5 +1,5 @@ -Server (:mod:`asyncio` - new) -============================= +Server (new :mod:`asyncio`) +=========================== .. automodule:: websockets.asyncio.server diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index dfa7065e7..a1ba59a37 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -21,8 +21,8 @@ cryptocurrency css ctrl deserialize -django dev +django Dockerfile dyno formatter @@ -44,6 +44,7 @@ linkerd liveness lookups MiB +middleware mutex mypy nginx @@ -77,8 +78,8 @@ uple uvicorn uvloop virtualenv -WebSocket websocket +WebSocket websockets ws wsgi From 8385cf02fccd5e171e1ee5b8949df11773c0f954 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 14 Aug 2024 08:40:46 +0200 Subject: [PATCH 550/762] Add write_limit parameter to the new asyncio API. --- docs/howto/upgrade.rst | 84 ++++++++++++++++++---------- src/websockets/asyncio/client.py | 16 ++++++ src/websockets/asyncio/connection.py | 16 +++++- src/websockets/asyncio/messages.py | 21 ++++--- src/websockets/asyncio/server.py | 16 ++++++ tests/asyncio/test_connection.py | 38 ++++++++++++- tests/asyncio/test_messages.py | 35 +++++++----- 7 files changed, 166 insertions(+), 60 deletions(-) diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index bb4c59bc4..10e8967d8 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -115,26 +115,6 @@ In other words, the following pattern isn't supported:: async for websocket in connect(...): # this doesn't work yet ... -Configuring buffers -................... - -The new implementation doesn't provide a way to configure read and write buffers -yet. - -In practice, :func:`~asyncio.client.connect` and :func:`~asyncio.server.serve` -don't accept the ``max_queue``, ``read_limit``, and ``write_limit`` arguments. - -Here's the most likely outcome: - -* ``max_queue`` will be implemented but its semantics will change from "maximum - number of messages" to "maximum number of frames", which makes a difference - when messages are fragmented. -* ``read_limit`` won't be implemented because the buffer that it configured was - removed from the new implementation. The queue that ``max_queue`` configures - is the only read buffer now. -* ``write_limit`` will be implemented as in the original implementation. - Alternatively, the same functionality could be exposed with a different API. - .. _Update import paths: Import paths @@ -340,18 +320,60 @@ client. The list of subprotocols supported by the server was removed because ``select_subprotocols`` already knows which subprotocols it may select and under which conditions. -Miscellaneous changes -..................... +Arguments of :func:`~asyncio.client.connect` and :func:`~asyncio.server.serve` +.............................................................................. + +``ws_handler`` → ``handler`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -The first argument of :func:`~asyncio.server.serve` is called ``handler`` instead -of ``ws_handler``. It's usually passed as a positional argument, making this -change transparent. If you're passing it as a keyword argument, you must update -its name. +The first argument of :func:`~asyncio.server.serve` is now called ``handler`` +instead of ``ws_handler``. It's usually passed as a positional argument, making +this change transparent. If you're passing it as a keyword argument, you must +update its name. + +``create_protocol`` → ``create_connection`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ The keyword argument of :func:`~asyncio.server.serve` for customizing the -creation of the connection object is called ``create_connection`` instead of +creation of the connection object is now called ``create_connection`` instead of ``create_protocol``. It must return a :class:`~asyncio.server.ServerConnection` -instead of a :class:`~server.WebSocketServerProtocol`. If you were customizing -connection objects, you should check the new implementation and possibly redo -your customization. Keep in mind that the changes to ``process_request`` and -``select_subprotocol`` remove most use cases for ``create_connection``. +instead of a :class:`~server.WebSocketServerProtocol`. + +If you were customizing connection objects, you should check the new +implementation and possibly redo your customization. Keep in mind that the +changes to ``process_request`` and ``select_subprotocol`` remove most use cases +for ``create_connection``. + +``max_queue`` +~~~~~~~~~~~~~ + +The ``max_queue`` argument of :func:`~asyncio.client.connect` and +:func:`~asyncio.server.serve` has a new meaning but achieves a similar effect. + +It is now the high-water mark of a buffer of incoming frames. It defaults to 16 +frames. It used to be the size of a buffer of incoming messages that refilled as +soon as a message was read. It used to default to 32 messages. + +This can make a difference when messages are fragmented in several frames. In +that case, you may want to increase ``max_queue``. If you're writing a high +performance server and you know that you're receiving fragmented messages, +probably you should adopt :meth:`~asyncio.connection.Connection.recv_streaming` +and optimize the performance of reads again. In all other cases, given how +uncommon fragmentation is, you shouldn't worry about this change. + +``read_limit`` +~~~~~~~~~~~~~~ + +The ``read_limit`` argument doesn't exist in the new implementation because it +doesn't buffer data received from the network in a +:class:`~asyncio.StreamReader`. With a better design, this buffer could be +removed. + +The buffer of incoming frames configured by ``max_queue`` is the only read +buffer now. + +``write_limit`` +~~~~~~~~~~~~~~~ + +The ``write_limit`` argument of :func:`~asyncio.client.connect` and +:func:`~asyncio.server.serve` defaults to 32 KiB instead of 64 KiB. diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index ac8ded8ca..b2eaf9a65 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -49,11 +49,15 @@ def __init__( protocol: ClientProtocol, *, close_timeout: float | None = 10, + max_queue: int | tuple[int, int | None] = 16, + write_limit: int | tuple[int, int | None] = 2**15, ) -> None: self.protocol: ClientProtocol super().__init__( protocol, close_timeout=close_timeout, + max_queue=max_queue, + write_limit=write_limit, ) self.response_rcvd: asyncio.Future[None] = self.loop.create_future() @@ -146,6 +150,14 @@ class connect: :obj:`None` disables the timeout. max_size: Maximum size of incoming messages in bytes. :obj:`None` disables the limit. + max_queue: High-water mark of the buffer where frames are received. + It defaults to 16 frames. The low-water mark defaults to ``max_queue + // 4``. You may pass a ``(high, low)`` tuple to set the high-water + and low-water marks. + write_limit: High-water mark of write buffer in bytes. It is passed to + :meth:`~asyncio.WriteTransport.set_write_buffer_limits`. It defaults + to 32 KiB. You may pass a ``(high, low)`` tuple to set the + high-water and low-water marks. logger: Logger for this client. It defaults to ``logging.getLogger("websockets.client")``. See the :doc:`logging guide <../../topics/logging>` for details. @@ -199,6 +211,8 @@ def __init__( close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, + max_queue: int | tuple[int, int | None] = 16, + write_limit: int | tuple[int, int | None] = 2**15, # Logging logger: LoggerLike | None = None, # Escape hatch for advanced customization @@ -243,6 +257,8 @@ def factory() -> ClientConnection: connection = create_connection( protocol, close_timeout=close_timeout, + max_queue=max_queue, + write_limit=write_limit, ) return connection diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 0a3ddb9aa..1c4424f0d 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -48,9 +48,17 @@ def __init__( protocol: Protocol, *, close_timeout: float | None = 10, + max_queue: int | tuple[int, int | None] = 16, + write_limit: int | tuple[int, int | None] = 2**15, ) -> None: self.protocol = protocol self.close_timeout = close_timeout + if isinstance(max_queue, int): + max_queue = (max_queue, None) + self.max_queue = max_queue + if isinstance(write_limit, int): + write_limit = (write_limit, None) + self.write_limit = write_limit # Inject reference to this instance in the protocol's logger. self.protocol.logger = logging.LoggerAdapter( @@ -803,11 +811,13 @@ def close_transport(self) -> None: def connection_made(self, transport: asyncio.BaseTransport) -> None: transport = cast(asyncio.Transport, transport) - self.transport = transport self.recv_messages = Assembler( - pause=self.transport.pause_reading, - resume=self.transport.resume_reading, + *self.max_queue, + pause=transport.pause_reading, + resume=transport.resume_reading, ) + transport.set_write_buffer_limits(*self.write_limit) + self.transport = transport def connection_lost(self, exc: Exception | None) -> None: self.protocol.receive_eof() # receive_eof is idempotent diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index bc33df8d7..33ab6a5e9 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -89,6 +89,8 @@ class Assembler: # coverage reports incorrectly: "line NN didn't jump to the function exit" def __init__( # pragma: no cover self, + high: int = 16, + low: int | None = None, pause: Callable[[], Any] = lambda: None, resume: Callable[[], Any] = lambda: None, ) -> None: @@ -99,11 +101,16 @@ def __init__( # pragma: no cover # call to Protocol.data_received() could produce thousands of frames, # which must be buffered. Instead, we pause reading when the buffer goes # above the high limit and we resume when it goes under the low limit. - self.high = 16 - self.low = 4 - self.paused = False + if low is None: + low = high // 4 + if low < 0: + raise ValueError("low must be positive or equal to zero") + if high < low: + raise ValueError("high must be greater than or equal to low") + self.high, self.low = high, low self.pause = pause self.resume = resume + self.paused = False # This flag prevents concurrent calls to get() by user code. self.get_in_progress = False @@ -254,14 +261,6 @@ def put(self, frame: Frame) -> None: self.frames.put(frame) self.maybe_pause() - def get_limits(self) -> tuple[int, int]: - """Return low and high water marks for flow control.""" - return self.low, self.high - - def set_limits(self, low: int = 4, high: int = 16) -> None: - """Configure low and high water marks for flow control.""" - self.low, self.high = low, high - def maybe_pause(self) -> None: """Pause the writer if queue is above the high water mark.""" # Check for "> high" to support high = 0 diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index 0c8b8780b..4feea13c4 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -62,11 +62,15 @@ def __init__( server: WebSocketServer, *, close_timeout: float | None = 10, + max_queue: int | tuple[int, int | None] = 16, + write_limit: int | tuple[int, int | None] = 2**15, ) -> None: self.protocol: ServerProtocol super().__init__( protocol, close_timeout=close_timeout, + max_queue=max_queue, + write_limit=write_limit, ) self.server = server self.request_rcvd: asyncio.Future[None] = self.loop.create_future() @@ -574,6 +578,14 @@ def handler(websocket): :obj:`None` disables the timeout. max_size: Maximum size of incoming messages in bytes. :obj:`None` disables the limit. + max_queue: High-water mark of the buffer where frames are received. + It defaults to 16 frames. The low-water mark defaults to ``max_queue + // 4``. You may pass a ``(high, low)`` tuple to set the high-water + and low-water marks. + write_limit: High-water mark of write buffer in bytes. It is passed to + :meth:`~asyncio.WriteTransport.set_write_buffer_limits`. It defaults + to 32 KiB. You may pass a ``(high, low)`` tuple to set the + high-water and low-water marks. logger: Logger for this server. It defaults to ``logging.getLogger("websockets.server")``. See the :doc:`logging guide <../../topics/logging>` for details. @@ -637,6 +649,8 @@ def __init__( close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, + max_queue: int | tuple[int, int | None] = 16, + write_limit: int | tuple[int, int | None] = 2**15, # Logging logger: LoggerLike | None = None, # Escape hatch for advanced customization @@ -709,6 +723,8 @@ def protocol_select_subprotocol( protocol, self.server, close_timeout=close_timeout, + max_queue=max_queue, + write_limit=write_limit, ) return connection diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 2efd4e96d..02029b754 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -4,7 +4,7 @@ import socket import unittest import uuid -from unittest.mock import patch +from unittest.mock import Mock, patch from websockets.asyncio.compatibility import TimeoutError, aiter, anext, asyncio_timeout from websockets.asyncio.connection import * @@ -867,6 +867,42 @@ async def test_pong_explicit_binary(self): await self.connection.pong(b"pong") await self.assertFrameSent(Frame(Opcode.PONG, b"pong")) + # Test parameters. + + async def test_close_timeout(self): + """close_timeout parameter configures close timeout.""" + connection = Connection(Protocol(self.LOCAL), close_timeout=42 * MS) + self.assertEqual(connection.close_timeout, 42 * MS) + + async def test_max_queue(self): + """max_queue parameter configures high-water mark of frames buffer.""" + connection = Connection(Protocol(self.LOCAL), max_queue=4) + transport = Mock() + connection.connection_made(transport) + self.assertEqual(connection.recv_messages.high, 4) + + async def test_max_queue_tuple(self): + """max_queue parameter configures high-water mark of frames buffer.""" + connection = Connection(Protocol(self.LOCAL), max_queue=(4, 2)) + transport = Mock() + connection.connection_made(transport) + self.assertEqual(connection.recv_messages.high, 4) + self.assertEqual(connection.recv_messages.low, 2) + + async def test_write_limit(self): + """write_limit parameter configures high-water mark of write buffer.""" + connection = Connection(Protocol(self.LOCAL), write_limit=4096) + transport = Mock() + connection.connection_made(transport) + transport.set_write_buffer_limits.assert_called_once_with(4096, None) + + async def test_write_limits(self): + """write_limit parameter configures high and low-water marks of write buffer.""" + connection = Connection(Protocol(self.LOCAL), write_limit=(4096, 2048)) + transport = Mock() + connection.connection_made(transport) + transport.set_write_buffer_limits.assert_called_once_with(4096, 2048) + # Test attributes. async def test_id(self): diff --git a/tests/asyncio/test_messages.py b/tests/asyncio/test_messages.py index c8a2d7cd5..615b1f3a8 100644 --- a/tests/asyncio/test_messages.py +++ b/tests/asyncio/test_messages.py @@ -70,8 +70,7 @@ class AssemblerTests(unittest.IsolatedAsyncioTestCase): async def asyncSetUp(self): self.pause = unittest.mock.Mock() self.resume = unittest.mock.Mock() - self.assembler = Assembler(pause=self.pause, resume=self.resume) - self.assembler.set_limits(low=1, high=2) + self.assembler = Assembler(high=2, low=1, pause=self.pause, resume=self.resume) # Test get @@ -455,17 +454,25 @@ async def test_get_iter_fails_when_get_iter_is_running(self): await alist(self.assembler.get_iter()) self.assembler.close() # let task terminate - # Test getting and setting limits + # Test setting limits - async def test_get_limits(self): - """get_limits returns low and high water marks.""" - low, high = self.assembler.get_limits() - self.assertEqual(low, 1) - self.assertEqual(high, 2) + async def test_set_high_water_mark(self): + """high sets the high-water mark.""" + assembler = Assembler(high=10) + self.assertEqual(assembler.high, 10) - async def test_set_limits(self): - """set_limits changes low and high water marks.""" - self.assembler.set_limits(low=2, high=4) - low, high = self.assembler.get_limits() - self.assertEqual(low, 2) - self.assertEqual(high, 4) + async def test_set_high_and_low_water_mark(self): + """high sets the high-water mark.""" + assembler = Assembler(high=10, low=5) + self.assertEqual(assembler.high, 10) + self.assertEqual(assembler.low, 5) + + async def test_set_invalid_high_water_mark(self): + """high must be a non-negative integer.""" + with self.assertRaises(ValueError): + Assembler(high=-1) + + async def test_set_invalid_low_water_mark(self): + """low must be higher than high.""" + with self.assertRaises(ValueError): + Assembler(low=10, high=5) From 5eafbe466b909f21dc7e74b1350583b4d5ae0606 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 15 Aug 2024 16:25:51 +0200 Subject: [PATCH 551/762] Rewrite documentation of buffers. Describe all implementations. Also update documentation of compression. --- .gitignore | 2 +- docs/topics/compression.rst | 173 ++++++++++-------- docs/topics/design.rst | 49 ----- docs/topics/memory.rst | 156 +++++++++++++--- experiments/compression/benchmark.py | 74 ++------ experiments/compression/client.py | 18 +- experiments/compression/corpus.py | 52 ++++++ experiments/compression/server.py | 10 +- .../extensions/permessage_deflate.py | 6 +- 9 files changed, 316 insertions(+), 224 deletions(-) create mode 100644 experiments/compression/corpus.py diff --git a/.gitignore b/.gitignore index 324e77069..d8e6697a8 100644 --- a/.gitignore +++ b/.gitignore @@ -8,7 +8,7 @@ .tox build/ compliance/reports/ -experiments/compression/corpus.pkl +experiments/compression/corpus/ dist/ docs/_build/ htmlcov/ diff --git a/docs/topics/compression.rst b/docs/topics/compression.rst index eaf99070d..be263e56f 100644 --- a/docs/topics/compression.rst +++ b/docs/topics/compression.rst @@ -7,37 +7,36 @@ Most WebSocket servers exchange JSON messages because they're convenient to parse and serialize in a browser. These messages contain text data and tend to be repetitive. -This makes the stream of messages highly compressible. Enabling compression +This makes the stream of messages highly compressible. Compressing messages can reduce network traffic by more than 80%. -There's a standard for compressing messages. :rfc:`7692` defines WebSocket -Per-Message Deflate, a compression extension based on the Deflate_ algorithm. +websockets implements WebSocket Per-Message Deflate, a compression extension +based on the Deflate_ algorithm specified in :rfc:`7692`. .. _Deflate: https://en.wikipedia.org/wiki/Deflate -Configuring compression ------------------------ +:func:`~websockets.asyncio.client.connect` and +:func:`~websockets.asyncio.server.serve` enable compression by default because +the reduction in network bandwidth is usually worth the additional memory and +CPU cost. -:func:`~websockets.client.connect` and :func:`~websockets.server.serve` enable -compression by default because the reduction in network bandwidth is usually -worth the additional memory and CPU cost. -If you want to disable compression, set ``compression=None``:: +Configuring compression +----------------------- - import websockets +To disable compression, set ``compression=None``:: - websockets.connect(..., compression=None) + connect(..., compression=None, ...) - websockets.serve(..., compression=None) + serve(..., compression=None, ...) -If you want to customize compression settings, you can enable the Per-Message -Deflate extension explicitly with :class:`ClientPerMessageDeflateFactory` or +To customize compression settings, enable the Per-Message Deflate extension +explicitly with :class:`ClientPerMessageDeflateFactory` or :class:`ServerPerMessageDeflateFactory`:: - import websockets from websockets.extensions import permessage_deflate - websockets.connect( + connect( ..., extensions=[ permessage_deflate.ClientPerMessageDeflateFactory( @@ -46,9 +45,10 @@ Deflate extension explicitly with :class:`ClientPerMessageDeflateFactory` or compress_settings={"memLevel": 4}, ), ], + ..., ) - websockets.serve( + serve( ..., extensions=[ permessage_deflate.ServerPerMessageDeflateFactory( @@ -57,13 +57,14 @@ Deflate extension explicitly with :class:`ClientPerMessageDeflateFactory` or compress_settings={"memLevel": 4}, ), ], + ..., ) The Window Bits and Memory Level values in these examples reduce memory usage at the expense of compression rate. -Compression settings --------------------- +Compression parameters +---------------------- When a client and a server enable the Per-Message Deflate extension, they negotiate two parameters to guarantee compatibility between compression and @@ -81,9 +82,9 @@ and memory usage for both sides. This requires retaining the compression context and state between messages, which increases the memory footprint of a connection. -* **Window Bits** controls the size of the compression context. It must be - an integer between 9 (lowest memory usage) and 15 (best compression). - Setting it to 8 is possible but rejected by some versions of zlib. +* **Window Bits** controls the size of the compression context. It must be an + integer between 9 (lowest memory usage) and 15 (best compression). Setting it + to 8 is possible but rejected by some versions of zlib and not very useful. On the server side, websockets defaults to 12. Specifically, the compression window size (server to client) is always 12 while the decompression window @@ -94,9 +95,8 @@ and memory usage for both sides. has the same effect as defaulting to 15. :mod:`zlib` offers additional parameters for tuning compression. They control -the trade-off between compression rate, memory usage, and CPU usage only for -compressing. They're transparent for decompressing. Unless mentioned -otherwise, websockets inherits defaults of :func:`~zlib.compressobj`. +the trade-off between compression rate, memory usage, and CPU usage for +compressing. They're transparent for decompressing. * **Memory Level** controls the size of the compression state. It must be an integer between 1 (lowest memory usage) and 9 (best compression). @@ -108,87 +108,82 @@ otherwise, websockets inherits defaults of :func:`~zlib.compressobj`. * **Compression Level** controls the effort to optimize compression. It must be an integer between 1 (lowest CPU usage) and 9 (best compression). + websockets relies on the default value chosen by :func:`~zlib.compressobj`, + ``Z_DEFAULT_COMPRESSION``. + * **Strategy** selects the compression strategy. The best choice depends on the type of data being compressed. + websockets relies on the default value chosen by :func:`~zlib.compressobj`, + ``Z_DEFAULT_STRATEGY``. -Tuning compression ------------------- +To customize these parameters, add keyword arguments for +:func:`~zlib.compressobj` in ``compress_settings``. -For servers -........... +Default settings for servers +---------------------------- By default, websockets enables compression with conservative settings that optimize memory usage at the cost of a slightly worse compression rate: -Window Bits = 12 and Memory Level = 5. This strikes a good balance for small +Window Bits = 12 and Memory Level = 5. This strikes a good balance for small messages that are typical of WebSocket servers. -Here's how various compression settings affect memory usage of a single -connection on a 64-bit system, as well a benchmark of compressed size and -compression time for a corpus of small JSON documents. +Here's an example of how compression settings affect memory usage per +connection, compressed size, and compression time for a corpus of JSON +documents. =========== ============ ============ ================ ================ Window Bits Memory Level Memory usage Size vs. default Time vs. default =========== ============ ============ ================ ================ -15 8 322 KiB -4.0% +15% -14 7 178 KiB -2.6% +10% -13 6 106 KiB -1.4% +5% -**12** **5** **70 KiB** **=** **=** -11 4 52 KiB +3.7% -5% -10 3 43 KiB +90% +50% -9 2 39 KiB +160% +100% -— — 19 KiB +452% — +15 8 316 KiB -10% +10% +14 7 172 KiB -7% +5% +13 6 100 KiB -3% +2% +**12** **5** **64 KiB** **=** **=** +11 4 46 KiB +10% +4% +10 3 37 KiB +70% +40% +9 2 33 KiB +130% +90% +— — 14 KiB +350% — =========== ============ ============ ================ ================ Window Bits and Memory Level don't have to move in lockstep. However, other combinations don't yield significantly better results than those shown above. -Compressed size and compression time depend heavily on the kind of messages -exchanged by the application so this example may not apply to your use case. - -You can adapt `compression/benchmark.py`_ by creating a list of typical -messages and passing it to the ``_run`` function. - -Window Bits = 11 and Memory Level = 4 looks like the sweet spot in this table. - -websockets defaults to Window Bits = 12 and Memory Level = 5 to stay away from -Window Bits = 10 or Memory Level = 3 where performance craters, raising doubts -on what could happen at Window Bits = 11 and Memory Level = 4 on a different +websockets defaults to Window Bits = 12 and Memory Level = 5 to stay away from +Window Bits = 10 or Memory Level = 3 where performance craters, raising doubts +on what could happen at Window Bits = 11 and Memory Level = 4 on a different corpus. Defaults must be safe for all applications, hence a more conservative choice. -.. _compression/benchmark.py: https://github.com/python-websockets/websockets/blob/main/experiments/compression/benchmark.py +Optimizing settings +------------------- -The benchmark focuses on compression because it's more expensive than -decompression. Indeed, leaving aside small allocations, theoretical memory -usage is: +Compressed size and compression time depend on the structure of messages +exchanged by your application. As a consequence, default settings may not be +optimal for your use case. -* ``(1 << (windowBits + 2)) + (1 << (memLevel + 9))`` for compression; -* ``1 << windowBits`` for decompression. +To compare how various compression settings perform for your use case: -CPU usage is also higher for compression than decompression. +1. Create a corpus of typical messages in a directory, one message per file. +2. Run the `compression/benchmark.py`_ script, passing the directory in + argument. -While it's always possible for a server to use a smaller window size for -compressing outgoing messages, using a smaller window size for decompressing -incoming messages requires collaboration from clients. +The script measures compressed size and compression time for all combinations of +Window Bits and Memory Level. It outputs two tables with absolute values and two +tables with values relative to websockets' default settings. -When a client doesn't support configuring the size of its compression window, -websockets enables compression with the largest possible decompression window. -In most use cases, this is more efficient than disabling compression both ways. +Pick your favorite settings in these tables and configure them as shown above. -If you are very sensitive to memory usage, you can reverse this behavior by -setting the ``require_client_max_window_bits`` parameter of -:class:`ServerPerMessageDeflateFactory` to ``True``. +.. _compression/benchmark.py: https://github.com/python-websockets/websockets/blob/main/experiments/compression/benchmark.py -For clients -........... +Default settings for clients +---------------------------- -By default, websockets enables compression with Memory Level = 5 but leaves +By default, websockets enables compression with Memory Level = 5 but leaves the Window Bits setting up to the server. -There's two good reasons and one bad reason for not optimizing the client side -like the server side: +There's two good reasons and one bad reason for not optimizing Window Bits on +the client side as on the server side: 1. If the maintainers of a server configured some optimized settings, we don't want to override them with more restrictive settings. @@ -196,8 +191,9 @@ like the server side: 2. Optimizing memory usage doesn't matter very much for clients because it's uncommon to open thousands of client connections in a program. -3. On a more pragmatic note, some servers misbehave badly when a client - configures compression settings. `AWS API Gateway`_ is the worst offender. +3. On a more pragmatic and annoying note, some servers misbehave badly when a + client configures compression settings. `AWS API Gateway`_ is the worst + offender. .. _AWS API Gateway: https://github.com/python-websockets/websockets/issues/1065 @@ -207,6 +203,29 @@ like the server side: Until the ecosystem levels up, interoperability with buggy servers seems more valuable than optimizing memory usage. +Decompression +------------- + +The discussion above focuses on compression because it's more expensive than +decompression. Indeed, leaving aside small allocations, theoretical memory +usage is: + +* ``(1 << (windowBits + 2)) + (1 << (memLevel + 9))`` for compression; +* ``1 << windowBits`` for decompression. + +CPU usage is also higher for compression than decompression. + +While it's always possible for a server to use a smaller window size for +compressing outgoing messages, using a smaller window size for decompressing +incoming messages requires collaboration from clients. + +When a client doesn't support configuring the size of its compression window, +websockets enables compression with the largest possible decompression window. +In most use cases, this is more efficient than disabling compression both ways. + +If you are very sensitive to memory usage, you can reverse this behavior by +setting the ``require_client_max_window_bits`` parameter of +:class:`ServerPerMessageDeflateFactory` to ``True``. Further reading --------------- @@ -216,7 +235,7 @@ settings affect memory usage and how to optimize them. .. _blog post by Ilya Grigorik: https://www.igvita.com/2013/11/27/configuring-and-optimizing-websocket-compression/ -This `experiment by Peter Thorson`_ recommends Window Bits = 11 and Memory -Level = 4 for optimizing memory usage. +This `experiment by Peter Thorson`_ recommends Window Bits = 11 and Memory +Level = 4 for optimizing memory usage. .. _experiment by Peter Thorson: https://mailarchive.ietf.org/arch/msg/hybi/F9t4uPufVEy8KBLuL36cZjCmM_Y/ diff --git a/docs/topics/design.rst b/docs/topics/design.rst index f164d2990..cc65e6a70 100644 --- a/docs/topics/design.rst +++ b/docs/topics/design.rst @@ -488,55 +488,6 @@ they're drained. That's why all APIs that write frames are asynchronous. Of course, it's still possible for an application to create its own unbounded buffers and break the backpressure. Be careful with queues. - -.. _buffers: - -Buffers -------- - -.. note:: - - This section discusses buffers from the perspective of a server but it - applies to clients as well. - -An asynchronous systems works best when its buffers are almost always empty. - -For example, if a client sends data too fast for a server, the queue of -incoming messages will be constantly full. The server will always be 32 -messages (by default) behind the client. This consumes memory and increases -latency for no good reason. The problem is called bufferbloat. - -If buffers are almost always full and that problem cannot be solved by adding -capacity — typically because the system is bottlenecked by the output and -constantly regulated by backpressure — reducing the size of buffers minimizes -negative consequences. - -By default websockets has rather high limits. You can decrease them according -to your application's characteristics. - -Bufferbloat can happen at every level in the stack where there is a buffer. -For each connection, the receiving side contains these buffers: - -- OS buffers: tuning them is an advanced optimization. -- :class:`~asyncio.StreamReader` bytes buffer: the default limit is 64 KiB. - You can set another limit by passing a ``read_limit`` keyword argument to - :func:`~client.connect()` or :func:`~server.serve`. -- Incoming messages :class:`~collections.deque`: its size depends both on - the size and the number of messages it contains. By default the maximum - UTF-8 encoded size is 1 MiB and the maximum number is 32. In the worst case, - after UTF-8 decoding, a single message could take up to 4 MiB of memory and - the overall memory consumption could reach 128 MiB. You should adjust these - limits by setting the ``max_size`` and ``max_queue`` keyword arguments of - :func:`~client.connect()` or :func:`~server.serve` according to your - application's requirements. - -For each connection, the sending side contains these buffers: - -- :class:`~asyncio.StreamWriter` bytes buffer: the default size is 64 KiB. - You can set another limit by passing a ``write_limit`` keyword argument to - :func:`~client.connect()` or :func:`~server.serve`. -- OS buffers: tuning them is an advanced optimization. - Concurrency ----------- diff --git a/docs/topics/memory.rst b/docs/topics/memory.rst index e44247a77..efbcbb83f 100644 --- a/docs/topics/memory.rst +++ b/docs/topics/memory.rst @@ -1,5 +1,5 @@ -Memory usage -============ +Memory and buffers +================== .. currentmodule:: websockets @@ -9,40 +9,148 @@ memory usage can become a bottleneck. Memory usage of a single connection is the sum of: -1. the baseline amount of memory websockets requires for each connection, -2. the amount of data held in buffers before the application processes it, -3. any additional memory allocated by the application itself. +1. the baseline amount of memory that websockets uses for each connection; +2. the amount of memory needed by your application code; +3. the amount of data held in buffers. -Baseline --------- +Connection +---------- -Compression settings are the main factor affecting the baseline amount of -memory used by each connection. +Compression settings are the primary factor affecting how much memory each +connection uses. -With websockets' defaults, on the server side, a single connections uses -70 KiB of memory. +The :mod:`asyncio` implementation with default settings uses 64 KiB of memory +for each connection. + +You can reduce memory usage to 14 KiB per connection if you disable compression +entirely. Refer to the :doc:`topic guide on compression <../topics/compression>` to learn more about tuning compression settings. +Application +----------- + +Your application will allocate memory for its data structures. Memory usage +depends on your use case and your implementation. + +Make sure that you don't keep references to data that you don't need anymore +because this prevents garbage collection. + Buffers ------- -Under normal circumstances, buffers are almost always empty. +Typical WebSocket applications exchange small messages at a rate that doesn't +saturate the CPU or the network. Buffers are almost always empty. This is the +optimal situation. Buffers absorb bursts of incoming or outgoing messages +without having to pause reading or writing. + +If the application receives messages faster than it can process them, receive +buffers will fill up when. If the application sends messages faster than the +network can transmit them, send buffers will fill up. + +When buffers are almost always full, not only does the additional memory usage +fail to bring any benefit, but latency degrades as well. This problem is called +bufferbloat_. If it cannot be resolved by adding capacity, typically because the +system is bottlenecked by its output and constantly regulated by +:ref:`backpressure `, then buffers should be kept small to ensure +that backpressure kicks in quickly. + +.. _bufferbloat: https://en.wikipedia.org/wiki/Bufferbloat + +To sum up, buffers should be sized to absorb bursts of messages. Making them +larger than necessary often causes more harm than good. + +There are three levels of buffering in an application built with websockets. + +TCP buffers +........... + +The operating system allocates buffers for each TCP connection. The receive +buffer stores data received from the network until the application reads it. +The send buffer stores data written by the application until it's sent to +the network and acknowledged by the recipient. + +Modern operating systems adjust the size of TCP buffers automatically to match +network conditions. Overall, you shouldn't worry about TCP buffers. Just be +aware that they exist. + +In very high throughput scenarios, TCP buffers may grow to several megabytes +to store the data in flight. Then, they can make up the bulk of the memory +usage of a connection. + +I/O library buffers +................... + +I/O libraries like :mod:`asyncio` may provide read and write buffers to reduce +the frequency of system calls or the need to pause reading or writing. + +You should keep these buffers small. Increasing them can help with spiky +workloads but it can also backfire because it delays backpressure. + +* In the new :mod:`asyncio` implementation, there is no library-level read + buffer. + + There is a write buffer. The ``write_limit`` argument of + :func:`~asyncio.client.connect` and :func:`~asyncio.server.serve` controls its + size. When the write buffer grows above the high-water mark, + :meth:`~asyncio.connection.Connection.send` waits until it drains under the + low-water mark to return. This creates backpressure on coroutines that send + messages. + +* In the legacy :mod:`asyncio` implementation, there is a library-level read + buffer. The ``read_limit`` argument of :func:`~client.connect` and + :func:`~server.serve` controls its size. When the read buffer grows above the + high-water mark, the connection stops reading from the network until it drains + under the low-water mark. This creates backpressure on the TCP connection. + + There is a write buffer. It as controlled by ``write_limit``. It behaves like + the new :mod:`asyncio` implementation described above. + +* In the :mod:`threading` implementation, there are no library-level buffers. + All I/O operations are performed directly on the :class:`~socket.socket`. + +websockets' buffers +................... + +Incoming messages are queued in a buffer after they have been received from the +network and parsed. A larger buffer may help a slow applications handle bursts +of messages while remaining responsive to control frames. + +The memory footprint of this buffer is bounded by the product of ``max_size``, +which controls the size of items in the queue, and ``max_queue``, which controls +the number of items. + +The ``max_size`` argument of :func:`~asyncio.client.connect` and +:func:`~asyncio.server.serve` defaults to 1 MiB. Most applications never receive +such large messages. Configuring a smaller value puts a tighter boundary on +memory usage. This can make your application more resilient to denial of service +attacks. + +The behavior of the ``max_queue`` argument of :func:`~asyncio.client.connect` +and :func:`~asyncio.server.serve` varies across implementations. -Under high load, if a server receives more messages than it can process, -bufferbloat can result in excessive memory usage. +* In the new :mod:`asyncio` implementation, ``max_queue`` is the high-water mark + of a queue of incoming frames. It defaults to 16 frames. If the queue grows + larger, the connection stops reading from the network until the application + consumes messages and the queue goes below the low-water mark. This creates + backpressure on the TCP connection. -By default websockets has generous limits. It is strongly recommended to adapt -them to your application. When you call :func:`~server.serve`: + Each item in the queue is a frame. A frame can be a message or a message + fragment. Either way, it must be smaller than ``max_size``, the maximum size + of a message. The queue may use up to ``max_size * max_queue`` bytes of + memory. By default, this is 16 MiB. -- Set ``max_size`` (default: 1 MiB, UTF-8 encoded) to the maximum size of - messages your application generates. -- Set ``max_queue`` (default: 32) to the maximum number of messages your - application expects to receive faster than it can process them. The queue - provides burst tolerance without slowing down the TCP connection. +* In the legacy :mod:`asyncio` implementation, ``max_queue`` is the maximum + size of a queue of incoming messages. It defaults to 32 messages. If the queue + fills up, the connection stops reading from the library-level read buffer + described above. If that buffer fills up as well, it will create backpressure + on the TCP connection. -Furthermore, you can lower ``read_limit`` and ``write_limit`` (default: -64 KiB) to reduce the size of buffers for incoming and outgoing data. + Text messages are decoded before they're added to the queue. Since Python can + use up to 4 bytes of memory per character, the queue may use up to ``4 * + max_size * max_queue`` bytes of memory. By default, this is 128 MiB. -The design document provides :ref:`more details about buffers `. +* In the :mod:`threading` implementation, there is no queue of incoming + messages. The ``max_queue`` argument doesn't exist. The connection keeps at + most one message in memory at a time. diff --git a/experiments/compression/benchmark.py b/experiments/compression/benchmark.py index 4fbdf6220..86ebece31 100644 --- a/experiments/compression/benchmark.py +++ b/experiments/compression/benchmark.py @@ -1,72 +1,32 @@ #!/usr/bin/env python -import getpass -import json -import pickle -import subprocess +import collections +import pathlib import sys import time import zlib -CORPUS_FILE = "corpus.pkl" - REPEAT = 10 WB, ML = 12, 5 # defaults used as a reference -def _corpus(): - OAUTH_TOKEN = getpass.getpass("OAuth Token? ") - COMMIT_API = ( - f'curl -H "Authorization: token {OAUTH_TOKEN}" ' - f"https://api.github.com/repos/python-websockets/websockets/git/commits/:sha" - ) - - commits = [] - - head = subprocess.check_output("git rev-parse HEAD", shell=True).decode().strip() - todo = [head] - seen = set() - - while todo: - sha = todo.pop(0) - commit = subprocess.check_output(COMMIT_API.replace(":sha", sha), shell=True) - commits.append(commit) - seen.add(sha) - for parent in json.loads(commit)["parents"]: - sha = parent["sha"] - if sha not in seen and sha not in todo: - todo.append(sha) - time.sleep(1) # rate throttling - - return commits - - -def corpus(): - data = _corpus() - with open(CORPUS_FILE, "wb") as handle: - pickle.dump(data, handle) - - -def _run(data): - size = {} - duration = {} +def benchmark(data): + size = collections.defaultdict(dict) + duration = collections.defaultdict(dict) for wbits in range(9, 16): - size[wbits] = {} - duration[wbits] = {} - for memLevel in range(1, 10): encoder = zlib.compressobj(wbits=-wbits, memLevel=memLevel) encoded = [] + print(f"Compressing {REPEAT} times with {wbits=} and {memLevel=}") + t0 = time.perf_counter() for _ in range(REPEAT): for item in data: - if isinstance(item, str): - item = item.encode() # Taken from PerMessageDeflate.encode item = encoder.compress(item) + encoder.flush(zlib.Z_SYNC_FLUSH) if item.endswith(b"\x00\x00\xff\xff"): @@ -75,7 +35,7 @@ def _run(data): t1 = time.perf_counter() - size[wbits][memLevel] = sum(len(item) for item in encoded) + size[wbits][memLevel] = sum(len(item) for item in encoded) / REPEAT duration[wbits][memLevel] = (t1 - t0) / REPEAT raw_size = sum(len(item) for item in data) @@ -149,15 +109,13 @@ def _run(data): print() -def run(): - with open(CORPUS_FILE, "rb") as handle: - data = pickle.load(handle) - _run(data) +def main(corpus): + data = [file.read_bytes() for file in corpus.iterdir()] + benchmark(data) -try: - run = globals()[sys.argv[1]] -except (KeyError, IndexError): - print(f"Usage: {sys.argv[0]} [corpus|run]") -else: - run() +if __name__ == "__main__": + if len(sys.argv) < 2: + print(f"Usage: {sys.argv[0]} [directory]") + sys.exit(2) + main(pathlib.Path(sys.argv[1])) diff --git a/experiments/compression/client.py b/experiments/compression/client.py index 3ee19ddc5..69bfd5e7c 100644 --- a/experiments/compression/client.py +++ b/experiments/compression/client.py @@ -4,8 +4,8 @@ import statistics import tracemalloc -import websockets -from websockets.extensions import permessage_deflate +from websockets.asyncio.client import connect +from websockets.extensions.permessage_deflate import ClientPerMessageDeflateFactory CLIENTS = 20 @@ -16,16 +16,16 @@ MEM_SIZE = [] -async def client(client): +async def client(num): # Space out connections to make them sequential. - await asyncio.sleep(client * INTERVAL) + await asyncio.sleep(num * INTERVAL) tracemalloc.start() - async with websockets.connect( + async with connect( "ws://localhost:8765", extensions=[ - permessage_deflate.ClientPerMessageDeflateFactory( + ClientPerMessageDeflateFactory( server_max_window_bits=WB, client_max_window_bits=WB, compress_settings={"memLevel": ML}, @@ -42,11 +42,13 @@ async def client(client): tracemalloc.stop() # Hold connection open until the end of the test. - await asyncio.sleep(CLIENTS * INTERVAL) + await asyncio.sleep((CLIENTS + 1 - num) * INTERVAL) async def clients(): - await asyncio.gather(*[client(client) for client in range(CLIENTS + 1)]) + # Start one more client than necessary because we will ignore + # non-representative results from the first connection. + await asyncio.gather(*[client(num) for num in range(CLIENTS + 1)]) asyncio.run(clients()) diff --git a/experiments/compression/corpus.py b/experiments/compression/corpus.py new file mode 100644 index 000000000..da5661dfa --- /dev/null +++ b/experiments/compression/corpus.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python + +import getpass +import json +import pathlib +import subprocess +import sys +import time + + +def github_commits(): + OAUTH_TOKEN = getpass.getpass("OAuth Token? ") + COMMIT_API = ( + f'curl -H "Authorization: token {OAUTH_TOKEN}" ' + f"https://api.github.com/repos/python-websockets/websockets/git/commits/:sha" + ) + + commits = [] + + head = subprocess.check_output( + "git rev-parse origin/main", + shell=True, + text=True, + ).strip() + todo = [head] + seen = set() + + while todo: + sha = todo.pop(0) + commit = subprocess.check_output(COMMIT_API.replace(":sha", sha), shell=True) + commits.append(commit) + seen.add(sha) + for parent in json.loads(commit)["parents"]: + sha = parent["sha"] + if sha not in seen and sha not in todo: + todo.append(sha) + time.sleep(1) # rate throttling + + return commits + + +def main(corpus): + data = github_commits() + for num, content in enumerate(reversed(data)): + (corpus / f"{num:04d}.json").write_bytes(content) + + +if __name__ == "__main__": + if len(sys.argv) < 2: + print(f"Usage: {sys.argv[0]} [directory]") + sys.exit(2) + main(pathlib.Path(sys.argv[1])) diff --git a/experiments/compression/server.py b/experiments/compression/server.py index 8d1ee3cd7..1c28f7355 100644 --- a/experiments/compression/server.py +++ b/experiments/compression/server.py @@ -6,8 +6,8 @@ import statistics import tracemalloc -import websockets -from websockets.extensions import permessage_deflate +from websockets.asyncio.server import serve +from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory CLIENTS = 20 @@ -44,12 +44,12 @@ async def server(): print() loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - async with websockets.serve( + async with serve( handler, "localhost", 8765, extensions=[ - permessage_deflate.ServerPerMessageDeflateFactory( + ServerPerMessageDeflateFactory( server_max_window_bits=WB, client_max_window_bits=WB, compress_settings={"memLevel": ML}, @@ -63,7 +63,7 @@ async def server(): asyncio.run(server()) -# First connection may incur non-representative setup costs. +# First connection incurs non-representative setup costs. del MEM_SIZE[0] print(f"µ = {statistics.mean(MEM_SIZE) / 1024:.1f} KiB") diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index fea14131e..5b907b79f 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -62,7 +62,8 @@ def __init__( if not self.local_no_context_takeover: self.encoder = zlib.compressobj( - wbits=-self.local_max_window_bits, **self.compress_settings + wbits=-self.local_max_window_bits, + **self.compress_settings, ) # To handle continuation frames properly, we must keep track of @@ -156,7 +157,8 @@ def encode(self, frame: frames.Frame) -> frames.Frame: # Re-initialize per-message decoder. if self.local_no_context_takeover: self.encoder = zlib.compressobj( - wbits=-self.local_max_window_bits, **self.compress_settings + wbits=-self.local_max_window_bits, + **self.compress_settings, ) # Compress data. From c3b162d05c3788b9367eb3cce8c5001c37a3e6fa Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 17 Aug 2024 08:53:45 +0200 Subject: [PATCH 552/762] Add broadcast to the new asyncio implementation. --- docs/faq/server.rst | 2 +- docs/howto/upgrade.rst | 10 +- docs/intro/tutorial2.rst | 16 +-- docs/project/changelog.rst | 4 +- docs/reference/asyncio/server.rst | 2 +- docs/reference/new-asyncio/server.rst | 5 + docs/topics/broadcast.rst | 69 ++++++------ docs/topics/logging.rst | 2 +- docs/topics/performance.rst | 6 +- experiments/broadcast/server.py | 21 ++-- src/websockets/asyncio/connection.py | 100 ++++++++++++++++- src/websockets/legacy/protocol.py | 14 +-- src/websockets/sync/connection.py | 2 +- tests/asyncio/test_connection.py | 156 ++++++++++++++++++++++++++ tests/legacy/test_protocol.py | 12 +- 15 files changed, 341 insertions(+), 80 deletions(-) diff --git a/docs/faq/server.rst b/docs/faq/server.rst index cba1cd35f..53e34632f 100644 --- a/docs/faq/server.rst +++ b/docs/faq/server.rst @@ -102,7 +102,7 @@ Record all connections in a global variable:: finally: CONNECTIONS.remove(websocket) -Then, call :func:`~websockets.broadcast`:: +Then, call :func:`~asyncio.connection.broadcast`:: import websockets diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index 10e8967d8..6efaf0f56 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -70,12 +70,6 @@ Missing features If your application relies on one of them, you should stick to the original implementation until the new implementation supports it in a future release. -Broadcast -......... - -The new implementation doesn't support :doc:`broadcasting messages -<../topics/broadcast>` yet. - Keepalive ......... @@ -178,8 +172,8 @@ Server APIs | :class:`websockets.server.WebSocketServerProtocol` |br| | | | ``websockets.legacy.server.WebSocketServerProtocol`` | | +-------------------------------------------------------------------+-----------------------------------------------------+ -| :func:`websockets.broadcast` |br| | *not available yet* | -| ``websockets.legacy.protocol.broadcast()`` | | +| ``websockets.broadcast`` |br| | :func:`websockets.asyncio.connection.broadcast` | +| :func:`websockets.legacy.protocol.broadcast()` | | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.BasicAuthWebSocketServerProtocol`` |br| | *not available yet* | | :class:`websockets.auth.BasicAuthWebSocketServerProtocol` |br| | | diff --git a/docs/intro/tutorial2.rst b/docs/intro/tutorial2.rst index 5ac4ae9dd..b8e35f292 100644 --- a/docs/intro/tutorial2.rst +++ b/docs/intro/tutorial2.rst @@ -482,7 +482,7 @@ you're using this pattern: ... Since this is a very common pattern in WebSocket servers, websockets provides -the :func:`broadcast` helper for this purpose: +the :func:`~legacy.protocol.broadcast` helper for this purpose: .. code-block:: python @@ -494,13 +494,14 @@ the :func:`broadcast` helper for this purpose: ... -Calling :func:`broadcast` once is more efficient than +Calling :func:`legacy.protocol.broadcast` once is more efficient than calling :meth:`~legacy.protocol.WebSocketCommonProtocol.send` in a loop. -However, there's a subtle difference in behavior. Did you notice that there's -no ``await`` in the second version? Indeed, :func:`broadcast` is a function, -not a coroutine like :meth:`~legacy.protocol.WebSocketCommonProtocol.send` -or :meth:`~legacy.protocol.WebSocketCommonProtocol.recv`. +However, there's a subtle difference in behavior. Did you notice that there's no +``await`` in the second version? Indeed, :func:`legacy.protocol.broadcast` is a +function, not a coroutine like +:meth:`~legacy.protocol.WebSocketCommonProtocol.send` or +:meth:`~legacy.protocol.WebSocketCommonProtocol.recv`. It's quite obvious why :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` is a coroutine. When you want to receive the next message, you have to wait @@ -521,7 +522,8 @@ That said, when you're sending the same messages to many clients in a loop, applying backpressure in this way can become counterproductive. When you're broadcasting, you don't want to slow down everyone to the pace of the slowest clients; you want to drop clients that cannot keep up with the data stream. -That's why :func:`broadcast` doesn't wait until write buffers drain. +That's why :func:`legacy.protocol.broadcast` doesn't wait until write buffers +drain. For our Connect Four game, there's no difference in practice: the total amount of data sent on a connection for a game of Connect Four is less than 64 KB, diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index f033f5632..eaabb2e9f 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -212,7 +212,7 @@ Improvements * Added platform-independent wheels. -* Improved error handling in :func:`~websockets.broadcast`. +* Improved error handling in :func:`~legacy.protocol.broadcast`. * Set ``server_hostname`` automatically on TLS connections when providing a ``sock`` argument to :func:`~sync.client.connect`. @@ -402,7 +402,7 @@ New features * Added compatibility with Python 3.10. -* Added :func:`~websockets.broadcast` to send a message to many clients. +* Added :func:`~legacy.protocol.broadcast` to send a message to many clients. * Added support for reconnecting automatically by using :func:`~client.connect` as an asynchronous iterator. diff --git a/docs/reference/asyncio/server.rst b/docs/reference/asyncio/server.rst index 4bd52b40b..3636f0b33 100644 --- a/docs/reference/asyncio/server.rst +++ b/docs/reference/asyncio/server.rst @@ -110,4 +110,4 @@ websockets supports HTTP Basic Authentication according to Broadcast --------- -.. autofunction:: websockets.broadcast +.. autofunction:: websockets.legacy.protocol.broadcast diff --git a/docs/reference/new-asyncio/server.rst b/docs/reference/new-asyncio/server.rst index c43673d33..7f9de6148 100644 --- a/docs/reference/new-asyncio/server.rst +++ b/docs/reference/new-asyncio/server.rst @@ -70,3 +70,8 @@ Using a connection .. autoattribute:: response .. autoproperty:: subprotocol + +Broadcast +--------- + +.. autofunction:: websockets.asyncio.connection.broadcast diff --git a/docs/topics/broadcast.rst b/docs/topics/broadcast.rst index b6ddda734..671319136 100644 --- a/docs/topics/broadcast.rst +++ b/docs/topics/broadcast.rst @@ -1,21 +1,22 @@ -Broadcasting messages -===================== +Broadcasting +============ .. currentmodule:: websockets - -.. admonition:: If you just want to send a message to all connected clients, - use :func:`broadcast`. +.. admonition:: If you want to send a message to all connected clients, + use :func:`~asyncio.connection.broadcast`. :class: tip - If you want to learn about its design in depth, continue reading this - document. + If you want to learn about its design, continue reading this document. + + For the legacy :mod:`asyncio` implementation, use + :func:`~legacy.protocol.broadcast`. WebSocket servers often send the same message to all connected clients or to a subset of clients for which the message is relevant. -Let's explore options for broadcasting a message, explain the design -of :func:`broadcast`, and discuss alternatives. +Let's explore options for broadcasting a message, explain the design of +:func:`~asyncio.connection.broadcast`, and discuss alternatives. For each option, we'll provide a connection handler called ``handler()`` and a function or coroutine called ``broadcast()`` that sends a message to all @@ -24,7 +25,7 @@ connected clients. Integrating them is left as an exercise for the reader. You could start with:: import asyncio - import websockets + from websockets.asyncio.server import serve async def handler(websocket): ... @@ -39,7 +40,7 @@ Integrating them is left as an exercise for the reader. You could start with:: await broadcast(message) async def main(): - async with websockets.serve(handler, "localhost", 8765): + async with serve(handler, "localhost", 8765): await broadcast_messages() # runs forever if __name__ == "__main__": @@ -82,11 +83,13 @@ to:: Here's a coroutine that broadcasts a message to all clients:: + from websockets import ConnectionClosed + async def broadcast(message): for websocket in CLIENTS.copy(): try: await websocket.send(message) - except websockets.ConnectionClosed: + except ConnectionClosed: pass There are two tricks in this version of ``broadcast()``. @@ -117,11 +120,11 @@ which is usually outside of the control of the server. If you know for sure that you will never write more than ``write_limit`` bytes within ``ping_interval + ping_timeout``, then websockets will terminate slow -connections before the write buffer has time to fill up. +connections before the write buffer can fill up. -Don't set extreme ``write_limit``, ``ping_interval``, and ``ping_timeout`` -values to ensure that this condition holds. Set reasonable values and use the -built-in :func:`broadcast` function instead. +Don't set extreme values of ``write_limit``, ``ping_interval``, or +``ping_timeout`` to ensure that this condition holds! Instead, set reasonable +values and use the built-in :func:`~asyncio.connection.broadcast` function. The concurrent way ------------------ @@ -134,7 +137,7 @@ Let's modify ``broadcast()`` to send messages concurrently:: async def send(websocket, message): try: await websocket.send(message) - except websockets.ConnectionClosed: + except ConnectionClosed: pass def broadcast(message): @@ -179,20 +182,20 @@ doesn't work well when broadcasting a message to thousands of clients. When you're sending messages to a single client, you don't want to send them faster than the network can transfer them and the client accept them. This is -why :meth:`~server.WebSocketServerProtocol.send` checks if the write buffer -is full and, if it is, waits until it drain, giving the network and the -client time to catch up. This provides backpressure. +why :meth:`~asyncio.server.ServerConnection.send` checks if the write buffer is +above the high-water mark and, if it is, waits until it drains, giving the +network and the client time to catch up. This provides backpressure. Without backpressure, you could pile up data in the write buffer until the server process runs out of memory and the operating system kills it. -The :meth:`~server.WebSocketServerProtocol.send` API is designed to enforce +The :meth:`~asyncio.server.ServerConnection.send` API is designed to enforce backpressure by default. This helps users of websockets write robust programs even if they never heard about backpressure. For comparison, :class:`asyncio.StreamWriter` requires users to understand -backpressure and to await :meth:`~asyncio.StreamWriter.drain` explicitly -after each :meth:`~asyncio.StreamWriter.write`. +backpressure and to await :meth:`~asyncio.StreamWriter.drain` after each +:meth:`~asyncio.StreamWriter.write` — or at least sufficiently frequently. When broadcasting messages, backpressure consists in slowing down all clients in an attempt to let the slowest client catch up. With thousands of clients, @@ -203,14 +206,14 @@ How do we avoid running out of memory when slow clients can't keep up with the broadcast rate, then? The most straightforward option is to disconnect them. If a client gets too far behind, eventually it reaches the limit defined by -``ping_timeout`` and websockets terminates the connection. You can read the -discussion of :doc:`keepalive and timeouts <./timeouts>` for details. +``ping_timeout`` and websockets terminates the connection. You can refer to +the discussion of :doc:`keepalive and timeouts ` for details. -How :func:`broadcast` works ---------------------------- +How :func:`~asyncio.connection.broadcast` works +----------------------------------------------- -The built-in :func:`broadcast` function is similar to the naive way. The main -difference is that it doesn't apply backpressure. +The built-in :func:`~asyncio.connection.broadcast` function is similar to the +naive way. The main difference is that it doesn't apply backpressure. This provides the best performance by avoiding the overhead of scheduling and running one task per client. @@ -321,9 +324,9 @@ the asynchronous iterator returned by ``subscribe()``. Performance considerations -------------------------- -The built-in :func:`broadcast` function sends all messages without yielding -control to the event loop. So does the naive way when the network and clients -are fast and reliable. +The built-in :func:`~asyncio.connection.broadcast` function sends all messages +without yielding control to the event loop. So does the naive way when the +network and clients are fast and reliable. For each client, a WebSocket frame is prepared and sent to the network. This is the minimum amount of work required to broadcast a message. @@ -343,7 +346,7 @@ However, this isn't possible in general for two reasons: All other patterns discussed above yield control to the event loop once per client because messages are sent by different tasks. This makes them slower -than the built-in :func:`broadcast` function. +than the built-in :func:`~asyncio.connection.broadcast` function. There is no major difference between the performance of per-client queues and publish–subscribe. diff --git a/docs/topics/logging.rst b/docs/topics/logging.rst index 873c852c2..765278360 100644 --- a/docs/topics/logging.rst +++ b/docs/topics/logging.rst @@ -220,7 +220,7 @@ Here's what websockets logs at each level. ``WARNING`` ........... -* Failures in :func:`~websockets.broadcast` +* Failures in :func:`~asyncio.connection.broadcast` ``INFO`` ........ diff --git a/docs/topics/performance.rst b/docs/topics/performance.rst index 45e23b239..b226cec43 100644 --- a/docs/topics/performance.rst +++ b/docs/topics/performance.rst @@ -1,6 +1,8 @@ Performance =========== +.. currentmodule:: websockets + Here are tips to optimize performance. uvloop @@ -16,5 +18,5 @@ application.) broadcast --------- -:func:`~websockets.broadcast` is the most efficient way to send a message to -many clients. +:func:`~asyncio.connection.broadcast` is the most efficient way to send a +message to many clients. diff --git a/experiments/broadcast/server.py b/experiments/broadcast/server.py index b0407ba34..0a5c82b3c 100644 --- a/experiments/broadcast/server.py +++ b/experiments/broadcast/server.py @@ -6,7 +6,9 @@ import sys import time -import websockets +from websockets import ConnectionClosed +from websockets.asyncio.server import serve +from websockets.asyncio.connection import broadcast CLIENTS = set() @@ -15,7 +17,7 @@ async def send(websocket, message): try: await websocket.send(message) - except websockets.ConnectionClosed: + except ConnectionClosed: pass @@ -43,9 +45,6 @@ async def subscribe(self): __aiter__ = subscribe -PUBSUB = PubSub() - - async def handler(websocket, method=None): if method in ["default", "naive", "task", "wait"]: CLIENTS.add(websocket) @@ -63,14 +62,18 @@ async def handler(websocket, method=None): CLIENTS.remove(queue) relay_task.cancel() elif method == "pubsub": + global PUBSUB async for message in PUBSUB: await websocket.send(message) else: raise NotImplementedError(f"unsupported method: {method}") -async def broadcast(method, size, delay): +async def broadcast_messages(method, size, delay): """Broadcast messages at regular intervals.""" + if method == "pubsub": + global PUBSUB + PUBSUB = PubSub() load_average = 0 time_average = 0 pc1, pt1 = time.perf_counter_ns(), time.process_time_ns() @@ -90,7 +93,7 @@ async def broadcast(method, size, delay): message = str(time.time_ns()).encode() + b" " + os.urandom(size - 20) if method == "default": - websockets.broadcast(CLIENTS, message) + broadcast(CLIENTS, message) elif method == "naive": # Since the loop can yield control, make a copy of CLIENTS # to avoid: RuntimeError: Set changed size during iteration @@ -128,14 +131,14 @@ async def broadcast(method, size, delay): async def main(method, size, delay): - async with websockets.serve( + async with serve( functools.partial(handler, method=method), "localhost", 8765, compression=None, ping_timeout=None, ): - await broadcast(method, size, delay) + await broadcast_messages(method, size, delay) if __name__ == "__main__": diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 1c4424f0d..9d2f087da 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -6,6 +6,7 @@ import logging import random import struct +import sys import uuid from types import TracebackType from typing import ( @@ -27,7 +28,7 @@ from .messages import Assembler -__all__ = ["Connection"] +__all__ = ["Connection", "broadcast"] class Connection(asyncio.Protocol): @@ -338,7 +339,6 @@ async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> No Raises: ConnectionClosed: When the connection is closed. - RuntimeError: If the connection busy sending a fragmented message. TypeError: If ``message`` doesn't have a supported type. """ @@ -488,7 +488,7 @@ async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> No self.fragmented_send_waiter = None else: - raise TypeError("data must be bytes, str, iterable, or async iterable") + raise TypeError("data must be str, bytes, iterable, or async iterable") async def close(self, code: int = 1000, reason: str = "") -> None: """ @@ -673,7 +673,7 @@ async def send_context( On entry, :meth:`send_context` checks that the connection is open; on exit, it writes outgoing data to the socket:: - async async with self.send_context(): + async with self.send_context(): self.protocol.send_text(message.encode()) When the connection isn't open on entry, when the connection is expected @@ -916,3 +916,95 @@ def eof_received(self) -> None: # As a consequence, they never need to write after receiving EOF, so # there's no reason to keep the transport open by returning True. # Besides, that doesn't work on TLS connections. + + +def broadcast( + connections: Iterable[Connection], + message: Data, + raise_exceptions: bool = False, +) -> None: + """ + Broadcast a message to several WebSocket connections. + + A string (:class:`str`) is sent as a Text_ frame. A bytestring or bytes-like + object (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) is sent + as a Binary_ frame. + + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + + :func:`broadcast` pushes the message synchronously to all connections even + if their write buffers are overflowing. There's no backpressure. + + If you broadcast messages faster than a connection can handle them, messages + will pile up in its write buffer until the connection times out. Keep + ``ping_interval`` and ``ping_timeout`` low to prevent excessive memory usage + from slow connections. + + Unlike :meth:`~Connection.send`, :func:`broadcast` doesn't support sending + fragmented messages. Indeed, fragmentation is useful for sending large + messages without buffering them in memory, while :func:`broadcast` buffers + one copy per connection as fast as possible. + + :func:`broadcast` skips connections that aren't open in order to avoid + errors on connections where the closing handshake is in progress. + + :func:`broadcast` ignores failures to write the message on some connections. + It continues writing to other connections. On Python 3.11 and above, you may + set ``raise_exceptions`` to :obj:`True` to record failures and raise all + exceptions in a :pep:`654` :exc:`ExceptionGroup`. + + Args: + websockets: WebSocket connections to which the message will be sent. + message: Message to send. + raise_exceptions: Whether to raise an exception in case of failures. + + Raises: + TypeError: If ``message`` doesn't have a supported type. + + """ + if isinstance(message, str): + send_method = "send_text" + message = message.encode() + elif isinstance(message, BytesLike): + send_method = "send_binary" + else: + raise TypeError("data must be str or bytes") + + if raise_exceptions: + if sys.version_info[:2] < (3, 11): # pragma: no cover + raise ValueError("raise_exceptions requires at least Python 3.11") + exceptions = [] + + for connection in connections: + if connection.protocol.state is not OPEN: + continue + + if connection.fragmented_send_waiter is not None: + if raise_exceptions: + exception = RuntimeError("sending a fragmented message") + exceptions.append(exception) + else: + connection.logger.warning( + "skipped broadcast: sending a fragmented message", + ) + continue + + try: + # Call connection.protocol.send_text or send_binary. + # Either way, message is already converted to bytes. + getattr(connection.protocol, send_method)(message) + connection.send_data() + except Exception as write_exception: + if raise_exceptions: + exception = RuntimeError("failed to write message") + exception.__cause__ = write_exception + exceptions.append(exception) + else: + connection.logger.warning( + "skipped broadcast: failed to write message", + exc_info=True, + ) + + if raise_exceptions and exceptions: + raise ExceptionGroup("skipped broadcast", exceptions) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 6f8916576..b948257e0 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -1570,18 +1570,17 @@ def broadcast( ``ping_interval`` and ``ping_timeout`` low to prevent excessive memory usage from slow connections. - Unlike :meth:`~websockets.server.WebSocketServerProtocol.send`, - :func:`broadcast` doesn't support sending fragmented messages. Indeed, - fragmentation is useful for sending large messages without buffering them in - memory, while :func:`broadcast` buffers one copy per connection as fast as - possible. + Unlike :meth:`~WebSocketCommonProtocol.send`, :func:`broadcast` doesn't + support sending fragmented messages. Indeed, fragmentation is useful for + sending large messages without buffering them in memory, while + :func:`broadcast` buffers one copy per connection as fast as possible. :func:`broadcast` skips connections that aren't open in order to avoid errors on connections where the closing handshake is in progress. :func:`broadcast` ignores failures to write the message on some connections. - It continues writing to other connections. On Python 3.11 and above, you - may set ``raise_exceptions`` to :obj:`True` to record failures and raise all + It continues writing to other connections. On Python 3.11 and above, you may + set ``raise_exceptions`` to :obj:`True` to record failures and raise all exceptions in a :pep:`654` :exc:`ExceptionGroup`. Args: @@ -1615,6 +1614,7 @@ def broadcast( websocket.logger.warning( "skipped broadcast: sending a fragmented message", ) + continue try: websocket.write_frame_sync(True, opcode, data) diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index a4826c785..88d6aee1f 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -387,7 +387,7 @@ def send(self, message: Data | Iterable[Data]) -> None: raise else: - raise TypeError("data must be bytes, str, or iterable") + raise TypeError("data must be str, bytes, or iterable") def close(self, code: int = CloseCode.NORMAL_CLOSURE, reason: str = "") -> None: """ diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 02029b754..1cf382a01 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -2,6 +2,7 @@ import contextlib import logging import socket +import sys import unittest import uuid from unittest.mock import Mock, patch @@ -1004,6 +1005,161 @@ async def test_unexpected_failure_in_send_context(self, send_text): self.assertEqual(str(exc), "no close frame received or sent") self.assertIsInstance(exc.__cause__, AssertionError) + # Test broadcast. + + async def test_broadcast_text(self): + """broadcast broadcasts a text message.""" + broadcast([self.connection], "😀") + await self.assertFrameSent(Frame(Opcode.TEXT, "😀".encode())) + + @unittest.skipIf( + sys.version_info[:2] < (3, 11), + "raise_exceptions requires Python 3.11+", + ) + async def test_broadcast_text_reports_no_errors(self): + """broadcast broadcasts a text message without raising exceptions.""" + broadcast([self.connection], "😀", raise_exceptions=True) + await self.assertFrameSent(Frame(Opcode.TEXT, "😀".encode())) + + async def test_broadcast_binary(self): + """broadcast broadcasts a binary message.""" + broadcast([self.connection], b"\x01\x02\xfe\xff") + await self.assertFrameSent(Frame(Opcode.BINARY, b"\x01\x02\xfe\xff")) + + @unittest.skipIf( + sys.version_info[:2] < (3, 11), + "raise_exceptions requires Python 3.11+", + ) + async def test_broadcast_binary_reports_no_errors(self): + """broadcast broadcasts a binary message without raising exceptions.""" + broadcast([self.connection], b"\x01\x02\xfe\xff", raise_exceptions=True) + await self.assertFrameSent(Frame(Opcode.BINARY, b"\x01\x02\xfe\xff")) + + async def test_broadcast_no_clients(self): + """broadcast does nothing when called with an empty list of clients.""" + broadcast([], "😀") + await self.assertNoFrameSent() + + async def test_broadcast_two_clients(self): + """broadcast broadcasts a message to several clients.""" + broadcast([self.connection, self.connection], "😀") + await self.assertFramesSent( + [ + Frame(Opcode.TEXT, "😀".encode()), + Frame(Opcode.TEXT, "😀".encode()), + ] + ) + + async def test_broadcast_skips_closed_connection(self): + """broadcast ignores closed connections.""" + await self.connection.close() + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) + + with self.assertNoLogs(): + broadcast([self.connection], "😀") + await self.assertNoFrameSent() + + async def test_broadcast_skips_closing_connection(self): + """broadcast ignores closing connections.""" + async with self.delay_frames_rcvd(MS): + close_task = asyncio.create_task(self.connection.close()) + await asyncio.sleep(0) + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) + + with self.assertNoLogs(): + broadcast([self.connection], "😀") + await self.assertNoFrameSent() + + await close_task + + async def test_broadcast_skips_connection_with_send_blocked(self): + """broadcast logs a warning when a connection is blocked in send.""" + gate = asyncio.get_running_loop().create_future() + + async def fragments(): + yield "⏳" + await gate + + send_task = asyncio.create_task(self.connection.send(fragments())) + await asyncio.sleep(MS) + await self.assertFrameSent(Frame(Opcode.TEXT, "⏳".encode(), fin=False)) + + with self.assertLogs("websockets", logging.WARNING) as logs: + broadcast([self.connection], "😀") + + self.assertEqual( + [record.getMessage() for record in logs.records][:2], + ["skipped broadcast: sending a fragmented message"], + ) + + gate.set_result(None) + await send_task + + @unittest.skipIf( + sys.version_info[:2] < (3, 11), + "raise_exceptions requires Python 3.11+", + ) + async def test_broadcast_reports_connection_with_send_blocked(self): + """broadcast raises exceptions for connections blocked in send.""" + gate = asyncio.get_running_loop().create_future() + + async def fragments(): + yield "⏳" + await gate + + send_task = asyncio.create_task(self.connection.send(fragments())) + await asyncio.sleep(MS) + await self.assertFrameSent(Frame(Opcode.TEXT, "⏳".encode(), fin=False)) + + with self.assertRaises(ExceptionGroup) as raised: + broadcast([self.connection], "😀", raise_exceptions=True) + + self.assertEqual(str(raised.exception), "skipped broadcast (1 sub-exception)") + exc = raised.exception.exceptions[0] + self.assertEqual(str(exc), "sending a fragmented message") + self.assertIsInstance(exc, RuntimeError) + + gate.set_result(None) + await send_task + + async def test_broadcast_skips_connection_failing_to_send(self): + """broadcast logs a warning when a connection fails to send.""" + # Inject a fault by shutting down the transport for writing. + self.transport.write_eof() + + with self.assertLogs("websockets", logging.WARNING) as logs: + broadcast([self.connection], "😀") + + self.assertEqual( + [record.getMessage() for record in logs.records][:2], + ["skipped broadcast: failed to write message"], + ) + + @unittest.skipIf( + sys.version_info[:2] < (3, 11), + "raise_exceptions requires Python 3.11+", + ) + async def test_broadcast_reports_connection_failing_to_send(self): + """broadcast raises exceptions for connections failing to send.""" + # Inject a fault by shutting down the transport for writing. + self.transport.write_eof() + + with self.assertRaises(ExceptionGroup) as raised: + broadcast([self.connection], "😀", raise_exceptions=True) + + self.assertEqual(str(raised.exception), "skipped broadcast (1 sub-exception)") + exc = raised.exception.exceptions[0] + self.assertEqual(str(exc), "failed to write message") + self.assertIsInstance(exc, RuntimeError) + cause = exc.__cause__ + self.assertEqual(str(cause), "Cannot call write() after write_eof()") + self.assertIsInstance(cause, RuntimeError) + + async def test_broadcast_type_error(self): + """broadcast raises TypeError when called with an unsupported type.""" + with self.assertRaises(TypeError): + broadcast([self.connection], ["⏳", "⌛️"]) + class ServerConnectionTests(ClientConnectionTests): LOCAL = SERVER diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index d6303dcc7..ccea34719 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -1473,7 +1473,8 @@ def test_broadcast_text(self): self.assertOneFrameSent(True, OP_TEXT, "café".encode()) @unittest.skipIf( - sys.version_info[:2] < (3, 11), "raise_exceptions requires Python 3.11+" + sys.version_info[:2] < (3, 11), + "raise_exceptions requires Python 3.11+", ) def test_broadcast_text_reports_no_errors(self): broadcast([self.protocol], "café", raise_exceptions=True) @@ -1484,7 +1485,8 @@ def test_broadcast_binary(self): self.assertOneFrameSent(True, OP_BINARY, b"tea") @unittest.skipIf( - sys.version_info[:2] < (3, 11), "raise_exceptions requires Python 3.11+" + sys.version_info[:2] < (3, 11), + "raise_exceptions requires Python 3.11+", ) def test_broadcast_binary_reports_no_errors(self): broadcast([self.protocol], b"tea", raise_exceptions=True) @@ -1536,7 +1538,8 @@ def test_broadcast_skips_connection_sending_fragmented_text(self): ) @unittest.skipIf( - sys.version_info[:2] < (3, 11), "raise_exceptions requires Python 3.11+" + sys.version_info[:2] < (3, 11), + "raise_exceptions requires Python 3.11+", ) def test_broadcast_reports_connection_sending_fragmented_text(self): self.make_drain_slow() @@ -1565,7 +1568,8 @@ def test_broadcast_skips_connection_failing_to_send(self): ) @unittest.skipIf( - sys.version_info[:2] < (3, 11), "raise_exceptions requires Python 3.11+" + sys.version_info[:2] < (3, 11), + "raise_exceptions requires Python 3.11+", ) def test_broadcast_reports_connection_failing_to_send(self): # Configure mock to raise an exception when writing to the network. From 8d9f9a1cc791df01d7995693551cd9cf83e154c2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 17 Aug 2024 11:15:51 +0200 Subject: [PATCH 553/762] Expose connection state in the new asyncio implementation. --- docs/reference/new-asyncio/client.rst | 2 ++ docs/reference/new-asyncio/common.rst | 2 ++ docs/reference/new-asyncio/server.rst | 2 ++ src/websockets/asyncio/connection.py | 12 ++++++++++++ src/websockets/protocol.py | 2 +- tests/asyncio/test_connection.py | 6 +++++- 6 files changed, 24 insertions(+), 2 deletions(-) diff --git a/docs/reference/new-asyncio/client.rst b/docs/reference/new-asyncio/client.rst index 196bda2b7..efd143f14 100644 --- a/docs/reference/new-asyncio/client.rst +++ b/docs/reference/new-asyncio/client.rst @@ -43,6 +43,8 @@ Using a connection .. autoproperty:: remote_address + .. autoproperty:: state + The following attributes are available after the opening handshake, once the WebSocket connection is open: diff --git a/docs/reference/new-asyncio/common.rst b/docs/reference/new-asyncio/common.rst index 4fa97dcf2..60ea6bb37 100644 --- a/docs/reference/new-asyncio/common.rst +++ b/docs/reference/new-asyncio/common.rst @@ -33,6 +33,8 @@ Both sides (new :mod:`asyncio`) .. autoproperty:: remote_address + .. autoproperty:: state + The following attributes are available after the opening handshake, once the WebSocket connection is open: diff --git a/docs/reference/new-asyncio/server.rst b/docs/reference/new-asyncio/server.rst index 7f9de6148..b163e0fcd 100644 --- a/docs/reference/new-asyncio/server.rst +++ b/docs/reference/new-asyncio/server.rst @@ -62,6 +62,8 @@ Using a connection .. autoproperty:: remote_address + .. autoproperty:: state + The following attributes are available after the opening handshake, once the WebSocket connection is open: diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 9d2f087da..a323376ca 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -137,6 +137,18 @@ def remote_address(self) -> Any: """ return self.transport.get_extra_info("peername") + @property + def state(self) -> State: + """ + State of the WebSocket connection, defined in :rfc:`6455`. + + This attribute is provided for completeness. Typical applications + shouldn't check its value. Instead, they should call :meth:`~recv` or + :meth:`send` and handle :exc:`~exceptions.ConnectionClosed` exceptions. + + """ + return self.protocol.state + @property def subprotocol(self) -> Subprotocol | None: """ diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 917c19163..de065c544 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -156,7 +156,7 @@ def __init__( @property def state(self) -> State: """ - WebSocket connection state. + State of the WebSocket connection. Defined in 4.1, 4.2, 7.1.3, and 7.1.4 of :rfc:`6455`. diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 1cf382a01..239b5312e 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -11,7 +11,7 @@ from websockets.asyncio.connection import * from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK from websockets.frames import CloseCode, Frame, Opcode -from websockets.protocol import CLIENT, SERVER, Protocol +from websockets.protocol import CLIENT, SERVER, Protocol, State from ..protocol import RecordingProtocol from ..utils import MS @@ -930,6 +930,10 @@ async def test_remote_address(self, get_extra_info): self.assertEqual(self.connection.remote_address, ("peer", 1234)) get_extra_info.assert_called_with("peername") + async def test_state(self): + """Connection has a state attribute.""" + self.assertEqual(self.connection.state, State.OPEN) + async def test_request(self): """Connection has a request attribute.""" self.assertIsNone(self.connection.request) From 7c8e0b9d6246cd7bdd304f630f719fc55620f89a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 17 Aug 2024 10:20:37 +0200 Subject: [PATCH 554/762] Document removal of open and closed properties. They won't be added to the new asyncio implementation. --- docs/howto/upgrade.rst | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index 6efaf0f56..8ff18c594 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -371,3 +371,29 @@ buffer now. The ``write_limit`` argument of :func:`~asyncio.client.connect` and :func:`~asyncio.server.serve` defaults to 32 KiB instead of 64 KiB. + +Attributes of connections +......................... + +``open`` and ``closed`` +~~~~~~~~~~~~~~~~~~~~~~~ + +The :attr:`~legacy.protocol.WebSocketCommonProtocol.open` and +:attr:`~legacy.protocol.WebSocketCommonProtocol.closed` properties are removed. +Using them was discouraged. + +Instead, you should call :meth:`~asyncio.connection.Connection.recv` or +:meth:`~asyncio.connection.Connection.send` and handle +:exc:`~exceptions.ConnectionClosed` exceptions. + +If your code relies on them, you can replace:: + + connection.open + connection.closed + +with:: + + from websockets.protocol import State + + connection.state is State.OPEN + connection.state is State.CLOSED From 7c1d1d9b97fa698034d2b3651eb5a757e42b3dfb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 17 Aug 2024 11:53:37 +0200 Subject: [PATCH 555/762] Add a respond method to server connections. It's an alias of the reject method of the underlying server protocol. It makes it easier to write process_request It's called respond because the semantics are "consider the request as an HTTP request and create an HTTP response". There isn't a similar alias for accept because process_request should just return and websockets will call accept. --- docs/howto/upgrade.rst | 2 +- docs/reference/new-asyncio/server.rst | 2 ++ docs/reference/sync/server.rst | 2 ++ src/websockets/asyncio/server.py | 23 ++++++++++++++++++++++- src/websockets/server.py | 27 ++++++++++++++------------- src/websockets/sync/server.py | 23 ++++++++++++++++++++++- tests/asyncio/test_server.py | 4 ++-- tests/sync/test_server.py | 2 +- 8 files changed, 66 insertions(+), 19 deletions(-) diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index 8ff18c594..fe95a6517 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -247,7 +247,7 @@ an example:: # New implementation def process_request(connection, request): - return connection.protocol.reject(http.HTTPStatus.OK, "OK\n") + return connection.respond(http.HTTPStatus.OK, "OK\n") serve(..., process_request=process_request, ...) diff --git a/docs/reference/new-asyncio/server.rst b/docs/reference/new-asyncio/server.rst index b163e0fcd..5ffcff843 100644 --- a/docs/reference/new-asyncio/server.rst +++ b/docs/reference/new-asyncio/server.rst @@ -52,6 +52,8 @@ Using a connection .. automethod:: pong + .. automethod:: respond + WebSocket connection objects also provide these attributes: .. autoattribute:: id diff --git a/docs/reference/sync/server.rst b/docs/reference/sync/server.rst index 7ed744df2..26ab872c8 100644 --- a/docs/reference/sync/server.rst +++ b/docs/reference/sync/server.rst @@ -40,6 +40,8 @@ Using a connection .. automethod:: pong + .. automethod:: respond + WebSocket connection objects also provide these attributes: .. autoattribute:: id diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index 4feea13c4..cc2f46216 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -23,7 +23,7 @@ from ..http11 import SERVER, Request, Response from ..protocol import CONNECTING, Event from ..server import ServerProtocol -from ..typing import LoggerLike, Origin, Subprotocol +from ..typing import LoggerLike, Origin, StatusLike, Subprotocol from .compatibility import asyncio_timeout from .connection import Connection @@ -75,6 +75,27 @@ def __init__( self.server = server self.request_rcvd: asyncio.Future[None] = self.loop.create_future() + def respond(self, status: StatusLike, text: str) -> Response: + """ + Create a plain text HTTP response. + + ``process_request`` and ``process_response`` may call this method to + return an HTTP response instead of performing the WebSocket opening + handshake. + + You can modify the response before returning it, for example by changing + HTTP headers. + + Args: + status: HTTP status code. + text: HTTP response body; it will be encoded to UTF-8. + + Returns: + HTTP response to send to the client. + + """ + return self.protocol.reject(status, text) + async def handshake( self, process_request: ( diff --git a/src/websockets/server.py b/src/websockets/server.py index 1b4c3bf29..2ab9102f7 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -113,18 +113,22 @@ def accept(self, request: Request) -> Response: """ Create a handshake response to accept the connection. - If the connection cannot be established, the handshake response - actually rejects the handshake. + If the handshake request is valid and the handshake successful, + :meth:`accept` returns an HTTP response with status code 101. + + Else, it returns an HTTP response with another status code. This rejects + the connection, like :meth:`reject` would. You must send the handshake response with :meth:`send_response`. - You may modify it before sending it, for example to add HTTP headers. + You may modify the response before sending it, typically by adding HTTP + headers. Args: - request: WebSocket handshake request event received from the client. + request: WebSocket handshake request received from the client. Returns: - WebSocket handshake response event to send to the client. + WebSocket handshake response or HTTP response to send to the client. """ try: @@ -485,11 +489,7 @@ def select_subprotocol(protocol, subprotocols): + ", ".join(self.available_subprotocols) ) - def reject( - self, - status: StatusLike, - text: str, - ) -> Response: + def reject(self, status: StatusLike, text: str) -> Response: """ Create a handshake response to reject the connection. @@ -498,14 +498,15 @@ def reject( You must send the handshake response with :meth:`send_response`. - You can modify it before sending it, for example to alter HTTP headers. + You may modify the response before sending it, for example by changing + HTTP headers. Args: status: HTTP status code. - text: HTTP response body; will be encoded to UTF-8. + text: HTTP response body; it will be encoded to UTF-8. Returns: - WebSocket handshake response event to send to the client. + HTTP response to send to the client. """ # If a user passes an int instead of a HTTPStatus, fix it automatically. diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 10fbe4859..b381908ca 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -19,7 +19,7 @@ from ..http11 import SERVER, Request, Response from ..protocol import CONNECTING, OPEN, Event from ..server import ServerProtocol -from ..typing import LoggerLike, Origin, Subprotocol +from ..typing import LoggerLike, Origin, StatusLike, Subprotocol from .connection import Connection from .utils import Deadline @@ -66,6 +66,27 @@ def __init__( close_timeout=close_timeout, ) + def respond(self, status: StatusLike, text: str) -> Response: + """ + Create a plain text HTTP response. + + ``process_request`` and ``process_response`` may call this method to + return an HTTP response instead of performing the WebSocket opening + handshake. + + You can modify the response before returning it, for example by changing + HTTP headers. + + Args: + status: HTTP status code. + text: HTTP response body; it will be encoded to UTF-8. + + Returns: + HTTP response to send to the client. + + """ + return self.protocol.reject(status, text) + def handshake( self, process_request: ( diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index 4a8a76a21..fa590210f 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -144,7 +144,7 @@ async def test_process_request_abort_handshake(self): """Server aborts handshake if process_request returns a response.""" def process_request(ws, request): - return ws.protocol.reject(http.HTTPStatus.FORBIDDEN, "Forbidden") + return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden") async with run_server(process_request=process_request) as server: with self.assertRaises(InvalidStatus) as raised: @@ -159,7 +159,7 @@ async def test_async_process_request_abort_handshake(self): """Server aborts handshake if async process_request returns a response.""" async def process_request(ws, request): - return ws.protocol.reject(http.HTTPStatus.FORBIDDEN, "Forbidden") + return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden") async with run_server(process_request=process_request) as server: with self.assertRaises(InvalidStatus) as raised: diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index 9d509a5c4..4e04a39d5 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -131,7 +131,7 @@ def test_process_request_abort_handshake(self): """Server aborts handshake if process_request returns a response.""" def process_request(ws, request): - return ws.protocol.reject(http.HTTPStatus.FORBIDDEN, "Forbidden") + return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden") with run_server(process_request=process_request) as server: with self.assertRaises(InvalidStatus) as raised: From 7b19e790ce766dadb0b90b040be68694074a7e0d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 17 Aug 2024 23:12:36 +0200 Subject: [PATCH 556/762] Add keepalive to the new asyncio implementation. --- docs/faq/common.rst | 4 +- docs/howto/upgrade.rst | 11 --- docs/reference/features.rst | 4 +- docs/reference/new-asyncio/client.rst | 2 + docs/reference/new-asyncio/common.rst | 2 + docs/reference/new-asyncio/server.rst | 2 + docs/topics/broadcast.rst | 4 +- docs/topics/index.rst | 2 +- docs/topics/{timeouts.rst => keepalive.rst} | 16 ++-- src/websockets/asyncio/client.py | 21 ++++- src/websockets/asyncio/connection.py | 89 ++++++++++++++++++-- src/websockets/asyncio/server.py | 21 ++++- src/websockets/legacy/protocol.py | 16 ++-- tests/asyncio/test_client.py | 15 ++++ tests/asyncio/test_connection.py | 91 ++++++++++++++++++++- tests/asyncio/test_server.py | 21 +++++ tests/legacy/test_protocol.py | 4 +- 17 files changed, 274 insertions(+), 51 deletions(-) rename docs/topics/{timeouts.rst => keepalive.rst} (90%) diff --git a/docs/faq/common.rst b/docs/faq/common.rst index 2c63c4f36..84256fdfe 100644 --- a/docs/faq/common.rst +++ b/docs/faq/common.rst @@ -97,7 +97,7 @@ There are two main reasons why latency may increase: * Poor network connectivity. * More traffic than the recipient can handle. -See the discussion of :doc:`timeouts <../topics/timeouts>` for details. +See the discussion of :doc:`keepalive <../topics/keepalive>` for details. If websockets' default timeout of 20 seconds is too short for your use case, you can adjust it with the ``ping_timeout`` argument. @@ -146,7 +146,7 @@ It closes the connection if it doesn't get a pong within 20 seconds. You can adjust this behavior with ``ping_interval`` and ``ping_timeout``. -See :doc:`../topics/timeouts` for details. +See :doc:`../topics/keepalive` for details. How do I respond to pings? -------------------------- diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index fe95a6517..16b010aca 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -70,17 +70,6 @@ Missing features If your application relies on one of them, you should stick to the original implementation until the new implementation supports it in a future release. -Keepalive -......... - -The new implementation doesn't provide a :ref:`keepalive mechanism ` -yet. - -As a consequence, :func:`~asyncio.client.connect` and -:func:`~asyncio.server.serve` don't accept the ``ping_interval`` and -``ping_timeout`` arguments and the -:attr:`~legacy.protocol.WebSocketCommonProtocol.latency` property doesn't exist. - HTTP Basic Authentication ......................... diff --git a/docs/reference/features.rst b/docs/reference/features.rst index 946770fe3..45fa79c48 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -53,9 +53,9 @@ Both sides +------------------------------------+--------+--------+--------+--------+ | Send a pong | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ - | Keepalive | ❌ | ❌ | — | ✅ | + | Keepalive | ✅ | ❌ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ - | Heartbeat | ❌ | ❌ | — | ✅ | + | Heartbeat | ✅ | ❌ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ | Perform the closing handshake | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ diff --git a/docs/reference/new-asyncio/client.rst b/docs/reference/new-asyncio/client.rst index efd143f14..77a3c5d53 100644 --- a/docs/reference/new-asyncio/client.rst +++ b/docs/reference/new-asyncio/client.rst @@ -43,6 +43,8 @@ Using a connection .. autoproperty:: remote_address + .. autoattribute:: latency + .. autoproperty:: state The following attributes are available after the opening handshake, diff --git a/docs/reference/new-asyncio/common.rst b/docs/reference/new-asyncio/common.rst index 60ea6bb37..a58325fb9 100644 --- a/docs/reference/new-asyncio/common.rst +++ b/docs/reference/new-asyncio/common.rst @@ -33,6 +33,8 @@ Both sides (new :mod:`asyncio`) .. autoproperty:: remote_address + .. autoattribute:: latency + .. autoproperty:: state The following attributes are available after the opening handshake, diff --git a/docs/reference/new-asyncio/server.rst b/docs/reference/new-asyncio/server.rst index 5ffcff843..7bceca5a0 100644 --- a/docs/reference/new-asyncio/server.rst +++ b/docs/reference/new-asyncio/server.rst @@ -64,6 +64,8 @@ Using a connection .. autoproperty:: remote_address + .. autoattribute:: latency + .. autoproperty:: state The following attributes are available after the opening handshake, diff --git a/docs/topics/broadcast.rst b/docs/topics/broadcast.rst index 671319136..ec358bbd2 100644 --- a/docs/topics/broadcast.rst +++ b/docs/topics/broadcast.rst @@ -206,8 +206,8 @@ How do we avoid running out of memory when slow clients can't keep up with the broadcast rate, then? The most straightforward option is to disconnect them. If a client gets too far behind, eventually it reaches the limit defined by -``ping_timeout`` and websockets terminates the connection. You can refer to -the discussion of :doc:`keepalive and timeouts ` for details. +``ping_timeout`` and websockets terminates the connection. You can refer to the +discussion of :doc:`keepalive ` for details. How :func:`~asyncio.connection.broadcast` works ----------------------------------------------- diff --git a/docs/topics/index.rst b/docs/topics/index.rst index 120a3dd32..a2b8ca879 100644 --- a/docs/topics/index.rst +++ b/docs/topics/index.rst @@ -11,7 +11,7 @@ Get a deeper understanding of how websockets is built and why. authentication broadcast compression - timeouts + keepalive design memory security diff --git a/docs/topics/timeouts.rst b/docs/topics/keepalive.rst similarity index 90% rename from docs/topics/timeouts.rst rename to docs/topics/keepalive.rst index 633fc1ab4..1c7a43264 100644 --- a/docs/topics/timeouts.rst +++ b/docs/topics/keepalive.rst @@ -1,5 +1,5 @@ -Timeouts -======== +Keepalive and latency +===================== .. currentmodule:: websockets @@ -49,9 +49,9 @@ This mechanism serves two purposes: application gets a :exc:`~exceptions.ConnectionClosed` exception. Timings are configurable with the ``ping_interval`` and ``ping_timeout`` -arguments of :func:`~client.connect` and :func:`~server.serve`. Shorter values -will detect connection drops faster but they will increase network traffic and -they will be more sensitive to latency. +arguments of :func:`~asyncio.client.connect` and :func:`~asyncio.server.serve`. +Shorter values will detect connection drops faster but they will increase +network traffic and they will be more sensitive to latency. Setting ``ping_interval`` to :obj:`None` disables the whole keepalive and heartbeat mechanism. @@ -111,6 +111,6 @@ Latency between a client and a server may increase for two reasons: than the client can accept. The latency measured during the last exchange of Ping and Pong frames is -available in the :attr:`~legacy.protocol.WebSocketCommonProtocol.latency` -attribute. Alternatively, you can measure the latency at any time with the -:attr:`~legacy.protocol.WebSocketCommonProtocol.ping` method. +available in the :attr:`~asyncio.connection.Connection.latency` attribute. +Alternatively, you can measure the latency at any time with the +:attr:`~asyncio.connection.Connection.ping` method. diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index b2eaf9a65..632d3ac2b 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -37,10 +37,11 @@ class ClientConnection(Connection): :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is closed with any other code. + The ``ping_interval``, ``ping_timeout``, ``close_timeout``, ``max_queue``, + and ``write_limit`` arguments the same meaning as in :func:`connect`. + Args: protocol: Sans-I/O connection. - close_timeout: Timeout for closing the connection in seconds. - :obj:`None` disables the timeout. """ @@ -48,6 +49,8 @@ def __init__( self, protocol: ClientProtocol, *, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, close_timeout: float | None = 10, max_queue: int | tuple[int, int | None] = 16, write_limit: int | tuple[int, int | None] = 2**15, @@ -55,6 +58,8 @@ def __init__( self.protocol: ClientProtocol super().__init__( protocol, + ping_interval=ping_interval, + ping_timeout=ping_timeout, close_timeout=close_timeout, max_queue=max_queue, write_limit=write_limit, @@ -84,7 +89,9 @@ async def handshake( if self.response is None: raise ConnectionError("connection closed during handshake") - if self.protocol.handshake_exc is not None: + if self.protocol.handshake_exc is None: + self.start_keepalive() + else: try: async with asyncio_timeout(self.close_timeout): await self.connection_lost_waiter @@ -146,6 +153,10 @@ class connect: :doc:`compression guide <../../topics/compression>` for details. open_timeout: Timeout for opening the connection in seconds. :obj:`None` disables the timeout. + ping_interval: Interval between keepalive pings in seconds. + :obj:`None` disables keepalive. + ping_timeout: Timeout for keepalive pings in seconds. + :obj:`None` disables timeouts. close_timeout: Timeout for closing the connection in seconds. :obj:`None` disables the timeout. max_size: Maximum size of incoming messages in bytes. @@ -208,6 +219,8 @@ def __init__( compression: str | None = "deflate", # Timeouts open_timeout: float | None = 10, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, @@ -256,6 +269,8 @@ def factory() -> ClientConnection: # This is a connection in websockets and a protocol in asyncio. connection = create_connection( protocol, + ping_interval=ping_interval, + ping_timeout=ping_timeout, close_timeout=close_timeout, max_queue=max_queue, write_limit=write_limit, diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index a323376ca..b232b7956 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -24,7 +24,13 @@ from ..http11 import Request, Response from ..protocol import CLOSED, OPEN, Event, Protocol, State from ..typing import Data, LoggerLike, Subprotocol -from .compatibility import TimeoutError, aiter, anext, asyncio_timeout_at +from .compatibility import ( + TimeoutError, + aiter, + anext, + asyncio_timeout, + asyncio_timeout_at, +) from .messages import Assembler @@ -48,11 +54,15 @@ def __init__( self, protocol: Protocol, *, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, close_timeout: float | None = 10, max_queue: int | tuple[int, int | None] = 16, write_limit: int | tuple[int, int | None] = 2**15, ) -> None: self.protocol = protocol + self.ping_interval = ping_interval + self.ping_timeout = ping_timeout self.close_timeout = close_timeout if isinstance(max_queue, int): max_queue = (max_queue, None) @@ -95,6 +105,21 @@ def __init__( # Mapping of ping IDs to pong waiters, in chronological order. self.pong_waiters: dict[bytes, tuple[asyncio.Future[float], float]] = {} + self.latency: float = 0 + """ + Latency of the connection, in seconds. + + This value is updated after sending a ping frame and receiving a + matching pong frame. Before the first ping, :attr:`latency` is ``0``. + + By default, websockets enables a :ref:`keepalive ` mechanism + that sends ping frames automatically at regular intervals. You can also + send ping frames and measure latency with :meth:`ping`. + """ + + # Task that sends keepalive pings. None when ping_interval is None. + self.keepalive_task: asyncio.Task[None] | None = None + # Exception raised while reading from the connection, to be chained to # ConnectionClosed in order to show why the TCP connection dropped. self.recv_exc: BaseException | None = None @@ -144,7 +169,8 @@ def state(self) -> State: This attribute is provided for completeness. Typical applications shouldn't check its value. Instead, they should call :meth:`~recv` or - :meth:`send` and handle :exc:`~exceptions.ConnectionClosed` exceptions. + :meth:`send` and handle :exc:`~websockets.exceptions.ConnectionClosed` + exceptions. """ return self.protocol.state @@ -540,7 +566,7 @@ async def wait_closed(self) -> None: """ await asyncio.shield(self.connection_lost_waiter) - async def ping(self, data: Data | None = None) -> Awaitable[None]: + async def ping(self, data: Data | None = None) -> Awaitable[float]: """ Send a Ping_. @@ -643,8 +669,10 @@ def acknowledge_pings(self, data: bytes) -> None: ping_ids = [] for ping_id, (pong_waiter, ping_timestamp) in self.pong_waiters.items(): ping_ids.append(ping_id) - pong_waiter.set_result(pong_timestamp - ping_timestamp) + latency = pong_timestamp - ping_timestamp + pong_waiter.set_result(latency) if ping_id == data: + self.latency = latency break else: raise AssertionError("solicited pong not found in pings") @@ -664,7 +692,8 @@ def abort_pings(self) -> None: exc = self.protocol.close_exc for pong_waiter, _ping_timestamp in self.pong_waiters.values(): - pong_waiter.set_exception(exc) + if not pong_waiter.done(): + pong_waiter.set_exception(exc) # If the exception is never retrieved, it will be logged when ping # is garbage-collected. This is confusing for users. # Given that ping is done (with an exception), canceling it does @@ -673,6 +702,50 @@ def abort_pings(self) -> None: self.pong_waiters.clear() + async def keepalive(self) -> None: + """ + Send a Ping frame and wait for a Pong frame at regular intervals. + + """ + assert self.ping_interval is not None + latency = 0.0 + try: + while True: + # If self.ping_timeout > latency > self.ping_interval, pings + # will be sent immediately after receiving pongs. The period + # will be longer than self.ping_interval. + await asyncio.sleep(self.ping_interval - latency) + + self.logger.debug("% sending keepalive ping") + pong_waiter = await self.ping() + + if self.ping_timeout is not None: + try: + async with asyncio_timeout(self.ping_timeout): + latency = await pong_waiter + self.logger.debug("% received keepalive pong") + except asyncio.TimeoutError: + if self.debug: + self.logger.debug("! timed out waiting for keepalive pong") + async with self.send_context(): + self.protocol.fail( + CloseCode.INTERNAL_ERROR, + "keepalive ping timeout", + ) + break + except ConnectionClosed: + pass + except Exception: + self.logger.error("keepalive ping failed", exc_info=True) + + def start_keepalive(self) -> None: + """ + Run :meth:`keepalive` in a task, unless keepalive is disabled. + + """ + if self.ping_interval is not None: + self.keepalive_task = self.loop.create_task(self.keepalive()) + @contextlib.asynccontextmanager async def send_context( self, @@ -835,11 +908,15 @@ def connection_lost(self, exc: Exception | None) -> None: self.protocol.receive_eof() # receive_eof is idempotent self.recv_messages.close() self.set_recv_exc(exc) + self.abort_pings() + # If keepalive() was waiting for a pong, abort_pings() terminated it. + # If it was sleeping until the next ping, we need to cancel it now + if self.keepalive_task is not None: + self.keepalive_task.cancel() # If self.connection_lost_waiter isn't pending, that's a bug, because: # - it's set only here in connection_lost() which is called only once; # - it must never be canceled. self.connection_lost_waiter.set_result(None) - self.abort_pings() # Adapted from asyncio.streams.FlowControlMixin if self.paused: # pragma: no cover diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index cc2f46216..1f55502bb 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -48,11 +48,12 @@ class ServerConnection(Connection): :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is closed with any other code. + The ``ping_interval``, ``ping_timeout``, ``close_timeout``, ``max_queue``, + and ``write_limit`` arguments the same meaning as in :func:`serve`. + Args: protocol: Sans-I/O connection. server: Server that manages this connection. - close_timeout: Timeout for closing connections in seconds. - :obj:`None` disables the timeout. """ @@ -61,6 +62,8 @@ def __init__( protocol: ServerProtocol, server: WebSocketServer, *, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, close_timeout: float | None = 10, max_queue: int | tuple[int, int | None] = 16, write_limit: int | tuple[int, int | None] = 2**15, @@ -68,6 +71,8 @@ def __init__( self.protocol: ServerProtocol super().__init__( protocol, + ping_interval=ping_interval, + ping_timeout=ping_timeout, close_timeout=close_timeout, max_queue=max_queue, write_limit=write_limit, @@ -182,7 +187,9 @@ async def handshake( self.protocol.send_response(self.response) - if self.protocol.handshake_exc is not None: + if self.protocol.handshake_exc is None: + self.start_keepalive() + else: try: async with asyncio_timeout(self.close_timeout): await self.connection_lost_waiter @@ -595,6 +602,10 @@ def handler(websocket): :doc:`compression guide <../../topics/compression>` for details. open_timeout: Timeout for opening connections in seconds. :obj:`None` disables the timeout. + ping_interval: Interval between keepalive pings in seconds. + :obj:`None` disables keepalive. + ping_timeout: Timeout for keepalive pings in seconds. + :obj:`None` disables timeouts. close_timeout: Timeout for closing connections in seconds. :obj:`None` disables the timeout. max_size: Maximum size of incoming messages in bytes. @@ -667,6 +678,8 @@ def __init__( compression: str | None = "deflate", # Timeouts open_timeout: float | None = 10, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, @@ -743,6 +756,8 @@ def protocol_select_subprotocol( connection = create_connection( protocol, self.server, + ping_interval=ping_interval, + ping_timeout=ping_timeout, close_timeout=close_timeout, max_queue=max_queue, write_limit=write_limit, diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index b948257e0..191350de3 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -89,7 +89,7 @@ class WebSocketCommonProtocol(asyncio.Protocol): .. _Pong: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3 - See the discussion of :doc:`timeouts <../../topics/timeouts>` for details. + See the discussion of :doc:`timeouts <../../topics/keepalive>` for details. The ``close_timeout`` parameter defines a maximum wait time for completing the closing handshake and terminating the TCP connection. For legacy @@ -144,8 +144,8 @@ class WebSocketCommonProtocol(asyncio.Protocol): logger: Logger for this server. It defaults to ``logging.getLogger("websockets.protocol")``. See the :doc:`logging guide <../../topics/logging>` for details. - ping_interval: Delay between keepalive pings in seconds. - :obj:`None` disables keepalive pings. + ping_interval: Interval between keepalive pings in seconds. + :obj:`None` disables keepalive. ping_timeout: Timeout for keepalive pings in seconds. :obj:`None` disables timeouts. close_timeout: Timeout for closing the connection in seconds. @@ -1242,18 +1242,16 @@ async def keepalive_ping(self) -> None: while True: await asyncio.sleep(self.ping_interval) - # ping() raises CancelledError if the connection is closed, - # when close_connection() cancels self.keepalive_ping_task. - - # ping() raises ConnectionClosed if the connection is lost, - # when connection_lost() calls abort_pings(). - self.logger.debug("% sending keepalive ping") pong_waiter = await self.ping() if self.ping_timeout is not None: try: async with asyncio_timeout(self.ping_timeout): + # Raises CancelledError if the connection is closed, + # when close_connection() cancels keepalive_ping(). + # Raises ConnectionClosed if the connection is lost, + # when connection_lost() calls abort_pings(). await pong_waiter self.logger.debug("% received keepalive pong") except asyncio.TimeoutError: diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index b74617ef0..0bd2af4f1 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -63,6 +63,21 @@ async def test_disable_compression(self): async with run_client(server, compression=None) as client: self.assertEqual(client.protocol.extensions, []) + async def test_keepalive_is_enabled(self): + """Client enables keepalive and measures latency by default.""" + async with run_server() as server: + async with run_client(server, ping_interval=MS) as client: + self.assertEqual(client.latency, 0) + await asyncio.sleep(2 * MS) + self.assertGreater(client.latency, 0) + + async def test_disable_keepalive(self): + """Client disables keepalive.""" + async with run_server() as server: + async with run_client(server, ping_interval=None) as client: + await asyncio.sleep(2 * MS) + self.assertEqual(client.latency, 0) + async def test_custom_connection_factory(self): """Client runs ClientConnection factory provided in create_connection.""" diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 239b5312e..9b84a6b81 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -868,6 +868,93 @@ async def test_pong_explicit_binary(self): await self.connection.pong(b"pong") await self.assertFrameSent(Frame(Opcode.PONG, b"pong")) + # Test keepalive. + + @patch("random.getrandbits") + async def test_keepalive(self, getrandbits): + """keepalive sends pings.""" + self.connection.ping_interval = 2 * MS + getrandbits.return_value = 1918987876 + self.connection.start_keepalive() + await asyncio.sleep(3 * MS) + await self.assertFrameSent(Frame(Opcode.PING, b"rand")) + + @patch("random.getrandbits") + async def test_keepalive_times_out(self, getrandbits): + """keepalive closes the connection if ping_timeout elapses.""" + self.connection.ping_interval = 4 * MS + self.connection.ping_timeout = 2 * MS + async with self.drop_frames_rcvd(): + getrandbits.return_value = 1918987876 + self.connection.start_keepalive() + await asyncio.sleep(4 * MS) + # Exiting the context manager sleeps for MS. + await self.assertFrameSent(Frame(Opcode.PING, b"rand")) + await asyncio.sleep(MS) + await self.assertFrameSent( + Frame(Opcode.CLOSE, b"\x03\xf3keepalive ping timeout") + ) + + @patch("random.getrandbits") + async def test_keepalive_ignores_timeout(self, getrandbits): + """keepalive ignores timeouts if ping_timeout isn't set.""" + self.connection.ping_interval = 4 * MS + self.connection.ping_timeout = None + async with self.drop_frames_rcvd(): + getrandbits.return_value = 1918987876 + self.connection.start_keepalive() + await asyncio.sleep(4 * MS) + # Exiting the context manager sleeps for MS. + await self.assertFrameSent(Frame(Opcode.PING, b"rand")) + await asyncio.sleep(MS) + await self.assertNoFrameSent() + + async def test_disable_keepalive(self): + """keepalive is disabled when ping_interval is None.""" + self.connection.ping_interval = None + self.connection.start_keepalive() + await asyncio.sleep(3 * MS) + await self.assertNoFrameSent() + + async def test_keepalive_terminates_while_sleeping(self): + """keepalive task terminates while waiting to send a ping.""" + self.connection.ping_interval = 2 * MS + self.connection.start_keepalive() + await asyncio.sleep(MS) + await self.connection.close() + self.assertTrue(self.connection.keepalive_task.done()) + + async def test_keepalive_terminates_while_waiting_for_pong(self): + """keepalive task terminates while waiting to receive a pong.""" + self.connection.ping_interval = 2 * MS + async with self.drop_frames_rcvd(): + self.connection.start_keepalive() + await asyncio.sleep(2 * MS) + # Exiting the context manager sleeps for MS. + await self.connection.close() + self.assertTrue(self.connection.keepalive_task.done()) + + async def test_keepalive_reports_errors(self): + """keepalive reports unexpected errors in logs.""" + self.connection.ping_interval = 2 * MS + # Inject a fault by raising an exception in a pending pong waiter. + async with self.drop_frames_rcvd(): + self.connection.start_keepalive() + await asyncio.sleep(2 * MS) + # Exiting the context manager sleeps for MS. + pong_waiter = next(iter(self.connection.pong_waiters.values()))[0] + with self.assertLogs("websockets", logging.ERROR) as logs: + pong_waiter.set_exception(Exception("BOOM")) + await asyncio.sleep(0) + self.assertEqual( + [record.getMessage() for record in logs.records], + ["keepalive ping failed"], + ) + self.assertEqual( + [str(record.exc_info[1]) for record in logs.records], + ["BOOM"], + ) + # Test parameters. async def test_close_timeout(self): @@ -1092,7 +1179,7 @@ async def fragments(): broadcast([self.connection], "😀") self.assertEqual( - [record.getMessage() for record in logs.records][:2], + [record.getMessage() for record in logs.records], ["skipped broadcast: sending a fragmented message"], ) @@ -1135,7 +1222,7 @@ async def test_broadcast_skips_connection_failing_to_send(self): broadcast([self.connection], "😀") self.assertEqual( - [record.getMessage() for record in logs.records][:2], + [record.getMessage() for record in logs.records], ["skipped broadcast: failed to write message"], ) diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index fa590210f..b3023434b 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -312,6 +312,27 @@ async def test_disable_compression(self): async with run_client(server) as client: await self.assertEval(client, "ws.protocol.extensions", "[]") + async def test_keepalive_is_enabled(self): + """Server enables keepalive and measures latency.""" + async with run_server(ping_interval=MS) as server: + async with run_client(server) as client: + await client.send("ws.latency") + latency = eval(await client.recv()) + self.assertEqual(latency, 0) + await asyncio.sleep(2 * MS) + await client.send("ws.latency") + latency = eval(await client.recv()) + self.assertGreater(latency, 0) + + async def test_disable_keepalive(self): + """Client disables keepalive.""" + async with run_server(ping_interval=None) as server: + async with run_client(server) as client: + await asyncio.sleep(2 * MS) + await client.send("ws.latency") + latency = eval(await client.recv()) + self.assertEqual(latency, 0) + async def test_custom_connection_factory(self): """Server runs ServerConnection factory provided in create_connection.""" diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index ccea34719..8751b9ac6 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -1533,7 +1533,7 @@ def test_broadcast_skips_connection_sending_fragmented_text(self): broadcast([self.protocol], "café") self.assertEqual( - [record.getMessage() for record in logs.records][:2], + [record.getMessage() for record in logs.records], ["skipped broadcast: sending a fragmented message"], ) @@ -1563,7 +1563,7 @@ def test_broadcast_skips_connection_failing_to_send(self): broadcast([self.protocol], "café") self.assertEqual( - [record.getMessage() for record in logs.records][:2], + [record.getMessage() for record in logs.records], ["skipped broadcast: failed to write message"], ) From 60381d2566b55126f0f89f0c8380cf44ddc51aa1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 19 Aug 2024 07:58:31 +0200 Subject: [PATCH 557/762] Fix exception chaining for ConnectionClosed. --- src/websockets/asyncio/connection.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index b232b7956..284fe2124 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -906,13 +906,17 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None: def connection_lost(self, exc: Exception | None) -> None: self.protocol.receive_eof() # receive_eof is idempotent - self.recv_messages.close() + + # Abort recv() and pending pings with a ConnectionClosed exception. + # Set recv_exc first to get proper exception reporting. self.set_recv_exc(exc) + self.recv_messages.close() self.abort_pings() # If keepalive() was waiting for a pong, abort_pings() terminated it. # If it was sleeping until the next ping, we need to cancel it now if self.keepalive_task is not None: self.keepalive_task.cancel() + # If self.connection_lost_waiter isn't pending, that's a bug, because: # - it's set only here in connection_lost() which is called only once; # - it must never be canceled. From a78b5546074ed9e89a265eec1b54292a628d9b25 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 19 Aug 2024 08:00:16 +0200 Subject: [PATCH 558/762] Document legacy implementation in websockets.legacy. Until now it was documented directly in the websockets package. Also update most examples to use the new asyncio implementation. Some drive-by documentation improvements too. --- compliance/test_client.py | 10 +- compliance/test_server.py | 6 +- docs/conf.py | 28 +- docs/faq/asyncio.rst | 44 ++-- docs/faq/client.rst | 62 +++-- docs/faq/common.rst | 53 ++-- docs/faq/misc.rst | 17 +- docs/faq/server.rst | 111 ++++---- docs/howto/cheatsheet.rst | 57 ++-- docs/howto/django.rst | 2 +- docs/howto/heroku.rst | 2 +- docs/howto/nginx.rst | 6 +- docs/howto/patterns.rst | 10 +- docs/howto/quickstart.rst | 14 +- docs/howto/upgrade.rst | 91 ++++--- docs/intro/tutorial1.rst | 34 +-- docs/intro/tutorial3.rst | 6 +- docs/project/changelog.rst | 117 ++++----- docs/reference/asyncio/client.rst | 37 ++- docs/reference/asyncio/common.rst | 33 +-- docs/reference/asyncio/server.rst | 70 ++--- docs/reference/features.rst | 5 +- docs/reference/index.rst | 19 +- docs/reference/legacy/client.rst | 64 +++++ docs/reference/legacy/common.rst | 54 ++++ docs/reference/legacy/server.rst | 113 ++++++++ docs/reference/new-asyncio/client.rst | 57 ---- docs/reference/new-asyncio/common.rst | 47 ---- docs/reference/new-asyncio/server.rst | 83 ------ docs/topics/authentication.rst | 22 +- docs/topics/deployment.rst | 13 +- docs/topics/design.rst | 286 ++++++++++----------- docs/topics/logging.rst | 11 +- docs/topics/memory.rst | 9 +- docs/topics/security.rst | 6 +- example/deployment/fly/app.py | 10 +- example/deployment/haproxy/app.py | 4 +- example/deployment/heroku/app.py | 4 +- example/deployment/kubernetes/app.py | 18 +- example/deployment/kubernetes/benchmark.py | 5 +- example/deployment/nginx/app.py | 4 +- example/deployment/render/app.py | 10 +- example/deployment/supervisor/app.py | 4 +- example/django/authentication.py | 4 +- example/django/notifications.py | 7 +- example/echo.py | 2 +- example/faq/health_check_server.py | 15 +- example/faq/shutdown_client.py | 8 +- example/faq/shutdown_server.py | 5 +- example/legacy/basic_auth_client.py | 5 +- example/legacy/basic_auth_server.py | 8 +- example/legacy/unix_client.py | 5 +- example/legacy/unix_server.py | 5 +- example/logging/json_log_formatter.py | 2 +- example/quickstart/client.py | 5 +- example/quickstart/client_secure.py | 5 +- example/quickstart/counter.py | 13 +- example/quickstart/server.py | 5 +- example/quickstart/server_secure.py | 5 +- example/quickstart/show_time.py | 5 +- example/quickstart/show_time_2.py | 8 +- example/tutorial/step1/app.py | 4 +- example/tutorial/step2/app.py | 4 +- example/tutorial/step3/app.py | 4 +- experiments/authentication/app.py | 19 +- experiments/broadcast/clients.py | 4 +- src/websockets/exceptions.py | 4 +- src/websockets/legacy/auth.py | 6 +- src/websockets/legacy/client.py | 4 +- src/websockets/legacy/protocol.py | 16 +- tests/asyncio/test_connection.py | 2 +- tests/asyncio/test_server.py | 9 +- tests/sync/test_connection.py | 2 +- tests/sync/test_server.py | 5 +- 74 files changed, 951 insertions(+), 902 deletions(-) create mode 100644 docs/reference/legacy/client.rst create mode 100644 docs/reference/legacy/common.rst create mode 100644 docs/reference/legacy/server.rst delete mode 100644 docs/reference/new-asyncio/client.rst delete mode 100644 docs/reference/new-asyncio/common.rst delete mode 100644 docs/reference/new-asyncio/server.rst diff --git a/compliance/test_client.py b/compliance/test_client.py index 1ed4d711e..8e22569fd 100644 --- a/compliance/test_client.py +++ b/compliance/test_client.py @@ -1,9 +1,9 @@ +import asyncio import json import logging import urllib.parse -import asyncio -import websockets +from websockets.asyncio.client import connect logging.basicConfig(level=logging.WARNING) @@ -18,21 +18,21 @@ async def get_case_count(server): uri = f"{server}/getCaseCount" - async with websockets.connect(uri) as ws: + async with connect(uri) as ws: msg = ws.recv() return json.loads(msg) async def run_case(server, case, agent): uri = f"{server}/runCase?case={case}&agent={agent}" - async with websockets.connect(uri, max_size=2 ** 25, max_queue=1) as ws: + async with connect(uri, max_size=2 ** 25, max_queue=1) as ws: async for msg in ws: await ws.send(msg) async def update_reports(server, agent): uri = f"{server}/updateReports?agent={agent}" - async with websockets.connect(uri): + async with connect(uri): pass diff --git a/compliance/test_server.py b/compliance/test_server.py index 5701e4485..39176e902 100644 --- a/compliance/test_server.py +++ b/compliance/test_server.py @@ -1,7 +1,7 @@ +import asyncio import logging -import asyncio -import websockets +from websockets.asyncio.server import serve logging.basicConfig(level=logging.WARNING) @@ -19,7 +19,7 @@ async def echo(ws): async def main(): - with websockets.serve(echo, HOST, PORT, max_size=2 ** 25, max_queue=1): + with serve(echo, HOST, PORT, max_size=2 ** 25, max_queue=1): try: await asyncio.get_running_loop().create_future() # run forever except KeyboardInterrupt: diff --git a/docs/conf.py b/docs/conf.py index 9d61dc717..2c621bf41 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -36,20 +36,20 @@ # topics/design.rst discusses undocumented APIs ("py:meth", "client.WebSocketClientProtocol.handshake"), ("py:meth", "server.WebSocketServerProtocol.handshake"), - ("py:attr", "legacy.protocol.WebSocketCommonProtocol.is_client"), - ("py:attr", "legacy.protocol.WebSocketCommonProtocol.messages"), - ("py:meth", "legacy.protocol.WebSocketCommonProtocol.close_connection"), - ("py:attr", "legacy.protocol.WebSocketCommonProtocol.close_connection_task"), - ("py:meth", "legacy.protocol.WebSocketCommonProtocol.keepalive_ping"), - ("py:attr", "legacy.protocol.WebSocketCommonProtocol.keepalive_ping_task"), - ("py:meth", "legacy.protocol.WebSocketCommonProtocol.transfer_data"), - ("py:attr", "legacy.protocol.WebSocketCommonProtocol.transfer_data_task"), - ("py:meth", "legacy.protocol.WebSocketCommonProtocol.connection_open"), - ("py:meth", "legacy.protocol.WebSocketCommonProtocol.ensure_open"), - ("py:meth", "legacy.protocol.WebSocketCommonProtocol.fail_connection"), - ("py:meth", "legacy.protocol.WebSocketCommonProtocol.connection_lost"), - ("py:meth", "legacy.protocol.WebSocketCommonProtocol.read_message"), - ("py:meth", "legacy.protocol.WebSocketCommonProtocol.write_frame"), + ("py:attr", "protocol.WebSocketCommonProtocol.is_client"), + ("py:attr", "protocol.WebSocketCommonProtocol.messages"), + ("py:meth", "protocol.WebSocketCommonProtocol.close_connection"), + ("py:attr", "protocol.WebSocketCommonProtocol.close_connection_task"), + ("py:meth", "protocol.WebSocketCommonProtocol.keepalive_ping"), + ("py:attr", "protocol.WebSocketCommonProtocol.keepalive_ping_task"), + ("py:meth", "protocol.WebSocketCommonProtocol.transfer_data"), + ("py:attr", "protocol.WebSocketCommonProtocol.transfer_data_task"), + ("py:meth", "protocol.WebSocketCommonProtocol.connection_open"), + ("py:meth", "protocol.WebSocketCommonProtocol.ensure_open"), + ("py:meth", "protocol.WebSocketCommonProtocol.fail_connection"), + ("py:meth", "protocol.WebSocketCommonProtocol.connection_lost"), + ("py:meth", "protocol.WebSocketCommonProtocol.read_message"), + ("py:meth", "protocol.WebSocketCommonProtocol.write_frame"), ] # Add any Sphinx extension module names here, as strings. They can be diff --git a/docs/faq/asyncio.rst b/docs/faq/asyncio.rst index e77f50add..3bc381cfd 100644 --- a/docs/faq/asyncio.rst +++ b/docs/faq/asyncio.rst @@ -1,7 +1,12 @@ Using asyncio ============= -.. currentmodule:: websockets +.. currentmodule:: websockets.asyncio.connection + +.. admonition:: This FAQ is written for the new :mod:`asyncio` implementation. + :class: hint + + Answers are also valid for the legacy :mod:`asyncio` implementation. How do I run two coroutines in parallel? ---------------------------------------- @@ -9,8 +14,8 @@ How do I run two coroutines in parallel? You must start two tasks, which the event loop will run concurrently. You can achieve this with :func:`asyncio.gather` or :func:`asyncio.create_task`. -Keep track of the tasks and make sure they terminate or you cancel them when -the connection terminates. +Keep track of the tasks and make sure that they terminate or that you cancel +them when the connection terminates. Why does my program never receive any messages? ----------------------------------------------- @@ -22,13 +27,12 @@ Putting an ``await`` statement in a ``for`` or a ``while`` loop isn't enough to yield control. Awaiting a coroutine may yield control, but there's no guarantee that it will. -For example, :meth:`~legacy.protocol.WebSocketCommonProtocol.send` only yields -control when send buffers are full, which never happens in most practical -cases. +For example, :meth:`~Connection.send` only yields control when send buffers are +full, which never happens in most practical cases. -If you run a loop that contains only synchronous operations and -a :meth:`~legacy.protocol.WebSocketCommonProtocol.send` call, you must yield -control explicitly with :func:`asyncio.sleep`:: +If you run a loop that contains only synchronous operations and a +:meth:`~Connection.send` call, you must yield control explicitly with +:func:`asyncio.sleep`:: async def producer(websocket): message = generate_next_message() @@ -46,16 +50,19 @@ See `issue 867`_. Why am I having problems with threads? -------------------------------------- -If you choose websockets' default implementation based on :mod:`asyncio`, then -you shouldn't use threads. Indeed, choosing :mod:`asyncio` to handle concurrency -is mutually exclusive with :mod:`threading`. +If you choose websockets' :mod:`asyncio` implementation, then you shouldn't use +threads. Indeed, choosing :mod:`asyncio` to handle concurrency is mutually +exclusive with :mod:`threading`. If you believe that you need to run websockets in a thread and some logic in another thread, you should run that logic in a :class:`~asyncio.Task` instead. -If it blocks the event loop, :meth:`~asyncio.loop.run_in_executor` will help. -This question is really about :mod:`asyncio`. Please review the advice about -:ref:`asyncio-multithreading` in the Python documentation. +If it has to run in another thread because it would block the event loop, +:func:`~asyncio.to_thread` or :meth:`~asyncio.loop.run_in_executor` is the way +to go. + +Please review the advice about :ref:`asyncio-multithreading` in the Python +documentation. Why does my simple program misbehave mysteriously? -------------------------------------------------- @@ -63,7 +70,6 @@ Why does my simple program misbehave mysteriously? You are using :func:`time.sleep` instead of :func:`asyncio.sleep`, which blocks the event loop and prevents asyncio from operating normally. -This may lead to messages getting send but not received, to connection -timeouts, and to unexpected results of shotgun debugging e.g. adding an -unnecessary call to :meth:`~legacy.protocol.WebSocketCommonProtocol.send` -makes the program functional. +This may lead to messages getting send but not received, to connection timeouts, +and to unexpected results of shotgun debugging e.g. adding an unnecessary call +to a coroutine makes the program functional. diff --git a/docs/faq/client.rst b/docs/faq/client.rst index c590ac107..0dfc84253 100644 --- a/docs/faq/client.rst +++ b/docs/faq/client.rst @@ -1,7 +1,16 @@ Client ====== -.. currentmodule:: websockets +.. currentmodule:: websockets.asyncio.client + +.. admonition:: This FAQ is written for the new :mod:`asyncio` implementation. + :class: hint + + Answers are also valid for the legacy :mod:`asyncio` implementation. + + They translate to the :mod:`threading` implementation by removing ``await`` + and ``async`` keywords and by using a :class:`~threading.Thread` instead of + a :class:`~asyncio.Task` for concurrent execution. Why does the client close the connection prematurely? ----------------------------------------------------- @@ -22,46 +31,47 @@ change it to:: How do I access HTTP headers? ----------------------------- -Once the connection is established, HTTP headers are available in -:attr:`~client.WebSocketClientProtocol.request_headers` and -:attr:`~client.WebSocketClientProtocol.response_headers`. +Once the connection is established, HTTP headers are available in the +:attr:`~ClientConnection.request` and :attr:`~ClientConnection.response` +objects:: + + async with connect(...) as websocket: + websocket.request.headers + websocket.response.headers How do I set HTTP headers? -------------------------- To set the ``Origin``, ``Sec-WebSocket-Extensions``, or ``Sec-WebSocket-Protocol`` headers in the WebSocket handshake request, use the -``origin``, ``extensions``, or ``subprotocols`` arguments of -:func:`~client.connect`. +``origin``, ``extensions``, or ``subprotocols`` arguments of :func:`~connect`. To override the ``User-Agent`` header, use the ``user_agent_header`` argument. Set it to :obj:`None` to remove the header. To set other HTTP headers, for example the ``Authorization`` header, use the -``extra_headers`` argument:: +``additional_headers`` argument:: - async with connect(..., extra_headers={"Authorization": ...}) as websocket: + async with connect(..., additional_headers={"Authorization": ...}) as websocket: ... -In the :mod:`threading` API, this argument is named ``additional_headers``:: - - with connect(..., additional_headers={"Authorization": ...}) as websocket: - ... +In the legacy :mod:`asyncio` API, this argument is named ``extra_headers``. How do I force the IP address that the client connects to? ---------------------------------------------------------- -Use the ``host`` argument of :meth:`~asyncio.loop.create_connection`:: +Use the ``host`` argument :func:`~connect`:: - await websockets.connect("ws://example.com", host="192.168.0.1") + async with connect(..., host="192.168.0.1") as websocket: + ... -:func:`~client.connect` accepts the same arguments as -:meth:`~asyncio.loop.create_connection`. +:func:`~connect` accepts the same arguments as +:meth:`~asyncio.loop.create_connection` and passes them through. How do I close a connection? ---------------------------- -The easiest is to use :func:`~client.connect` as a context manager:: +The easiest is to use :func:`~connect` as a context manager:: async with connect(...) as websocket: ... @@ -71,9 +81,17 @@ The connection is closed when exiting the context manager. How do I reconnect when the connection drops? --------------------------------------------- -Use :func:`~client.connect` as an asynchronous iterator:: +.. admonition:: This feature is only supported by the legacy :mod:`asyncio` + implementation. + :class: warning + + It will be added to the new :mod:`asyncio` implementation soon. + +Use :func:`~websockets.legacy.client.connect` as an asynchronous iterator:: + + from websockets.legacy.client import connect - async for websocket in websockets.connect(...): + async for websocket in connect(...): try: ... except websockets.ConnectionClosed: @@ -90,12 +108,12 @@ You can close the connection. Here's an example that terminates cleanly when it receives SIGTERM on Unix: .. literalinclude:: ../../example/faq/shutdown_client.py - :emphasize-lines: 10-13 + :emphasize-lines: 11-13 How do I disable TLS/SSL certificate verification? -------------------------------------------------- Look at the ``ssl`` argument of :meth:`~asyncio.loop.create_connection`. -:func:`~client.connect` accepts the same arguments as -:meth:`~asyncio.loop.create_connection`. +:func:`~connect` accepts the same arguments as +:meth:`~asyncio.loop.create_connection` and passes them through. diff --git a/docs/faq/common.rst b/docs/faq/common.rst index 84256fdfe..0dc4a3aeb 100644 --- a/docs/faq/common.rst +++ b/docs/faq/common.rst @@ -1,7 +1,7 @@ Both sides ========== -.. currentmodule:: websockets +.. currentmodule:: websockets.asyncio.connection What does ``ConnectionClosedError: no close frame received or sent`` mean? -------------------------------------------------------------------------- @@ -11,12 +11,6 @@ If you're seeing this traceback in the logs of a server: .. code-block:: pytb connection handler failed - Traceback (most recent call last): - ... - asyncio.exceptions.IncompleteReadError: 0 bytes read on a total of 2 expected bytes - - The above exception was the direct cause of the following exception: - Traceback (most recent call last): ... websockets.exceptions.ConnectionClosedError: no close frame received or sent @@ -25,12 +19,6 @@ or if a client crashes with this traceback: .. code-block:: pytb - Traceback (most recent call last): - ... - ConnectionResetError: [Errno 54] Connection reset by peer - - The above exception was the direct cause of the following exception: - Traceback (most recent call last): ... websockets.exceptions.ConnectionClosedError: no close frame received or sent @@ -39,8 +27,8 @@ it means that the TCP connection was lost. As a consequence, the WebSocket connection was closed without receiving and sending a close frame, which is abnormal. -You can catch and handle :exc:`~exceptions.ConnectionClosed` to prevent it -from being logged. +You can catch and handle :exc:`~websockets.exceptions.ConnectionClosed` to +prevent it from being logged. There are several reasons why long-lived connections may be lost: @@ -62,12 +50,6 @@ If you're seeing this traceback in the logs of a server: .. code-block:: pytb connection handler failed - Traceback (most recent call last): - ... - asyncio.exceptions.CancelledError - - The above exception was the direct cause of the following exception: - Traceback (most recent call last): ... websockets.exceptions.ConnectionClosedError: sent 1011 (internal error) keepalive ping timeout; no close frame received @@ -76,12 +58,6 @@ or if a client crashes with this traceback: .. code-block:: pytb - Traceback (most recent call last): - ... - asyncio.exceptions.CancelledError - - The above exception was the direct cause of the following exception: - Traceback (most recent call last): ... websockets.exceptions.ConnectionClosedError: sent 1011 (internal error) keepalive ping timeout; no close frame received @@ -89,8 +65,8 @@ or if a client crashes with this traceback: it means that the WebSocket connection suffered from excessive latency and was closed after reaching the timeout of websockets' keepalive mechanism. -You can catch and handle :exc:`~exceptions.ConnectionClosed` to prevent it -from being logged. +You can catch and handle :exc:`~websockets.exceptions.ConnectionClosed` to +prevent it from being logged. There are two main reasons why latency may increase: @@ -102,8 +78,8 @@ See the discussion of :doc:`keepalive <../topics/keepalive>` for details. If websockets' default timeout of 20 seconds is too short for your use case, you can adjust it with the ``ping_timeout`` argument. -How do I set a timeout on :meth:`~legacy.protocol.WebSocketCommonProtocol.recv`? --------------------------------------------------------------------------------- +How do I set a timeout on :meth:`~Connection.recv`? +--------------------------------------------------- On Python ≥ 3.11, use :func:`asyncio.timeout`:: @@ -117,23 +93,24 @@ On older versions of Python, use :func:`asyncio.wait_for`:: This technique works for most APIs. When it doesn't, for example with asynchronous context managers, websockets provides an ``open_timeout`` argument. -How can I pass arguments to a custom protocol subclass? -------------------------------------------------------- +How can I pass arguments to a custom connection subclass? +--------------------------------------------------------- -You can bind additional arguments to the protocol factory with +You can bind additional arguments to the connection factory with :func:`functools.partial`:: import asyncio import functools - import websockets + from websockets.asyncio.server import ServerConnection, serve - class MyServerProtocol(websockets.WebSocketServerProtocol): + class MyServerConnection(ServerConnection): def __init__(self, *args, extra_argument=None, **kwargs): super().__init__(*args, **kwargs) # do something with extra_argument - create_protocol = functools.partial(MyServerProtocol, extra_argument=42) - start_server = websockets.serve(..., create_protocol=create_protocol) + create_connection = functools.partial(ServerConnection, extra_argument=42) + async with serve(..., create_connection=create_connection): + ... This example was for a server. The same pattern applies on a client. diff --git a/docs/faq/misc.rst b/docs/faq/misc.rst index 0e74a784f..4936aa6f3 100644 --- a/docs/faq/misc.rst +++ b/docs/faq/misc.rst @@ -13,27 +13,12 @@ Often, this is because you created a script called ``websockets.py`` in your current working directory. Then ``import websockets`` imports this module instead of the websockets library. -.. _real-import-paths: - -Why is the default implementation located in ``websockets.legacy``? -................................................................... - -This is an artifact of websockets' history. For its first eight years, only the -:mod:`asyncio` implementation existed. Then, the Sans-I/O implementation was -added. Moving the code in a ``legacy`` submodule eased this refactoring and -optimized maintainability. - -All public APIs were kept at their original locations. ``websockets.legacy`` -isn't a public API. It's only visible in the source code and in stack traces. -There is no intent to deprecate this implementation — at least until a superior -alternative exists. - Why is websockets slower than another library in my benchmark? .............................................................. Not all libraries are as feature-complete as websockets. For a fair benchmark, you should disable features that the other library doesn't provide. Typically, -you may need to disable: +you must disable: * Compression: set ``compression=None`` * Keepalive: set ``ping_interval=None`` diff --git a/docs/faq/server.rst b/docs/faq/server.rst index 53e34632f..e6b068316 100644 --- a/docs/faq/server.rst +++ b/docs/faq/server.rst @@ -1,7 +1,16 @@ Server ====== -.. currentmodule:: websockets +.. currentmodule:: websockets.asyncio.server + +.. admonition:: This FAQ is written for the new :mod:`asyncio` implementation. + :class: hint + + Answers are also valid for the legacy :mod:`asyncio` implementation. + + They translate to the :mod:`threading` implementation by removing ``await`` + and ``async`` keywords and by using a :class:`~threading.Thread` instead of + a :class:`~asyncio.Task` for concurrent execution. Why does the server close the connection prematurely? ----------------------------------------------------- @@ -36,8 +45,13 @@ change it like this:: async for message in websocket: print(message) -*Don't feel bad if this happens to you — it's the most common question in -websockets' issue tracker :-)* +If you have prior experience with an API that relies on callbacks, you may +assume that ``handler()`` is executed every time a message is received. The API +of websockets relies on coroutines instead. + +The handler coroutine is started when a new connection is established. Then, it +is responsible for receiving or sending messages throughout the lifetime of that +connection. Why can only one client connect at a time? ------------------------------------------ @@ -69,9 +83,9 @@ continuously:: while True: await websocket.send("firehose!") -:meth:`~legacy.protocol.WebSocketCommonProtocol.send` completes synchronously as -long as there's space in send buffers. The event loop never runs. (This pattern -is uncommon in real-world applications. It occurs mostly in toy programs.) +:meth:`~ServerConnection.send` completes synchronously as long as there's space +in send buffers. The event loop never runs. (This pattern is uncommon in +real-world applications. It occurs mostly in toy programs.) You can avoid the issue by yielding control to the event loop explicitly:: @@ -102,12 +116,12 @@ Record all connections in a global variable:: finally: CONNECTIONS.remove(websocket) -Then, call :func:`~asyncio.connection.broadcast`:: +Then, call :func:`~websockets.asyncio.connection.broadcast`:: - import websockets + from websockets.asyncio.connection import broadcast def message_all(message): - websockets.broadcast(CONNECTIONS, message) + broadcast(CONNECTIONS, message) If you're running multiple server processes, make sure you call ``message_all`` in each process. @@ -129,7 +143,7 @@ Record connections in a global variable, keyed by user identifier:: finally: del CONNECTIONS[user_id] -Then, call :meth:`~legacy.protocol.WebSocketCommonProtocol.send`:: +Then, call :meth:`~ServerConnection.send`:: async def message_user(user_id, message): websocket = CONNECTIONS[user_id] # raises KeyError if user disconnected @@ -178,15 +192,12 @@ How do I pass arguments to the connection handler? You can bind additional arguments to the connection handler with :func:`functools.partial`:: - import asyncio import functools - import websockets async def handler(websocket, extra_argument): ... bound_handler = functools.partial(handler, extra_argument=42) - start_server = websockets.serve(bound_handler, ...) Another way to achieve this result is to define the ``handler`` coroutine in a scope where the ``extra_argument`` variable exists instead of injecting it @@ -195,14 +206,14 @@ through an argument. How do I access the request path? --------------------------------- -It is available in the :attr:`~server.WebSocketServerProtocol.path` attribute. +It is available in the :attr:`~ServerConnection.request` object. You may route a connection to different handlers depending on the request path:: async def handler(websocket): - if websocket.path == "/blue": + if websocket.request.path == "/blue": await blue_handler(websocket) - elif websocket.path == "/green": + elif websocket.request.path == "/green": await green_handler(websocket) else: # No handler for this path; close the connection. @@ -219,35 +230,46 @@ it may ignore the request path entirely. How do I access HTTP headers? ----------------------------- -To access HTTP headers during the WebSocket handshake, you can override -:attr:`~server.WebSocketServerProtocol.process_request`:: +You can access HTTP headers during the WebSocket handshake by providing a +``process_request`` callable or coroutine:: - async def process_request(self, path, request_headers): - authorization = request_headers["Authorization"] + def process_request(connection, request): + authorization = request.headers["Authorization"] + ... + + async with serve(handler, process_request=process_request): + ... -Once the connection is established, HTTP headers are available in -:attr:`~server.WebSocketServerProtocol.request_headers` and -:attr:`~server.WebSocketServerProtocol.response_headers`:: +Once the connection is established, HTTP headers are available in the +:attr:`~ServerConnection.request` and :attr:`~ServerConnection.response` +objects:: async def handler(websocket): - authorization = websocket.request_headers["Authorization"] + authorization = websocket.request.headers["Authorization"] How do I set HTTP headers? -------------------------- To set the ``Sec-WebSocket-Extensions`` or ``Sec-WebSocket-Protocol`` headers in the WebSocket handshake response, use the ``extensions`` or ``subprotocols`` -arguments of :func:`~server.serve`. +arguments of :func:`~serve`. To override the ``Server`` header, use the ``server_header`` argument. Set it to :obj:`None` to remove the header. -To set other HTTP headers, use the ``extra_headers`` argument. +To set other HTTP headers, provide a ``process_response`` callable or +coroutine:: + + def process_response(connection, request, response): + response.headers["X-Blessing"] = "May the network be with you" + + async with serve(handler, process_response=process_response): + ... How do I get the IP address of the client? ------------------------------------------ -It's available in :attr:`~legacy.protocol.WebSocketCommonProtocol.remote_address`:: +It's available in :attr:`~ServerConnection.remote_address`:: async def handler(websocket): remote_ip = websocket.remote_address[0] @@ -255,18 +277,19 @@ It's available in :attr:`~legacy.protocol.WebSocketCommonProtocol.remote_address How do I set the IP addresses that my server listens on? -------------------------------------------------------- -Use the ``host`` argument of :meth:`~asyncio.loop.create_server`:: +Use the ``host`` argument of :meth:`~serve`:: - await websockets.serve(handler, host="192.168.0.1", port=8080) + async with serve(handler, host="192.168.0.1", port=8080): + ... -:func:`~server.serve` accepts the same arguments as -:meth:`~asyncio.loop.create_server`. +:func:`~serve` accepts the same arguments as +:meth:`~asyncio.loop.create_server` and passes them through. What does ``OSError: [Errno 99] error while attempting to bind on address ('::1', 80, 0, 0): address not available`` mean? -------------------------------------------------------------------------------------------------------------------------- -You are calling :func:`~server.serve` without a ``host`` argument in a context -where IPv6 isn't available. +You are calling :func:`~serve` without a ``host`` argument in a context where +IPv6 isn't available. To listen only on IPv4, specify ``host="0.0.0.0"`` or ``family=socket.AF_INET``. @@ -280,17 +303,17 @@ websockets takes care of closing the connection when the handler exits. How do I stop a server? ----------------------- -Exit the :func:`~server.serve` context manager. +Exit the :func:`~serve` context manager. Here's an example that terminates cleanly when it receives SIGTERM on Unix: .. literalinclude:: ../../example/faq/shutdown_server.py - :emphasize-lines: 12-15,18 + :emphasize-lines: 13-16,19 How do I stop a server while keeping existing connections open? --------------------------------------------------------------- -Call the server's :meth:`~server.WebSocketServer.close` method with +Call the server's :meth:`~WebSocketServer.close` method with ``close_connections=False``. Here's how to adapt the example just above:: @@ -298,7 +321,7 @@ Here's how to adapt the example just above:: async def server(): ... - server = await websockets.serve(echo, "localhost", 8765) + server = await serve(echo, "localhost", 8765) await stop server.close(close_connections=False) await server.wait_closed() @@ -306,14 +329,14 @@ Here's how to adapt the example just above:: How do I implement a health check? ---------------------------------- -Intercept WebSocket handshake requests with the -:meth:`~server.WebSocketServerProtocol.process_request` hook. - -When a request is sent to the health check endpoint, treat is as an HTTP request -and return a ``(status, headers, body)`` tuple, as in this example: +Intercept requests with the ``process_request`` hook. When a request is sent to +the health check endpoint, treat is as an HTTP request and return a response: .. literalinclude:: ../../example/faq/health_check_server.py - :emphasize-lines: 7-9,18 + :emphasize-lines: 7-9,16 + +:meth:`~ServerConnection.respond` makes it easy to send a plain text response. +You can also construct a :class:`~websockets.http11.Response` object directly. How do I run HTTP and WebSocket servers on the same port? --------------------------------------------------------- @@ -327,7 +350,7 @@ Providing an HTTP server is out of scope for websockets. It only aims at providing a WebSocket server. There's limited support for returning HTTP responses with the -:attr:`~server.WebSocketServerProtocol.process_request` hook. +``process_request`` hook. If you need more, pick an HTTP server and run it separately. diff --git a/docs/howto/cheatsheet.rst b/docs/howto/cheatsheet.rst index 95b551f67..8df2f234b 100644 --- a/docs/howto/cheatsheet.rst +++ b/docs/howto/cheatsheet.rst @@ -9,24 +9,24 @@ Server * Write a coroutine that handles a single connection. It receives a WebSocket protocol instance and the URI path in argument. - * Call :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` and - :meth:`~legacy.protocol.WebSocketCommonProtocol.send` to receive and send - messages at any time. + * Call :meth:`~asyncio.connection.Connection.recv` and + :meth:`~asyncio.connection.Connection.send` to receive and send messages at + any time. - * When :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` or - :meth:`~legacy.protocol.WebSocketCommonProtocol.send` raises - :exc:`~exceptions.ConnectionClosed`, clean up and exit. If you started - other :class:`asyncio.Task`, terminate them before exiting. + * When :meth:`~asyncio.connection.Connection.recv` or + :meth:`~asyncio.connection.Connection.send` raises + :exc:`~exceptions.ConnectionClosed`, clean up and exit. If you started other + :class:`asyncio.Task`, terminate them before exiting. - * If you aren't awaiting :meth:`~legacy.protocol.WebSocketCommonProtocol.recv`, - consider awaiting :meth:`~legacy.protocol.WebSocketCommonProtocol.wait_closed` - to detect quickly when the connection is closed. + * If you aren't awaiting :meth:`~asyncio.connection.Connection.recv`, consider + awaiting :meth:`~asyncio.connection.Connection.wait_closed` to detect + quickly when the connection is closed. - * You may :meth:`~legacy.protocol.WebSocketCommonProtocol.ping` or - :meth:`~legacy.protocol.WebSocketCommonProtocol.pong` if you wish but it isn't - needed in general. + * You may :meth:`~asyncio.connection.Connection.ping` or + :meth:`~asyncio.connection.Connection.pong` if you wish but it isn't needed + in general. -* Create a server with :func:`~server.serve` which is similar to asyncio's +* Create a server with :func:`~asyncio.server.serve` which is similar to asyncio's :meth:`~asyncio.loop.create_server`. You can also use it as an asynchronous context manager. @@ -35,30 +35,30 @@ Server handler exits normally or with an exception. * For advanced customization, you may subclass - :class:`~server.WebSocketServerProtocol` and pass either this subclass or - a factory function as the ``create_protocol`` argument. + :class:`~asyncio.server.ServerConnection` and pass either this subclass or a + factory function as the ``create_connection`` argument. Client ------ -* Create a client with :func:`~client.connect` which is similar to asyncio's - :meth:`~asyncio.loop.create_connection`. You can also use it as an +* Create a client with :func:`~asyncio.client.connect` which is similar to + asyncio's :meth:`~asyncio.loop.create_connection`. You can also use it as an asynchronous context manager. * For advanced customization, you may subclass - :class:`~client.WebSocketClientProtocol` and pass either this subclass or - a factory function as the ``create_protocol`` argument. + :class:`~asyncio.client.ClientConnection` and pass either this subclass or + a factory function as the ``create_connection`` argument. -* Call :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` and - :meth:`~legacy.protocol.WebSocketCommonProtocol.send` to receive and send messages - at any time. +* Call :meth:`~asyncio.connection.Connection.recv` and + :meth:`~asyncio.connection.Connection.send` to receive and send messages at + any time. -* You may :meth:`~legacy.protocol.WebSocketCommonProtocol.ping` or - :meth:`~legacy.protocol.WebSocketCommonProtocol.pong` if you wish but it isn't - needed in general. +* You may :meth:`~asyncio.connection.Connection.ping` or + :meth:`~asyncio.connection.Connection.pong` if you wish but it isn't needed in + general. -* If you aren't using :func:`~client.connect` as a context manager, call - :meth:`~legacy.protocol.WebSocketCommonProtocol.close` to terminate the connection. +* If you aren't using :func:`~asyncio.client.connect` as a context manager, call + :meth:`~asyncio.connection.Connection.close` to terminate the connection. .. _debugging: @@ -84,4 +84,3 @@ particular. Fortunately Python's official documentation provides advice to `develop with asyncio`_. Check it out: it's invaluable! .. _develop with asyncio: https://docs.python.org/3/library/asyncio-dev.html - diff --git a/docs/howto/django.rst b/docs/howto/django.rst index e3da0a878..dada9c5e4 100644 --- a/docs/howto/django.rst +++ b/docs/howto/django.rst @@ -124,7 +124,7 @@ support asynchronous I/O. It would block the event loop if it didn't run in a separate thread. :func:`~asyncio.to_thread` is available since Python 3.9. In earlier versions, use :meth:`~asyncio.loop.run_in_executor` instead. -Finally, we start a server with :func:`~websockets.server.serve`. +Finally, we start a server with :func:`~websockets.asyncio.server.serve`. We're ready to test! diff --git a/docs/howto/heroku.rst b/docs/howto/heroku.rst index a97d2e7ce..b335e14c5 100644 --- a/docs/howto/heroku.rst +++ b/docs/howto/heroku.rst @@ -42,7 +42,7 @@ Here's the implementation of the app, an echo server. Save it in a file called Heroku expects the server to `listen on a specific port`_, which is provided in the ``$PORT`` environment variable. The app reads it and passes it to -:func:`~websockets.server.serve`. +:func:`~websockets.asyncio.server.serve`. .. _listen on a specific port: https://devcenter.heroku.com/articles/preparing-a-codebase-for-heroku-deployment#4-listen-on-the-correct-port diff --git a/docs/howto/nginx.rst b/docs/howto/nginx.rst index ff42c3c2b..872353cad 100644 --- a/docs/howto/nginx.rst +++ b/docs/howto/nginx.rst @@ -21,9 +21,9 @@ We'd like nginx to connect to websockets servers via Unix sockets in order to avoid the overhead of TCP for communicating between processes running in the same OS. -We start the app with :func:`~websockets.server.unix_serve`. Each server -process listens on a different socket thanks to an environment variable set -by Supervisor to a different value. +We start the app with :func:`~websockets.asyncio.server.unix_serve`. Each server +process listens on a different socket thanks to an environment variable set by +Supervisor to a different value. Save this configuration to ``supervisord.conf``: diff --git a/docs/howto/patterns.rst b/docs/howto/patterns.rst index c6f325d21..60bc8ab42 100644 --- a/docs/howto/patterns.rst +++ b/docs/howto/patterns.rst @@ -8,7 +8,7 @@ client. You will certainly implement some of them in your application. This page gives examples of connection handlers for a server. However, they're also applicable to a client, simply by assuming that ``websocket`` is a -connection created with :func:`~client.connect`. +connection created with :func:`~asyncio.client.connect`. WebSocket connections are long-lived. You will usually write a loop to process several messages during the lifetime of a connection. @@ -42,10 +42,10 @@ In this example, ``producer()`` is a coroutine implementing your business logic for generating the next message to send on the WebSocket connection. Each message must be :class:`str` or :class:`bytes`. -Iteration terminates when the client disconnects -because :meth:`~server.WebSocketServerProtocol.send` raises a -:exc:`~exceptions.ConnectionClosed` exception, -which breaks out of the ``while True`` loop. +Iteration terminates when the client disconnects because +:meth:`~asyncio.server.ServerConnection.send` raises a +:exc:`~exceptions.ConnectionClosed` exception, which breaks out of the ``while +True`` loop. Consumer and producer --------------------- diff --git a/docs/howto/quickstart.rst b/docs/howto/quickstart.rst index ab870952c..e6bd362a4 100644 --- a/docs/howto/quickstart.rst +++ b/docs/howto/quickstart.rst @@ -17,9 +17,9 @@ It receives a name from the client, sends a greeting, and closes the connection. :language: python :linenos: -:func:`~server.serve` executes the connection handler coroutine ``hello()`` -once for each WebSocket connection. It closes the WebSocket connection when -the handler returns. +:func:`~asyncio.server.serve` executes the connection handler coroutine +``hello()`` once for each WebSocket connection. It closes the WebSocket +connection when the handler returns. Here's a corresponding WebSocket client. @@ -30,8 +30,8 @@ It sends a name to the server, receives a greeting, and closes the connection. :language: python :linenos: -Using :func:`~client.connect` as an asynchronous context manager ensures the -WebSocket connection is closed. +Using :func:`~asyncio.client.connect` as an asynchronous context manager ensures +the WebSocket connection is closed. .. _secure-server-example: @@ -73,8 +73,8 @@ In this example, the client needs a TLS context because the server uses a self-signed certificate. When connecting to a secure WebSocket server with a valid certificate — any -certificate signed by a CA that your Python installation trusts — you can -simply pass ``ssl=True`` to :func:`~client.connect`. +certificate signed by a CA that your Python installation trusts — you can simply +pass ``ssl=True`` to :func:`~asyncio.client.connect`. .. admonition:: Configure the TLS context securely :class: attention diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index 16b010aca..40c8c5ec9 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -49,12 +49,13 @@ the release notes of the version in which the feature was deprecated. * The ``path`` argument of connection handlers — unnecessary since :ref:`10.1` and deprecated in :ref:`13.0`. -* The ``loop`` and ``legacy_recv`` arguments of :func:`~client.connect` and - :func:`~server.serve`, which were removed — deprecated in :ref:`10.0`. -* The ``timeout`` and ``klass`` arguments of :func:`~client.connect` and - :func:`~server.serve`, which were renamed to ``close_timeout`` and +* The ``loop`` and ``legacy_recv`` arguments of :func:`~legacy.client.connect` + and :func:`~legacy.server.serve`, which were removed — deprecated in + :ref:`10.0`. +* The ``timeout`` and ``klass`` arguments of :func:`~legacy.client.connect` and + :func:`~legacy.server.serve`, which were renamed to ``close_timeout`` and ``create_protocol`` — deprecated in :ref:`7.0` and :ref:`3.4` respectively. -* An empty string in the ``origins`` argument of :func:`~server.serve` — +* An empty string in the ``origins`` argument of :func:`~legacy.server.serve` — deprecated in :ref:`7.0`. * The ``host``, ``port``, and ``secure`` attributes of connections — deprecated in :ref:`8.0`. @@ -127,16 +128,16 @@ Client APIs | Legacy :mod:`asyncio` implementation | New :mod:`asyncio` implementation | +===================================================================+=====================================================+ | ``websockets.connect()`` |br| | :func:`websockets.asyncio.client.connect` | -| :func:`websockets.client.connect` |br| | | -| ``websockets.legacy.client.connect()`` | | +| ``websockets.client.connect()`` |br| | | +| :func:`websockets.legacy.client.connect` | | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.unix_connect()`` |br| | :func:`websockets.asyncio.client.unix_connect` | -| :func:`websockets.client.unix_connect` |br| | | -| ``websockets.legacy.client.unix_connect()`` | | +| ``websockets.client.unix_connect()`` |br| | | +| :func:`websockets.legacy.client.unix_connect` | | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.WebSocketClientProtocol`` |br| | :class:`websockets.asyncio.client.ClientConnection` | -| :class:`websockets.client.WebSocketClientProtocol` |br| | | -| ``websockets.legacy.client.WebSocketClientProtocol`` | | +| ``websockets.client.WebSocketClientProtocol`` |br| | | +| :class:`websockets.legacy.client.WebSocketClientProtocol` | | +-------------------------------------------------------------------+-----------------------------------------------------+ Server APIs @@ -146,31 +147,31 @@ Server APIs | Legacy :mod:`asyncio` implementation | New :mod:`asyncio` implementation | +===================================================================+=====================================================+ | ``websockets.serve()`` |br| | :func:`websockets.asyncio.server.serve` | -| :func:`websockets.server.serve` |br| | | -| ``websockets.legacy.server.serve()`` | | +| ``websockets.server.serve()`` |br| | | +| :func:`websockets.legacy.server.serve` | | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.unix_serve()`` |br| | :func:`websockets.asyncio.server.unix_serve` | -| :func:`websockets.server.unix_serve` |br| | | -| ``websockets.legacy.server.unix_serve()`` | | +| ``websockets.server.unix_serve()`` |br| | | +| :func:`websockets.legacy.server.unix_serve` | | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.WebSocketServer`` |br| | :class:`websockets.asyncio.server.WebSocketServer` | -| :class:`websockets.server.WebSocketServer` |br| | | -| ``websockets.legacy.server.WebSocketServer`` | | +| ``websockets.server.WebSocketServer`` |br| | | +| :class:`websockets.legacy.server.WebSocketServer` | | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.WebSocketServerProtocol`` |br| | :class:`websockets.asyncio.server.ServerConnection` | -| :class:`websockets.server.WebSocketServerProtocol` |br| | | -| ``websockets.legacy.server.WebSocketServerProtocol`` | | +| ``websockets.server.WebSocketServerProtocol`` |br| | | +| :class:`websockets.legacy.server.WebSocketServerProtocol` | | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.broadcast`` |br| | :func:`websockets.asyncio.connection.broadcast` | | :func:`websockets.legacy.protocol.broadcast()` | | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.BasicAuthWebSocketServerProtocol`` |br| | *not available yet* | -| :class:`websockets.auth.BasicAuthWebSocketServerProtocol` |br| | | -| ``websockets.legacy.auth.BasicAuthWebSocketServerProtocol`` | | +| ``websockets.auth.BasicAuthWebSocketServerProtocol`` |br| | | +| :class:`websockets.legacy.auth.BasicAuthWebSocketServerProtocol` | | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.basic_auth_protocol_factory()`` |br| | *not available yet* | -| :func:`websockets.auth.basic_auth_protocol_factory` |br| | | -| ``websockets.legacy.auth.basic_auth_protocol_factory()`` | | +| ``websockets.auth.basic_auth_protocol_factory()`` |br| | | +| :func:`websockets.legacy.auth.basic_auth_protocol_factory` | | +-------------------------------------------------------------------+-----------------------------------------------------+ .. _Review API changes: @@ -209,12 +210,12 @@ Customizing the opening handshake ................................. On the client side, if you're adding headers to the handshake request sent by -:func:`~client.connect` with the ``extra_headers`` argument, you must rename it -to ``additional_headers``. +:func:`~legacy.client.connect` with the ``extra_headers`` argument, you must +rename it to ``additional_headers``. -On the server side, if you're customizing how :func:`~server.serve` processes -the opening handshake with the ``process_request``, ``extra_headers``, or -``select_subprotocol``, you must update your code. ``process_response`` and +On the server side, if you're customizing how :func:`~legacy.server.serve` +processes the opening handshake with the ``process_request``, ``extra_headers``, +or ``select_subprotocol``, you must update your code. ``process_response`` and ``select_subprotocol`` have new signatures; ``process_response`` replaces ``extra_headers`` and provides more flexibility. @@ -242,10 +243,10 @@ an example:: ``connection`` is always available in ``process_request``. In the original implementation, you had to write a subclass of -:class:`~server.WebSocketServerProtocol` and pass it in the ``create_protocol`` -argument to make the connection object available in a ``process_request`` -method. This pattern isn't useful anymore; you can replace it with a -``process_request`` function or coroutine. +:class:`~legacy.server.WebSocketServerProtocol` and pass it in the +``create_protocol`` argument to make the connection object available in a +``process_request`` method. This pattern isn't useful anymore; you can replace +it with a ``process_request`` function or coroutine. ``path`` and ``headers`` are available as attributes of the ``request`` object. @@ -296,7 +297,7 @@ The signature of ``select_subprotocol`` changed. Here's an example:: ``connection`` is always available in ``select_subprotocol``. This brings the same benefits as in ``process_request``. It may remove the need to subclass of -:class:`~server.WebSocketServerProtocol`. +:class:`~legacy.server.WebSocketServerProtocol`. The ``subprotocols`` argument contains the list of subprotocols offered by the client. The list of subprotocols supported by the server was removed because @@ -320,7 +321,7 @@ update its name. The keyword argument of :func:`~asyncio.server.serve` for customizing the creation of the connection object is now called ``create_connection`` instead of ``create_protocol``. It must return a :class:`~asyncio.server.ServerConnection` -instead of a :class:`~server.WebSocketServerProtocol`. +instead of a :class:`~legacy.server.WebSocketServerProtocol`. If you were customizing connection objects, you should check the new implementation and possibly redo your customization. Keep in mind that the @@ -364,6 +365,28 @@ The ``write_limit`` argument of :func:`~asyncio.client.connect` and Attributes of connections ......................... +``path``, ``request_headers`` and ``response_headers`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The :attr:`~legacy.protocol.WebSocketCommonProtocol.path`, +:attr:`~legacy.protocol.WebSocketCommonProtocol.request_headers` and +:attr:`~legacy.protocol.WebSocketCommonProtocol.response_headers` properties are +replaced by :attr:`~asyncio.connection.Connection.request` and +:attr:`~asyncio.connection.Connection.response`, which provide a ``headers`` +attribute. + +If your code relies on them, you can replace:: + + connection.path + connection.request_headers + connection.response_headers + +with:: + + connection.request.path + connection.request.headers + connection.response.headers + ``open`` and ``closed`` ~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/intro/tutorial1.rst b/docs/intro/tutorial1.rst index 6b32d47f6..74f5f79a3 100644 --- a/docs/intro/tutorial1.rst +++ b/docs/intro/tutorial1.rst @@ -184,7 +184,7 @@ Create an ``app.py`` file next to ``connect4.py`` with this content: import asyncio - import websockets + from websockets.asyncio.server import serve async def handler(websocket): @@ -194,7 +194,7 @@ Create an ``app.py`` file next to ``connect4.py`` with this content: async def main(): - async with websockets.serve(handler, "", 8001): + async with serve(handler, "", 8001): await asyncio.get_running_loop().create_future() # run forever @@ -204,8 +204,9 @@ Create an ``app.py`` file next to ``connect4.py`` with this content: The entry point of this program is ``asyncio.run(main())``. It creates an asyncio event loop, runs the ``main()`` coroutine, and shuts down the loop. -The ``main()`` coroutine calls :func:`~server.serve` to start a websockets -server. :func:`~server.serve` takes three positional arguments: +The ``main()`` coroutine calls :func:`~asyncio.server.serve` to start a +websockets server. :func:`~asyncio.server.serve` takes three positional +arguments: * ``handler`` is a coroutine that manages a connection. When a client connects, websockets calls ``handler`` with the connection in argument. @@ -215,7 +216,7 @@ server. :func:`~server.serve` takes three positional arguments: on the same local network can connect. * The third argument is the port on which the server listens. -Invoking :func:`~server.serve` as an asynchronous context manager, in an +Invoking :func:`~asyncio.server.serve` as an asynchronous context manager, in an ``async with`` block, ensures that the server shuts down properly when terminating the program. @@ -258,11 +259,11 @@ stack trace of an exception: ... websockets.exceptions.ConnectionClosedOK: received 1000 (OK); then sent 1000 (OK) -Indeed, the server was waiting for the next message -with :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` when the client -disconnected. When this happens, websockets raises -a :exc:`~exceptions.ConnectionClosedOK` exception to let you know that you -won't receive another message on this connection. +Indeed, the server was waiting for the next message with +:meth:`~asyncio.server.ServerConnection.recv` when the client disconnected. +When this happens, websockets raises a :exc:`~exceptions.ConnectionClosedOK` +exception to let you know that you won't receive another message on this +connection. This exception creates noise in the server logs, making it more difficult to spot real errors when you add functionality to the server. Catch it in the @@ -551,13 +552,12 @@ Summary In this first part of the tutorial, you learned how to: -* build and run a WebSocket server in Python with :func:`~server.serve`; -* receive a message in a connection handler - with :meth:`~server.WebSocketServerProtocol.recv`; -* send a message in a connection handler - with :meth:`~server.WebSocketServerProtocol.send`; -* iterate over incoming messages with ``async for - message in websocket: ...``; +* build and run a WebSocket server in Python with :func:`~asyncio.server.serve`; +* receive a message in a connection handler with + :meth:`~asyncio.server.ServerConnection.recv`; +* send a message in a connection handler with + :meth:`~asyncio.server.ServerConnection.send`; +* iterate over incoming messages with ``async for message in websocket: ...``; * open a WebSocket connection in JavaScript with the ``WebSocket`` API; * send messages in a browser with ``WebSocket.send()``; * receive messages in a browser by listening to ``message`` events; diff --git a/docs/intro/tutorial3.rst b/docs/intro/tutorial3.rst index 6fdec113b..21d51371b 100644 --- a/docs/intro/tutorial3.rst +++ b/docs/intro/tutorial3.rst @@ -93,9 +93,9 @@ called ``stop`` and registers a signal handler that sets the result of this future. The value of the future doesn't matter; it's only for waiting for ``SIGTERM``. -Then, by using :func:`~server.serve` as a context manager and exiting the -context when ``stop`` has a result, ``main()`` ensures that the server closes -connections cleanly and exits on ``SIGTERM``. +Then, by using :func:`~asyncio.server.serve` as a context manager and exiting +the context when ``stop`` has a result, ``main()`` ensures that the server +closes connections cleanly and exits on ``SIGTERM``. The app is now fully compatible with Heroku. diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index eaabb2e9f..df5af54f4 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -178,8 +178,8 @@ Backwards-incompatible changes As a consequence, calling ``WebSocket.close()`` without arguments in a browser isn't reported as an error anymore. -.. admonition:: :func:`~server.serve` times out on the opening handshake after - 10 seconds by default. +.. admonition:: :func:`~legacy.server.serve` times out on the opening handshake + after 10 seconds by default. :class: note You can adjust the timeout with the ``open_timeout`` parameter. Set it to @@ -200,7 +200,7 @@ New features See :func:`websockets.sync.client.connect` and :func:`websockets.sync.server.serve` for details. -* Added ``open_timeout`` to :func:`~server.serve`. +* Added ``open_timeout`` to :func:`~legacy.server.serve`. * Made it possible to close a server without closing existing connections. @@ -289,7 +289,7 @@ Bug fixes * Fixed backwards-incompatibility in 10.1 for connection handlers created with :func:`functools.partial`. -* Avoided leaking open sockets when :func:`~client.connect` is canceled. +* Avoided leaking open sockets when :func:`~legacy.client.connect` is canceled. .. _10.1: @@ -330,8 +330,8 @@ Improvements .. _AWS API Gateway: https://github.com/python-websockets/websockets/issues/1065 -* Mirrored the entire :class:`~asyncio.Server` API - in :class:`~server.WebSocketServer`. +* Mirrored the entire :class:`~asyncio.Server` API in + :class:`~legacy.server.WebSocketServer`. * Improved performance for large messages on ARM processors. @@ -364,9 +364,9 @@ Backwards-incompatible changes Python 3.10 for details. The ``loop`` parameter is also removed - from :class:`~server.WebSocketServer`. This should be transparent. + from :class:`~legacy.server.WebSocketServer`. This should be transparent. -.. admonition:: :func:`~client.connect` times out after 10 seconds by default. +.. admonition:: :func:`~legacy.client.connect` times out after 10 seconds by default. :class: note You can adjust the timeout with the ``open_timeout`` parameter. Set it to @@ -405,9 +405,9 @@ New features * Added :func:`~legacy.protocol.broadcast` to send a message to many clients. * Added support for reconnecting automatically by using - :func:`~client.connect` as an asynchronous iterator. + :func:`~legacy.client.connect` as an asynchronous iterator. -* Added ``open_timeout`` to :func:`~client.connect`. +* Added ``open_timeout`` to :func:`~legacy.client.connect`. * Documented how to integrate with `Django `_. @@ -427,12 +427,12 @@ Improvements * Optimized processing of client-to-server messages when the C extension isn't available. -* Supported relative redirects in :func:`~client.connect`. +* Supported relative redirects in :func:`~legacy.client.connect`. * Handled TCP connection drops during the opening handshake. * Made it easier to customize authentication with - :meth:`~auth.BasicAuthWebSocketServerProtocol.check_credentials`. + :meth:`~legacy.auth.BasicAuthWebSocketServerProtocol.check_credentials`. * Provided additional information in :exc:`~exceptions.ConnectionClosed` exceptions. @@ -590,7 +590,7 @@ Bug fixes ......... * Restored the ability to pass a socket with the ``sock`` parameter of - :func:`~server.serve`. + :func:`~legacy.server.serve`. * Removed an incorrect assertion when a connection drops. @@ -623,11 +623,12 @@ Backwards-incompatible changes .. admonition:: ``process_request`` is now expected to be a coroutine. :class: note - If you're passing a ``process_request`` argument to :func:`~server.serve` - or :class:`~server.WebSocketServerProtocol`, or if you're overriding - :meth:`~server.WebSocketServerProtocol.process_request` in a subclass, - define it with ``async def`` instead of ``def``. Previously, both were - supported. + If you're passing a ``process_request`` argument to + :func:`~legacy.server.serve` or + :class:`~legacy.server.WebSocketServerProtocol`, or if you're overriding + :meth:`~legacy.server.WebSocketServerProtocol.process_request` in a + subclass, define it with ``async def`` instead of ``def``. Previously, both + were supported. For backwards compatibility, functions are still accepted, but mixing functions and coroutines won't work in some inheritance scenarios. @@ -661,15 +662,15 @@ Backwards-incompatible changes New features ............ -* Added :func:`~auth.basic_auth_protocol_factory` to enforce HTTP - Basic Auth on the server side. +* Added :func:`~legacy.auth.basic_auth_protocol_factory` to enforce HTTP Basic + Auth on the server side. -* :func:`~client.connect` handles redirects from the server during the +* :func:`~legacy.client.connect` handles redirects from the server during the handshake. -* :func:`~client.connect` supports overriding ``host`` and ``port``. +* :func:`~legacy.client.connect` supports overriding ``host`` and ``port``. -* Added :func:`~client.unix_connect` for connecting to Unix sockets. +* Added :func:`~legacy.client.unix_connect` for connecting to Unix sockets. * Added support for asynchronous generators in :meth:`~legacy.protocol.WebSocketCommonProtocol.send` @@ -699,9 +700,8 @@ Improvements :exc:`~exceptions.ConnectionClosed` to tell apart normal connection termination from errors. -* Changed :meth:`WebSocketServer.close() - ` to perform a proper closing handshake - instead of failing the connection. +* Changed :meth:`WebSocketServer.close() ` + to perform a proper closing handshake instead of failing the connection. * Improved error messages when HTTP parsing fails. @@ -734,7 +734,7 @@ Backwards-incompatible changes See :class:`~legacy.protocol.WebSocketCommonProtocol` for details. .. admonition:: Termination of connections by :meth:`WebSocketServer.close() - ` changes. + ` changes. :class: caution Previously, connections handlers were canceled. Now, connections are @@ -758,15 +758,16 @@ Backwards-incompatible changes Concurrent calls lead to non-deterministic behavior because there are no guarantees about which coroutine will receive which message. -.. admonition:: The ``timeout`` argument of :func:`~server.serve` - and :func:`~client.connect` is renamed to ``close_timeout`` . +.. admonition:: The ``timeout`` argument of :func:`~legacy.server.serve` + and :func:`~legacy.client.connect` is renamed to ``close_timeout`` . :class: note This prevents confusion with ``ping_timeout``. For backwards compatibility, ``timeout`` is still supported. -.. admonition:: The ``origins`` argument of :func:`~server.serve` changes. +.. admonition:: The ``origins`` argument of :func:`~legacy.server.serve` + changes. :class: note Include :obj:`None` in the list rather than ``''`` to allow requests that @@ -786,10 +787,10 @@ New features ............ * Added ``process_request`` and ``select_subprotocol`` arguments to - :func:`~server.serve` and - :class:`~server.WebSocketServerProtocol` to facilitate customization of - :meth:`~server.WebSocketServerProtocol.process_request` and - :meth:`~server.WebSocketServerProtocol.select_subprotocol`. + :func:`~legacy.server.serve` and + :class:`~legacy.server.WebSocketServerProtocol` to facilitate customization of + :meth:`~legacy.server.WebSocketServerProtocol.process_request` and + :meth:`~legacy.server.WebSocketServerProtocol.select_subprotocol`. * Added support for sending fragmented messages. @@ -826,10 +827,10 @@ Backwards-incompatible changes several APIs are updated to use it. :class: caution - * The ``request_headers`` argument - of :meth:`~server.WebSocketServerProtocol.process_request` is now - a :class:`~datastructures.Headers` instead of - an ``http.client.HTTPMessage``. + * The ``request_headers`` argument of + :meth:`~legacy.server.WebSocketServerProtocol.process_request` is now a + :class:`~datastructures.Headers` instead of an + ``http.client.HTTPMessage``. * The ``request_headers`` and ``response_headers`` attributes of :class:`~legacy.protocol.WebSocketCommonProtocol` are now @@ -866,7 +867,7 @@ Bug fixes ......... * Fixed a regression in 5.0 that broke some invocations of - :func:`~server.serve` and :func:`~client.connect`. + :func:`~legacy.server.serve` and :func:`~legacy.client.connect`. .. _5.0: @@ -900,10 +901,10 @@ Backwards-incompatible changes New features ............ -* :func:`~client.connect` performs HTTP Basic Auth when the URI contains +* :func:`~legacy.client.connect` performs HTTP Basic Auth when the URI contains credentials. -* :func:`~server.unix_serve` can be used as an asynchronous context +* :func:`~legacy.server.unix_serve` can be used as an asynchronous context manager on Python ≥ 3.5.1. * Added the :attr:`~legacy.protocol.WebSocketCommonProtocol.closed` property @@ -979,7 +980,7 @@ Backwards-incompatible changes Compression should improve performance but it increases RAM and CPU use. If you want to disable compression, add ``compression=None`` when calling - :func:`~server.serve` or :func:`~client.connect`. + :func:`~legacy.server.serve` or :func:`~legacy.client.connect`. .. admonition:: The ``state_name`` attribute of protocols is deprecated. :class: note @@ -992,10 +993,10 @@ New features * :class:`~legacy.protocol.WebSocketCommonProtocol` instances can be used as asynchronous iterators on Python ≥ 3.6. They yield incoming messages. -* Added :func:`~server.unix_serve` for listening on Unix sockets. +* Added :func:`~legacy.server.unix_serve` for listening on Unix sockets. -* Added the :attr:`~server.WebSocketServer.sockets` attribute to the - return value of :func:`~server.serve`. +* Added the :attr:`~legacy.server.WebSocketServer.sockets` attribute to the + return value of :func:`~legacy.server.serve`. * Allowed ``extra_headers`` to override ``Server`` and ``User-Agent`` headers. @@ -1030,24 +1031,24 @@ Backwards-incompatible changes by :class:`~exceptions.InvalidStatusCode`. :class: note - This exception is raised when :func:`~client.connect` receives an invalid + This exception is raised when :func:`~legacy.client.connect` receives an invalid response status code from the server. New features ............ -* :func:`~server.serve` can be used as an asynchronous context manager +* :func:`~legacy.server.serve` can be used as an asynchronous context manager on Python ≥ 3.5.1. * Added support for customizing handling of incoming connections with - :meth:`~server.WebSocketServerProtocol.process_request`. + :meth:`~legacy.server.WebSocketServerProtocol.process_request`. * Made read and write buffer sizes configurable. Improvements ............ -* Renamed :func:`~server.serve` and :func:`~client.connect`'s +* Renamed :func:`~legacy.server.serve` and :func:`~legacy.client.connect`'s ``klass`` argument to ``create_protocol`` to reflect that it can also be a callable. For backwards compatibility, ``klass`` is still supported. @@ -1058,7 +1059,7 @@ Improvements Bug fixes ......... -* Providing a ``sock`` argument to :func:`~client.connect` no longer +* Providing a ``sock`` argument to :func:`~legacy.client.connect` no longer crashes. .. _3.3: @@ -1094,7 +1095,7 @@ New features ............ * Added ``timeout``, ``max_size``, and ``max_queue`` arguments to - :func:`~client.connect` and :func:`~server.serve`. + :func:`~legacy.client.connect` and :func:`~legacy.server.serve`. Improvements ............ @@ -1151,15 +1152,15 @@ Backwards-incompatible changes In order to avoid stranding projects built upon an earlier version, the previous behavior can be restored by passing ``legacy_recv=True`` to - :func:`~server.serve`, :func:`~client.connect`, - :class:`~server.WebSocketServerProtocol`, or - :class:`~client.WebSocketClientProtocol`. + :func:`~legacy.server.serve`, :func:`~legacy.client.connect`, + :class:`~legacy.server.WebSocketServerProtocol`, or + :class:`~legacy.client.WebSocketClientProtocol`. New features ............ -* :func:`~client.connect` can be used as an asynchronous context - manager on Python ≥ 3.5.1. +* :func:`~legacy.client.connect` can be used as an asynchronous context manager + on Python ≥ 3.5.1. * :meth:`~legacy.protocol.WebSocketCommonProtocol.ping` and :meth:`~legacy.protocol.WebSocketCommonProtocol.pong` support data passed as @@ -1260,8 +1261,8 @@ New features * Added support for subprotocols. -* Added ``loop`` argument to :func:`~client.connect` and - :func:`~server.serve`. +* Added ``loop`` argument to :func:`~legacy.client.connect` and + :func:`~legacy.server.serve`. .. _2.3: diff --git a/docs/reference/asyncio/client.rst b/docs/reference/asyncio/client.rst index f9ce2f2d8..77a3c5d53 100644 --- a/docs/reference/asyncio/client.rst +++ b/docs/reference/asyncio/client.rst @@ -1,24 +1,28 @@ -Client (legacy :mod:`asyncio`) -============================== +Client (new :mod:`asyncio`) +=========================== -.. automodule:: websockets.client +.. automodule:: websockets.asyncio.client Opening a connection -------------------- -.. autofunction:: connect(uri, *, create_protocol=None, logger=None, compression="deflate", origin=None, extensions=None, subprotocols=None, extra_headers=None, user_agent_header="Python/x.y.z websockets/X.Y", open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) +.. autofunction:: connect :async: -.. autofunction:: unix_connect(path, uri="ws://localhost/", *, create_protocol=None, logger=None, compression="deflate", origin=None, extensions=None, subprotocols=None, extra_headers=None, user_agent_header="Python/x.y.z websockets/X.Y", open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) +.. autofunction:: unix_connect :async: Using a connection ------------------ -.. autoclass:: WebSocketClientProtocol(*, logger=None, origin=None, extensions=None, subprotocols=None, extra_headers=None, user_agent_header="Python/x.y.z websockets/X.Y", ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) +.. autoclass:: ClientConnection + + .. automethod:: __aiter__ .. automethod:: recv + .. automethod:: recv_streaming + .. automethod:: send .. automethod:: close @@ -39,26 +43,15 @@ Using a connection .. autoproperty:: remote_address - .. autoproperty:: open - - .. autoproperty:: closed - .. autoattribute:: latency + .. autoproperty:: state + The following attributes are available after the opening handshake, once the WebSocket connection is open: - .. autoattribute:: path - - .. autoattribute:: request_headers - - .. autoattribute:: response_headers - - .. autoattribute:: subprotocol - - The following attributes are available after the closing handshake, - once the WebSocket connection is closed: + .. autoattribute:: request - .. autoproperty:: close_code + .. autoattribute:: response - .. autoproperty:: close_reason + .. autoproperty:: subprotocol diff --git a/docs/reference/asyncio/common.rst b/docs/reference/asyncio/common.rst index aee774479..a58325fb9 100644 --- a/docs/reference/asyncio/common.rst +++ b/docs/reference/asyncio/common.rst @@ -1,14 +1,18 @@ :orphan: -Both sides (legacy :mod:`asyncio`) -================================== +Both sides (new :mod:`asyncio`) +=============================== -.. automodule:: websockets.legacy.protocol +.. automodule:: websockets.asyncio.connection -.. autoclass:: WebSocketCommonProtocol(*, logger=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) +.. autoclass:: Connection + + .. automethod:: __aiter__ .. automethod:: recv + .. automethod:: recv_streaming + .. automethod:: send .. automethod:: close @@ -29,26 +33,15 @@ Both sides (legacy :mod:`asyncio`) .. autoproperty:: remote_address - .. autoproperty:: open - - .. autoproperty:: closed - .. autoattribute:: latency + .. autoproperty:: state + The following attributes are available after the opening handshake, once the WebSocket connection is open: - .. autoattribute:: path - - .. autoattribute:: request_headers - - .. autoattribute:: response_headers - - .. autoattribute:: subprotocol - - The following attributes are available after the closing handshake, - once the WebSocket connection is closed: + .. autoattribute:: request - .. autoproperty:: close_code + .. autoattribute:: response - .. autoproperty:: close_reason + .. autoproperty:: subprotocol diff --git a/docs/reference/asyncio/server.rst b/docs/reference/asyncio/server.rst index 3636f0b33..7bceca5a0 100644 --- a/docs/reference/asyncio/server.rst +++ b/docs/reference/asyncio/server.rst @@ -1,19 +1,19 @@ -Server (legacy :mod:`asyncio`) -============================== +Server (new :mod:`asyncio`) +=========================== -.. automodule:: websockets.server +.. automodule:: websockets.asyncio.server -Starting a server +Creating a server ----------------- -.. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, logger=None, compression="deflate", origins=None, extensions=None, subprotocols=None, extra_headers=None, server_header="Python/x.y.z websockets/X.Y", process_request=None, select_subprotocol=None, open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) +.. autofunction:: serve :async: -.. autofunction:: unix_serve(ws_handler, path=None, *, create_protocol=None, logger=None, compression="deflate", origins=None, extensions=None, subprotocols=None, extra_headers=None, server_header="Python/x.y.z websockets/X.Y", process_request=None, select_subprotocol=None, open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) +.. autofunction:: unix_serve :async: -Stopping a server ------------------ +Running a server +---------------- .. autoclass:: WebSocketServer @@ -34,10 +34,14 @@ Stopping a server Using a connection ------------------ -.. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, logger=None, origins=None, extensions=None, subprotocols=None, extra_headers=None, server_header="Python/x.y.z websockets/X.Y", process_request=None, select_subprotocol=None, open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) +.. autoclass:: ServerConnection + + .. automethod:: __aiter__ .. automethod:: recv + .. automethod:: recv_streaming + .. automethod:: send .. automethod:: close @@ -48,11 +52,7 @@ Using a connection .. automethod:: pong - You can customize the opening handshake in a subclass by overriding these methods: - - .. automethod:: process_request - - .. automethod:: select_subprotocol + .. automethod:: respond WebSocket connection objects also provide these attributes: @@ -64,50 +64,20 @@ Using a connection .. autoproperty:: remote_address - .. autoproperty:: open - - .. autoproperty:: closed - .. autoattribute:: latency + .. autoproperty:: state + The following attributes are available after the opening handshake, once the WebSocket connection is open: - .. autoattribute:: path - - .. autoattribute:: request_headers - - .. autoattribute:: response_headers - - .. autoattribute:: subprotocol - - The following attributes are available after the closing handshake, - once the WebSocket connection is closed: - - .. autoproperty:: close_code - - .. autoproperty:: close_reason - - -Basic authentication --------------------- - -.. automodule:: websockets.auth - -websockets supports HTTP Basic Authentication according to -:rfc:`7235` and :rfc:`7617`. - -.. autofunction:: basic_auth_protocol_factory - -.. autoclass:: BasicAuthWebSocketServerProtocol - - .. autoattribute:: realm + .. autoattribute:: request - .. autoattribute:: username + .. autoattribute:: response - .. automethod:: check_credentials + .. autoproperty:: subprotocol Broadcast --------- -.. autofunction:: websockets.legacy.protocol.broadcast +.. autofunction:: websockets.asyncio.connection.broadcast diff --git a/docs/reference/features.rst b/docs/reference/features.rst index 45fa79c48..6840fe15b 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -182,5 +182,6 @@ Request if it is missing or invalid (`#1246`). The client API doesn't attempt to guarantee that there is no more than one connection to a given IP address in a CONNECTING state. This behavior is -mandated by :rfc:`6455`, section 4.1. However, :func:`~client.connect()` isn't -the right layer for enforcing this constraint. It's the caller's responsibility. +mandated by :rfc:`6455`, section 4.1. However, :func:`~asyncio.client.connect()` +isn't the right layer for enforcing this constraint. It's the caller's +responsibility. diff --git a/docs/reference/index.rst b/docs/reference/index.rst index d3a0e935c..77b538b78 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -13,12 +13,12 @@ Check which implementations support which features and known limitations. features +:mod:`asyncio` (new) +-------------------- -:mod:`asyncio` --------------- +It's ideal for servers that handle many clients concurrently. -This is the default implementation. It's ideal for servers that handle many -clients concurrently. +It's a rewrite of the legacy :mod:`asyncio` implementation. .. toctree:: :titlesonly: @@ -26,17 +26,16 @@ clients concurrently. asyncio/server asyncio/client -:mod:`asyncio` (new) --------------------- +:mod:`asyncio` (legacy) +----------------------- -This is a rewrite of the :mod:`asyncio` implementation. It will become the -default in the future. +This is the historical implementation. .. toctree:: :titlesonly: - new-asyncio/server - new-asyncio/client + legacy/server + legacy/client :mod:`threading` ---------------- diff --git a/docs/reference/legacy/client.rst b/docs/reference/legacy/client.rst new file mode 100644 index 000000000..fca45d218 --- /dev/null +++ b/docs/reference/legacy/client.rst @@ -0,0 +1,64 @@ +Client (legacy :mod:`asyncio`) +============================== + +.. automodule:: websockets.legacy.client + +Opening a connection +-------------------- + +.. autofunction:: connect(uri, *, create_protocol=None, logger=None, compression="deflate", origin=None, extensions=None, subprotocols=None, extra_headers=None, user_agent_header="Python/x.y.z websockets/X.Y", open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) + :async: + +.. autofunction:: unix_connect(path, uri="ws://localhost/", *, create_protocol=None, logger=None, compression="deflate", origin=None, extensions=None, subprotocols=None, extra_headers=None, user_agent_header="Python/x.y.z websockets/X.Y", open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) + :async: + +Using a connection +------------------ + +.. autoclass:: WebSocketClientProtocol(*, logger=None, origin=None, extensions=None, subprotocols=None, extra_headers=None, user_agent_header="Python/x.y.z websockets/X.Y", ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) + + .. automethod:: recv + + .. automethod:: send + + .. automethod:: close + + .. automethod:: wait_closed + + .. automethod:: ping + + .. automethod:: pong + + WebSocket connection objects also provide these attributes: + + .. autoattribute:: id + + .. autoattribute:: logger + + .. autoproperty:: local_address + + .. autoproperty:: remote_address + + .. autoproperty:: open + + .. autoproperty:: closed + + .. autoattribute:: latency + + The following attributes are available after the opening handshake, + once the WebSocket connection is open: + + .. autoattribute:: path + + .. autoattribute:: request_headers + + .. autoattribute:: response_headers + + .. autoattribute:: subprotocol + + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: + + .. autoproperty:: close_code + + .. autoproperty:: close_reason diff --git a/docs/reference/legacy/common.rst b/docs/reference/legacy/common.rst new file mode 100644 index 000000000..aee774479 --- /dev/null +++ b/docs/reference/legacy/common.rst @@ -0,0 +1,54 @@ +:orphan: + +Both sides (legacy :mod:`asyncio`) +================================== + +.. automodule:: websockets.legacy.protocol + +.. autoclass:: WebSocketCommonProtocol(*, logger=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) + + .. automethod:: recv + + .. automethod:: send + + .. automethod:: close + + .. automethod:: wait_closed + + .. automethod:: ping + + .. automethod:: pong + + WebSocket connection objects also provide these attributes: + + .. autoattribute:: id + + .. autoattribute:: logger + + .. autoproperty:: local_address + + .. autoproperty:: remote_address + + .. autoproperty:: open + + .. autoproperty:: closed + + .. autoattribute:: latency + + The following attributes are available after the opening handshake, + once the WebSocket connection is open: + + .. autoattribute:: path + + .. autoattribute:: request_headers + + .. autoattribute:: response_headers + + .. autoattribute:: subprotocol + + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: + + .. autoproperty:: close_code + + .. autoproperty:: close_reason diff --git a/docs/reference/legacy/server.rst b/docs/reference/legacy/server.rst new file mode 100644 index 000000000..c2758f5a2 --- /dev/null +++ b/docs/reference/legacy/server.rst @@ -0,0 +1,113 @@ +Server (legacy :mod:`asyncio`) +============================== + +.. automodule:: websockets.legacy.server + +Starting a server +----------------- + +.. autofunction:: serve(ws_handler, host=None, port=None, *, create_protocol=None, logger=None, compression="deflate", origins=None, extensions=None, subprotocols=None, extra_headers=None, server_header="Python/x.y.z websockets/X.Y", process_request=None, select_subprotocol=None, open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) + :async: + +.. autofunction:: unix_serve(ws_handler, path=None, *, create_protocol=None, logger=None, compression="deflate", origins=None, extensions=None, subprotocols=None, extra_headers=None, server_header="Python/x.y.z websockets/X.Y", process_request=None, select_subprotocol=None, open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, **kwds) + :async: + +Stopping a server +----------------- + +.. autoclass:: WebSocketServer + + .. automethod:: close + + .. automethod:: wait_closed + + .. automethod:: get_loop + + .. automethod:: is_serving + + .. automethod:: start_serving + + .. automethod:: serve_forever + + .. autoattribute:: sockets + +Using a connection +------------------ + +.. autoclass:: WebSocketServerProtocol(ws_handler, ws_server, *, logger=None, origins=None, extensions=None, subprotocols=None, extra_headers=None, server_header="Python/x.y.z websockets/X.Y", process_request=None, select_subprotocol=None, open_timeout=10, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) + + .. automethod:: recv + + .. automethod:: send + + .. automethod:: close + + .. automethod:: wait_closed + + .. automethod:: ping + + .. automethod:: pong + + You can customize the opening handshake in a subclass by overriding these methods: + + .. automethod:: process_request + + .. automethod:: select_subprotocol + + WebSocket connection objects also provide these attributes: + + .. autoattribute:: id + + .. autoattribute:: logger + + .. autoproperty:: local_address + + .. autoproperty:: remote_address + + .. autoproperty:: open + + .. autoproperty:: closed + + .. autoattribute:: latency + + The following attributes are available after the opening handshake, + once the WebSocket connection is open: + + .. autoattribute:: path + + .. autoattribute:: request_headers + + .. autoattribute:: response_headers + + .. autoattribute:: subprotocol + + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: + + .. autoproperty:: close_code + + .. autoproperty:: close_reason + + +Basic authentication +-------------------- + +.. automodule:: websockets.legacy.auth + +websockets supports HTTP Basic Authentication according to +:rfc:`7235` and :rfc:`7617`. + +.. autofunction:: basic_auth_protocol_factory + +.. autoclass:: BasicAuthWebSocketServerProtocol + + .. autoattribute:: realm + + .. autoattribute:: username + + .. automethod:: check_credentials + +Broadcast +--------- + +.. autofunction:: websockets.legacy.protocol.broadcast diff --git a/docs/reference/new-asyncio/client.rst b/docs/reference/new-asyncio/client.rst deleted file mode 100644 index 77a3c5d53..000000000 --- a/docs/reference/new-asyncio/client.rst +++ /dev/null @@ -1,57 +0,0 @@ -Client (new :mod:`asyncio`) -=========================== - -.. automodule:: websockets.asyncio.client - -Opening a connection --------------------- - -.. autofunction:: connect - :async: - -.. autofunction:: unix_connect - :async: - -Using a connection ------------------- - -.. autoclass:: ClientConnection - - .. automethod:: __aiter__ - - .. automethod:: recv - - .. automethod:: recv_streaming - - .. automethod:: send - - .. automethod:: close - - .. automethod:: wait_closed - - .. automethod:: ping - - .. automethod:: pong - - WebSocket connection objects also provide these attributes: - - .. autoattribute:: id - - .. autoattribute:: logger - - .. autoproperty:: local_address - - .. autoproperty:: remote_address - - .. autoattribute:: latency - - .. autoproperty:: state - - The following attributes are available after the opening handshake, - once the WebSocket connection is open: - - .. autoattribute:: request - - .. autoattribute:: response - - .. autoproperty:: subprotocol diff --git a/docs/reference/new-asyncio/common.rst b/docs/reference/new-asyncio/common.rst deleted file mode 100644 index a58325fb9..000000000 --- a/docs/reference/new-asyncio/common.rst +++ /dev/null @@ -1,47 +0,0 @@ -:orphan: - -Both sides (new :mod:`asyncio`) -=============================== - -.. automodule:: websockets.asyncio.connection - -.. autoclass:: Connection - - .. automethod:: __aiter__ - - .. automethod:: recv - - .. automethod:: recv_streaming - - .. automethod:: send - - .. automethod:: close - - .. automethod:: wait_closed - - .. automethod:: ping - - .. automethod:: pong - - WebSocket connection objects also provide these attributes: - - .. autoattribute:: id - - .. autoattribute:: logger - - .. autoproperty:: local_address - - .. autoproperty:: remote_address - - .. autoattribute:: latency - - .. autoproperty:: state - - The following attributes are available after the opening handshake, - once the WebSocket connection is open: - - .. autoattribute:: request - - .. autoattribute:: response - - .. autoproperty:: subprotocol diff --git a/docs/reference/new-asyncio/server.rst b/docs/reference/new-asyncio/server.rst deleted file mode 100644 index 7bceca5a0..000000000 --- a/docs/reference/new-asyncio/server.rst +++ /dev/null @@ -1,83 +0,0 @@ -Server (new :mod:`asyncio`) -=========================== - -.. automodule:: websockets.asyncio.server - -Creating a server ------------------ - -.. autofunction:: serve - :async: - -.. autofunction:: unix_serve - :async: - -Running a server ----------------- - -.. autoclass:: WebSocketServer - - .. automethod:: close - - .. automethod:: wait_closed - - .. automethod:: get_loop - - .. automethod:: is_serving - - .. automethod:: start_serving - - .. automethod:: serve_forever - - .. autoattribute:: sockets - -Using a connection ------------------- - -.. autoclass:: ServerConnection - - .. automethod:: __aiter__ - - .. automethod:: recv - - .. automethod:: recv_streaming - - .. automethod:: send - - .. automethod:: close - - .. automethod:: wait_closed - - .. automethod:: ping - - .. automethod:: pong - - .. automethod:: respond - - WebSocket connection objects also provide these attributes: - - .. autoattribute:: id - - .. autoattribute:: logger - - .. autoproperty:: local_address - - .. autoproperty:: remote_address - - .. autoattribute:: latency - - .. autoproperty:: state - - The following attributes are available after the opening handshake, - once the WebSocket connection is open: - - .. autoattribute:: request - - .. autoattribute:: response - - .. autoproperty:: subprotocol - -Broadcast ---------- - -.. autofunction:: websockets.asyncio.connection.broadcast diff --git a/docs/topics/authentication.rst b/docs/topics/authentication.rst index 31bc8e6da..86d2e2587 100644 --- a/docs/topics/authentication.rst +++ b/docs/topics/authentication.rst @@ -212,7 +212,9 @@ the user. If authentication fails, it returns an HTTP 401: .. code-block:: python - class QueryParamProtocol(websockets.WebSocketServerProtocol): + from websockets.legacy.server import WebSocketServerProtocol + + class QueryParamProtocol(WebSocketServerProtocol): async def process_request(self, path, headers): token = get_query_parameter(path, "token") if token is None: @@ -258,7 +260,9 @@ the user. If authentication fails, it returns an HTTP 401: .. code-block:: python - class CookieProtocol(websockets.WebSocketServerProtocol): + from websockets.legacy.server import WebSocketServerProtocol + + class CookieProtocol(WebSocketServerProtocol): async def process_request(self, path, headers): # Serve iframe on non-WebSocket requests ... @@ -299,7 +303,9 @@ the user. If authentication fails, it returns an HTTP 401: .. code-block:: python - class UserInfoProtocol(websockets.BasicAuthWebSocketServerProtocol): + from websockets.legacy.auth import BasicAuthWebSocketServerProtocol + + class UserInfoProtocol(BasicAuthWebSocketServerProtocol): async def check_credentials(self, username, password): if username != "token": return False @@ -328,8 +334,10 @@ To authenticate a websockets client with HTTP Basic Authentication .. code-block:: python - async with websockets.connect( - f"wss://{username}:{password}@example.com", + from websockets.legacy.client import connect + + async with connect( + f"wss://{username}:{password}@example.com" ) as websocket: ... @@ -341,7 +349,9 @@ To authenticate a websockets client with HTTP Bearer Authentication .. code-block:: python - async with websockets.connect( + from websockets.legacy.client import connect + + async with connect( "wss://example.com", extra_headers={"Authorization": f"Bearer {token}"} ) as websocket: diff --git a/docs/topics/deployment.rst b/docs/topics/deployment.rst index 2a1fe9a78..48ef72b56 100644 --- a/docs/topics/deployment.rst +++ b/docs/topics/deployment.rst @@ -78,7 +78,7 @@ Option 2 almost always combines with option 3. How do I start a process? ......................... -Run a Python program that invokes :func:`~server.serve`. That's it. +Run a Python program that invokes :func:`~asyncio.server.serve`. That's it. Don't run an ASGI server such as Uvicorn, Hypercorn, or Daphne. They're alternatives to websockets, not complements. @@ -98,18 +98,19 @@ signal and exit the server to ensure a graceful shutdown. Here's an example: .. literalinclude:: ../../example/faq/shutdown_server.py - :emphasize-lines: 12-15,18 + :emphasize-lines: 13-16,19 -When exiting the context manager, :func:`~server.serve` closes all connections +When exiting the context manager, :func:`~asyncio.server.serve` closes all +connections with code 1001 (going away). As a consequence: * If the connection handler is awaiting - :meth:`~server.WebSocketServerProtocol.recv`, it receives a + :meth:`~asyncio.server.ServerConnection.recv`, it receives a :exc:`~exceptions.ConnectionClosedOK` exception. It can catch the exception and clean up before exiting. * Otherwise, it should be waiting on - :meth:`~server.WebSocketServerProtocol.wait_closed`, so it can receive the + :meth:`~asyncio.server.ServerConnection.wait_closed`, so it can receive the :exc:`~exceptions.ConnectionClosedOK` exception and exit. This example is easily adapted to handle other signals. @@ -173,7 +174,7 @@ Load balancers need a way to check whether server processes are up and running to avoid routing connections to a non-functional backend. websockets provide minimal support for responding to HTTP requests with the -:meth:`~server.WebSocketServerProtocol.process_request` hook. +``process_request`` hook. Here's an example: diff --git a/docs/topics/design.rst b/docs/topics/design.rst index cc65e6a70..d2fd18d0c 100644 --- a/docs/topics/design.rst +++ b/docs/topics/design.rst @@ -1,10 +1,11 @@ -Design -====== +Design (legacy :mod:`asyncio`) +============================== -.. currentmodule:: websockets +.. currentmodule:: websockets.legacy -This document describes the design of websockets. It assumes familiarity with -the specification of the WebSocket protocol in :rfc:`6455`. +This document describes the design of the legacy implementation of websockets. +It assumes familiarity with the specification of the WebSocket protocol in +:rfc:`6455`. It's primarily intended at maintainers. It may also be useful for users who wish to understand what happens under the hood. @@ -32,21 +33,19 @@ WebSocket connections go through a trivial state machine: Transitions happen in the following places: - ``CONNECTING -> OPEN``: in - :meth:`~legacy.protocol.WebSocketCommonProtocol.connection_open` which runs when - the :ref:`opening handshake ` completes and the WebSocket + :meth:`~protocol.WebSocketCommonProtocol.connection_open` which runs when the + :ref:`opening handshake ` completes and the WebSocket connection is established — not to be confused with - :meth:`~asyncio.BaseProtocol.connection_made` which runs when the TCP connection - is established; -- ``OPEN -> CLOSING``: in - :meth:`~legacy.protocol.WebSocketCommonProtocol.write_frame` immediately before - sending a close frame; since receiving a close frame triggers sending a - close frame, this does the right thing regardless of which side started the - :ref:`closing handshake `; also in - :meth:`~legacy.protocol.WebSocketCommonProtocol.fail_connection` which duplicates - a few lines of code from ``write_close_frame()`` and ``write_frame()``; -- ``* -> CLOSED``: in - :meth:`~legacy.protocol.WebSocketCommonProtocol.connection_lost` which is always - called exactly once when the TCP connection is closed. + :meth:`~asyncio.BaseProtocol.connection_made` which runs when the TCP + connection is established; +- ``OPEN -> CLOSING``: in :meth:`~protocol.WebSocketCommonProtocol.write_frame` + immediately before sending a close frame; since receiving a close frame + triggers sending a close frame, this does the right thing regardless of which + side started the :ref:`closing handshake `; also in + :meth:`~protocol.WebSocketCommonProtocol.fail_connection` which duplicates a + few lines of code from ``write_close_frame()`` and ``write_frame()``; +- ``* -> CLOSED``: in :meth:`~protocol.WebSocketCommonProtocol.connection_lost` + which is always called exactly once when the TCP connection is closed. Coroutines .......... @@ -57,38 +56,38 @@ connection lifecycle on the client side. .. image:: lifecycle.svg :target: _images/lifecycle.svg -The lifecycle is identical on the server side, except inversion of control -makes the equivalent of :meth:`~client.connect` implicit. +The lifecycle is identical on the server side, except inversion of control makes +the equivalent of :meth:`~client.connect` implicit. Coroutines shown in green are called by the application. Multiple coroutines may interact with the WebSocket connection concurrently. Coroutines shown in gray manage the connection. When the opening handshake -succeeds, :meth:`~legacy.protocol.WebSocketCommonProtocol.connection_open` starts -two tasks: - -- :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` runs - :meth:`~legacy.protocol.WebSocketCommonProtocol.transfer_data` which handles - incoming data and lets :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` - consume it. It may be canceled to terminate the connection. It never exits - with an exception other than :exc:`~asyncio.CancelledError`. See :ref:`data - transfer ` below. - -- :attr:`~legacy.protocol.WebSocketCommonProtocol.keepalive_ping_task` runs - :meth:`~legacy.protocol.WebSocketCommonProtocol.keepalive_ping` which sends Ping +succeeds, :meth:`~protocol.WebSocketCommonProtocol.connection_open` starts two +tasks: + +- :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` runs + :meth:`~protocol.WebSocketCommonProtocol.transfer_data` which handles incoming + data and lets :meth:`~protocol.WebSocketCommonProtocol.recv` consume it. It + may be canceled to terminate the connection. It never exits with an exception + other than :exc:`~asyncio.CancelledError`. See :ref:`data transfer + ` below. + +- :attr:`~protocol.WebSocketCommonProtocol.keepalive_ping_task` runs + :meth:`~protocol.WebSocketCommonProtocol.keepalive_ping` which sends Ping frames at regular intervals and ensures that corresponding Pong frames are - received. It is canceled when the connection terminates. It never exits - with an exception other than :exc:`~asyncio.CancelledError`. + received. It is canceled when the connection terminates. It never exits with + an exception other than :exc:`~asyncio.CancelledError`. -- :attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task` runs - :meth:`~legacy.protocol.WebSocketCommonProtocol.close_connection` which waits for - the data transfer to terminate, then takes care of closing the TCP - connection. It must not be canceled. It never exits with an exception. See - :ref:`connection termination ` below. +- :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` runs + :meth:`~protocol.WebSocketCommonProtocol.close_connection` which waits for the + data transfer to terminate, then takes care of closing the TCP connection. It + must not be canceled. It never exits with an exception. See :ref:`connection + termination ` below. -Besides, :meth:`~legacy.protocol.WebSocketCommonProtocol.fail_connection` starts -the same :attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task` when -the opening handshake fails, in order to close the TCP connection. +Besides, :meth:`~protocol.WebSocketCommonProtocol.fail_connection` starts the +same :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` when the +opening handshake fails, in order to close the TCP connection. Splitting the responsibilities between two tasks makes it easier to guarantee that websockets can terminate connections: @@ -99,11 +98,11 @@ that websockets can terminate connections: regardless of whether the connection terminates normally or abnormally. -:attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` completes when no +:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` completes when no more data will be received on the connection. Under normal circumstances, it exits after exchanging close frames. -:attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task` completes when +:attr:`~protocol.WebSocketCommonProtocol.close_connection_task` completes when the TCP connection is closed. @@ -113,10 +112,9 @@ Opening handshake ----------------- websockets performs the opening handshake when establishing a WebSocket -connection. On the client side, :meth:`~client.connect` executes it -before returning the protocol to the caller. On the server side, it's executed -before passing the protocol to the ``ws_handler`` coroutine handling the -connection. +connection. On the client side, :meth:`~client.connect` executes it before +returning the protocol to the caller. On the server side, it's executed before +passing the protocol to the ``ws_handler`` coroutine handling the connection. While the opening handshake is asymmetrical — the client sends an HTTP Upgrade request and the server replies with an HTTP Switching Protocols response — @@ -136,9 +134,9 @@ On the client side, :meth:`~client.WebSocketClientProtocol.handshake`: On the server side, :meth:`~server.WebSocketServerProtocol.handshake`: - reads an HTTP request from the network; -- calls :meth:`~server.WebSocketServerProtocol.process_request` which may - abort the WebSocket handshake and return an HTTP response instead; this - hook only makes sense on the server side; +- calls :meth:`~server.WebSocketServerProtocol.process_request` which may abort + the WebSocket handshake and return an HTTP response instead; this hook only + makes sense on the server side; - checks the HTTP request, negotiates ``extensions`` and ``subprotocol``, and configures the protocol accordingly; - builds an HTTP response based on the above and parameters passed to @@ -178,13 +176,13 @@ differences between a server and a client: These differences are so minor that all the logic for `data framing`_, for `sending and receiving data`_ and for `closing the connection`_ is implemented -in the same class, :class:`~legacy.protocol.WebSocketCommonProtocol`. +in the same class, :class:`~protocol.WebSocketCommonProtocol`. .. _data framing: https://www.rfc-editor.org/rfc/rfc6455.html#section-5 .. _sending and receiving data: https://www.rfc-editor.org/rfc/rfc6455.html#section-6 .. _closing the connection: https://www.rfc-editor.org/rfc/rfc6455.html#section-7 -The :attr:`~legacy.protocol.WebSocketCommonProtocol.is_client` attribute tells which +The :attr:`~protocol.WebSocketCommonProtocol.is_client` attribute tells which side a protocol instance is managing. This attribute is defined on the :attr:`~server.WebSocketServerProtocol` and :attr:`~client.WebSocketClientProtocol` classes. @@ -211,11 +209,11 @@ The left side of the diagram shows how websockets receives data. Incoming data is written to a :class:`~asyncio.StreamReader` in order to implement flow control and provide backpressure on the TCP connection. -:attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task`, which is started +:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task`, which is started when the WebSocket connection is established, processes this data. When it receives data frames, it reassembles fragments and puts the resulting -messages in the :attr:`~legacy.protocol.WebSocketCommonProtocol.messages` queue. +messages in the :attr:`~protocol.WebSocketCommonProtocol.messages` queue. When it encounters a control frame: @@ -227,11 +225,11 @@ When it encounters a control frame: Running this process in a task guarantees that control frames are processed promptly. Without such a task, websockets would depend on the application to drive the connection by having exactly one coroutine awaiting -:meth:`~legacy.protocol.WebSocketCommonProtocol.recv` at any time. While this -happens naturally in many use cases, it cannot be relied upon. +:meth:`~protocol.WebSocketCommonProtocol.recv` at any time. While this happens +naturally in many use cases, it cannot be relied upon. -Then :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` fetches the next message -from the :attr:`~legacy.protocol.WebSocketCommonProtocol.messages` queue, with some +Then :meth:`~protocol.WebSocketCommonProtocol.recv` fetches the next message +from the :attr:`~protocol.WebSocketCommonProtocol.messages` queue, with some complexity added for handling backpressure and termination correctly. Sending data @@ -239,19 +237,19 @@ Sending data The right side of the diagram shows how websockets sends data. -:meth:`~legacy.protocol.WebSocketCommonProtocol.send` writes one or several data -frames containing the message. While sending a fragmented message, concurrent -calls to :meth:`~legacy.protocol.WebSocketCommonProtocol.send` are put on hold until -all fragments are sent. This makes concurrent calls safe. +:meth:`~protocol.WebSocketCommonProtocol.send` writes one or several data frames +containing the message. While sending a fragmented message, concurrent calls to +:meth:`~protocol.WebSocketCommonProtocol.send` are put on hold until all +fragments are sent. This makes concurrent calls safe. -:meth:`~legacy.protocol.WebSocketCommonProtocol.ping` writes a ping frame and -yields a :class:`~asyncio.Future` which will be completed when a matching pong -frame is received. +:meth:`~protocol.WebSocketCommonProtocol.ping` writes a ping frame and yields a +:class:`~asyncio.Future` which will be completed when a matching pong frame is +received. -:meth:`~legacy.protocol.WebSocketCommonProtocol.pong` writes a pong frame. +:meth:`~protocol.WebSocketCommonProtocol.pong` writes a pong frame. -:meth:`~legacy.protocol.WebSocketCommonProtocol.close` writes a close frame and -waits for the TCP connection to terminate. +:meth:`~protocol.WebSocketCommonProtocol.close` writes a close frame and waits +for the TCP connection to terminate. Outgoing data is written to a :class:`~asyncio.StreamWriter` in order to implement flow control and provide backpressure from the TCP connection. @@ -262,17 +260,17 @@ Closing handshake ................. When the other side of the connection initiates the closing handshake, -:meth:`~legacy.protocol.WebSocketCommonProtocol.read_message` receives a close -frame while in the ``OPEN`` state. It moves to the ``CLOSING`` state, sends a -close frame, and returns :obj:`None`, causing -:attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` to terminate. +:meth:`~protocol.WebSocketCommonProtocol.read_message` receives a close frame +while in the ``OPEN`` state. It moves to the ``CLOSING`` state, sends a close +frame, and returns :obj:`None`, causing +:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` to terminate. When this side of the connection initiates the closing handshake with -:meth:`~legacy.protocol.WebSocketCommonProtocol.close`, it moves to the ``CLOSING`` +:meth:`~protocol.WebSocketCommonProtocol.close`, it moves to the ``CLOSING`` state and sends a close frame. When the other side sends a close frame, -:meth:`~legacy.protocol.WebSocketCommonProtocol.read_message` receives it in the +:meth:`~protocol.WebSocketCommonProtocol.read_message` receives it in the ``CLOSING`` state and returns :obj:`None`, also causing -:attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` to terminate. +:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` to terminate. If the other side doesn't send a close frame within the connection's close timeout, websockets :ref:`fails the connection `. @@ -289,33 +287,33 @@ Then websockets terminates the TCP connection. Connection termination ---------------------- -:attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task`, which is +:attr:`~protocol.WebSocketCommonProtocol.close_connection_task`, which is started when the WebSocket connection is established, is responsible for eventually closing the TCP connection. -First :attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task` waits -for :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` to terminate, -which may happen as a result of: +First :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` waits for +:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` to terminate, which +may happen as a result of: - a successful closing handshake: as explained above, this exits the infinite - loop in :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task`; + loop in :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task`; - a timeout while waiting for the closing handshake to complete: this cancels - :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task`; + :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task`; - a protocol error, including connection errors: depending on the exception, - :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` :ref:`fails the + :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` :ref:`fails the connection ` with a suitable code and exits. -:attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task` is separate -from :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` to make it -easier to implement the timeout on the closing handshake. Canceling -:attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` creates no risk -of canceling :attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task` -and failing to close the TCP connection, thus leaking resources. +:attr:`~protocol.WebSocketCommonProtocol.close_connection_task` is separate from +:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` to make it easier +to implement the timeout on the closing handshake. Canceling +:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` creates no risk of +canceling :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` and +failing to close the TCP connection, thus leaking resources. -Then :attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task` cancels -:meth:`~legacy.protocol.WebSocketCommonProtocol.keepalive_ping`. This task has no -protocol compliance responsibilities. Terminating it to avoid leaking it is -the only concern. +Then :attr:`~protocol.WebSocketCommonProtocol.close_connection_task` cancels +:meth:`~protocol.WebSocketCommonProtocol.keepalive_ping`. This task has no +protocol compliance responsibilities. Terminating it to avoid leaking it is the +only concern. Terminating the TCP connection can take up to ``2 * close_timeout`` on the server side and ``3 * close_timeout`` on the client side. Clients start by @@ -335,11 +333,11 @@ If the opening handshake doesn't complete successfully, websockets fails the connection by closing the TCP connection. Once the opening handshake has completed, websockets fails the connection by -canceling :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` -and sending a close frame if appropriate. +canceling :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` and +sending a close frame if appropriate. -:attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` exits, unblocking -:attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task`, which closes +:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` exits, unblocking +:attr:`~protocol.WebSocketCommonProtocol.close_connection_task`, which closes the TCP connection. @@ -348,7 +346,7 @@ the TCP connection. Server shutdown --------------- -:class:`~websockets.server.WebSocketServer` closes asynchronously like +:class:`~server.WebSocketServer` closes asynchronously like :class:`asyncio.Server`. The shutdown happen in two steps: 1. Stop listening and accepting new connections; @@ -356,10 +354,10 @@ Server shutdown the opening handshake is still in progress, with HTTP status code 503 (Service Unavailable). -The first call to :class:`~websockets.server.WebSocketServer.close` starts a -task that performs this sequence. Further calls are ignored. This is the -easiest way to make :class:`~websockets.server.WebSocketServer.close` and -:class:`~websockets.server.WebSocketServer.wait_closed` idempotent. +The first call to :class:`~server.WebSocketServer.close` starts a task that +performs this sequence. Further calls are ignored. This is the easiest way to +make :class:`~server.WebSocketServer.close` and +:class:`~server.WebSocketServer.wait_closed` idempotent. .. _cancellation: @@ -415,45 +413,45 @@ happen on the client side. On the server side, the opening handshake is managed by websockets and nothing results in a cancellation. Once the WebSocket connection is established, internal tasks -:attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` and -:attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task` mustn't get +:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` and +:attr:`~protocol.WebSocketCommonProtocol.close_connection_task` mustn't get accidentally canceled if a coroutine that awaits them is canceled. In other words, they must be shielded from cancellation. -:meth:`~legacy.protocol.WebSocketCommonProtocol.recv` waits for the next message in -the queue or for :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` -to terminate, whichever comes first. It relies on :func:`~asyncio.wait` for -waiting on two futures in parallel. As a consequence, even though it's waiting -on a :class:`~asyncio.Future` signaling the next message and on -:attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task`, it doesn't +:meth:`~protocol.WebSocketCommonProtocol.recv` waits for the next message in the +queue or for :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` to +terminate, whichever comes first. It relies on :func:`~asyncio.wait` for waiting +on two futures in parallel. As a consequence, even though it's waiting on a +:class:`~asyncio.Future` signaling the next message and on +:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task`, it doesn't propagate cancellation to them. -:meth:`~legacy.protocol.WebSocketCommonProtocol.ensure_open` is called by -:meth:`~legacy.protocol.WebSocketCommonProtocol.send`, -:meth:`~legacy.protocol.WebSocketCommonProtocol.ping`, and -:meth:`~legacy.protocol.WebSocketCommonProtocol.pong`. When the connection state is +:meth:`~protocol.WebSocketCommonProtocol.ensure_open` is called by +:meth:`~protocol.WebSocketCommonProtocol.send`, +:meth:`~protocol.WebSocketCommonProtocol.ping`, and +:meth:`~protocol.WebSocketCommonProtocol.pong`. When the connection state is ``CLOSING``, it waits for -:attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` but shields it to +:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` but shields it to prevent cancellation. -:meth:`~legacy.protocol.WebSocketCommonProtocol.close` waits for the data transfer -task to terminate with :func:`~asyncio.timeout`. If it's canceled or if the -timeout elapses, :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` -is canceled, which is correct at this point. -:meth:`~legacy.protocol.WebSocketCommonProtocol.close` then waits for -:attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task` but shields it +:meth:`~protocol.WebSocketCommonProtocol.close` waits for the data transfer task +to terminate with :func:`~asyncio.timeout`. If it's canceled or if the timeout +elapses, :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` is +canceled, which is correct at this point. +:meth:`~protocol.WebSocketCommonProtocol.close` then waits for +:attr:`~protocol.WebSocketCommonProtocol.close_connection_task` but shields it to prevent cancellation. -:meth:`~legacy.protocol.WebSocketCommonProtocol.close` and -:meth:`~legacy.protocol.WebSocketCommonProtocol.fail_connection` are the only -places where :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` may -be canceled. +:meth:`~protocol.WebSocketCommonProtocol.close` and +:meth:`~protocol.WebSocketCommonProtocol.fail_connection` are the only places +where :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` may be +canceled. -:attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task` starts by -waiting for :attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task`. It +:attr:`~protocol.WebSocketCommonProtocol.close_connection_task` starts by +waiting for :attr:`~protocol.WebSocketCommonProtocol.transfer_data_task`. It catches :exc:`~asyncio.CancelledError` to prevent a cancellation of -:attr:`~legacy.protocol.WebSocketCommonProtocol.transfer_data_task` from propagating -to :attr:`~legacy.protocol.WebSocketCommonProtocol.close_connection_task`. +:attr:`~protocol.WebSocketCommonProtocol.transfer_data_task` from propagating to +:attr:`~protocol.WebSocketCommonProtocol.close_connection_task`. .. _backpressure: @@ -491,28 +489,28 @@ buffers and break the backpressure. Be careful with queues. Concurrency ----------- -Awaiting any combination of :meth:`~legacy.protocol.WebSocketCommonProtocol.recv`, -:meth:`~legacy.protocol.WebSocketCommonProtocol.send`, -:meth:`~legacy.protocol.WebSocketCommonProtocol.close` -:meth:`~legacy.protocol.WebSocketCommonProtocol.ping`, or -:meth:`~legacy.protocol.WebSocketCommonProtocol.pong` concurrently is safe, including +Awaiting any combination of :meth:`~protocol.WebSocketCommonProtocol.recv`, +:meth:`~protocol.WebSocketCommonProtocol.send`, +:meth:`~protocol.WebSocketCommonProtocol.close` +:meth:`~protocol.WebSocketCommonProtocol.ping`, or +:meth:`~protocol.WebSocketCommonProtocol.pong` concurrently is safe, including multiple calls to the same method, with one exception and one limitation. -* **Only one coroutine can receive messages at a time.** This constraint - avoids non-deterministic behavior (and simplifies the implementation). If a - coroutine is awaiting :meth:`~legacy.protocol.WebSocketCommonProtocol.recv`, - awaiting it again in another coroutine raises :exc:`RuntimeError`. +* **Only one coroutine can receive messages at a time.** This constraint avoids + non-deterministic behavior (and simplifies the implementation). If a coroutine + is awaiting :meth:`~protocol.WebSocketCommonProtocol.recv`, awaiting it again + in another coroutine raises :exc:`RuntimeError`. * **Sending a fragmented message forces serialization.** Indeed, the WebSocket protocol doesn't support multiplexing messages. If a coroutine is awaiting - :meth:`~legacy.protocol.WebSocketCommonProtocol.send` to send a fragmented message, + :meth:`~protocol.WebSocketCommonProtocol.send` to send a fragmented message, awaiting it again in another coroutine waits until the first call completes. - This will be transparent in many cases. It may be a concern if the - fragmented message is generated slowly by an asynchronous iterator. + This will be transparent in many cases. It may be a concern if the fragmented + message is generated slowly by an asynchronous iterator. Receiving frames is independent from sending frames. This isolates -:meth:`~legacy.protocol.WebSocketCommonProtocol.recv`, which receives frames, from -the other methods, which send frames. +:meth:`~protocol.WebSocketCommonProtocol.recv`, which receives frames, from the +other methods, which send frames. While the connection is open, each frame is sent with a single write. Combined with the concurrency model of :mod:`asyncio`, this enforces serialization. The diff --git a/docs/topics/logging.rst b/docs/topics/logging.rst index 765278360..cad49ba55 100644 --- a/docs/topics/logging.rst +++ b/docs/topics/logging.rst @@ -101,9 +101,10 @@ However, this technique runs into two problems: * Even with :meth:`str.format` style, you're restricted to attribute and index lookups, which isn't enough to implement some fairly simple requirements. -There's a better way. :func:`~client.connect` and :func:`~server.serve` accept -a ``logger`` argument to override the default :class:`~logging.Logger`. You -can set ``logger`` to a :class:`~logging.LoggerAdapter` that enriches logs. +There's a better way. :func:`~asyncio.client.connect` and +:func:`~asyncio.server.serve` accept a ``logger`` argument to override the +default :class:`~logging.Logger`. You can set ``logger`` to a +:class:`~logging.LoggerAdapter` that enriches logs. For example, if the server is behind a reverse proxy, :attr:`~legacy.protocol.WebSocketCommonProtocol.remote_address` gives @@ -128,7 +129,7 @@ Here's how to include them in logs, assuming they're in the xff = websocket.request_headers.get("X-Forwarded-For") return f"{websocket.id} {xff} {msg}", kwargs - async with websockets.serve( + async with serve( ..., # Python < 3.10 requires passing None as the second argument. logger=LoggerAdapter(logging.getLogger("websockets.server"), None), @@ -170,7 +171,7 @@ a :class:`~logging.LoggerAdapter`:: } return msg, kwargs - async with websockets.serve( + async with serve( ..., # Python < 3.10 requires passing None as the second argument. logger=LoggerAdapter(logging.getLogger("websockets.server"), None), diff --git a/docs/topics/memory.rst b/docs/topics/memory.rst index efbcbb83f..61b1113e2 100644 --- a/docs/topics/memory.rst +++ b/docs/topics/memory.rst @@ -99,10 +99,11 @@ workloads but it can also backfire because it delays backpressure. messages. * In the legacy :mod:`asyncio` implementation, there is a library-level read - buffer. The ``read_limit`` argument of :func:`~client.connect` and - :func:`~server.serve` controls its size. When the read buffer grows above the - high-water mark, the connection stops reading from the network until it drains - under the low-water mark. This creates backpressure on the TCP connection. + buffer. The ``read_limit`` argument of :func:`~legacy.client.connect` and + :func:`~legacy.server.serve` controls its size. When the read buffer grows + above the high-water mark, the connection stops reading from the network until + it drains under the low-water mark. This creates backpressure on the TCP + connection. There is a write buffer. It as controlled by ``write_limit``. It behaves like the new :mod:`asyncio` implementation described above. diff --git a/docs/topics/security.rst b/docs/topics/security.rst index 83d79e35b..a22b752c7 100644 --- a/docs/topics/security.rst +++ b/docs/topics/security.rst @@ -49,9 +49,9 @@ Identification By default, websockets identifies itself with a ``Server`` or ``User-Agent`` header in the format ``"Python/x.y.z websockets/X.Y"``. -You can set the ``server_header`` argument of :func:`~server.serve` or the -``user_agent_header`` argument of :func:`~client.connect` to configure another -value. Setting them to :obj:`None` removes the header. +You can set the ``server_header`` argument of :func:`~asyncio.server.serve` or +the ``user_agent_header`` argument of :func:`~asyncio.client.connect` to +configure another value. Setting them to :obj:`None` removes the header. Alternatively, you can set the :envvar:`WEBSOCKETS_SERVER` and :envvar:`WEBSOCKETS_USER_AGENT` environment variables respectively. Setting them diff --git a/example/deployment/fly/app.py b/example/deployment/fly/app.py index 4ca34d23b..c8e6af4f9 100644 --- a/example/deployment/fly/app.py +++ b/example/deployment/fly/app.py @@ -4,7 +4,7 @@ import http import signal -import websockets +from websockets.asyncio.server import serve async def echo(websocket): @@ -12,9 +12,9 @@ async def echo(websocket): await websocket.send(message) -async def health_check(path, request_headers): - if path == "/healthz": - return http.HTTPStatus.OK, [], b"OK\n" +def health_check(connection, request): + if request.path == "/healthz": + return connection.respond(http.HTTPStatus.OK, "OK\n") async def main(): @@ -23,7 +23,7 @@ async def main(): stop = loop.create_future() loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - async with websockets.serve( + async with serve( echo, host="", port=8080, diff --git a/example/deployment/haproxy/app.py b/example/deployment/haproxy/app.py index 360479b8e..ef6d9c42d 100644 --- a/example/deployment/haproxy/app.py +++ b/example/deployment/haproxy/app.py @@ -4,7 +4,7 @@ import os import signal -import websockets +from websockets.asyncio.server import serve async def echo(websocket): @@ -18,7 +18,7 @@ async def main(): stop = loop.create_future() loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - async with websockets.serve( + async with serve( echo, host="localhost", port=8000 + int(os.environ["SUPERVISOR_PROCESS_NAME"][-2:]), diff --git a/example/deployment/heroku/app.py b/example/deployment/heroku/app.py index d4ba3edb5..17ad09d26 100644 --- a/example/deployment/heroku/app.py +++ b/example/deployment/heroku/app.py @@ -4,7 +4,7 @@ import signal import os -import websockets +from websockets.asyncio.server import serve async def echo(websocket): @@ -18,7 +18,7 @@ async def main(): stop = loop.create_future() loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - async with websockets.serve( + async with serve( echo, host="", port=int(os.environ["PORT"]), diff --git a/example/deployment/kubernetes/app.py b/example/deployment/kubernetes/app.py index a8bcef688..387f0ade1 100755 --- a/example/deployment/kubernetes/app.py +++ b/example/deployment/kubernetes/app.py @@ -6,7 +6,7 @@ import sys import time -import websockets +from websockets.asyncio.server import serve async def slow_echo(websocket): @@ -17,17 +17,17 @@ async def slow_echo(websocket): await websocket.send(message) -async def health_check(path, request_headers): - if path == "/healthz": - return http.HTTPStatus.OK, [], b"OK\n" - if path == "/inemuri": +def health_check(connection, request): + if request.path == "/healthz": + return connection.respond(http.HTTPStatus.OK, "OK\n") + if request.path == "/inemuri": loop = asyncio.get_running_loop() loop.call_later(1, time.sleep, 10) - return http.HTTPStatus.OK, [], b"Sleeping for 10s\n" - if path == "/seppuku": + return connection.respond(http.HTTPStatus.OK, "Sleeping for 10s\n") + if request.path == "/seppuku": loop = asyncio.get_running_loop() loop.call_later(1, sys.exit, 69) - return http.HTTPStatus.OK, [], b"Terminating\n" + return connection.respond(http.HTTPStatus.OK, "Terminating\n") async def main(): @@ -36,7 +36,7 @@ async def main(): stop = loop.create_future() loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - async with websockets.serve( + async with serve( slow_echo, host="", port=80, diff --git a/example/deployment/kubernetes/benchmark.py b/example/deployment/kubernetes/benchmark.py index 22ee4c5bd..11a452d55 100755 --- a/example/deployment/kubernetes/benchmark.py +++ b/example/deployment/kubernetes/benchmark.py @@ -2,14 +2,15 @@ import asyncio import sys -import websockets + +from websockets.asyncio.client import connect URI = "ws://localhost:32080" async def run(client_id, messages): - async with websockets.connect(URI) as websocket: + async with connect(URI) as websocket: for message_id in range(messages): await websocket.send(f"{client_id}:{message_id}") await websocket.recv() diff --git a/example/deployment/nginx/app.py b/example/deployment/nginx/app.py index 24e608975..134070f61 100644 --- a/example/deployment/nginx/app.py +++ b/example/deployment/nginx/app.py @@ -4,7 +4,7 @@ import os import signal -import websockets +from websockets.asyncio.server import unix_serve async def echo(websocket): @@ -18,7 +18,7 @@ async def main(): stop = loop.create_future() loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - async with websockets.unix_serve( + async with unix_serve( echo, path=f"{os.environ['SUPERVISOR_PROCESS_NAME']}.sock", ): diff --git a/example/deployment/render/app.py b/example/deployment/render/app.py index 4ca34d23b..c8e6af4f9 100644 --- a/example/deployment/render/app.py +++ b/example/deployment/render/app.py @@ -4,7 +4,7 @@ import http import signal -import websockets +from websockets.asyncio.server import serve async def echo(websocket): @@ -12,9 +12,9 @@ async def echo(websocket): await websocket.send(message) -async def health_check(path, request_headers): - if path == "/healthz": - return http.HTTPStatus.OK, [], b"OK\n" +def health_check(connection, request): + if request.path == "/healthz": + return connection.respond(http.HTTPStatus.OK, "OK\n") async def main(): @@ -23,7 +23,7 @@ async def main(): stop = loop.create_future() loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - async with websockets.serve( + async with serve( echo, host="", port=8080, diff --git a/example/deployment/supervisor/app.py b/example/deployment/supervisor/app.py index bf61983ef..5e69f16a6 100644 --- a/example/deployment/supervisor/app.py +++ b/example/deployment/supervisor/app.py @@ -3,7 +3,7 @@ import asyncio import signal -import websockets +from websockets.asyncio.server import serve async def echo(websocket): @@ -17,7 +17,7 @@ async def main(): stop = loop.create_future() loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - async with websockets.serve( + async with serve( echo, host="", port=8080, diff --git a/example/django/authentication.py b/example/django/authentication.py index 83e128f07..c4f12a3f8 100644 --- a/example/django/authentication.py +++ b/example/django/authentication.py @@ -3,11 +3,11 @@ import asyncio import django -import websockets django.setup() from sesame.utils import get_user +from websockets.asyncio.server import serve from websockets.frames import CloseCode @@ -22,7 +22,7 @@ async def handler(websocket): async def main(): - async with websockets.serve(handler, "localhost", 8888): + async with serve(handler, "localhost", 8888): await asyncio.get_running_loop().create_future() # run forever diff --git a/example/django/notifications.py b/example/django/notifications.py index 3a9ed10cf..445438d2d 100644 --- a/example/django/notifications.py +++ b/example/django/notifications.py @@ -5,12 +5,13 @@ import aioredis import django -import websockets django.setup() from django.contrib.contenttypes.models import ContentType from sesame.utils import get_user +from websockets.asyncio.connection import broadcast +from websockets.asyncio.server import serve from websockets.frames import CloseCode @@ -61,11 +62,11 @@ async def process_events(): for websocket, connection in CONNECTIONS.items() if event["content_type_id"] in connection["content_type_ids"] ) - websockets.broadcast(recipients, payload) + broadcast(recipients, payload) async def main(): - async with websockets.serve(handler, "localhost", 8888): + async with serve(handler, "localhost", 8888): await process_events() # runs forever diff --git a/example/echo.py b/example/echo.py index d11b33527..b952a5cfb 100755 --- a/example/echo.py +++ b/example/echo.py @@ -1,7 +1,7 @@ #!/usr/bin/env python import asyncio -from websockets.server import serve +from websockets.asyncio.server import serve async def echo(websocket): async for message in websocket: diff --git a/example/faq/health_check_server.py b/example/faq/health_check_server.py index 6c7681e8a..c0fa4327f 100755 --- a/example/faq/health_check_server.py +++ b/example/faq/health_check_server.py @@ -1,22 +1,19 @@ #!/usr/bin/env python import asyncio -import http -import websockets +from http import HTTPStatus +from websockets.asyncio.server import serve -async def health_check(path, request_headers): - if path == "/healthz": - return http.HTTPStatus.OK, [], b"OK\n" +def health_check(connection, request): + if request.path == "/healthz": + return connection.respond(HTTPStatus.OK, b"OK\n") async def echo(websocket): async for message in websocket: await websocket.send(message) async def main(): - async with websockets.serve( - echo, "localhost", 8765, - process_request=health_check, - ): + async with serve(echo, "localhost", 8765, process_request=health_check): await asyncio.get_running_loop().create_future() # run forever asyncio.run(main()) diff --git a/example/faq/shutdown_client.py b/example/faq/shutdown_client.py index 539dd0304..5c8bd8cbe 100755 --- a/example/faq/shutdown_client.py +++ b/example/faq/shutdown_client.py @@ -2,15 +2,15 @@ import asyncio import signal -import websockets + +from websockets.asyncio.client import connect async def client(): uri = "ws://localhost:8765" - async with websockets.connect(uri) as websocket: + async with connect(uri) as websocket: # Close the connection when receiving SIGTERM. loop = asyncio.get_running_loop() - loop.add_signal_handler( - signal.SIGTERM, loop.create_task, websocket.close()) + loop.add_signal_handler(signal.SIGTERM, loop.create_task, websocket.close()) # Process messages received on the connection. async for message in websocket: diff --git a/example/faq/shutdown_server.py b/example/faq/shutdown_server.py index 1bcc9c90b..3f7bc5732 100755 --- a/example/faq/shutdown_server.py +++ b/example/faq/shutdown_server.py @@ -2,7 +2,8 @@ import asyncio import signal -import websockets + +from websockets.asyncio.server import serve async def echo(websocket): async for message in websocket: @@ -14,7 +15,7 @@ async def server(): stop = loop.create_future() loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - async with websockets.serve(echo, "localhost", 8765): + async with serve(echo, "localhost", 8765): await stop asyncio.run(server()) diff --git a/example/legacy/basic_auth_client.py b/example/legacy/basic_auth_client.py index 164732152..0252894b7 100755 --- a/example/legacy/basic_auth_client.py +++ b/example/legacy/basic_auth_client.py @@ -3,11 +3,12 @@ # WS client example with HTTP Basic Authentication import asyncio -import websockets + +from websockets.legacy.client import connect async def hello(): uri = "ws://mary:p@ssw0rd@localhost:8765" - async with websockets.connect(uri) as websocket: + async with connect(uri) as websocket: greeting = await websocket.recv() print(greeting) diff --git a/example/legacy/basic_auth_server.py b/example/legacy/basic_auth_server.py index 6f6020253..fc45a0270 100755 --- a/example/legacy/basic_auth_server.py +++ b/example/legacy/basic_auth_server.py @@ -3,16 +3,18 @@ # Server example with HTTP Basic Authentication over TLS import asyncio -import websockets + +from websockets.legacy.auth import basic_auth_protocol_factory +from websockets.legacy.server import serve async def hello(websocket): greeting = f"Hello {websocket.username}!" await websocket.send(greeting) async def main(): - async with websockets.serve( + async with serve( hello, "localhost", 8765, - create_protocol=websockets.basic_auth_protocol_factory( + create_protocol=basic_auth_protocol_factory( realm="example", credentials=("mary", "p@ssw0rd") ), ): diff --git a/example/legacy/unix_client.py b/example/legacy/unix_client.py index 926156730..87201c9e4 100755 --- a/example/legacy/unix_client.py +++ b/example/legacy/unix_client.py @@ -4,11 +4,12 @@ import asyncio import os.path -import websockets + +from websockets.legacy.client import unix_connect async def hello(): socket_path = os.path.join(os.path.dirname(__file__), "socket") - async with websockets.unix_connect(socket_path) as websocket: + async with unix_connect(socket_path) as websocket: name = input("What's your name? ") await websocket.send(name) print(f">>> {name}") diff --git a/example/legacy/unix_server.py b/example/legacy/unix_server.py index 5bfb66072..8a4981f5f 100755 --- a/example/legacy/unix_server.py +++ b/example/legacy/unix_server.py @@ -4,7 +4,8 @@ import asyncio import os.path -import websockets + +from websockets.legacy.server import unix_serve async def hello(websocket): name = await websocket.recv() @@ -17,7 +18,7 @@ async def hello(websocket): async def main(): socket_path = os.path.join(os.path.dirname(__file__), "socket") - async with websockets.unix_serve(hello, socket_path): + async with unix_serve(hello, socket_path): await asyncio.get_running_loop().create_future() # run forever asyncio.run(main()) diff --git a/example/logging/json_log_formatter.py b/example/logging/json_log_formatter.py index b8fc8d6dc..ff7fce8b5 100644 --- a/example/logging/json_log_formatter.py +++ b/example/logging/json_log_formatter.py @@ -1,6 +1,6 @@ +import datetime import json import logging -import datetime class JSONFormatter(logging.Formatter): """ diff --git a/example/quickstart/client.py b/example/quickstart/client.py index 8d588c2b0..934af69e3 100755 --- a/example/quickstart/client.py +++ b/example/quickstart/client.py @@ -1,11 +1,12 @@ #!/usr/bin/env python import asyncio -import websockets + +from websockets.asyncio.client import connect async def hello(): uri = "ws://localhost:8765" - async with websockets.connect(uri) as websocket: + async with connect(uri) as websocket: name = input("What's your name? ") await websocket.send(name) diff --git a/example/quickstart/client_secure.py b/example/quickstart/client_secure.py index f4b39f2b8..a1449587a 100755 --- a/example/quickstart/client_secure.py +++ b/example/quickstart/client_secure.py @@ -3,7 +3,8 @@ import asyncio import pathlib import ssl -import websockets + +from websockets.asyncio.client import connect ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) localhost_pem = pathlib.Path(__file__).with_name("localhost.pem") @@ -11,7 +12,7 @@ async def hello(): uri = "wss://localhost:8765" - async with websockets.connect(uri, ssl=ssl_context) as websocket: + async with connect(uri, ssl=ssl_context) as websocket: name = input("What's your name? ") await websocket.send(name) diff --git a/example/quickstart/counter.py b/example/quickstart/counter.py index 414919e04..d42069e64 100755 --- a/example/quickstart/counter.py +++ b/example/quickstart/counter.py @@ -3,7 +3,8 @@ import asyncio import json import logging -import websockets +from websockets.asyncio.connection import broadcast +from websockets.asyncio.server import serve logging.basicConfig() @@ -22,7 +23,7 @@ async def counter(websocket): try: # Register user USERS.add(websocket) - websockets.broadcast(USERS, users_event()) + broadcast(USERS, users_event()) # Send current state to user await websocket.send(value_event()) # Manage state changes @@ -30,19 +31,19 @@ async def counter(websocket): event = json.loads(message) if event["action"] == "minus": VALUE -= 1 - websockets.broadcast(USERS, value_event()) + broadcast(USERS, value_event()) elif event["action"] == "plus": VALUE += 1 - websockets.broadcast(USERS, value_event()) + broadcast(USERS, value_event()) else: logging.error("unsupported event: %s", event) finally: # Unregister user USERS.remove(websocket) - websockets.broadcast(USERS, users_event()) + broadcast(USERS, users_event()) async def main(): - async with websockets.serve(counter, "localhost", 6789): + async with serve(counter, "localhost", 6789): await asyncio.get_running_loop().create_future() # run forever if __name__ == "__main__": diff --git a/example/quickstart/server.py b/example/quickstart/server.py index 64d7adeb6..bde5e6126 100755 --- a/example/quickstart/server.py +++ b/example/quickstart/server.py @@ -1,7 +1,8 @@ #!/usr/bin/env python import asyncio -import websockets + +from websockets.asyncio.server import serve async def hello(websocket): name = await websocket.recv() @@ -13,7 +14,7 @@ async def hello(websocket): print(f">>> {greeting}") async def main(): - async with websockets.serve(hello, "localhost", 8765): + async with serve(hello, "localhost", 8765): await asyncio.get_running_loop().create_future() # run forever if __name__ == "__main__": diff --git a/example/quickstart/server_secure.py b/example/quickstart/server_secure.py index 11db5fb3a..8b456ed6e 100755 --- a/example/quickstart/server_secure.py +++ b/example/quickstart/server_secure.py @@ -3,7 +3,8 @@ import asyncio import pathlib import ssl -import websockets + +from websockets.asyncio.server import serve async def hello(websocket): name = await websocket.recv() @@ -19,7 +20,7 @@ async def hello(websocket): ssl_context.load_cert_chain(localhost_pem) async def main(): - async with websockets.serve(hello, "localhost", 8765, ssl=ssl_context): + async with serve(hello, "localhost", 8765, ssl=ssl_context): await asyncio.get_running_loop().create_future() # run forever if __name__ == "__main__": diff --git a/example/quickstart/show_time.py b/example/quickstart/show_time.py index add226869..8aeb811db 100755 --- a/example/quickstart/show_time.py +++ b/example/quickstart/show_time.py @@ -3,7 +3,8 @@ import asyncio import datetime import random -import websockets + +from websockets.asyncio.server import serve async def show_time(websocket): while True: @@ -12,7 +13,7 @@ async def show_time(websocket): await asyncio.sleep(random.random() * 2 + 1) async def main(): - async with websockets.serve(show_time, "localhost", 5678): + async with serve(show_time, "localhost", 5678): await asyncio.get_running_loop().create_future() # run forever if __name__ == "__main__": diff --git a/example/quickstart/show_time_2.py b/example/quickstart/show_time_2.py index 08e87f593..4fa244a23 100755 --- a/example/quickstart/show_time_2.py +++ b/example/quickstart/show_time_2.py @@ -3,7 +3,9 @@ import asyncio import datetime import random -import websockets + +from websockets.asyncio.connection import broadcast +from websockets.asyncio.server import serve CONNECTIONS = set() @@ -17,11 +19,11 @@ async def register(websocket): async def show_time(): while True: message = datetime.datetime.utcnow().isoformat() + "Z" - websockets.broadcast(CONNECTIONS, message) + broadcast(CONNECTIONS, message) await asyncio.sleep(random.random() * 2 + 1) async def main(): - async with websockets.serve(register, "localhost", 5678): + async with serve(register, "localhost", 5678): await show_time() if __name__ == "__main__": diff --git a/example/tutorial/step1/app.py b/example/tutorial/step1/app.py index 6ec1c60b8..db69070a1 100644 --- a/example/tutorial/step1/app.py +++ b/example/tutorial/step1/app.py @@ -4,7 +4,7 @@ import itertools import json -import websockets +from websockets.asyncio.server import serve from connect4 import PLAYER1, PLAYER2, Connect4 @@ -57,7 +57,7 @@ async def handler(websocket): async def main(): - async with websockets.serve(handler, "", 8001): + async with serve(handler, "", 8001): await asyncio.get_running_loop().create_future() # run forever diff --git a/example/tutorial/step2/app.py b/example/tutorial/step2/app.py index db3e36374..feaf223a0 100644 --- a/example/tutorial/step2/app.py +++ b/example/tutorial/step2/app.py @@ -4,7 +4,7 @@ import json import secrets -import websockets +from websockets.asyncio.server import serve from connect4 import PLAYER1, PLAYER2, Connect4 @@ -182,7 +182,7 @@ async def handler(websocket): async def main(): - async with websockets.serve(handler, "", 8001): + async with serve(handler, "", 8001): await asyncio.get_running_loop().create_future() # run forever diff --git a/example/tutorial/step3/app.py b/example/tutorial/step3/app.py index c2ee020d2..a428e29e7 100644 --- a/example/tutorial/step3/app.py +++ b/example/tutorial/step3/app.py @@ -6,7 +6,7 @@ import secrets import signal -import websockets +from websockets.asyncio.server import serve from connect4 import PLAYER1, PLAYER2, Connect4 @@ -190,7 +190,7 @@ async def main(): loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) port = int(os.environ.get("PORT", "8001")) - async with websockets.serve(handler, "", port): + async with serve(handler, "", port): await stop diff --git a/experiments/authentication/app.py b/experiments/authentication/app.py index 039e21174..e3b2cf1f6 100644 --- a/experiments/authentication/app.py +++ b/experiments/authentication/app.py @@ -8,8 +8,9 @@ import urllib.parse import uuid -import websockets from websockets.frames import CloseCode +from websockets.legacy.auth import BasicAuthWebSocketServerProtocol +from websockets.legacy.server import WebSocketServerProtocol, serve # User accounts database @@ -107,7 +108,7 @@ async def first_message_handler(websocket): # Add credentials to the WebSocket URI in a query parameter -class QueryParamProtocol(websockets.WebSocketServerProtocol): +class QueryParamProtocol(WebSocketServerProtocol): async def process_request(self, path, headers): token = get_query_param(path, "token") if token is None: @@ -131,7 +132,7 @@ async def query_param_handler(websocket): # Set a cookie on the domain of the WebSocket URI -class CookieProtocol(websockets.WebSocketServerProtocol): +class CookieProtocol(WebSocketServerProtocol): async def process_request(self, path, headers): if "Upgrade" not in headers: template = pathlib.Path(__file__).with_name(path[1:]) @@ -161,7 +162,7 @@ async def cookie_handler(websocket): # Adding credentials to the WebSocket URI in user information -class UserInfoProtocol(websockets.BasicAuthWebSocketServerProtocol): +class UserInfoProtocol(BasicAuthWebSocketServerProtocol): async def check_credentials(self, username, password): if username != "token": return False @@ -192,26 +193,26 @@ async def main(): loop.add_signal_handler(signal.SIGINT, stop.set_result, None) loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - async with websockets.serve( + async with serve( noop_handler, host="", port=8000, process_request=serve_html, - ), websockets.serve( + ), serve( first_message_handler, host="", port=8001, - ), websockets.serve( + ), serve( query_param_handler, host="", port=8002, create_protocol=QueryParamProtocol, - ), websockets.serve( + ), serve( cookie_handler, host="", port=8003, create_protocol=CookieProtocol, - ), websockets.serve( + ), serve( user_info_handler, host="", port=8004, diff --git a/experiments/broadcast/clients.py b/experiments/broadcast/clients.py index fe39dfe05..64334f20f 100644 --- a/experiments/broadcast/clients.py +++ b/experiments/broadcast/clients.py @@ -5,7 +5,7 @@ import sys import time -import websockets +from websockets.asyncio.client import connect LATENCIES = {} @@ -26,7 +26,7 @@ async def log_latency(interval): async def client(): try: - async with websockets.connect( + async with connect( "ws://localhost:8765", ping_timeout=None, ) as websocket: diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index adb66e262..52cc48898 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -319,8 +319,8 @@ class AbortHandshake(InvalidHandshake): This exception is an implementation detail. - The public API - is :meth:`~websockets.server.WebSocketServerProtocol.process_request`. + The public API is + :meth:`~websockets.legacy.server.WebSocketServerProtocol.process_request`. Attributes: status (~http.HTTPStatus): HTTP status code. diff --git a/src/websockets/legacy/auth.py b/src/websockets/legacy/auth.py index 8526bad6b..4d030e5e2 100644 --- a/src/websockets/legacy/auth.py +++ b/src/websockets/legacy/auth.py @@ -125,11 +125,11 @@ def basic_auth_protocol_factory( Protocol factory that enforces HTTP Basic Auth. :func:`basic_auth_protocol_factory` is designed to integrate with - :func:`~websockets.server.serve` like this:: + :func:`~websockets.legacy.server.serve` like this:: - websockets.serve( + serve( ..., - create_protocol=websockets.basic_auth_protocol_factory( + create_protocol=basic_auth_protocol_factory( realm="my dev server", credentials=("hello", "iloveyou"), ) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index b61126c81..256bee14c 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -342,7 +342,7 @@ class Connect: :func:`connect` can be used as a asynchronous context manager:: - async with websockets.connect(...) as websocket: + async with connect(...) as websocket: ... The connection is closed automatically when exiting the context. @@ -350,7 +350,7 @@ class Connect: :func:`connect` can be used as an infinite asynchronous iterator to reconnect automatically on errors:: - async for websocket in websockets.connect(...): + async for websocket in connect(...): try: ... except websockets.ConnectionClosed: diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 191350de3..66eb94199 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -67,13 +67,13 @@ class WebSocketCommonProtocol(asyncio.Protocol): :class:`WebSocketCommonProtocol` provides APIs shared between WebSocket servers and clients. You shouldn't use it directly. Instead, use - :class:`~websockets.client.WebSocketClientProtocol` or - :class:`~websockets.server.WebSocketServerProtocol`. + :class:`~websockets.legacy.client.WebSocketClientProtocol` or + :class:`~websockets.legacy.server.WebSocketServerProtocol`. This documentation focuses on low-level details that aren't covered in the - documentation of :class:`~websockets.client.WebSocketClientProtocol` and - :class:`~websockets.server.WebSocketServerProtocol` for the sake of - simplicity. + documentation of :class:`~websockets.legacy.client.WebSocketClientProtocol` + and :class:`~websockets.legacy.server.WebSocketServerProtocol` for the sake + of simplicity. Once the connection is open, a Ping_ frame is sent every ``ping_interval`` seconds. This serves as a keepalive. It helps keeping the connection open, @@ -89,7 +89,7 @@ class WebSocketCommonProtocol(asyncio.Protocol): .. _Pong: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3 - See the discussion of :doc:`timeouts <../../topics/keepalive>` for details. + See the discussion of :doc:`keepalive <../../topics/keepalive>` for details. The ``close_timeout`` parameter defines a maximum wait time for completing the closing handshake and terminating the TCP connection. For legacy @@ -99,8 +99,8 @@ class WebSocketCommonProtocol(asyncio.Protocol): ``close_timeout`` is a parameter of the protocol because websockets usually calls :meth:`close` implicitly upon exit: - * on the client side, when using :func:`~websockets.client.connect` as a - context manager; + * on the client side, when using :func:`~websockets.legacy.client.connect` + as a context manager; * on the server side, when the connection handler terminates. To apply a timeout to any other API, wrap it in :func:`~asyncio.timeout` or diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 9b84a6b81..3cdfeb21b 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -150,7 +150,7 @@ async def test_aiter_connection_closed_ok(self): await anext(aiterator) async def test_aiter_connection_closed_error(self): - """__aiter__ raises ConnnectionClosedError after an error.""" + """__aiter__ raises ConnectionClosedError after an error.""" aiterator = aiter(self.connection) await self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) with self.assertRaises(ConnectionClosedError): diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index b3023434b..5d4f0e2f8 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -1,5 +1,4 @@ import asyncio -import dataclasses import http import logging import socket @@ -228,9 +227,7 @@ async def test_process_response_override_response(self): """Server runs process_response and overrides the handshake response.""" def process_response(ws, request, response): - headers = response.headers.copy() - headers["X-ProcessResponse-Ran"] = "true" - return dataclasses.replace(response, headers=headers) + response.headers["X-ProcessResponse-Ran"] = "true" async with run_server(process_response=process_response) as server: async with run_client(server) as client: @@ -242,9 +239,7 @@ async def test_async_process_response_override_response(self): """Server runs async process_response and overrides the handshake response.""" async def process_response(ws, request, response): - headers = response.headers.copy() - headers["X-ProcessResponse-Ran"] = "true" - return dataclasses.replace(response, headers=headers) + response.headers["X-ProcessResponse-Ran"] = "true" async with run_server(process_response=process_response) as server: async with run_client(server) as client: diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index 88cbcd669..877adc4bf 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -132,7 +132,7 @@ def test_iter_connection_closed_ok(self): next(iterator) def test_iter_connection_closed_error(self): - """__iter__ raises ConnnectionClosedError after an error.""" + """__iter__ raises ConnectionClosedError after an error.""" iterator = iter(self.connection) self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) with self.assertRaises(ConnectionClosedError): diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index 4e04a39d5..c0a5f01e6 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -1,4 +1,3 @@ -import dataclasses import http import logging import socket @@ -173,9 +172,7 @@ def test_process_response_override_response(self): """Server runs process_response and overrides the handshake response.""" def process_response(ws, request, response): - headers = response.headers.copy() - headers["X-ProcessResponse-Ran"] = "true" - return dataclasses.replace(response, headers=headers) + response.headers["X-ProcessResponse-Ran"] = "true" with run_server(process_response=process_response) as server: with run_client(server) as client: From 472f9517b0f8d1f190ae5961fe10a064ef016972 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 19 Aug 2024 19:34:38 +0200 Subject: [PATCH 559/762] Explain new asyncio implementation in docs index page. --- docs/index.rst | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/docs/index.rst b/docs/index.rst index d9737db12..218a489a3 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -28,23 +28,42 @@ with a focus on correctness, simplicity, robustness, and performance. It supports several network I/O and control flow paradigms: -1. The default implementation builds upon :mod:`asyncio`, Python's standard +1. The primary implementation builds upon :mod:`asyncio`, Python's standard asynchronous I/O framework. It provides an elegant coroutine-based API. It's ideal for servers that handle many clients concurrently. + + .. admonition:: As of version :ref:`13.0`, there is a new :mod:`asyncio` + implementation. + :class: important + + The historical implementation in ``websockets.legacy`` traces its roots to + early versions of websockets. Although it's stable and robust, it is now + considered legacy. + + The new implementation in ``websockets.asyncio`` is a rewrite on top of + the Sans-I/O implementation. It adds a few features that were impossible + to implement within the original design. + + The new implementation will become the default as soon as it reaches + feature parity. If you're using the historical implementation, you should + :doc:`ugrade to the new implementation `. It's usually + straightforward. + 2. The :mod:`threading` implementation is a good alternative for clients, especially if you aren't familiar with :mod:`asyncio`. It may also be used for servers that don't need to serve many clients. + 3. The `Sans-I/O`_ implementation is designed for integrating in third-party libraries, typically application servers, in addition being used internally by websockets. .. _Sans-I/O: https://sans-io.readthedocs.io/ -Here's an echo server with the :mod:`asyncio` API: +Here's an echo server using the :mod:`asyncio` API: .. literalinclude:: ../example/echo.py -Here's how a client sends and receives messages with the :mod:`threading` API: +Here's a client using the :mod:`threading` API: .. literalinclude:: ../example/hello.py From 14ca557f53cf19084eb64aef2e4563e5630c211b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 19 Aug 2024 19:35:02 +0200 Subject: [PATCH 560/762] Proof-read tutorial. Switch to the new asyncio implementation. Change exception to ValueError when arguments have incorrect values. --- docs/intro/tutorial1.rst | 4 ++-- docs/intro/tutorial2.rst | 38 ++++++++++++++++-------------- example/tutorial/start/connect4.py | 6 ++--- example/tutorial/step1/app.py | 2 +- example/tutorial/step2/app.py | 7 +++--- example/tutorial/step3/app.py | 7 +++--- 6 files changed, 34 insertions(+), 30 deletions(-) diff --git a/docs/intro/tutorial1.rst b/docs/intro/tutorial1.rst index 74f5f79a3..6e91867c8 100644 --- a/docs/intro/tutorial1.rst +++ b/docs/intro/tutorial1.rst @@ -123,7 +123,7 @@ wins. Here's its API. :param player: :data:`~connect4.PLAYER1` or :data:`~connect4.PLAYER2`. :param column: between ``0`` and ``6``. :returns: Row where the checker lands, between ``0`` and ``5``. - :raises RuntimeError: if the move is illegal. + :raises ValueError: if the move is illegal. .. attribute:: moves @@ -520,7 +520,7 @@ Then, you're going to iterate over incoming messages and take these steps: interface sends; * play the move in the board with the :meth:`~connect4.Connect4.play` method, alternating between the two players; -* if :meth:`~connect4.Connect4.play` raises :exc:`RuntimeError` because the +* if :meth:`~connect4.Connect4.play` raises :exc:`ValueError` because the move is illegal, send an event of type ``"error"``; * else, send an event of type ``"play"`` to tell the user interface where the checker lands; diff --git a/docs/intro/tutorial2.rst b/docs/intro/tutorial2.rst index b8e35f292..b5d3a3dc8 100644 --- a/docs/intro/tutorial2.rst +++ b/docs/intro/tutorial2.rst @@ -84,7 +84,7 @@ When the second player joins the game, look it up: async def handler(websocket): ... - join_key = ... # TODO + join_key = ... # Find the Connect Four game. game, connected = JOIN[join_key] @@ -434,7 +434,7 @@ Once the initialization sequence is done, watching a game is as simple as registering the WebSocket connection in the ``connected`` set in order to receive game events and doing nothing until the spectator disconnects. You can wait for a connection to terminate with -:meth:`~legacy.protocol.WebSocketCommonProtocol.wait_closed`: +:meth:`~asyncio.server.ServerConnection.wait_closed`: .. code-block:: python @@ -482,38 +482,40 @@ you're using this pattern: ... Since this is a very common pattern in WebSocket servers, websockets provides -the :func:`~legacy.protocol.broadcast` helper for this purpose: +the :func:`~asyncio.connection.broadcast` helper for this purpose: .. code-block:: python + from websockets.asyncio.connection import broadcast + async def handler(websocket): ... - websockets.broadcast(connected, json.dumps(event)) + broadcast(connected, json.dumps(event)) ... -Calling :func:`legacy.protocol.broadcast` once is more efficient than -calling :meth:`~legacy.protocol.WebSocketCommonProtocol.send` in a loop. +Calling :func:`~asyncio.connection.broadcast` once is more efficient than +calling :meth:`~asyncio.server.ServerConnection.send` in a loop. However, there's a subtle difference in behavior. Did you notice that there's no -``await`` in the second version? Indeed, :func:`legacy.protocol.broadcast` is a -function, not a coroutine like -:meth:`~legacy.protocol.WebSocketCommonProtocol.send` or -:meth:`~legacy.protocol.WebSocketCommonProtocol.recv`. +``await`` in the second version? Indeed, :func:`~asyncio.connection.broadcast` +is a function, not a coroutine like +:meth:`~asyncio.server.ServerConnection.send` or +:meth:`~asyncio.server.ServerConnection.recv`. -It's quite obvious why :meth:`~legacy.protocol.WebSocketCommonProtocol.recv` +It's quite obvious why :meth:`~asyncio.server.ServerConnection.recv` is a coroutine. When you want to receive the next message, you have to wait until the client sends it and the network transmits it. -It's less obvious why :meth:`~legacy.protocol.WebSocketCommonProtocol.send` is +It's less obvious why :meth:`~asyncio.server.ServerConnection.send` is a coroutine. If you send many messages or large messages, you could write data faster than the network can transmit it or the client can read it. Then, outgoing data will pile up in buffers, which will consume memory and may crash your application. -To avoid this problem, :meth:`~legacy.protocol.WebSocketCommonProtocol.send` +To avoid this problem, :meth:`~asyncio.server.ServerConnection.send` waits until the write buffer drains. By slowing down the application as necessary, this ensures that the server doesn't send data too quickly. This is called backpressure and it's useful for building robust systems. @@ -522,12 +524,12 @@ That said, when you're sending the same messages to many clients in a loop, applying backpressure in this way can become counterproductive. When you're broadcasting, you don't want to slow down everyone to the pace of the slowest clients; you want to drop clients that cannot keep up with the data stream. -That's why :func:`legacy.protocol.broadcast` doesn't wait until write buffers -drain. +That's why :func:`~asyncio.connection.broadcast` doesn't wait until write +buffers drain and therefore doesn't need to be a coroutine. -For our Connect Four game, there's no difference in practice: the total amount -of data sent on a connection for a game of Connect Four is less than 64 KB, -so the write buffer never fills up and backpressure never kicks in anyway. +For our Connect Four game, there's no difference in practice. The total amount +of data sent on a connection for a game of Connect Four is so small that the +write buffer cannot fill up. As a consequence, backpressure never kicks in. Summary ------- diff --git a/example/tutorial/start/connect4.py b/example/tutorial/start/connect4.py index 0a61e7c7e..104476962 100644 --- a/example/tutorial/start/connect4.py +++ b/example/tutorial/start/connect4.py @@ -43,15 +43,15 @@ def play(self, player, column): Returns the row where the checker lands. - Raises :exc:`RuntimeError` if the move is illegal. + Raises :exc:`ValueError` if the move is illegal. """ if player == self.last_player: - raise RuntimeError("It isn't your turn.") + raise ValueError("It isn't your turn.") row = self.top[column] if row == 6: - raise RuntimeError("This slot is full.") + raise ValueError("This slot is full.") self.moves.append((player, column, row)) self.top[column] += 1 diff --git a/example/tutorial/step1/app.py b/example/tutorial/step1/app.py index db69070a1..595a10dc7 100644 --- a/example/tutorial/step1/app.py +++ b/example/tutorial/step1/app.py @@ -26,7 +26,7 @@ async def handler(websocket): try: # Play the move. row = game.play(player, column) - except RuntimeError as exc: + except ValueError as exc: # Send an "error" event if the move was illegal. event = { "type": "error", diff --git a/example/tutorial/step2/app.py b/example/tutorial/step2/app.py index feaf223a0..86b2c88c3 100644 --- a/example/tutorial/step2/app.py +++ b/example/tutorial/step2/app.py @@ -4,6 +4,7 @@ import json import secrets +from websockets.asyncio.connection import broadcast from websockets.asyncio.server import serve from connect4 import PLAYER1, PLAYER2, Connect4 @@ -59,7 +60,7 @@ async def play(websocket, game, player, connected): try: # Play the move. row = game.play(player, column) - except RuntimeError as exc: + except ValueError as exc: # Send an "error" event if the move was illegal. await error(websocket, str(exc)) continue @@ -71,7 +72,7 @@ async def play(websocket, game, player, connected): "column": column, "row": row, } - websockets.broadcast(connected, json.dumps(event)) + broadcast(connected, json.dumps(event)) # If move is winning, send a "win" event. if game.winner is not None: @@ -79,7 +80,7 @@ async def play(websocket, game, player, connected): "type": "win", "player": game.winner, } - websockets.broadcast(connected, json.dumps(event)) + broadcast(connected, json.dumps(event)) async def start(websocket): diff --git a/example/tutorial/step3/app.py b/example/tutorial/step3/app.py index a428e29e7..34024d087 100644 --- a/example/tutorial/step3/app.py +++ b/example/tutorial/step3/app.py @@ -6,6 +6,7 @@ import secrets import signal +from websockets.asyncio.connection import broadcast from websockets.asyncio.server import serve from connect4 import PLAYER1, PLAYER2, Connect4 @@ -61,7 +62,7 @@ async def play(websocket, game, player, connected): try: # Play the move. row = game.play(player, column) - except RuntimeError as exc: + except ValueError as exc: # Send an "error" event if the move was illegal. await error(websocket, str(exc)) continue @@ -73,7 +74,7 @@ async def play(websocket, game, player, connected): "column": column, "row": row, } - websockets.broadcast(connected, json.dumps(event)) + broadcast(connected, json.dumps(event)) # If move is winning, send a "win" event. if game.winner is not None: @@ -81,7 +82,7 @@ async def play(websocket, game, player, connected): "type": "win", "player": game.winner, } - websockets.broadcast(connected, json.dumps(event)) + broadcast(connected, json.dumps(event)) async def start(websocket): From 2a17e1dac6a4514f3663c4e15546ebe33bf90e4b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 19 Aug 2024 21:49:46 +0200 Subject: [PATCH 561/762] Move broadcast() to the server module. Rationale: * It is only useful for servers. (Maybe there's a use case for client but I couldn't picture it.) * It was already documented in the page covering the server module. * Within this page, it was the only API from the connection or protocol module. The implementation remains in the connection or protocol module because moving it would require refactoring tests. I'd rather keep them simple. (And I'm lazy.) This change doesn't require a backwards compatibility shim because the documentated location of the legacy implementation of broadcast was websockets.broadcast, it's changing with the introduction of the new asyncio API, and the changes are already documented. --- docs/faq/server.rst | 4 ++-- docs/howto/patterns.rst | 2 +- docs/howto/upgrade.rst | 4 ++-- docs/intro/tutorial2.rst | 15 +++++++-------- docs/project/changelog.rst | 4 ++-- docs/reference/asyncio/server.rst | 2 +- docs/reference/features.rst | 2 ++ docs/reference/legacy/server.rst | 10 +++++----- docs/topics/broadcast.rst | 20 ++++++++++---------- docs/topics/logging.rst | 2 +- docs/topics/performance.rst | 4 ++-- example/django/notifications.py | 3 +-- example/quickstart/counter.py | 3 +-- example/quickstart/show_time_2.py | 3 +-- example/tutorial/step2/app.py | 3 +-- example/tutorial/step3/app.py | 3 +-- experiments/broadcast/server.py | 3 +-- src/websockets/__init__.py | 7 ++++--- src/websockets/asyncio/connection.py | 25 ++++++++++++++++++++----- src/websockets/asyncio/server.py | 4 ++-- src/websockets/legacy/protocol.py | 25 ++++++++++++++++++++----- src/websockets/legacy/server.py | 10 ++++++++-- tests/asyncio/test_connection.py | 1 + 23 files changed, 96 insertions(+), 63 deletions(-) diff --git a/docs/faq/server.rst b/docs/faq/server.rst index e6b068316..66e81edfe 100644 --- a/docs/faq/server.rst +++ b/docs/faq/server.rst @@ -116,9 +116,9 @@ Record all connections in a global variable:: finally: CONNECTIONS.remove(websocket) -Then, call :func:`~websockets.asyncio.connection.broadcast`:: +Then, call :func:`~websockets.asyncio.server.broadcast`:: - from websockets.asyncio.connection import broadcast + from websockets.asyncio.server import broadcast def message_all(message): broadcast(CONNECTIONS, message) diff --git a/docs/howto/patterns.rst b/docs/howto/patterns.rst index 60bc8ab42..bfb78b6ca 100644 --- a/docs/howto/patterns.rst +++ b/docs/howto/patterns.rst @@ -90,7 +90,7 @@ connect and unregister them when they disconnect:: connected.add(websocket) try: # Broadcast a message to all connected clients. - websockets.broadcast(connected, "Hello!") + broadcast(connected, "Hello!") await asyncio.sleep(10) finally: # Unregister. diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index 40c8c5ec9..c5320155e 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -162,8 +162,8 @@ Server APIs | ``websockets.server.WebSocketServerProtocol`` |br| | | | :class:`websockets.legacy.server.WebSocketServerProtocol` | | +-------------------------------------------------------------------+-----------------------------------------------------+ -| ``websockets.broadcast`` |br| | :func:`websockets.asyncio.connection.broadcast` | -| :func:`websockets.legacy.protocol.broadcast()` | | +| ``websockets.broadcast`` |br| | :func:`websockets.asyncio.server.broadcast` | +| :func:`websockets.legacy.server.broadcast()` | | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.BasicAuthWebSocketServerProtocol`` |br| | *not available yet* | | ``websockets.auth.BasicAuthWebSocketServerProtocol`` |br| | | diff --git a/docs/intro/tutorial2.rst b/docs/intro/tutorial2.rst index b5d3a3dc8..0211615d1 100644 --- a/docs/intro/tutorial2.rst +++ b/docs/intro/tutorial2.rst @@ -482,11 +482,11 @@ you're using this pattern: ... Since this is a very common pattern in WebSocket servers, websockets provides -the :func:`~asyncio.connection.broadcast` helper for this purpose: +the :func:`~asyncio.server.broadcast` helper for this purpose: .. code-block:: python - from websockets.asyncio.connection import broadcast + from websockets.asyncio.server import broadcast async def handler(websocket): @@ -496,13 +496,12 @@ the :func:`~asyncio.connection.broadcast` helper for this purpose: ... -Calling :func:`~asyncio.connection.broadcast` once is more efficient than +Calling :func:`~asyncio.server.broadcast` once is more efficient than calling :meth:`~asyncio.server.ServerConnection.send` in a loop. However, there's a subtle difference in behavior. Did you notice that there's no -``await`` in the second version? Indeed, :func:`~asyncio.connection.broadcast` -is a function, not a coroutine like -:meth:`~asyncio.server.ServerConnection.send` or +``await`` in the second version? Indeed, :func:`~asyncio.server.broadcast` is a +function, not a coroutine like :meth:`~asyncio.server.ServerConnection.send` or :meth:`~asyncio.server.ServerConnection.recv`. It's quite obvious why :meth:`~asyncio.server.ServerConnection.recv` @@ -524,8 +523,8 @@ That said, when you're sending the same messages to many clients in a loop, applying backpressure in this way can become counterproductive. When you're broadcasting, you don't want to slow down everyone to the pace of the slowest clients; you want to drop clients that cannot keep up with the data stream. -That's why :func:`~asyncio.connection.broadcast` doesn't wait until write -buffers drain and therefore doesn't need to be a coroutine. +That's why :func:`~asyncio.server.broadcast` doesn't wait until write buffers +drain and therefore doesn't need to be a coroutine. For our Connect Four game, there's no difference in practice. The total amount of data sent on a connection for a game of Connect Four is so small that the diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index df5af54f4..e85c3a395 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -212,7 +212,7 @@ Improvements * Added platform-independent wheels. -* Improved error handling in :func:`~legacy.protocol.broadcast`. +* Improved error handling in :func:`~legacy.server.broadcast`. * Set ``server_hostname`` automatically on TLS connections when providing a ``sock`` argument to :func:`~sync.client.connect`. @@ -402,7 +402,7 @@ New features * Added compatibility with Python 3.10. -* Added :func:`~legacy.protocol.broadcast` to send a message to many clients. +* Added :func:`~legacy.server.broadcast` to send a message to many clients. * Added support for reconnecting automatically by using :func:`~legacy.client.connect` as an asynchronous iterator. diff --git a/docs/reference/asyncio/server.rst b/docs/reference/asyncio/server.rst index 7bceca5a0..541c9952c 100644 --- a/docs/reference/asyncio/server.rst +++ b/docs/reference/asyncio/server.rst @@ -80,4 +80,4 @@ Using a connection Broadcast --------- -.. autofunction:: websockets.asyncio.connection.broadcast +.. autofunction:: websockets.asyncio.server.broadcast diff --git a/docs/reference/features.rst b/docs/reference/features.rst index 6840fe15b..cb0e564f9 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -35,6 +35,8 @@ Both sides +------------------------------------+--------+--------+--------+--------+ | Send a message | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ + | Broadcast a message | ✅ | ❌ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ | Receive a message | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ | Iterate over received messages | ✅ | ✅ | — | ✅ | diff --git a/docs/reference/legacy/server.rst b/docs/reference/legacy/server.rst index c2758f5a2..b6c383ce7 100644 --- a/docs/reference/legacy/server.rst +++ b/docs/reference/legacy/server.rst @@ -89,6 +89,11 @@ Using a connection .. autoproperty:: close_reason +Broadcast +--------- + +.. autofunction:: websockets.legacy.server.broadcast + Basic authentication -------------------- @@ -106,8 +111,3 @@ websockets supports HTTP Basic Authentication according to .. autoattribute:: username .. automethod:: check_credentials - -Broadcast ---------- - -.. autofunction:: websockets.legacy.protocol.broadcast diff --git a/docs/topics/broadcast.rst b/docs/topics/broadcast.rst index ec358bbd2..c9699feb2 100644 --- a/docs/topics/broadcast.rst +++ b/docs/topics/broadcast.rst @@ -4,19 +4,19 @@ Broadcasting .. currentmodule:: websockets .. admonition:: If you want to send a message to all connected clients, - use :func:`~asyncio.connection.broadcast`. + use :func:`~asyncio.server.broadcast`. :class: tip If you want to learn about its design, continue reading this document. For the legacy :mod:`asyncio` implementation, use - :func:`~legacy.protocol.broadcast`. + :func:`~legacy.server.broadcast`. WebSocket servers often send the same message to all connected clients or to a subset of clients for which the message is relevant. Let's explore options for broadcasting a message, explain the design of -:func:`~asyncio.connection.broadcast`, and discuss alternatives. +:func:`~asyncio.server.broadcast`, and discuss alternatives. For each option, we'll provide a connection handler called ``handler()`` and a function or coroutine called ``broadcast()`` that sends a message to all @@ -124,7 +124,7 @@ connections before the write buffer can fill up. Don't set extreme values of ``write_limit``, ``ping_interval``, or ``ping_timeout`` to ensure that this condition holds! Instead, set reasonable -values and use the built-in :func:`~asyncio.connection.broadcast` function. +values and use the built-in :func:`~asyncio.server.broadcast` function. The concurrent way ------------------ @@ -209,11 +209,11 @@ If a client gets too far behind, eventually it reaches the limit defined by ``ping_timeout`` and websockets terminates the connection. You can refer to the discussion of :doc:`keepalive ` for details. -How :func:`~asyncio.connection.broadcast` works ------------------------------------------------ +How :func:`~asyncio.server.broadcast` works +------------------------------------------- -The built-in :func:`~asyncio.connection.broadcast` function is similar to the -naive way. The main difference is that it doesn't apply backpressure. +The built-in :func:`~asyncio.server.broadcast` function is similar to the naive +way. The main difference is that it doesn't apply backpressure. This provides the best performance by avoiding the overhead of scheduling and running one task per client. @@ -324,7 +324,7 @@ the asynchronous iterator returned by ``subscribe()``. Performance considerations -------------------------- -The built-in :func:`~asyncio.connection.broadcast` function sends all messages +The built-in :func:`~asyncio.server.broadcast` function sends all messages without yielding control to the event loop. So does the naive way when the network and clients are fast and reliable. @@ -346,7 +346,7 @@ However, this isn't possible in general for two reasons: All other patterns discussed above yield control to the event loop once per client because messages are sent by different tasks. This makes them slower -than the built-in :func:`~asyncio.connection.broadcast` function. +than the built-in :func:`~asyncio.server.broadcast` function. There is no major difference between the performance of per-client queues and publish–subscribe. diff --git a/docs/topics/logging.rst b/docs/topics/logging.rst index cad49ba55..9580b4c50 100644 --- a/docs/topics/logging.rst +++ b/docs/topics/logging.rst @@ -221,7 +221,7 @@ Here's what websockets logs at each level. ``WARNING`` ........... -* Failures in :func:`~asyncio.connection.broadcast` +* Failures in :func:`~asyncio.server.broadcast` ``INFO`` ........ diff --git a/docs/topics/performance.rst b/docs/topics/performance.rst index b226cec43..b0828fe0d 100644 --- a/docs/topics/performance.rst +++ b/docs/topics/performance.rst @@ -18,5 +18,5 @@ application.) broadcast --------- -:func:`~asyncio.connection.broadcast` is the most efficient way to send a -message to many clients. +:func:`~asyncio.server.broadcast` is the most efficient way to send a message to +many clients. diff --git a/example/django/notifications.py b/example/django/notifications.py index 445438d2d..76ce9c2d7 100644 --- a/example/django/notifications.py +++ b/example/django/notifications.py @@ -10,8 +10,7 @@ from django.contrib.contenttypes.models import ContentType from sesame.utils import get_user -from websockets.asyncio.connection import broadcast -from websockets.asyncio.server import serve +from websockets.asyncio.server import broadcast, serve from websockets.frames import CloseCode diff --git a/example/quickstart/counter.py b/example/quickstart/counter.py index d42069e64..91eedc56a 100755 --- a/example/quickstart/counter.py +++ b/example/quickstart/counter.py @@ -3,8 +3,7 @@ import asyncio import json import logging -from websockets.asyncio.connection import broadcast -from websockets.asyncio.server import serve +from websockets.asyncio.server import broadcast, serve logging.basicConfig() diff --git a/example/quickstart/show_time_2.py b/example/quickstart/show_time_2.py index 4fa244a23..9c9659d14 100755 --- a/example/quickstart/show_time_2.py +++ b/example/quickstart/show_time_2.py @@ -4,8 +4,7 @@ import datetime import random -from websockets.asyncio.connection import broadcast -from websockets.asyncio.server import serve +from websockets.asyncio.server import broadcast, serve CONNECTIONS = set() diff --git a/example/tutorial/step2/app.py b/example/tutorial/step2/app.py index 86b2c88c3..ef3dd9483 100644 --- a/example/tutorial/step2/app.py +++ b/example/tutorial/step2/app.py @@ -4,8 +4,7 @@ import json import secrets -from websockets.asyncio.connection import broadcast -from websockets.asyncio.server import serve +from websockets.asyncio.server import broadcast, serve from connect4 import PLAYER1, PLAYER2, Connect4 diff --git a/example/tutorial/step3/app.py b/example/tutorial/step3/app.py index 34024d087..261057f9a 100644 --- a/example/tutorial/step3/app.py +++ b/example/tutorial/step3/app.py @@ -6,8 +6,7 @@ import secrets import signal -from websockets.asyncio.connection import broadcast -from websockets.asyncio.server import serve +from websockets.asyncio.server import broadcast, serve from connect4 import PLAYER1, PLAYER2, Connect4 diff --git a/experiments/broadcast/server.py b/experiments/broadcast/server.py index 0a5c82b3c..d5b50bd71 100644 --- a/experiments/broadcast/server.py +++ b/experiments/broadcast/server.py @@ -7,8 +7,7 @@ import time from websockets import ConnectionClosed -from websockets.asyncio.server import serve -from websockets.asyncio.connection import broadcast +from websockets.asyncio.server import broadcast, serve CLIENTS = set() diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index fdb028f4c..b618a6dff 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -48,10 +48,10 @@ "unix_connect", # .legacy.protocol "WebSocketCommonProtocol", - "broadcast", # .legacy.server "WebSocketServer", "WebSocketServerProtocol", + "broadcast", "serve", "unix_serve", # .server @@ -102,10 +102,11 @@ basic_auth_protocol_factory, ) from .legacy.client import WebSocketClientProtocol, connect, unix_connect - from .legacy.protocol import WebSocketCommonProtocol, broadcast + from .legacy.protocol import WebSocketCommonProtocol from .legacy.server import ( WebSocketServer, WebSocketServerProtocol, + broadcast, serve, unix_serve, ) @@ -164,10 +165,10 @@ "unix_connect": ".legacy.client", # .legacy.protocol "WebSocketCommonProtocol": ".legacy.protocol", - "broadcast": ".legacy.protocol", # .legacy.server "WebSocketServer": ".legacy.server", "WebSocketServerProtocol": ".legacy.server", + "broadcast": ".legacy.server", "serve": ".legacy.server", "unix_serve": ".legacy.server", # .server diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 284fe2124..a6b909c72 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -34,7 +34,7 @@ from .messages import Assembler -__all__ = ["Connection", "broadcast"] +__all__ = ["Connection"] class Connection(asyncio.Protocol): @@ -1011,6 +1011,12 @@ def eof_received(self) -> None: # Besides, that doesn't work on TLS connections. +# broadcast() is defined in the connection module even though it's primarily +# used by servers and documented in the server module because it works with +# client connections too and because it's easier to test together with the +# Connection class. + + def broadcast( connections: Iterable[Connection], message: Data, @@ -1034,10 +1040,11 @@ def broadcast( ``ping_interval`` and ``ping_timeout`` low to prevent excessive memory usage from slow connections. - Unlike :meth:`~Connection.send`, :func:`broadcast` doesn't support sending - fragmented messages. Indeed, fragmentation is useful for sending large - messages without buffering them in memory, while :func:`broadcast` buffers - one copy per connection as fast as possible. + Unlike :meth:`~websockets.asyncio.connection.Connection.send`, + :func:`broadcast` doesn't support sending fragmented messages. Indeed, + fragmentation is useful for sending large messages without buffering them in + memory, while :func:`broadcast` buffers one copy per connection as fast as + possible. :func:`broadcast` skips connections that aren't open in order to avoid errors on connections where the closing handshake is in progress. @@ -1047,6 +1054,10 @@ def broadcast( set ``raise_exceptions`` to :obj:`True` to record failures and raise all exceptions in a :pep:`654` :exc:`ExceptionGroup`. + While :func:`broadcast` makes more sense for servers, it works identically + with clients, if you have a use case for opening connections to many servers + and broadcasting a message to them. + Args: websockets: WebSocket connections to which the message will be sent. message: Message to send. @@ -1101,3 +1112,7 @@ def broadcast( if raise_exceptions and exceptions: raise ExceptionGroup("skipped broadcast", exceptions) + + +# Pretend that broadcast is actually defined in the server module. +broadcast.__module__ = "websockets.asyncio.server" diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index 1f55502bb..35637a18f 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -25,10 +25,10 @@ from ..server import ServerProtocol from ..typing import LoggerLike, Origin, StatusLike, Subprotocol from .compatibility import asyncio_timeout -from .connection import Connection +from .connection import Connection, broadcast -__all__ = ["serve", "unix_serve", "ServerConnection", "WebSocketServer"] +__all__ = ["broadcast", "serve", "unix_serve", "ServerConnection", "WebSocketServer"] class ServerConnection(Connection): diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 66eb94199..e83e146f9 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -53,7 +53,7 @@ from .framing import Frame -__all__ = ["WebSocketCommonProtocol", "broadcast"] +__all__ = ["WebSocketCommonProtocol"] # In order to ensure consistency, the code always checks the current value of @@ -1545,6 +1545,12 @@ def eof_received(self) -> None: self.reader.feed_eof() +# broadcast() is defined in the protocol module even though it's primarily +# used by servers and documented in the server module because it works with +# client connections too and because it's easier to test together with the +# WebSocketCommonProtocol class. + + def broadcast( websockets: Iterable[WebSocketCommonProtocol], message: Data, @@ -1568,10 +1574,11 @@ def broadcast( ``ping_interval`` and ``ping_timeout`` low to prevent excessive memory usage from slow connections. - Unlike :meth:`~WebSocketCommonProtocol.send`, :func:`broadcast` doesn't - support sending fragmented messages. Indeed, fragmentation is useful for - sending large messages without buffering them in memory, while - :func:`broadcast` buffers one copy per connection as fast as possible. + Unlike :meth:`~websockets.legacy.protocol.WebSocketCommonProtocol.send`, + :func:`broadcast` doesn't support sending fragmented messages. Indeed, + fragmentation is useful for sending large messages without buffering them in + memory, while :func:`broadcast` buffers one copy per connection as fast as + possible. :func:`broadcast` skips connections that aren't open in order to avoid errors on connections where the closing handshake is in progress. @@ -1581,6 +1588,10 @@ def broadcast( set ``raise_exceptions`` to :obj:`True` to record failures and raise all exceptions in a :pep:`654` :exc:`ExceptionGroup`. + While :func:`broadcast` makes more sense for servers, it works identically + with clients, if you have a use case for opening connections to many servers + and broadcasting a message to them. + Args: websockets: WebSocket connections to which the message will be sent. message: Message to send. @@ -1629,3 +1640,7 @@ def broadcast( if raise_exceptions and exceptions: raise ExceptionGroup("skipped broadcast", exceptions) + + +# Pretend that broadcast is actually defined in the server module. +broadcast.__module__ = "websockets.legacy.server" diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index d230f009e..43136db3e 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -45,10 +45,16 @@ from ..typing import ExtensionHeader, LoggerLike, Origin, StatusLike, Subprotocol from .handshake import build_response, check_request from .http import read_request -from .protocol import WebSocketCommonProtocol +from .protocol import WebSocketCommonProtocol, broadcast -__all__ = ["serve", "unix_serve", "WebSocketServerProtocol", "WebSocketServer"] +__all__ = [ + "broadcast", + "serve", + "unix_serve", + "WebSocketServerProtocol", + "WebSocketServer", +] # Change to HeadersLike | ... when dropping Python < 3.10. diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 3cdfeb21b..52e4fc5c8 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -9,6 +9,7 @@ from websockets.asyncio.compatibility import TimeoutError, aiter, anext, asyncio_timeout from websockets.asyncio.connection import * +from websockets.asyncio.connection import broadcast from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK from websockets.frames import CloseCode, Frame, Opcode from websockets.protocol import CLIENT, SERVER, Protocol, State From b05fa2cceefcc5cfaba4e0e06e40d588505c8334 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 19 Aug 2024 22:17:14 +0200 Subject: [PATCH 562/762] Rename WebSocketServer to Server. The shorter name is better. Representation remains unambiguous: The change isn't applied to the legacy implementation because it has longer names for other API too. --- docs/faq/server.rst | 3 +-- docs/howto/upgrade.rst | 2 +- docs/project/changelog.rst | 30 ++++++++++++++-------- docs/reference/asyncio/server.rst | 2 +- docs/reference/sync/server.rst | 2 +- src/websockets/asyncio/server.py | 42 +++++++++++++++---------------- src/websockets/sync/server.py | 28 ++++++++++++++------- tests/asyncio/client.py | 4 +-- tests/asyncio/test_server.py | 2 +- tests/sync/client.py | 4 +-- tests/sync/test_server.py | 11 +++++--- 11 files changed, 77 insertions(+), 53 deletions(-) diff --git a/docs/faq/server.rst b/docs/faq/server.rst index 66e81edfe..63eb5ffc6 100644 --- a/docs/faq/server.rst +++ b/docs/faq/server.rst @@ -313,8 +313,7 @@ Here's an example that terminates cleanly when it receives SIGTERM on Unix: How do I stop a server while keeping existing connections open? --------------------------------------------------------------- -Call the server's :meth:`~WebSocketServer.close` method with -``close_connections=False``. +Call the server's :meth:`~Server.close` method with ``close_connections=False``. Here's how to adapt the example just above:: diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index c5320155e..8d0895638 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -154,7 +154,7 @@ Server APIs | ``websockets.server.unix_serve()`` |br| | | | :func:`websockets.legacy.server.unix_serve` | | +-------------------------------------------------------------------+-----------------------------------------------------+ -| ``websockets.WebSocketServer`` |br| | :class:`websockets.asyncio.server.WebSocketServer` | +| ``websockets.WebSocketServer`` |br| | :class:`websockets.asyncio.server.Server` | | ``websockets.server.WebSocketServer`` |br| | | | :class:`websockets.legacy.server.WebSocketServer` | | +-------------------------------------------------------------------+-----------------------------------------------------+ diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index e85c3a395..f4ae76702 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -35,16 +35,6 @@ notice. Backwards-incompatible changes .............................. -.. admonition:: The ``ssl_context`` argument of :func:`~sync.client.connect` - and :func:`~sync.server.serve` in the :mod:`threading` implementation is - renamed to ``ssl``. - :class: note - - This aligns the API of the :mod:`threading` implementation with the - :mod:`asyncio` implementation. - - For backwards compatibility, ``ssl_context`` is still supported. - .. admonition:: Receiving the request path in the second parameter of connection handlers is deprecated. :class: note @@ -60,6 +50,26 @@ Backwards-incompatible changes path = request.path # only if handler() uses the path argument ... +.. admonition:: The ``ssl_context`` argument of :func:`~sync.client.connect` + and :func:`~sync.server.serve` in the :mod:`threading` implementation is + renamed to ``ssl``. + :class: note + + This aligns the API of the :mod:`threading` implementation with the + :mod:`asyncio` implementation. + + For backwards compatibility, ``ssl_context`` is still supported. + +.. admonition:: The ``WebSocketServer`` class in the :mod:`threading` + implementation is renamed to :class:`~sync.server.Server`. + :class: note + + This class isn't designed to be imported or instantiated directly. + :func:`~sync.server.serve` returns an instance. For this reason, + the change should be transparent. + + Regardless, an alias provides backwards compatibility. + New features ............ diff --git a/docs/reference/asyncio/server.rst b/docs/reference/asyncio/server.rst index 541c9952c..bd5a34b19 100644 --- a/docs/reference/asyncio/server.rst +++ b/docs/reference/asyncio/server.rst @@ -15,7 +15,7 @@ Creating a server Running a server ---------------- -.. autoclass:: WebSocketServer +.. autoclass:: Server .. automethod:: close diff --git a/docs/reference/sync/server.rst b/docs/reference/sync/server.rst index 26ab872c8..23cb04097 100644 --- a/docs/reference/sync/server.rst +++ b/docs/reference/sync/server.rst @@ -13,7 +13,7 @@ Creating a server Running a server ---------------- -.. autoclass:: WebSocketServer +.. autoclass:: Server .. automethod:: serve_forever diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index 35637a18f..8ebbddb67 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -28,7 +28,7 @@ from .connection import Connection, broadcast -__all__ = ["broadcast", "serve", "unix_serve", "ServerConnection", "WebSocketServer"] +__all__ = ["broadcast", "serve", "unix_serve", "ServerConnection", "Server"] class ServerConnection(Connection): @@ -60,7 +60,7 @@ class ServerConnection(Connection): def __init__( self, protocol: ServerProtocol, - server: WebSocketServer, + server: Server, *, ping_interval: float | None = 20, ping_timeout: float | None = 20, @@ -223,11 +223,11 @@ def connection_lost(self, exc: Exception | None) -> None: self.request_rcvd.set_result(None) -class WebSocketServer: +class Server: """ WebSocket server returned by :func:`serve`. - This class mirrors the API of :class:`~asyncio.Server`. + This class mirrors the API of :class:`asyncio.Server`. It keeps track of WebSocket connections in order to close them properly when shutting down. @@ -299,16 +299,16 @@ def __init__( def wrap(self, server: asyncio.Server) -> None: """ - Attach to a given :class:`~asyncio.Server`. + Attach to a given :class:`asyncio.Server`. Since :meth:`~asyncio.loop.create_server` doesn't support injecting a custom ``Server`` class, the easiest solution that doesn't rely on private :mod:`asyncio` APIs is to: - - instantiate a :class:`WebSocketServer` + - instantiate a :class:`Server` - give the protocol factory a reference to that instance - call :meth:`~asyncio.loop.create_server` with the factory - - attach the resulting :class:`~asyncio.Server` with this method + - attach the resulting :class:`asyncio.Server` with this method """ self.server = server @@ -378,7 +378,7 @@ def close(self, close_connections: bool = True) -> None: """ Close the server. - * Close the underlying :class:`~asyncio.Server`. + * Close the underlying :class:`asyncio.Server`. * When ``close_connections`` is :obj:`True`, which is the default, close existing connections. Specifically: @@ -402,7 +402,7 @@ async def _close(self, close_connections: bool) -> None: Implementation of :meth:`close`. This calls :meth:`~asyncio.Server.close` on the underlying - :class:`~asyncio.Server` object to stop accepting new connections and + :class:`asyncio.Server` object to stop accepting new connections and then closes open connections with close code 1001. """ @@ -516,7 +516,7 @@ def sockets(self) -> Iterable[socket.socket]: """ return self.server.sockets - async def __aenter__(self) -> WebSocketServer: # pragma: no cover + async def __aenter__(self) -> Server: # pragma: no cover return self async def __aexit__( @@ -543,8 +543,8 @@ class serve: Once the handler completes, either normally or with an exception, the server performs the closing handshake and closes the connection. - This coroutine returns a :class:`WebSocketServer` whose API mirrors - :class:`~asyncio.Server`. Treat it as an asynchronous context manager to + This coroutine returns a :class:`Server` whose API mirrors + :class:`asyncio.Server`. Treat it as an asynchronous context manager to ensure that the server will be closed:: def handler(websocket): @@ -556,8 +556,8 @@ def handler(websocket): async with websockets.asyncio.server.serve(handler, host, port): await stop - Alternatively, call :meth:`~WebSocketServer.serve_forever` to serve requests - and cancel it to stop the server:: + Alternatively, call :meth:`~Server.serve_forever` to serve requests and + cancel it to stop the server:: server = await websockets.asyncio.server.serve(handler, host, port) await server.serve_forever() @@ -638,8 +638,8 @@ def handler(websocket): socket and customize it. * You can set ``start_serving`` to ``False`` to start accepting connections - only after you call :meth:`~WebSocketServer.start_serving()` or - :meth:`~WebSocketServer.serve_forever()`. + only after you call :meth:`~Server.start_serving()` or + :meth:`~Server.serve_forever()`. """ @@ -704,7 +704,7 @@ def __init__( if create_connection is None: create_connection = ServerConnection - self.server = WebSocketServer( + self.server = Server( handler, process_request=process_request, process_response=process_response, @@ -773,7 +773,7 @@ def protocol_select_subprotocol( # async with serve(...) as ...: ... - async def __aenter__(self) -> WebSocketServer: + async def __aenter__(self) -> Server: return await self async def __aexit__( @@ -787,11 +787,11 @@ async def __aexit__( # ... = await serve(...) - def __await__(self) -> Generator[Any, None, WebSocketServer]: + def __await__(self) -> Generator[Any, None, Server]: # Create a suitable iterator by calling __await__ on a coroutine. return self.__await_impl__().__await__() - async def __await_impl__(self) -> WebSocketServer: + async def __await_impl__(self) -> Server: server = await self._create_server self.server.wrap(server) return self.server @@ -805,7 +805,7 @@ def unix_serve( handler: Callable[[ServerConnection], Awaitable[None]], path: str | None = None, **kwargs: Any, -) -> Awaitable[WebSocketServer]: +) -> Awaitable[Server]: """ Create a WebSocket server listening on a Unix socket. diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index b381908ca..85a7e9907 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -24,7 +24,7 @@ from .utils import Deadline -__all__ = ["serve", "unix_serve", "ServerConnection", "WebSocketServer"] +__all__ = ["serve", "unix_serve", "ServerConnection", "Server"] class ServerConnection(Connection): @@ -196,7 +196,7 @@ def recv_events(self) -> None: self.request_rcvd.set() -class WebSocketServer: +class Server: """ WebSocket server returned by :func:`serve`. @@ -283,7 +283,7 @@ def fileno(self) -> int: """ return self.socket.fileno() - def __enter__(self) -> WebSocketServer: + def __enter__(self) -> Server: return self def __exit__( @@ -295,6 +295,16 @@ def __exit__( self.shutdown() +def __getattr__(name: str) -> Any: + if name == "WebSocketServer": + warnings.warn( + "WebSocketServer was renamed to Server", + DeprecationWarning, + ) + return Server + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + def serve( handler: Callable[[ServerConnection], None], host: str | None = None, @@ -340,7 +350,7 @@ def serve( # Escape hatch for advanced customization create_connection: type[ServerConnection] | None = None, **kwargs: Any, -) -> WebSocketServer: +) -> Server: """ Create a WebSocket server listening on ``host`` and ``port``. @@ -353,10 +363,10 @@ def serve( Once the handler completes, either normally or with an exception, the server performs the closing handshake and closes the connection. - This function returns a :class:`WebSocketServer` whose API mirrors + This function returns a :class:`Server` whose API mirrors :class:`~socketserver.BaseServer`. Treat it as a context manager to ensure - that it will be closed and call :meth:`~WebSocketServer.serve_forever` to - serve requests:: + that it will be closed and call :meth:`~Server.serve_forever` to serve + requests:: def handler(websocket): ... @@ -552,14 +562,14 @@ def protocol_select_subprotocol( # Initialize server - return WebSocketServer(sock, conn_handler, logger) + return Server(sock, conn_handler, logger) def unix_serve( handler: Callable[[ServerConnection], None], path: str | None = None, **kwargs: Any, -) -> WebSocketServer: +) -> Server: """ Create a WebSocket server listening on a Unix socket. diff --git a/tests/asyncio/client.py b/tests/asyncio/client.py index e5826add7..a73079c6e 100644 --- a/tests/asyncio/client.py +++ b/tests/asyncio/client.py @@ -1,7 +1,7 @@ import contextlib from websockets.asyncio.client import * -from websockets.asyncio.server import WebSocketServer +from websockets.asyncio.server import Server from .server import get_server_host_port @@ -17,7 +17,7 @@ async def run_client(wsuri_or_server, secure=None, resource_name="/", **kwargs): if isinstance(wsuri_or_server, str): wsuri = wsuri_or_server else: - assert isinstance(wsuri_or_server, WebSocketServer) + assert isinstance(wsuri_or_server, Server) if secure is None: secure = "ssl" in kwargs protocol = "wss" if secure else "ws" diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index 5d4f0e2f8..4b637f3af 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -535,7 +535,7 @@ async def test_unsupported_compression(self): class WebSocketServerTests(unittest.IsolatedAsyncioTestCase): async def test_logger(self): - """WebSocketServer accepts a logger argument.""" + """Server accepts a logger argument.""" logger = logging.getLogger("test") async with run_server(logger=logger) as server: self.assertIs(server.logger, logger) diff --git a/tests/sync/client.py b/tests/sync/client.py index 72eb5b8d2..acbf97fa7 100644 --- a/tests/sync/client.py +++ b/tests/sync/client.py @@ -1,7 +1,7 @@ import contextlib from websockets.sync.client import * -from websockets.sync.server import WebSocketServer +from websockets.sync.server import Server __all__ = [ @@ -15,7 +15,7 @@ def run_client(wsuri_or_server, secure=None, resource_name="/", **kwargs): if isinstance(wsuri_or_server, str): wsuri = wsuri_or_server else: - assert isinstance(wsuri_or_server, WebSocketServer) + assert isinstance(wsuri_or_server, Server) if secure is None: # Backwards compatibility: ssl used to be called ssl_context. secure = "ssl" in kwargs or "ssl_context" in kwargs diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index c0a5f01e6..315601eca 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -383,18 +383,18 @@ def test_unsupported_compression(self): class WebSocketServerTests(unittest.TestCase): def test_logger(self): - """WebSocketServer accepts a logger argument.""" + """Server accepts a logger argument.""" logger = logging.getLogger("test") with run_server(logger=logger) as server: self.assertIs(server.logger, logger) def test_fileno(self): - """WebSocketServer provides a fileno attribute.""" + """Server provides a fileno attribute.""" with run_server() as server: self.assertIsInstance(server.fileno(), int) def test_shutdown(self): - """WebSocketServer provides a shutdown method.""" + """Server provides a shutdown method.""" with run_server() as server: server.shutdown() # Check that the server socket is closed. @@ -409,3 +409,8 @@ def test_ssl_context_argument(self): with run_server(ssl_context=SERVER_CONTEXT) as server: with run_client(server, ssl=CLIENT_CONTEXT): pass + + def test_web_socket_server_class(self): + with self.assertDeprecationWarning("WebSocketServer was renamed to Server"): + from websockets.sync.server import WebSocketServer + self.assertIs(WebSocketServer, Server) From 09b1d8d4d585ed6d4e2c0db6e200e48f176215b1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 19 Aug 2024 22:26:56 +0200 Subject: [PATCH 563/762] Fix tests on Python < 3.10. --- tests/asyncio/test_connection.py | 17 +++++++++++++++++ tests/legacy/utils.py | 31 ++++++++++++++++--------------- 2 files changed, 33 insertions(+), 15 deletions(-) diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 52e4fc5c8..29bb00418 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -44,6 +44,23 @@ async def asyncTearDown(self): await self.remote_connection.close() await self.connection.close() + if sys.version_info[:2] < (3, 10): # pragma: no cover + + @contextlib.contextmanager + def assertNoLogs(self, logger="websockets", level=logging.ERROR): + """ + No message is logged on the given logger with at least the given level. + + """ + with self.assertLogs(logger, level) as logs: + # We want to test that no log message is emitted + # but assertLogs expects at least one log message. + logging.getLogger(logger).log(level, "dummy") + yield + + level_name = logging.getLevelName(level) + self.assertEqual(logs.output, [f"{level_name}:{logger}:dummy"]) + # Test helpers built upon RecordingProtocol and InterceptingConnection. async def assertFrameSent(self, frame): diff --git a/tests/legacy/utils.py b/tests/legacy/utils.py index 5bb56b26f..5b56050d5 100644 --- a/tests/legacy/utils.py +++ b/tests/legacy/utils.py @@ -56,21 +56,22 @@ def run_loop_once(self): self.loop.call_soon(self.loop.stop) self.loop.run_forever() - # Remove when dropping Python < 3.10 - @contextlib.contextmanager - def assertNoLogs(self, logger="websockets", level=logging.ERROR): - """ - No message is logged on the given logger with at least the given level. - - """ - with self.assertLogs(logger, level) as logs: - # We want to test that no log message is emitted - # but assertLogs expects at least one log message. - logging.getLogger(logger).log(level, "dummy") - yield - - level_name = logging.getLevelName(level) - self.assertEqual(logs.output, [f"{level_name}:{logger}:dummy"]) + if sys.version_info[:2] < (3, 10): # pragma: no cover + + @contextlib.contextmanager + def assertNoLogs(self, logger="websockets", level=logging.ERROR): + """ + No message is logged on the given logger with at least the given level. + + """ + with self.assertLogs(logger, level) as logs: + # We want to test that no log message is emitted + # but assertLogs expects at least one log message. + logging.getLogger(logger).log(level, "dummy") + yield + + level_name = logging.getLevelName(level) + self.assertEqual(logs.output, [f"{level_name}:{logger}:dummy"]) def assertDeprecationWarnings(self, recorded_warnings, expected_warnings): """ From 8eaa5a26b667fccb1b9d75034a77b2d6906e2b2e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 20 Aug 2024 07:49:47 +0200 Subject: [PATCH 564/762] Document & test process_response modifying the response. a78b5546 inadvertently changed the test from "returning a new response" to "modifying the existing response". Both are supported.. --- src/websockets/asyncio/server.py | 11 +++--- src/websockets/sync/server.py | 13 ++++--- tests/asyncio/test_server.py | 65 +++++++++++++++++++++----------- tests/sync/test_server.py | 33 ++++++++++------ 4 files changed, 78 insertions(+), 44 deletions(-) diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index 8ebbddb67..8f04ec318 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -236,15 +236,16 @@ class Server: handler: Connection handler. It receives the WebSocket connection, which is a :class:`ServerConnection`, in argument. process_request: Intercept the request during the opening handshake. - Return an HTTP response to force the response or :obj:`None` to + Return an HTTP response to force the response. Return :obj:`None` to continue normally. When you force an HTTP 101 Continue response, the handshake is successful. Else, the connection is aborted. ``process_request`` may be a function or a coroutine. process_response: Intercept the response during the opening handshake. - Return an HTTP response to force the response or :obj:`None` to - continue normally. When you force an HTTP 101 Continue response, the - handshake is successful. Else, the connection is aborted. - ``process_response`` may be a function or a coroutine. + Modify the response or return a new HTTP response to force the + response. Return :obj:`None` to continue normally. When you force an + HTTP 101 Continue response, the handshake is successful. Else, the + connection is aborted. ``process_response`` may be a function or a + coroutine. server_header: Value of the ``Server`` response header. It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to :obj:`None` removes the header. diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 85a7e9907..86c162af3 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -401,13 +401,14 @@ def handler(websocket): :meth:`ServerProtocol.select_subprotocol ` method. process_request: Intercept the request during the opening handshake. - Return an HTTP response to force the response or :obj:`None` to - continue normally. When you force an HTTP 101 Continue response, - the handshake is successful. Else, the connection is aborted. + Return an HTTP response to force the response. Return :obj:`None` to + continue normally. When you force an HTTP 101 Continue response, the + handshake is successful. Else, the connection is aborted. process_response: Intercept the response during the opening handshake. - Return an HTTP response to force the response or :obj:`None` to - continue normally. When you force an HTTP 101 Continue response, - the handshake is successful. Else, the connection is aborted. + Modify the response or return a new HTTP response to force the + response. Return :obj:`None` to continue normally. When you force an + HTTP 101 Continue response, the handshake is successful. Else, the + connection is aborted. server_header: Value of the ``Server`` response header. It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to :obj:`None` removes the header. diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index 4b637f3af..b899998f4 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -1,4 +1,5 @@ import asyncio +import dataclasses import http import logging import socket @@ -117,8 +118,8 @@ def select_subprotocol(ws, subprotocols): "server rejected WebSocket connection: HTTP 500", ) - async def test_process_request(self): - """Server runs process_request before processing the handshake.""" + async def test_process_request_returns_none(self): + """Server runs process_request and continues the handshake.""" def process_request(ws, request): self.assertIsInstance(request, Request) @@ -128,8 +129,8 @@ def process_request(ws, request): async with run_client(server) as client: await self.assertEval(client, "ws.process_request_ran", "True") - async def test_async_process_request(self): - """Server runs async process_request before processing the handshake.""" + async def test_async_process_request_returns_none(self): + """Server runs async process_request and continues the handshake.""" async def process_request(ws, request): self.assertIsInstance(request, Request) @@ -139,7 +140,7 @@ async def process_request(ws, request): async with run_client(server) as client: await self.assertEval(client, "ws.process_request_ran", "True") - async def test_process_request_abort_handshake(self): + async def test_process_request_returns_response(self): """Server aborts handshake if process_request returns a response.""" def process_request(ws, request): @@ -154,7 +155,7 @@ def process_request(ws, request): "server rejected WebSocket connection: HTTP 403", ) - async def test_async_process_request_abort_handshake(self): + async def test_async_process_request_returns_response(self): """Server aborts handshake if async process_request returns a response.""" async def process_request(ws, request): @@ -199,8 +200,8 @@ async def process_request(ws, request): "server rejected WebSocket connection: HTTP 500", ) - async def test_process_response(self): - """Server runs process_response after processing the handshake.""" + async def test_process_response_returns_none(self): + """Server runs process_response but keeps the handshake response.""" def process_response(ws, request, response): self.assertIsInstance(request, Request) @@ -211,8 +212,8 @@ def process_response(ws, request, response): async with run_client(server) as client: await self.assertEval(client, "ws.process_response_ran", "True") - async def test_async_process_response(self): - """Server runs async process_response after processing the handshake.""" + async def test_async_process_response_returns_none(self): + """Server runs async process_response but keeps the handshake response.""" async def process_response(ws, request, response): self.assertIsInstance(request, Request) @@ -223,29 +224,49 @@ async def process_response(ws, request, response): async with run_client(server) as client: await self.assertEval(client, "ws.process_response_ran", "True") - async def test_process_response_override_response(self): - """Server runs process_response and overrides the handshake response.""" + async def test_process_response_modifies_response(self): + """Server runs process_response and modifies the handshake response.""" def process_response(ws, request, response): - response.headers["X-ProcessResponse-Ran"] = "true" + response.headers["X-ProcessResponse"] = "OK" async with run_server(process_response=process_response) as server: async with run_client(server) as client: - self.assertEqual( - client.response.headers["X-ProcessResponse-Ran"], "true" - ) + self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") - async def test_async_process_response_override_response(self): - """Server runs async process_response and overrides the handshake response.""" + async def test_async_process_response_modifies_response(self): + """Server runs async process_response and modifies the handshake response.""" async def process_response(ws, request, response): - response.headers["X-ProcessResponse-Ran"] = "true" + response.headers["X-ProcessResponse"] = "OK" async with run_server(process_response=process_response) as server: async with run_client(server) as client: - self.assertEqual( - client.response.headers["X-ProcessResponse-Ran"], "true" - ) + self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") + + async def test_process_response_replaces_response(self): + """Server runs process_response and replaces the handshake response.""" + + def process_response(ws, request, response): + headers = response.headers.copy() + headers["X-ProcessResponse"] = "OK" + return dataclasses.replace(response, headers=headers) + + async with run_server(process_response=process_response) as server: + async with run_client(server) as client: + self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") + + async def test_async_process_response_replaces_response(self): + """Server runs async process_response and replaces the handshake response.""" + + async def process_response(ws, request, response): + headers = response.headers.copy() + headers["X-ProcessResponse"] = "OK" + return dataclasses.replace(response, headers=headers) + + async with run_server(process_response=process_response) as server: + async with run_client(server) as client: + self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") async def test_process_response_raises_exception(self): """Server returns an error if process_response raises an exception.""" diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index 315601eca..e3dfeb271 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -1,3 +1,4 @@ +import dataclasses import http import logging import socket @@ -115,8 +116,8 @@ def select_subprotocol(ws, subprotocols): "server rejected WebSocket connection: HTTP 500", ) - def test_process_request(self): - """Server runs process_request before processing the handshake.""" + def test_process_request_returns_none(self): + """Server runs process_request and continues the handshake.""" def process_request(ws, request): self.assertIsInstance(request, Request) @@ -126,7 +127,7 @@ def process_request(ws, request): with run_client(server) as client: self.assertEval(client, "ws.process_request_ran", "True") - def test_process_request_abort_handshake(self): + def test_process_request_returns_response(self): """Server aborts handshake if process_request returns a response.""" def process_request(ws, request): @@ -156,8 +157,8 @@ def process_request(ws, request): "server rejected WebSocket connection: HTTP 500", ) - def test_process_response(self): - """Server runs process_response after processing the handshake.""" + def test_process_response_returns_none(self): + """Server runs process_response but keeps the handshake response.""" def process_response(ws, request, response): self.assertIsInstance(request, Request) @@ -168,17 +169,27 @@ def process_response(ws, request, response): with run_client(server) as client: self.assertEval(client, "ws.process_response_ran", "True") - def test_process_response_override_response(self): - """Server runs process_response and overrides the handshake response.""" + def test_process_response_modifies_response(self): + """Server runs process_response and modifies the handshake response.""" def process_response(ws, request, response): - response.headers["X-ProcessResponse-Ran"] = "true" + response.headers["X-ProcessResponse"] = "OK" with run_server(process_response=process_response) as server: with run_client(server) as client: - self.assertEqual( - client.response.headers["X-ProcessResponse-Ran"], "true" - ) + self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") + + def test_process_response_replaces_response(self): + """Server runs process_response and replaces the handshake response.""" + + def process_response(ws, request, response): + headers = response.headers.copy() + headers["X-ProcessResponse"] = "OK" + return dataclasses.replace(response, headers=headers) + + with run_server(process_response=process_response) as server: + with run_client(server) as client: + self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") def test_process_response_raises_exception(self): """Server returns an error if process_response raises an exception.""" From 9e5b91bf8f9039de0af85e597e7fd643cfd1a139 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 20 Aug 2024 08:26:44 +0200 Subject: [PATCH 565/762] Improve documentation of latency. Also fix #1414. --- docs/reference/features.rst | 2 ++ docs/topics/keepalive.rst | 28 ++++++++++++++++++++-------- src/websockets/asyncio/connection.py | 9 +++++---- src/websockets/legacy/protocol.py | 9 +++++---- 4 files changed, 32 insertions(+), 16 deletions(-) diff --git a/docs/reference/features.rst b/docs/reference/features.rst index cb0e564f9..a380f4555 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -59,6 +59,8 @@ Both sides +------------------------------------+--------+--------+--------+--------+ | Heartbeat | ✅ | ❌ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ + | Measure latency | ✅ | ❌ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+ | Perform the closing handshake | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ | Enforce closing timeout | ✅ | ✅ | — | ✅ | diff --git a/docs/topics/keepalive.rst b/docs/topics/keepalive.rst index 1c7a43264..91f11fb11 100644 --- a/docs/topics/keepalive.rst +++ b/docs/topics/keepalive.rst @@ -40,13 +40,16 @@ It loops through these steps: If the Pong frame isn't received, websockets considers the connection broken and closes it. -This mechanism serves two purposes: +This mechanism serves three purposes: 1. It creates a trickle of traffic so that the TCP connection isn't idle and network infrastructure along the path keeps it open ("keepalive"). 2. It detects if the connection drops or becomes so slow that it's unusable in practice ("heartbeat"). In that case, it terminates the connection and your application gets a :exc:`~exceptions.ConnectionClosed` exception. +3. It measures the :attr:`~asyncio.connection.Connection.latency` of the + connection. The time between sending a Ping frame and receiving a matching + Pong frame approximates the round-trip time. Timings are configurable with the ``ping_interval`` and ``ping_timeout`` arguments of :func:`~asyncio.client.connect` and :func:`~asyncio.server.serve`. @@ -54,7 +57,7 @@ Shorter values will detect connection drops faster but they will increase network traffic and they will be more sensitive to latency. Setting ``ping_interval`` to :obj:`None` disables the whole keepalive and -heartbeat mechanism. +heartbeat mechanism, including measurement of latency. Setting ``ping_timeout`` to :obj:`None` disables only timeouts. This enables keepalive, to keep idle connections open, and disables heartbeat, to support large @@ -85,9 +88,23 @@ Unfortunately, the WebSocket API in browsers doesn't expose the native Ping and Pong functionality in the WebSocket protocol. You have to roll your own in the application layer. +Read this `blog post `_ for +a complete walk-through of this issue. + Latency issues -------------- +The :attr:`~asyncio.connection.Connection.latency` attribute stores latency +measured during the last exchange of Ping and Pong frames:: + + latency = websocket.latency + +Alternatively, you can measure the latency at any time by calling +:attr:`~asyncio.connection.Connection.ping` and awaiting its result:: + + pong_waiter = await websocket.ping() + latency = await pong_waiter + Latency between a client and a server may increase for two reasons: * Network connectivity is poor. When network packets are lost, TCP attempts to @@ -97,7 +114,7 @@ Latency between a client and a server may increase for two reasons: * Traffic is high. For example, if a client sends messages on the connection faster than a server can process them, this manifests as latency as well, - because data is waiting in flight, mostly in OS buffers. + because data is waiting in :doc:`buffers `. If the server is more than 20 seconds behind, it doesn't see the Pong before the default timeout elapses. As a consequence, it closes the connection. @@ -109,8 +126,3 @@ Latency between a client and a server may increase for two reasons: The same reasoning applies to situations where the server sends more traffic than the client can accept. - -The latency measured during the last exchange of Ping and Pong frames is -available in the :attr:`~asyncio.connection.Connection.latency` attribute. -Alternatively, you can measure the latency at any time with the -:attr:`~asyncio.connection.Connection.ping` method. diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index a6b909c72..9e7ea3d8c 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -109,12 +109,13 @@ def __init__( """ Latency of the connection, in seconds. - This value is updated after sending a ping frame and receiving a - matching pong frame. Before the first ping, :attr:`latency` is ``0``. + Latency is defined as the round-trip time of the connection. It is + measured by sending a Ping frame and waiting for a matching Pong frame. + Before the first measurement, :attr:`latency` is ``0``. By default, websockets enables a :ref:`keepalive ` mechanism - that sends ping frames automatically at regular intervals. You can also - send ping frames and measure latency with :meth:`ping`. + that sends Ping frames automatically at regular intervals. You can also + send Ping frames and measure latency with :meth:`ping`. """ # Task that sends keepalive pings. None when ping_interval is None. diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index e83e146f9..3b9a8c4aa 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -289,12 +289,13 @@ def __init__( """ Latency of the connection, in seconds. - This value is updated after sending a ping frame and receiving a - matching pong frame. Before the first ping, :attr:`latency` is ``0``. + Latency is defined as the round-trip time of the connection. It is + measured by sending a Ping frame and waiting for a matching Pong frame. + Before the first measurement, :attr:`latency` is ``0``. By default, websockets enables a :ref:`keepalive ` mechanism - that sends ping frames automatically at regular intervals. You can also - send ping frames and measure latency with :meth:`ping`. + that sends Ping frames automatically at regular intervals. You can also + send Ping frames and measure latency with :meth:`ping`. """ # Task running the data transfer. From 453e55ac2a20a50bfd30c0b4c011c50d01e7bb0a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 20 Aug 2024 09:52:31 +0200 Subject: [PATCH 566/762] Standardize on raise AssertionError(...). --- experiments/optimization/parse_frames.py | 2 +- experiments/optimization/parse_handshake.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/experiments/optimization/parse_frames.py b/experiments/optimization/parse_frames.py index e3acbe3c2..9ea71c58e 100644 --- a/experiments/optimization/parse_frames.py +++ b/experiments/optimization/parse_frames.py @@ -33,7 +33,7 @@ def parse_frame(data, count, mask, extensions): except StopIteration: pass else: - assert False, "parser should return frame" + raise AssertionError("parser should return frame") reader.feed_eof() assert reader.at_eof(), "parser should consume all data" diff --git a/experiments/optimization/parse_handshake.py b/experiments/optimization/parse_handshake.py index af5a4ecae..393e0215c 100644 --- a/experiments/optimization/parse_handshake.py +++ b/experiments/optimization/parse_handshake.py @@ -71,7 +71,7 @@ def parse_handshake(handshake): except StopIteration: pass else: - assert False, "parser should return request" + raise AssertionError("parser should return request") reader.feed_eof() assert reader.at_eof(), "parser should consume all data" From 9d355bfeb784e41b2a879645113c73c4560c9a91 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 20 Aug 2024 09:53:05 +0200 Subject: [PATCH 567/762] Remove unnecessary code paths in keepalive(). Also add comments in tests to clarify the intended sequence. --- src/websockets/asyncio/connection.py | 14 +++++--- tests/asyncio/test_connection.py | 50 ++++++++++++++++++---------- 2 files changed, 42 insertions(+), 22 deletions(-) diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 9e7ea3d8c..005e9b4bb 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -723,6 +723,10 @@ async def keepalive(self) -> None: if self.ping_timeout is not None: try: async with asyncio_timeout(self.ping_timeout): + # connection_lost cancels keepalive immediately + # after setting a ConnectionClosed exception on + # pong_waiter. A CancelledError is raised here, + # not a ConnectionClosed exception. latency = await pong_waiter self.logger.debug("% received keepalive pong") except asyncio.TimeoutError: @@ -733,9 +737,10 @@ async def keepalive(self) -> None: CloseCode.INTERNAL_ERROR, "keepalive ping timeout", ) - break - except ConnectionClosed: - pass + raise AssertionError( + "send_context() should wait for connection_lost(), " + "which cancels keepalive()" + ) except Exception: self.logger.error("keepalive ping failed", exc_info=True) @@ -913,8 +918,7 @@ def connection_lost(self, exc: Exception | None) -> None: self.set_recv_exc(exc) self.recv_messages.close() self.abort_pings() - # If keepalive() was waiting for a pong, abort_pings() terminated it. - # If it was sleeping until the next ping, we need to cancel it now + if self.keepalive_task is not None: self.keepalive_task.cancel() diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 29bb00418..59218de4b 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -890,12 +890,25 @@ async def test_pong_explicit_binary(self): @patch("random.getrandbits") async def test_keepalive(self, getrandbits): - """keepalive sends pings.""" + """keepalive sends pings at ping_interval and measures latency.""" self.connection.ping_interval = 2 * MS getrandbits.return_value = 1918987876 self.connection.start_keepalive() + self.assertEqual(self.connection.latency, 0) + # 2 ms: keepalive() sends a ping frame. + # 2.x ms: a pong frame is received. await asyncio.sleep(3 * MS) + # 3 ms: check that the ping frame was sent. await self.assertFrameSent(Frame(Opcode.PING, b"rand")) + self.assertGreater(self.connection.latency, 0) + self.assertLess(self.connection.latency, MS) + + async def test_disable_keepalive(self): + """keepalive is disabled when ping_interval is None.""" + self.connection.ping_interval = None + self.connection.start_keepalive() + await asyncio.sleep(3 * MS) + await self.assertNoFrameSent() @patch("random.getrandbits") async def test_keepalive_times_out(self, getrandbits): @@ -905,13 +918,14 @@ async def test_keepalive_times_out(self, getrandbits): async with self.drop_frames_rcvd(): getrandbits.return_value = 1918987876 self.connection.start_keepalive() + # 4 ms: keepalive() sends a ping frame. await asyncio.sleep(4 * MS) # Exiting the context manager sleeps for MS. - await self.assertFrameSent(Frame(Opcode.PING, b"rand")) - await asyncio.sleep(MS) - await self.assertFrameSent( - Frame(Opcode.CLOSE, b"\x03\xf3keepalive ping timeout") - ) + # 4.x ms: a pong frame is dropped. + # 6 ms: no pong frame is received; the connection is closed. + await asyncio.sleep(2 * MS) + # 7 ms: check that the connection is closed. + self.assertEqual(self.connection.state, State.CLOSED) @patch("random.getrandbits") async def test_keepalive_ignores_timeout(self, getrandbits): @@ -921,18 +935,14 @@ async def test_keepalive_ignores_timeout(self, getrandbits): async with self.drop_frames_rcvd(): getrandbits.return_value = 1918987876 self.connection.start_keepalive() + # 4 ms: keepalive() sends a ping frame. await asyncio.sleep(4 * MS) # Exiting the context manager sleeps for MS. - await self.assertFrameSent(Frame(Opcode.PING, b"rand")) - await asyncio.sleep(MS) - await self.assertNoFrameSent() - - async def test_disable_keepalive(self): - """keepalive is disabled when ping_interval is None.""" - self.connection.ping_interval = None - self.connection.start_keepalive() - await asyncio.sleep(3 * MS) - await self.assertNoFrameSent() + # 4.x ms: a pong frame is dropped. + # 6 ms: no pong frame is received; the connection remains open. + await asyncio.sleep(2 * MS) + # 7 ms: check that the connection is still open. + self.assertEqual(self.connection.state, State.OPEN) async def test_keepalive_terminates_while_sleeping(self): """keepalive task terminates while waiting to send a ping.""" @@ -945,21 +955,27 @@ async def test_keepalive_terminates_while_sleeping(self): async def test_keepalive_terminates_while_waiting_for_pong(self): """keepalive task terminates while waiting to receive a pong.""" self.connection.ping_interval = 2 * MS + self.connection.ping_timeout = 2 * MS async with self.drop_frames_rcvd(): self.connection.start_keepalive() + # 2 ms: keepalive() sends a ping frame. await asyncio.sleep(2 * MS) # Exiting the context manager sleeps for MS. + # 2.x ms: a pong frame is dropped. + # 3 ms: close the connection before ping_timeout elapses. await self.connection.close() self.assertTrue(self.connection.keepalive_task.done()) async def test_keepalive_reports_errors(self): """keepalive reports unexpected errors in logs.""" self.connection.ping_interval = 2 * MS - # Inject a fault by raising an exception in a pending pong waiter. async with self.drop_frames_rcvd(): self.connection.start_keepalive() + # 2 ms: keepalive() sends a ping frame. await asyncio.sleep(2 * MS) # Exiting the context manager sleeps for MS. + # 2.x ms: a pong frame is dropped. + # 3 ms: inject a fault: raise an exception in the pending pong waiter. pong_waiter = next(iter(self.connection.pong_waiters.values()))[0] with self.assertLogs("websockets", logging.ERROR) as logs: pong_waiter.set_exception(Exception("BOOM")) From 12fa8bc8fcc03a120ceb05700905b6e4698df563 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 20 Aug 2024 14:24:25 +0200 Subject: [PATCH 568/762] Complete changelog with changes since 12.0. --- docs/project/changelog.rst | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index f4ae76702..06d8a7774 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -35,6 +35,11 @@ notice. Backwards-incompatible changes .............................. +.. admonition:: websockets 13.0 requires Python ≥ 3.8. + :class: tip + + websockets 12.0 is the last version supporting Python 3.7. + .. admonition:: Receiving the request path in the second parameter of connection handlers is deprecated. :class: note @@ -64,9 +69,8 @@ Backwards-incompatible changes implementation is renamed to :class:`~sync.server.Server`. :class: note - This class isn't designed to be imported or instantiated directly. - :func:`~sync.server.serve` returns an instance. For this reason, - the change should be transparent. + This change should be transparent because this class shouldn't be + instantiated directly; :func:`~sync.server.serve` returns an instance. Regardless, an alias provides backwards compatibility. @@ -91,6 +95,22 @@ New features If you were monkey-patching constants, be aware that they were renamed, which will break your configuration. You must switch to the environment variables. +Improvements +............ + +* The error message in server logs when a header is too long is more explicit. + +Bug fixes +......... + +* Fixed a bug in the :mod:`threading` implementation that could prevent the + program from exiting when a connection wasn't closed properly. + +* Redirecting from a ``ws://`` URI to a ``wss://`` URI now works. + +* ``broadcast(raise_exceptions=True)`` no longer crashes when there isn't any + exception. + .. _12.0: 12.0 From 0019943e551d285ec27c29315a23dcc959a2ec29 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 20 Aug 2024 14:39:43 +0200 Subject: [PATCH 569/762] Release version 13.0. --- docs/project/changelog.rst | 2 +- src/websockets/version.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 06d8a7774..7c5998288 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -30,7 +30,7 @@ notice. 13.0 ---- -*In development* +*August 20, 2024* Backwards-incompatible changes .............................. diff --git a/src/websockets/version.py b/src/websockets/version.py index 44709a91b..56c321940 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -18,9 +18,9 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = False +released = True -tag = version = commit = "12.1" +tag = version = commit = "13.0" if not released: # pragma: no cover From 4d0e0e10c6ebb779780a9e590667661381df78dc Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 20 Aug 2024 14:55:32 +0200 Subject: [PATCH 570/762] Build sdist and arch-independent wheel with build. This removes the dependency on setuptools, which isn't installed by default anymore, causing the build to fail. --- .github/workflows/release.yml | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index ed52ddd80..184444e56 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -17,22 +17,18 @@ jobs: uses: actions/setup-python@v5 with: python-version: 3.x - - name: Build sdist - run: python setup.py sdist - - name: Save sdist - uses: actions/upload-artifact@v4 - with: - path: dist/*.tar.gz - - name: Install wheel - run: pip install wheel - - name: Build wheel + - name: Install build + run: pip install build + - name: Build sdist & wheel + run: python -m build env: BUILD_EXTENSION: no - run: python setup.py bdist_wheel - - name: Save wheel + - name: Save sdist & wheel uses: actions/upload-artifact@v4 with: - path: dist/*.whl + path: | + dist/*.tar.gz + dist/*.whl wheels: name: Build architecture-specific wheels on ${{ matrix.os }} From f9c20d0e4c9a25b66d4643879bc4594137036793 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 20 Aug 2024 15:06:07 +0200 Subject: [PATCH 571/762] Avoid deleting .so files in .direnv or equivalent. --- Makefile | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Makefile b/Makefile index dacfe2a0b..a69248b6e 100644 --- a/Makefile +++ b/Makefile @@ -30,6 +30,7 @@ build: python setup.py build_ext --inplace clean: - find . -name '*.pyc' -delete -o -name '*.so' -delete + find src -name '*.so' -delete + find . -name '*.pyc' -delete find . -name __pycache__ -delete rm -rf .coverage .mypy_cache build compliance/reports dist docs/_build htmlcov MANIFEST src/websockets.egg-info From 323adef1f3000cf07617d7ee649c27c0801126e6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 20 Aug 2024 15:19:28 +0200 Subject: [PATCH 572/762] Migrate to actions/upload-artifact@v4. The version number was increased without accounting for backwards-incompatible changes. --- .github/workflows/release.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 184444e56..4d2b5b75e 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -26,6 +26,7 @@ jobs: - name: Save sdist & wheel uses: actions/upload-artifact@v4 with: + name: dist-architecture-independent path: | dist/*.tar.gz dist/*.whl @@ -58,6 +59,7 @@ jobs: - name: Save wheels uses: actions/upload-artifact@v4 with: + name: dist-${{ matrix.os }} path: wheelhouse/*.whl upload: @@ -74,7 +76,8 @@ jobs: - name: Download artifacts uses: actions/download-artifact@v4 with: - name: artifact + pattern: dist-* + merge-multiple: true path: dist - name: Upload to PyPI uses: pypa/gh-action-pypi-publish@release/v1 From ed2f21e3ce5bb494e0c9a51833b7da6692b5b9fc Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 20 Aug 2024 16:12:31 +0200 Subject: [PATCH 573/762] Attempt to fix automatic creation of GitHub release. --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 4d2b5b75e..a68714ed1 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -71,7 +71,7 @@ jobs: # Don't release when running the workflow manually from GitHub's UI. if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') permissions: - id-token: write + contents: write steps: - name: Download artifacts uses: actions/download-artifact@v4 From 3944595a1f1de2271b2f682f65b534f27628d3b6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 20 Aug 2024 16:13:40 +0200 Subject: [PATCH 574/762] Start version 13.1. --- docs/project/changelog.rst | 7 +++++++ src/websockets/version.py | 4 ++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 7c5998288..955136ac0 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -25,6 +25,13 @@ fixing regressions shortly after a release. Only documented APIs are public. Undocumented, private APIs may change without notice. +.. _13.1: + +13.1 +---- + +*In development* + .. _13.0: 13.0 diff --git a/src/websockets/version.py b/src/websockets/version.py index 56c321940..bbda56d6b 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -18,9 +18,9 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = True +released = False -tag = version = commit = "13.0" +tag = version = commit = "13.1" if not released: # pragma: no cover From 6b1cc94caa2aa8f0395b8d8d2e6e9e211c927e8c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 21 Aug 2024 08:37:37 +0200 Subject: [PATCH 575/762] Move prepare_data/ctrl back to the legacy framing module. They are deprecated and they must go away with the legacy implementation. They were only documented in the framing module, not in the new frames module, so this doesn't require more backwards-compatibility shims. --- src/websockets/asyncio/connection.py | 17 ++++++-- src/websockets/frames.py | 50 ------------------------ src/websockets/legacy/framing.py | 58 +++++++++++++++++++++++++--- src/websockets/legacy/protocol.py | 4 +- src/websockets/speedups.c | 1 - src/websockets/sync/connection.py | 17 ++++++-- tests/asyncio/test_connection.py | 10 +++++ tests/legacy/test_framing.py | 56 +++++++++++++++++++++++++++ tests/sync/test_connection.py | 10 +++++ tests/test_frames.py | 56 --------------------------- 10 files changed, 156 insertions(+), 123 deletions(-) diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 005e9b4bb..069b3e1d2 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -20,7 +20,7 @@ ) from ..exceptions import ConnectionClosed, ConnectionClosedOK, ProtocolError -from ..frames import DATA_OPCODES, BytesLike, CloseCode, Frame, Opcode, prepare_ctrl +from ..frames import DATA_OPCODES, BytesLike, CloseCode, Frame, Opcode from ..http11 import Request, Response from ..protocol import CLOSED, OPEN, Event, Protocol, State from ..typing import Data, LoggerLike, Subprotocol @@ -597,8 +597,12 @@ async def ping(self, data: Data | None = None) -> Awaitable[float]: the corresponding pong wasn't received yet. """ - if data is not None: - data = prepare_ctrl(data) + if isinstance(data, BytesLike): + data = bytes(data) + elif isinstance(data, str): + data = data.encode() + elif data is not None: + raise TypeError("data must be str or bytes-like") async with self.send_context(): # Protect against duplicates if a payload is explicitly set. @@ -632,7 +636,12 @@ async def pong(self, data: Data = b"") -> None: ConnectionClosed: When the connection is closed. """ - data = prepare_ctrl(data) + if isinstance(data, BytesLike): + data = bytes(data) + elif isinstance(data, str): + data = data.encode() + else: + raise TypeError("data must be str or bytes-like") async with self.send_context(): self.protocol.send_pong(data) diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 819fdd742..8e44dd3a2 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -9,7 +9,6 @@ from typing import Callable, Generator, Sequence from . import exceptions, extensions -from .typing import Data try: @@ -29,8 +28,6 @@ "DATA_OPCODES", "CTRL_OPCODES", "Frame", - "prepare_data", - "prepare_ctrl", "Close", ] @@ -354,53 +351,6 @@ def check(self) -> None: raise exceptions.ProtocolError("fragmented control frame") -def prepare_data(data: Data) -> tuple[int, bytes]: - """ - Convert a string or byte-like object to an opcode and a bytes-like object. - - This function is designed for data frames. - - If ``data`` is a :class:`str`, return ``OP_TEXT`` and a :class:`bytes` - object encoding ``data`` in UTF-8. - - If ``data`` is a bytes-like object, return ``OP_BINARY`` and a bytes-like - object. - - Raises: - TypeError: If ``data`` doesn't have a supported type. - - """ - if isinstance(data, str): - return OP_TEXT, data.encode() - elif isinstance(data, BytesLike): - return OP_BINARY, data - else: - raise TypeError("data must be str or bytes-like") - - -def prepare_ctrl(data: Data) -> bytes: - """ - Convert a string or byte-like object to bytes. - - This function is designed for ping and pong frames. - - If ``data`` is a :class:`str`, return a :class:`bytes` object encoding - ``data`` in UTF-8. - - If ``data`` is a bytes-like object, return a :class:`bytes` object. - - Raises: - TypeError: If ``data`` doesn't have a supported type. - - """ - if isinstance(data, str): - return data.encode() - elif isinstance(data, BytesLike): - return bytes(data) - else: - raise TypeError("data must be str or bytes-like") - - @dataclasses.dataclass class Close: """ diff --git a/src/websockets/legacy/framing.py b/src/websockets/legacy/framing.py index 1aaca5cc6..4c2f8c23f 100644 --- a/src/websockets/legacy/framing.py +++ b/src/websockets/legacy/framing.py @@ -5,6 +5,8 @@ from .. import extensions, frames from ..exceptions import PayloadTooBig, ProtocolError +from ..frames import BytesLike +from ..typing import Data try: @@ -144,12 +146,58 @@ def write( write(self.new_frame.serialize(mask=mask, extensions=extensions)) +def prepare_data(data: Data) -> tuple[int, bytes]: + """ + Convert a string or byte-like object to an opcode and a bytes-like object. + + This function is designed for data frames. + + If ``data`` is a :class:`str`, return ``OP_TEXT`` and a :class:`bytes` + object encoding ``data`` in UTF-8. + + If ``data`` is a bytes-like object, return ``OP_BINARY`` and a bytes-like + object. + + Raises: + TypeError: If ``data`` doesn't have a supported type. + + """ + if isinstance(data, str): + return frames.Opcode.TEXT, data.encode() + elif isinstance(data, BytesLike): + return frames.Opcode.BINARY, data + else: + raise TypeError("data must be str or bytes-like") + + +def prepare_ctrl(data: Data) -> bytes: + """ + Convert a string or byte-like object to bytes. + + This function is designed for ping and pong frames. + + If ``data`` is a :class:`str`, return a :class:`bytes` object encoding + ``data`` in UTF-8. + + If ``data`` is a bytes-like object, return a :class:`bytes` object. + + Raises: + TypeError: If ``data`` doesn't have a supported type. + + """ + if isinstance(data, str): + return data.encode() + elif isinstance(data, BytesLike): + return bytes(data) + else: + raise TypeError("data must be str or bytes-like") + + +# Backwards compatibility with previously documented public APIs +encode_data = prepare_ctrl + # Backwards compatibility with previously documented public APIs -from ..frames import ( # noqa: E402, F401, I001 - Close, - prepare_ctrl as encode_data, - prepare_data, -) +from ..frames import Close # noqa: E402 F401, I001 def parse_close(data: bytes) -> tuple[int, str]: diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 3b9a8c4aa..998e390d4 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -45,12 +45,10 @@ Close, CloseCode, Opcode, - prepare_ctrl, - prepare_data, ) from ..protocol import State from ..typing import Data, LoggerLike, Subprotocol -from .framing import Frame +from .framing import Frame, prepare_ctrl, prepare_data __all__ = ["WebSocketCommonProtocol"] diff --git a/src/websockets/speedups.c b/src/websockets/speedups.c index a19590419..cb10dedb8 100644 --- a/src/websockets/speedups.c +++ b/src/websockets/speedups.c @@ -19,7 +19,6 @@ _PyBytesLike_AsStringAndSize(PyObject *obj, PyObject **tmp, char **buffer, Py_ss { // This supports bytes, bytearrays, and memoryview objects, // which are common data structures for handling byte streams. - // websockets.framing.prepare_data() returns only these types. // If *tmp isn't NULL, the caller gets a new reference. if (PyBytes_Check(obj)) { diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 88d6aee1f..16e51abda 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -11,7 +11,7 @@ from typing import Any, Iterable, Iterator, Mapping from ..exceptions import ConnectionClosed, ConnectionClosedOK, ProtocolError -from ..frames import DATA_OPCODES, BytesLike, CloseCode, Frame, Opcode, prepare_ctrl +from ..frames import DATA_OPCODES, BytesLike, CloseCode, Frame, Opcode from ..http11 import Request, Response from ..protocol import CLOSED, OPEN, Event, Protocol, State from ..typing import Data, LoggerLike, Subprotocol @@ -449,8 +449,12 @@ def ping(self, data: Data | None = None) -> threading.Event: the corresponding pong wasn't received yet. """ - if data is not None: - data = prepare_ctrl(data) + if isinstance(data, BytesLike): + data = bytes(data) + elif isinstance(data, str): + data = data.encode() + elif data is not None: + raise TypeError("data must be str or bytes-like") with self.send_context(): # Protect against duplicates if a payload is explicitly set. @@ -481,7 +485,12 @@ def pong(self, data: Data = b"") -> None: ConnectionClosed: When the connection is closed. """ - data = prepare_ctrl(data) + if isinstance(data, BytesLike): + data = bytes(data) + elif isinstance(data, str): + data = data.encode() + else: + raise TypeError("data must be str or bytes-like") with self.send_context(): self.protocol.send_pong(data) diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 59218de4b..78f3adf68 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -869,6 +869,11 @@ async def test_ping_duplicate_payload(self): await self.connection.ping("idem") # doesn't raise an exception + async def test_ping_unsupported_type(self): + """ping raises TypeError when called with an unsupported type.""" + with self.assertRaises(TypeError): + await self.connection.ping([]) + # Test pong. async def test_pong(self): @@ -886,6 +891,11 @@ async def test_pong_explicit_binary(self): await self.connection.pong(b"pong") await self.assertFrameSent(Frame(Opcode.PONG, b"pong")) + async def test_pong_unsupported_type(self): + """pong raises TypeError when called with an unsupported type.""" + with self.assertRaises(TypeError): + await self.connection.pong([]) + # Test keepalive. @patch("random.getrandbits") diff --git a/tests/legacy/test_framing.py b/tests/legacy/test_framing.py index 6f811bd5e..e816b91e0 100644 --- a/tests/legacy/test_framing.py +++ b/tests/legacy/test_framing.py @@ -173,6 +173,62 @@ def decode(frame, *, max_size=None): ) +class PrepareDataTests(unittest.TestCase): + def test_prepare_data_str(self): + self.assertEqual( + prepare_data("café"), + (OP_TEXT, b"caf\xc3\xa9"), + ) + + def test_prepare_data_bytes(self): + self.assertEqual( + prepare_data(b"tea"), + (OP_BINARY, b"tea"), + ) + + def test_prepare_data_bytearray(self): + self.assertEqual( + prepare_data(bytearray(b"tea")), + (OP_BINARY, bytearray(b"tea")), + ) + + def test_prepare_data_memoryview(self): + self.assertEqual( + prepare_data(memoryview(b"tea")), + (OP_BINARY, memoryview(b"tea")), + ) + + def test_prepare_data_list(self): + with self.assertRaises(TypeError): + prepare_data([]) + + def test_prepare_data_none(self): + with self.assertRaises(TypeError): + prepare_data(None) + + +class PrepareCtrlTests(unittest.TestCase): + def test_prepare_ctrl_str(self): + self.assertEqual(prepare_ctrl("café"), b"caf\xc3\xa9") + + def test_prepare_ctrl_bytes(self): + self.assertEqual(prepare_ctrl(b"tea"), b"tea") + + def test_prepare_ctrl_bytearray(self): + self.assertEqual(prepare_ctrl(bytearray(b"tea")), b"tea") + + def test_prepare_ctrl_memoryview(self): + self.assertEqual(prepare_ctrl(memoryview(b"tea")), b"tea") + + def test_prepare_ctrl_list(self): + with self.assertRaises(TypeError): + prepare_ctrl([]) + + def test_prepare_ctrl_none(self): + with self.assertRaises(TypeError): + prepare_ctrl(None) + + class ParseAndSerializeCloseTests(unittest.TestCase): def assertCloseData(self, code, reason, data): """ diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index 877adc4bf..d9fb2093b 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -665,6 +665,11 @@ def test_ping_duplicate_payload(self): self.connection.ping("idem") # doesn't raise an exception + def test_ping_unsupported_type(self): + """ping raises TypeError when called with an unsupported type.""" + with self.assertRaises(TypeError): + self.connection.ping([]) + # Test pong. def test_pong(self): @@ -682,6 +687,11 @@ def test_pong_explicit_binary(self): self.connection.pong(b"pong") self.assertFrameSent(Frame(Opcode.PONG, b"pong")) + def test_pong_unsupported_type(self): + """pong raises TypeError when called with an unsupported type.""" + with self.assertRaises(TypeError): + self.connection.pong([]) + # Test attributes. def test_id(self): diff --git a/tests/test_frames.py b/tests/test_frames.py index 3e9f5d6f8..857b41fe4 100644 --- a/tests/test_frames.py +++ b/tests/test_frames.py @@ -377,62 +377,6 @@ def test_pong_binary(self): ) -class PrepareDataTests(unittest.TestCase): - def test_prepare_data_str(self): - self.assertEqual( - prepare_data("café"), - (OP_TEXT, b"caf\xc3\xa9"), - ) - - def test_prepare_data_bytes(self): - self.assertEqual( - prepare_data(b"tea"), - (OP_BINARY, b"tea"), - ) - - def test_prepare_data_bytearray(self): - self.assertEqual( - prepare_data(bytearray(b"tea")), - (OP_BINARY, bytearray(b"tea")), - ) - - def test_prepare_data_memoryview(self): - self.assertEqual( - prepare_data(memoryview(b"tea")), - (OP_BINARY, memoryview(b"tea")), - ) - - def test_prepare_data_list(self): - with self.assertRaises(TypeError): - prepare_data([]) - - def test_prepare_data_none(self): - with self.assertRaises(TypeError): - prepare_data(None) - - -class PrepareCtrlTests(unittest.TestCase): - def test_prepare_ctrl_str(self): - self.assertEqual(prepare_ctrl("café"), b"caf\xc3\xa9") - - def test_prepare_ctrl_bytes(self): - self.assertEqual(prepare_ctrl(b"tea"), b"tea") - - def test_prepare_ctrl_bytearray(self): - self.assertEqual(prepare_ctrl(bytearray(b"tea")), b"tea") - - def test_prepare_ctrl_memoryview(self): - self.assertEqual(prepare_ctrl(memoryview(b"tea")), b"tea") - - def test_prepare_ctrl_list(self): - with self.assertRaises(TypeError): - prepare_ctrl([]) - - def test_prepare_ctrl_none(self): - with self.assertRaises(TypeError): - prepare_ctrl(None) - - class CloseTests(unittest.TestCase): def assertCloseData(self, close, data): """ From d3dcfd150385970a0d1df7187a989680cc426f55 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 21 Aug 2024 23:09:46 +0200 Subject: [PATCH 576/762] Add HTTP Basic Auth to asyncio & threading servers. --- docs/howto/upgrade.rst | 86 +++++++++++++--- docs/project/changelog.rst | 6 ++ docs/reference/asyncio/server.rst | 8 ++ docs/reference/features.rst | 2 +- docs/reference/sync/server.rst | 8 ++ docs/topics/compression.rst | 2 - src/websockets/asyncio/client.py | 4 +- src/websockets/asyncio/server.py | 149 ++++++++++++++++++++++++++-- src/websockets/sync/client.py | 4 +- src/websockets/sync/server.py | 133 ++++++++++++++++++++++++- tests/asyncio/test_server.py | 160 ++++++++++++++++++++++++++++++ tests/sync/test_server.py | 145 +++++++++++++++++++++++++++ 12 files changed, 679 insertions(+), 28 deletions(-) diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index 8d0895638..d68f5fe99 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -71,15 +71,6 @@ Missing features If your application relies on one of them, you should stick to the original implementation until the new implementation supports it in a future release. -HTTP Basic Authentication -......................... - -On the server side, :func:`~asyncio.server.serve` doesn't provide HTTP Basic -Authentication yet. - -For the avoidance of doubt, on the client side, :func:`~asyncio.client.connect` -performs HTTP Basic Authentication. - Following redirects ................... @@ -165,12 +156,12 @@ Server APIs | ``websockets.broadcast`` |br| | :func:`websockets.asyncio.server.broadcast` | | :func:`websockets.legacy.server.broadcast()` | | +-------------------------------------------------------------------+-----------------------------------------------------+ -| ``websockets.BasicAuthWebSocketServerProtocol`` |br| | *not available yet* | -| ``websockets.auth.BasicAuthWebSocketServerProtocol`` |br| | | +| ``websockets.BasicAuthWebSocketServerProtocol`` |br| | See below :ref:`how to migrate ` to | +| ``websockets.auth.BasicAuthWebSocketServerProtocol`` |br| | :func:`websockets.asyncio.server.basic_auth`. | | :class:`websockets.legacy.auth.BasicAuthWebSocketServerProtocol` | | +-------------------------------------------------------------------+-----------------------------------------------------+ -| ``websockets.basic_auth_protocol_factory()`` |br| | *not available yet* | -| ``websockets.auth.basic_auth_protocol_factory()`` |br| | | +| ``websockets.basic_auth_protocol_factory()`` |br| | See below :ref:`how to migrate ` to | +| ``websockets.auth.basic_auth_protocol_factory()`` |br| | :func:`websockets.asyncio.server.basic_auth`. | | :func:`websockets.legacy.auth.basic_auth_protocol_factory` | | +-------------------------------------------------------------------+-----------------------------------------------------+ @@ -206,6 +197,75 @@ implementation. Depending on your use case, adopting this method may improve performance when streaming large messages. Specifically, it could reduce memory usage. +.. _basic-auth: + +Performing HTTP Basic Authentication +.................................... + +.. admonition:: This section applies only to servers. + :class: tip + + On the client side, :func:`~asyncio.client.connect` performs HTTP Basic + Authentication automatically when the URI contains credentials. + +In the original implementation, the recommended way to add HTTP Basic +Authentication to a server was to set the ``create_protocol`` argument of +:func:`~legacy.server.serve` to a factory function generated by +:func:`~legacy.auth.basic_auth_protocol_factory`:: + + from websockets.legacy.auth import basic_auth_protocol_factory + from websockets.legacy.server import serve + + async with serve(..., create_protocol=basic_auth_protocol_factory(...)): + ... + +In the new implementation, the :func:`~asyncio.server.basic_auth` function +generates a ``process_request`` coroutine that performs HTTP Basic +Authentication:: + + from websockets.asyncio.server import basic_auth, serve + + async with serve(..., process_request=basic_auth(...)): + ... + +:func:`~asyncio.server.basic_auth` accepts either hard coded ``credentials`` or +a ``check_credentials`` coroutine as well as an optional ``realm`` just like +:func:`~legacy.auth.basic_auth_protocol_factory`. Furthermore, +``check_credentials`` may be a function instead of a coroutine. + +This new API has more obvious semantics. That makes it easier to understand and +also easier to extend. + +In the original implementation, overriding ``create_protocol`` changed the type +of connection objects to :class:`~legacy.auth.BasicAuthWebSocketServerProtocol`, +a subclass of :class:`~legacy.server.WebSocketServerProtocol` that performs HTTP +Basic Authentication in its ``process_request`` method. If you wanted to +customize ``process_request`` further, you had: + +* an ill-defined option: add a ``process_request`` argument to + :func:`~legacy.server.serve`; to tell which one would run first, you had to + experiment or read the code; +* a cumbersome option: subclass + :class:`~legacy.auth.BasicAuthWebSocketServerProtocol`, then pass that + subclass in the ``create_protocol`` argument of + :func:`~legacy.auth.basic_auth_protocol_factory`. + +In the new implementation, you just write a ``process_request`` coroutine:: + + from websockets.asyncio.server import basic_auth, serve + + process_basic_auth = basic_auth(...) + + async def process_request(connection, request): + ... # some logic here + response = await process_basic_auth(connection, request) + if response is not None: + return response + ... # more logic here + + async with serve(..., process_request=process_request): + ... + Customizing the opening handshake ................................. diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 955136ac0..d940b1ea2 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -32,6 +32,12 @@ notice. *In development* +New features +............ + +* The new :mod:`asyncio` and :mod:`threading` implementations provide an API for + enforcing HTTP Basic Auth on the server side. + .. _13.0: 13.0 diff --git a/docs/reference/asyncio/server.rst b/docs/reference/asyncio/server.rst index bd5a34b19..d4d20aeb6 100644 --- a/docs/reference/asyncio/server.rst +++ b/docs/reference/asyncio/server.rst @@ -81,3 +81,11 @@ Broadcast --------- .. autofunction:: websockets.asyncio.server.broadcast + +HTTP Basic Authentication +------------------------- + +websockets supports HTTP Basic Authentication according to +:rfc:`7235` and :rfc:`7617`. + +.. autofunction:: websockets.asyncio.server.basic_auth diff --git a/docs/reference/features.rst b/docs/reference/features.rst index a380f4555..d9941e408 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -118,7 +118,7 @@ Server +------------------------------------+--------+--------+--------+--------+ | Force an HTTP response | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ - | Perform HTTP Basic Authentication | ❌ | ❌ | ❌ | ✅ | + | Perform HTTP Basic Authentication | ✅ | ✅ | ❌ | ✅ | +------------------------------------+--------+--------+--------+--------+ | Perform HTTP Digest Authentication | ❌ | ❌ | ❌ | ❌ | +------------------------------------+--------+--------+--------+--------+ diff --git a/docs/reference/sync/server.rst b/docs/reference/sync/server.rst index 23cb04097..80e9c17bb 100644 --- a/docs/reference/sync/server.rst +++ b/docs/reference/sync/server.rst @@ -60,3 +60,11 @@ Using a connection .. autoattribute:: response .. autoproperty:: subprotocol + +HTTP Basic Authentication +------------------------- + +websockets supports HTTP Basic Authentication according to +:rfc:`7235` and :rfc:`7617`. + +.. autofunction:: websockets.sync.server.basic_auth diff --git a/docs/topics/compression.rst b/docs/topics/compression.rst index be263e56f..5f09bbf73 100644 --- a/docs/topics/compression.rst +++ b/docs/topics/compression.rst @@ -45,7 +45,6 @@ explicitly with :class:`ClientPerMessageDeflateFactory` or compress_settings={"memLevel": 4}, ), ], - ..., ) serve( @@ -57,7 +56,6 @@ explicitly with :class:`ClientPerMessageDeflateFactory` or compress_settings={"memLevel": 4}, ), ], - ..., ) The Window Bits and Memory Level values in these examples reduce memory usage diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 632d3ac2b..033887e87 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -131,7 +131,9 @@ class connect: :func:`connect` may be used as an asynchronous context manager:: - async with websockets.asyncio.client.connect(...) as websocket: + from websockets.asyncio.client import connect + + async with connect(...) as websocket: ... The connection is closed automatically when exiting the context. diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index 8f04ec318..f6cd9a1b5 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import hmac import http import logging import socket @@ -13,13 +14,19 @@ Generator, Iterable, Sequence, + Tuple, + cast, ) -from websockets.frames import CloseCode - +from ..exceptions import InvalidHeader from ..extensions.base import ServerExtensionFactory from ..extensions.permessage_deflate import enable_server_permessage_deflate -from ..headers import validate_subprotocols +from ..frames import CloseCode +from ..headers import ( + build_www_authenticate_basic, + parse_authorization_basic, + validate_subprotocols, +) from ..http11 import SERVER, Request, Response from ..protocol import CONNECTING, Event from ..server import ServerProtocol @@ -28,7 +35,14 @@ from .connection import Connection, broadcast -__all__ = ["broadcast", "serve", "unix_serve", "ServerConnection", "Server"] +__all__ = [ + "broadcast", + "serve", + "unix_serve", + "ServerConnection", + "Server", + "basic_auth", +] class ServerConnection(Connection): @@ -79,6 +93,7 @@ def __init__( ) self.server = server self.request_rcvd: asyncio.Future[None] = self.loop.create_future() + self.username: str # see basic_auth() def respond(self, status: StatusLike, text: str) -> Response: """ @@ -548,19 +563,21 @@ class serve: :class:`asyncio.Server`. Treat it as an asynchronous context manager to ensure that the server will be closed:: + from websockets.asyncio.server import serve + def handler(websocket): ... # set this future to exit the server stop = asyncio.get_running_loop().create_future() - async with websockets.asyncio.server.serve(handler, host, port): + async with serve(handler, host, port): await stop Alternatively, call :meth:`~Server.serve_forever` to serve requests and cancel it to stop the server:: - server = await websockets.asyncio.server.serve(handler, host, port) + server = await serve(handler, host, port) await server.serve_forever() Args: @@ -822,3 +839,123 @@ def unix_serve( """ return serve(handler, unix=True, path=path, **kwargs) + + +def is_credentials(credentials: Any) -> bool: + try: + username, password = credentials + except (TypeError, ValueError): + return False + else: + return isinstance(username, str) and isinstance(password, str) + + +def basic_auth( + realm: str = "", + credentials: tuple[str, str] | Iterable[tuple[str, str]] | None = None, + check_credentials: Callable[[str, str], Awaitable[bool] | bool] | None = None, +) -> Callable[[ServerConnection, Request], Awaitable[Response | None]]: + """ + Factory for ``process_request`` to enforce HTTP Basic Authentication. + + :func:`basic_auth` is designed to integrate with :func:`serve` as follows:: + + from websockets.asyncio.server import basic_auth, serve + + async with serve( + ..., + process_request=basic_auth( + realm="my dev server", + credentials=("hello", "iloveyou"), + ), + ): + + If authentication succeeds, the connection's ``username`` attribute is set. + If it fails, the server responds with an HTTP 401 Unauthorized status. + + One of ``credentials`` or ``check_credentials`` must be provided; not both. + + Args: + realm: Scope of protection. It should contain only ASCII characters + because the encoding of non-ASCII characters is undefined. Refer to + section 2.2 of :rfc:`7235` for details. + credentials: Hard coded authorized credentials. It can be a + ``(username, password)`` pair or a list of such pairs. + check_credentials: Function or coroutine that verifies credentials. + It receives ``username`` and ``password`` arguments and returns + whether they're valid. + Raises: + TypeError: If ``credentials`` or ``check_credentials`` is wrong. + + """ + if (credentials is None) == (check_credentials is None): + raise TypeError("provide either credentials or check_credentials") + + if credentials is not None: + if is_credentials(credentials): + credentials_list = [cast(Tuple[str, str], credentials)] + elif isinstance(credentials, Iterable): + credentials_list = list(cast(Iterable[Tuple[str, str]], credentials)) + if not all(is_credentials(item) for item in credentials_list): + raise TypeError(f"invalid credentials argument: {credentials}") + else: + raise TypeError(f"invalid credentials argument: {credentials}") + + credentials_dict = dict(credentials_list) + + def check_credentials(username: str, password: str) -> bool: + try: + expected_password = credentials_dict[username] + except KeyError: + return False + return hmac.compare_digest(expected_password, password) + + assert check_credentials is not None # help mypy + + async def process_request( + connection: ServerConnection, + request: Request, + ) -> Response | None: + """ + Perform HTTP Basic Authentication. + + If it succeeds, set the connection's ``username`` attribute and return + :obj:`None`. If it fails, return an HTTP 401 Unauthorized responss. + + """ + try: + authorization = request.headers["Authorization"] + except KeyError: + response = connection.respond( + http.HTTPStatus.UNAUTHORIZED, + "Missing credentials\n", + ) + response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm) + return response + + try: + username, password = parse_authorization_basic(authorization) + except InvalidHeader: + response = connection.respond( + http.HTTPStatus.UNAUTHORIZED, + "Unsupported credentials\n", + ) + response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm) + return response + + valid_credentials = check_credentials(username, password) + if isinstance(valid_credentials, Awaitable): + valid_credentials = await valid_credentials + + if not valid_credentials: + response = connection.respond( + http.HTTPStatus.UNAUTHORIZED, + "Invalid credentials\n", + ) + response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm) + return response + + connection.username = username + return None + + return process_request diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index e33d53f62..3c700a377 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -156,7 +156,9 @@ def connect( :func:`connect` may be used as a context manager:: - with websockets.sync.client.connect(...) as websocket: + from websockets.sync.client import connect + + with connect(...) as websocket: ... The connection is closed automatically when exiting the context. diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 86c162af3..5e22e112e 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -1,5 +1,6 @@ from __future__ import annotations +import hmac import http import logging import os @@ -10,12 +11,17 @@ import threading import warnings from types import TracebackType -from typing import Any, Callable, Sequence +from typing import Any, Callable, Iterable, Sequence, Tuple, cast +from ..exceptions import InvalidHeader from ..extensions.base import ServerExtensionFactory from ..extensions.permessage_deflate import enable_server_permessage_deflate from ..frames import CloseCode -from ..headers import validate_subprotocols +from ..headers import ( + build_www_authenticate_basic, + parse_authorization_basic, + validate_subprotocols, +) from ..http11 import SERVER, Request, Response from ..protocol import CONNECTING, OPEN, Event from ..server import ServerProtocol @@ -24,7 +30,7 @@ from .utils import Deadline -__all__ = ["serve", "unix_serve", "ServerConnection", "Server"] +__all__ = ["serve", "unix_serve", "ServerConnection", "Server", "basic_auth"] class ServerConnection(Connection): @@ -65,6 +71,7 @@ def __init__( protocol, close_timeout=close_timeout, ) + self.username: str # see basic_auth() def respond(self, status: StatusLike, text: str) -> Response: """ @@ -368,10 +375,12 @@ def serve( that it will be closed and call :meth:`~Server.serve_forever` to serve requests:: + from websockets.sync.server import serve + def handler(websocket): ... - with websockets.sync.server.serve(handler, ...) as server: + with serve(handler, ...) as server: server.serve_forever() Args: @@ -587,3 +596,119 @@ def unix_serve( """ return serve(handler, unix=True, path=path, **kwargs) + + +def is_credentials(credentials: Any) -> bool: + try: + username, password = credentials + except (TypeError, ValueError): + return False + else: + return isinstance(username, str) and isinstance(password, str) + + +def basic_auth( + realm: str = "", + credentials: tuple[str, str] | Iterable[tuple[str, str]] | None = None, + check_credentials: Callable[[str, str], bool] | None = None, +) -> Callable[[ServerConnection, Request], Response | None]: + """ + Factory for ``process_request`` to enforce HTTP Basic Authentication. + + :func:`basic_auth` is designed to integrate with :func:`serve` as follows:: + + from websockets.sync.server import basic_auth, serve + + with serve( + ..., + process_request=basic_auth( + realm="my dev server", + credentials=("hello", "iloveyou"), + ), + ): + + If authentication succeeds, the connection's ``username`` attribute is set. + If it fails, the server responds with an HTTP 401 Unauthorized status. + + One of ``credentials`` or ``check_credentials`` must be provided; not both. + + Args: + realm: Scope of protection. It should contain only ASCII characters + because the encoding of non-ASCII characters is undefined. Refer to + section 2.2 of :rfc:`7235` for details. + credentials: Hard coded authorized credentials. It can be a + ``(username, password)`` pair or a list of such pairs. + check_credentials: Function that verifies credentials. + It receives ``username`` and ``password`` arguments and returns + whether they're valid. + Raises: + TypeError: If ``credentials`` or ``check_credentials`` is wrong. + + """ + if (credentials is None) == (check_credentials is None): + raise TypeError("provide either credentials or check_credentials") + + if credentials is not None: + if is_credentials(credentials): + credentials_list = [cast(Tuple[str, str], credentials)] + elif isinstance(credentials, Iterable): + credentials_list = list(cast(Iterable[Tuple[str, str]], credentials)) + if not all(is_credentials(item) for item in credentials_list): + raise TypeError(f"invalid credentials argument: {credentials}") + else: + raise TypeError(f"invalid credentials argument: {credentials}") + + credentials_dict = dict(credentials_list) + + def check_credentials(username: str, password: str) -> bool: + try: + expected_password = credentials_dict[username] + except KeyError: + return False + return hmac.compare_digest(expected_password, password) + + assert check_credentials is not None # help mypy + + def process_request( + connection: ServerConnection, + request: Request, + ) -> Response | None: + """ + Perform HTTP Basic Authentication. + + If it succeeds, set the connection's ``username`` attribute and return + :obj:`None`. If it fails, return an HTTP 401 Unauthorized responss. + + """ + try: + authorization = request.headers["Authorization"] + except KeyError: + response = connection.respond( + http.HTTPStatus.UNAUTHORIZED, + "Missing credentials\n", + ) + response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm) + return response + + try: + username, password = parse_authorization_basic(authorization) + except InvalidHeader: + response = connection.respond( + http.HTTPStatus.UNAUTHORIZED, + "Unsupported credentials\n", + ) + response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm) + return response + + if not check_credentials(username, password): + response = connection.respond( + http.HTTPStatus.UNAUTHORIZED, + "Invalid credentials\n", + ) + response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm) + return response + + connection.username = username + return None + + return process_request diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index b899998f4..f05b9f1e4 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -1,5 +1,6 @@ import asyncio import dataclasses +import hmac import http import logging import socket @@ -560,3 +561,162 @@ async def test_logger(self): logger = logging.getLogger("test") async with run_server(logger=logger) as server: self.assertIs(server.logger, logger) + + +class BasicAuthTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): + async def test_valid_authorization(self): + """basic_auth authenticates client with HTTP Basic Authentication.""" + async with run_server( + process_request=basic_auth(credentials=("hello", "iloveyou")), + ) as server: + async with run_client( + server, + additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, + ) as client: + await self.assertEval(client, "ws.username", "hello") + + async def test_missing_authorization(self): + """basic_auth rejects client without credentials.""" + async with run_server( + process_request=basic_auth(credentials=("hello", "iloveyou")), + ) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 401", + ) + + async def test_unsupported_authorization(self): + """basic_auth rejects client with unsupported credentials.""" + async with run_server( + process_request=basic_auth(credentials=("hello", "iloveyou")), + ) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client( + server, + additional_headers={"Authorization": "Negotiate ..."}, + ): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 401", + ) + + async def test_authorization_with_unknown_username(self): + """basic_auth rejects client with unknown username.""" + async with run_server( + process_request=basic_auth(credentials=("hello", "iloveyou")), + ) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client( + server, + additional_headers={"Authorization": "Basic YnllOnlvdWxvdmVtZQ=="}, + ): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 401", + ) + + async def test_authorization_with_incorrect_password(self): + """basic_auth rejects client with incorrect password.""" + async with run_server( + process_request=basic_auth(credentials=("hello", "changeme")), + ) as server: + with self.assertRaises(InvalidStatus) as raised: + async with run_client( + server, + additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, + ): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 401", + ) + + async def test_list_of_credentials(self): + """basic_auth accepts a list of hard coded credentials.""" + async with run_server( + process_request=basic_auth( + credentials=[ + ("hello", "iloveyou"), + ("bye", "youloveme"), + ] + ), + ) as server: + async with run_client( + server, + additional_headers={"Authorization": "Basic YnllOnlvdWxvdmVtZQ=="}, + ) as client: + await self.assertEval(client, "ws.username", "bye") + + async def test_check_credentials_function(self): + """basic_auth accepts a check_credentials function.""" + + def check_credentials(username, password): + return hmac.compare_digest(password, "iloveyou") + + async with run_server( + process_request=basic_auth(check_credentials=check_credentials), + ) as server: + async with run_client( + server, + additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, + ) as client: + await self.assertEval(client, "ws.username", "hello") + + async def test_check_credentials_coroutine(self): + """basic_auth accepts a check_credentials coroutine.""" + + async def check_credentials(username, password): + return hmac.compare_digest(password, "iloveyou") + + async with run_server( + process_request=basic_auth(check_credentials=check_credentials), + ) as server: + async with run_client( + server, + additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, + ) as client: + await self.assertEval(client, "ws.username", "hello") + + async def test_without_credentials_or_check_credentials(self): + """basic_auth requires either credentials or check_credentials.""" + with self.assertRaises(TypeError) as raised: + basic_auth() + self.assertEqual( + str(raised.exception), + "provide either credentials or check_credentials", + ) + + async def test_with_credentials_and_check_credentials(self): + """basic_auth requires only one of credentials and check_credentials.""" + with self.assertRaises(TypeError) as raised: + basic_auth( + credentials=("hello", "iloveyou"), + check_credentials=lambda: False, # pragma: no cover + ) + self.assertEqual( + str(raised.exception), + "provide either credentials or check_credentials", + ) + + async def test_bad_credentials(self): + """basic_auth receives an unsupported credentials argument.""" + with self.assertRaises(TypeError) as raised: + basic_auth(credentials=42) + self.assertEqual( + str(raised.exception), + "invalid credentials argument: 42", + ) + + async def test_bad_list_of_credentials(self): + """basic_auth receives an unsupported credentials argument.""" + with self.assertRaises(TypeError) as raised: + basic_auth(credentials=[42]) + self.assertEqual( + str(raised.exception), + "invalid credentials argument: [42]", + ) diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index e3dfeb271..39d7501bc 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -1,4 +1,5 @@ import dataclasses +import hmac import http import logging import socket @@ -413,6 +414,150 @@ def test_shutdown(self): server.socket.accept() +class BasicAuthTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): + def test_valid_authorization(self): + """basic_auth authenticates client with HTTP Basic Authentication.""" + with run_server( + process_request=basic_auth(credentials=("hello", "iloveyou")), + ) as server: + with run_client( + server, + additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, + ) as client: + self.assertEval(client, "ws.username", "hello") + + def test_missing_authorization(self): + """basic_auth rejects client without credentials.""" + with run_server( + process_request=basic_auth(credentials=("hello", "iloveyou")), + ) as server: + with self.assertRaises(InvalidStatus) as raised: + with run_client(server): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 401", + ) + + def test_unsupported_authorization(self): + """basic_auth rejects client with unsupported credentials.""" + with run_server( + process_request=basic_auth(credentials=("hello", "iloveyou")), + ) as server: + with self.assertRaises(InvalidStatus) as raised: + with run_client( + server, + additional_headers={"Authorization": "Negotiate ..."}, + ): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 401", + ) + + def test_authorization_with_unknown_username(self): + """basic_auth rejects client with unknown username.""" + with run_server( + process_request=basic_auth(credentials=("hello", "iloveyou")), + ) as server: + with self.assertRaises(InvalidStatus) as raised: + with run_client( + server, + additional_headers={"Authorization": "Basic YnllOnlvdWxvdmVtZQ=="}, + ): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 401", + ) + + def test_authorization_with_incorrect_password(self): + """basic_auth rejects client with incorrect password.""" + with run_server( + process_request=basic_auth(credentials=("hello", "changeme")), + ) as server: + with self.assertRaises(InvalidStatus) as raised: + with run_client( + server, + additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, + ): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 401", + ) + + def test_list_of_credentials(self): + """basic_auth accepts a list of hard coded credentials.""" + with run_server( + process_request=basic_auth( + credentials=[ + ("hello", "iloveyou"), + ("bye", "youloveme"), + ] + ), + ) as server: + with run_client( + server, + additional_headers={"Authorization": "Basic YnllOnlvdWxvdmVtZQ=="}, + ) as client: + self.assertEval(client, "ws.username", "bye") + + def test_check_credentials(self): + """basic_auth accepts a check_credentials function.""" + + def check_credentials(username, password): + return hmac.compare_digest(password, "iloveyou") + + with run_server( + process_request=basic_auth(check_credentials=check_credentials), + ) as server: + with run_client( + server, + additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, + ) as client: + self.assertEval(client, "ws.username", "hello") + + def test_without_credentials_or_check_credentials(self): + """basic_auth requires either credentials or check_credentials.""" + with self.assertRaises(TypeError) as raised: + basic_auth() + self.assertEqual( + str(raised.exception), + "provide either credentials or check_credentials", + ) + + def test_with_credentials_and_check_credentials(self): + """basic_auth requires only one of credentials and check_credentials.""" + with self.assertRaises(TypeError) as raised: + basic_auth( + credentials=("hello", "iloveyou"), + check_credentials=lambda: False, # pragma: no cover + ) + self.assertEqual( + str(raised.exception), + "provide either credentials or check_credentials", + ) + + def test_bad_credentials(self): + """basic_auth receives an unsupported credentials argument.""" + with self.assertRaises(TypeError) as raised: + basic_auth(credentials=42) + self.assertEqual( + str(raised.exception), + "invalid credentials argument: 42", + ) + + def test_bad_list_of_credentials(self): + """basic_auth receives an unsupported credentials argument.""" + with self.assertRaises(TypeError) as raised: + basic_auth(credentials=[42]) + self.assertEqual( + str(raised.exception), + "invalid credentials argument: [42]", + ) + + class BackwardsCompatibilityTests(DeprecationTestCase): def test_ssl_context_argument(self): """Client supports the deprecated ssl_context argument.""" From 4920a585beef9c4f8a1aa7e332dc20d75f252423 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 22 Aug 2024 09:12:16 +0200 Subject: [PATCH 577/762] Make logging examples more robust. Prevent them from crashing when the opening handshake didn't complete successfully. Also migrate them to the new asyncio API. Fix #1428. --- docs/topics/logging.rst | 15 +++++++++------ src/websockets/server.py | 2 +- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/docs/topics/logging.rst b/docs/topics/logging.rst index 9580b4c50..be5678455 100644 --- a/docs/topics/logging.rst +++ b/docs/topics/logging.rst @@ -124,9 +124,11 @@ Here's how to include them in logs, assuming they're in the def process(self, msg, kwargs): try: websocket = kwargs["extra"]["websocket"] - except KeyError: + except KeyError: # log entry not coming from a connection + return msg, kwargs + if websocket.request is None: # opening handshake not complete return msg, kwargs - xff = websocket.request_headers.get("X-Forwarded-For") + xff = headers.get("X-Forwarded-For") return f"{websocket.id} {xff} {msg}", kwargs async with serve( @@ -165,10 +167,11 @@ a :class:`~logging.LoggerAdapter`:: websocket = kwargs["extra"]["websocket"] except KeyError: return msg, kwargs - kwargs["extra"]["event_data"] = { - "connection_id": str(websocket.id), - "remote_addr": websocket.request_headers.get("X-Forwarded-For"), - } + event_data = {"connection_id": str(websocket.id)} + if websocket.request is not None: # opening handshake complete + headers = websocket.request.headers + event_data["remote_addr"] = headers.get("X-Forwarded-For") + kwargs["extra"]["event_data"] = event_data return msg, kwargs async with serve( diff --git a/src/websockets/server.py b/src/websockets/server.py index 2ab9102f7..11ba8b425 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -69,7 +69,7 @@ class ServerProtocol(Protocol): max_size: Maximum size of incoming messages in bytes; :obj:`None` disables the limit. logger: Logger for this connection; - defaults to ``logging.getLogger("websockets.client")``; + defaults to ``logging.getLogger("websockets.server")``; see the :doc:`logging guide <../../topics/logging>` for details. """ From d341bbabd81facbf084e7c5b321e7c8c2cd977f5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 22 Aug 2024 21:55:55 +0200 Subject: [PATCH 578/762] Add API for the set of active connections. Fix #1486. --- docs/howto/upgrade.rst | 10 ++++++++ docs/project/changelog.rst | 7 ++++-- docs/reference/asyncio/server.rst | 2 ++ src/websockets/asyncio/server.py | 14 ++++++++++- tests/asyncio/test_server.py | 24 ++++++++++++------- tests/sync/test_server.py | 40 +++++++++++++++---------------- 6 files changed, 65 insertions(+), 32 deletions(-) diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index d68f5fe99..602d8a4e6 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -197,6 +197,16 @@ implementation. Depending on your use case, adopting this method may improve performance when streaming large messages. Specifically, it could reduce memory usage. +Tracking open connections +......................... + +The new implementation of :class:`~asyncio.server.Server` provides a +:attr:`~asyncio.server.Server.connections` property, which is a set of all open +connections. This didn't exist in the original implementation. + +If you were keeping track of open connections, you may be able to simplify your +code by using this property. + .. _basic-auth: Performing HTTP Basic Authentication diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index d940b1ea2..1c87882f0 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -35,8 +35,11 @@ notice. New features ............ -* The new :mod:`asyncio` and :mod:`threading` implementations provide an API for - enforcing HTTP Basic Auth on the server side. +* Made the set of active connections available in the :attr:`Server.connections + ` property. + +* Added HTTP Basic Auth to the new :mod:`asyncio` and :mod:`threading` + implementations of servers. .. _13.0: diff --git a/docs/reference/asyncio/server.rst b/docs/reference/asyncio/server.rst index d4d20aeb6..2fcaeb414 100644 --- a/docs/reference/asyncio/server.rst +++ b/docs/reference/asyncio/server.rst @@ -17,6 +17,8 @@ Running a server .. autoclass:: Server + .. autoattribute:: connections + .. automethod:: close .. automethod:: wait_closed diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index f6cd9a1b5..29860e565 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -28,7 +28,7 @@ validate_subprotocols, ) from ..http11 import SERVER, Request, Response -from ..protocol import CONNECTING, Event +from ..protocol import CONNECTING, OPEN, Event from ..server import ServerProtocol from ..typing import LoggerLike, Origin, StatusLike, Subprotocol from .compatibility import asyncio_timeout @@ -313,6 +313,18 @@ def __init__( # Completed when the server is closed and connections are terminated. self.closed_waiter: asyncio.Future[None] = self.loop.create_future() + @property + def connections(self) -> set[ServerConnection]: + """ + Set of active connections. + + This property contains all connections that completed the opening + handshake successfully and didn't start the closing handshake yet. + It can be useful in combination with :func:`~broadcast`. + + """ + return {connection for connection in self.handlers if connection.state is OPEN} + def wrap(self, server: asyncio.Server) -> None: """ Attach to a given :class:`asyncio.Server`. diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index f05b9f1e4..38f226903 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -350,6 +350,12 @@ async def test_disable_keepalive(self): latency = eval(await client.recv()) self.assertEqual(latency, 0) + async def test_logger(self): + """Server accepts a logger argument.""" + logger = logging.getLogger("test") + async with run_server(logger=logger) as server: + self.assertIs(server.logger, logger) + async def test_custom_connection_factory(self): """Server runs ServerConnection factory provided in create_connection.""" @@ -362,6 +368,16 @@ def create_connection(*args, **kwargs): async with run_client(server) as client: await self.assertEval(client, "ws.create_connection_ran", "True") + async def test_connections(self): + """Server provides a connections property.""" + async with run_server() as server: + self.assertEqual(server.connections, set()) + async with run_client(server) as client: + self.assertEqual(len(server.connections), 1) + ws_id = str(next(iter(server.connections)).id) + await self.assertEval(client, "ws.id", ws_id) + self.assertEqual(server.connections, set()) + async def test_handshake_fails(self): """Server receives connection from client but the handshake fails.""" @@ -555,14 +571,6 @@ async def test_unsupported_compression(self): ) -class WebSocketServerTests(unittest.IsolatedAsyncioTestCase): - async def test_logger(self): - """Server accepts a logger argument.""" - logger = logging.getLogger("test") - async with run_server(logger=logger) as server: - self.assertIs(server.logger, logger) - - class BasicAuthTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): async def test_valid_authorization(self): """basic_auth authenticates client with HTTP Basic Authentication.""" diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index 39d7501bc..a17634716 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -235,6 +235,12 @@ def test_disable_compression(self): with run_client(server) as client: self.assertEval(client, "ws.protocol.extensions", "[]") + def test_logger(self): + """Server accepts a logger argument.""" + logger = logging.getLogger("test") + with run_server(logger=logger) as server: + self.assertIs(server.logger, logger) + def test_custom_connection_factory(self): """Server runs ServerConnection factory provided in create_connection.""" @@ -247,6 +253,19 @@ def create_connection(*args, **kwargs): with run_client(server) as client: self.assertEval(client, "ws.create_connection_ran", "True") + def test_fileno(self): + """Server provides a fileno attribute.""" + with run_server() as server: + self.assertIsInstance(server.fileno(), int) + + def test_shutdown(self): + """Server provides a shutdown method.""" + with run_server() as server: + server.shutdown() + # Check that the server socket is closed. + with self.assertRaises(OSError): + server.socket.accept() + def test_handshake_fails(self): """Server receives connection from client but the handshake fails.""" @@ -393,27 +412,6 @@ def test_unsupported_compression(self): ) -class WebSocketServerTests(unittest.TestCase): - def test_logger(self): - """Server accepts a logger argument.""" - logger = logging.getLogger("test") - with run_server(logger=logger) as server: - self.assertIs(server.logger, logger) - - def test_fileno(self): - """Server provides a fileno attribute.""" - with run_server() as server: - self.assertIsInstance(server.fileno(), int) - - def test_shutdown(self): - """Server provides a shutdown method.""" - with run_server() as server: - server.shutdown() - # Check that the server socket is closed. - with self.assertRaises(OSError): - server.socket.accept() - - class BasicAuthTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): def test_valid_authorization(self): """basic_auth authenticates client with HTTP Basic Authentication.""" From 2306835d7bf936d3b7c09caabd12dd072c7d7a39 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 23 Aug 2024 08:41:35 +0200 Subject: [PATCH 579/762] Set Protocol.close_rcvd_then_sent in an edge case. When receiving a close frame in the middle of a fragmented message -- cf. test_(client|server)_receive_close_in_fragmented_message) -- recv_frame() was raising a ProtocolError, fail() was sending a Close frame, and close_rcvd_then_sent was never set. These tests don't proceed to the CLOSED state, making it difficult to test close_exc. --- src/websockets/protocol.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index de065c544..3b3e80cf5 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -378,9 +378,11 @@ def send_close(self, code: int | None = None, reason: str = "") -> None: else: close = Close(code, reason) data = close.serialize() - # send_frame() guarantees that self.state is OPEN at this point. # 7.1.3. The WebSocket Closing Handshake is Started self.send_frame(Frame(OP_CLOSE, data)) + # Since the state is OPEN, no close frame was received yet. + # As a consequence, self.close_rcvd_then_sent remains None. + assert self.close_rcvd is None self.close_sent = close self.state = CLOSING @@ -441,6 +443,12 @@ def fail(self, code: int, reason: str = "") -> None: data = close.serialize() self.send_frame(Frame(OP_CLOSE, data)) self.close_sent = close + # If recv_messages() raised an exception upon receiving a close + # frame but before echoing it, then close_rcvd is not None even + # though the state is OPEN. This happens when the connection is + # closed while receiving a fragmented message. + if self.close_rcvd is not None: + self.close_rcvd_then_sent = True self.state = CLOSING # When failing the connection, a server closes the TCP connection From 61f42acdbe47c51ad51283890d814091d2c201e6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 23 Aug 2024 07:20:15 +0200 Subject: [PATCH 580/762] Deprecate private attributes of ConnectionClosed. While I don't think they were ever documented, they were tested and ConnectionClosed is such a common API that they may have been used. --- src/websockets/exceptions.py | 11 +++++++++++ tests/legacy/test_protocol.py | 10 ---------- tests/test_exceptions.py | 32 ++++++++++++++++++++++++++------ 3 files changed, 37 insertions(+), 16 deletions(-) diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 52cc48898..64b7d3102 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -31,6 +31,7 @@ from __future__ import annotations import http +import warnings from . import datastructures, frames, http11 from .typing import StatusLike @@ -120,12 +121,22 @@ def __str__(self) -> str: @property def code(self) -> int: + warnings.warn( # deprecated in 13.1 + "ConnectionClosed.code is deprecated; " + "use Protocol.close_code or ConnectionClosed.rcvd.code", + DeprecationWarning, + ) if self.rcvd is None: return frames.CloseCode.ABNORMAL_CLOSURE return self.rcvd.code @property def reason(self) -> str: + warnings.warn( # deprecated in 13.1 + "ConnectionClosed.reason is deprecated; " + "use Protocol.close_reason or ConnectionClosed.rcvd.reason", + DeprecationWarning, + ) if self.rcvd is None: return "" return self.rcvd.reason diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index 8751b9ac6..de2a320b5 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -1180,16 +1180,6 @@ def test_legacy_recv(self): # Now recv() returns None instead of raising ConnectionClosed. self.assertIsNone(self.loop.run_until_complete(self.protocol.recv())) - def test_connection_closed_attributes(self): - self.close_connection() - - with self.assertRaises(ConnectionClosed) as context: - self.loop.run_until_complete(self.protocol.recv()) - - connection_closed_exc = context.exception - self.assertEqual(connection_closed_exc.code, CloseCode.NORMAL_CLOSURE) - self.assertEqual(connection_closed_exc.reason, "close") - # Test the protocol logic for sending keepalive pings. def restart_protocol_with_keepalive_ping( diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 1e6f58fad..92ba7dda8 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -5,6 +5,8 @@ from websockets.frames import Close, CloseCode from websockets.http11 import Response +from .utils import DeprecationTestCase + class ExceptionsTests(unittest.TestCase): def test_str(self): @@ -185,12 +187,30 @@ def test_str(self): with self.subTest(exception=exception): self.assertEqual(str(exception), exception_str) - def test_connection_closed_attributes_backwards_compatibility(self): + +class DeprecationTests(DeprecationTestCase): + def test_connection_closed_attributes_deprecation(self): exception = ConnectionClosed(Close(CloseCode.NORMAL_CLOSURE, "OK"), None, None) - self.assertEqual(exception.code, CloseCode.NORMAL_CLOSURE) - self.assertEqual(exception.reason, "OK") + with self.assertDeprecationWarning( + "ConnectionClosed.code is deprecated; " + "use Protocol.close_code or ConnectionClosed.rcvd.code" + ): + self.assertEqual(exception.code, CloseCode.NORMAL_CLOSURE) + with self.assertDeprecationWarning( + "ConnectionClosed.reason is deprecated; " + "use Protocol.close_reason or ConnectionClosed.rcvd.reason" + ): + self.assertEqual(exception.reason, "OK") - def test_connection_closed_attributes_backwards_compatibility_defaults(self): + def test_connection_closed_attributes_deprecation_defaults(self): exception = ConnectionClosed(None, None, None) - self.assertEqual(exception.code, CloseCode.ABNORMAL_CLOSURE) - self.assertEqual(exception.reason, "") + with self.assertDeprecationWarning( + "ConnectionClosed.code is deprecated; " + "use Protocol.close_code or ConnectionClosed.rcvd.code" + ): + self.assertEqual(exception.code, CloseCode.ABNORMAL_CLOSURE) + with self.assertDeprecationWarning( + "ConnectionClosed.reason is deprecated; " + "use Protocol.close_reason or ConnectionClosed.rcvd.reason" + ): + self.assertEqual(exception.reason, "") From 5638611169aa763ec84cdc8d2a967b0533401129 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 23 Aug 2024 07:25:06 +0200 Subject: [PATCH 581/762] Clean up implementation of ConnectionClosed. --- src/websockets/exceptions.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 64b7d3102..8df319dae 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -78,13 +78,13 @@ class ConnectionClosed(WebSocketException): Raised when trying to interact with a closed connection. Attributes: - rcvd (Close | None): if a close frame was received, its code and - reason are available in ``rcvd.code`` and ``rcvd.reason``. - sent (Close | None): if a close frame was sent, its code and reason - are available in ``sent.code`` and ``sent.reason``. - rcvd_then_sent (bool | None): if close frames were received and - sent, this attribute tells in which order this happened, from the - perspective of this side of the connection. + rcvd: If a close frame was received, its code and reason are available + in ``rcvd.code`` and ``rcvd.reason``. + sent: If a close frame was sent, its code and reason are available + in ``sent.code`` and ``sent.reason``. + rcvd_then_sent: If close frames were received and sent, this attribute + tells in which order this happened, from the perspective of this + side of the connection. """ @@ -97,21 +97,18 @@ def __init__( self.rcvd = rcvd self.sent = sent self.rcvd_then_sent = rcvd_then_sent + assert (self.rcvd_then_sent is None) == (self.rcvd is None or self.sent is None) def __str__(self) -> str: if self.rcvd is None: if self.sent is None: - assert self.rcvd_then_sent is None return "no close frame received or sent" else: - assert self.rcvd_then_sent is None return f"sent {self.sent}; no close frame received" else: if self.sent is None: - assert self.rcvd_then_sent is None return f"received {self.rcvd}; no close frame sent" else: - assert self.rcvd_then_sent is not None if self.rcvd_then_sent: return f"received {self.rcvd}; then sent {self.sent}" else: From 567f9859158fe9849b8a1ae7536471119ce4c61f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 23 Aug 2024 07:32:28 +0200 Subject: [PATCH 582/762] Move InvalidMessage to the legacy package. It is only used by the legacy implementation. --- src/websockets/__init__.py | 8 +++++--- src/websockets/exceptions.py | 23 +++++++++++++++-------- src/websockets/legacy/client.py | 2 +- src/websockets/legacy/exceptions.py | 8 ++++++++ src/websockets/legacy/server.py | 2 +- tests/legacy/test_exceptions.py | 15 +++++++++++++++ tests/test_exceptions.py | 4 ---- 7 files changed, 45 insertions(+), 17 deletions(-) create mode 100644 src/websockets/legacy/exceptions.py create mode 100644 tests/legacy/test_exceptions.py diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index b618a6dff..63b0a260b 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -23,7 +23,6 @@ "InvalidHeader", "InvalidHeaderFormat", "InvalidHeaderValue", - "InvalidMessage", "InvalidOrigin", "InvalidParameterName", "InvalidParameterValue", @@ -46,6 +45,8 @@ "WebSocketClientProtocol", "connect", "unix_connect", + # .legacy.exceptions + "InvalidMessage", # .legacy.protocol "WebSocketCommonProtocol", # .legacy.server @@ -80,7 +81,6 @@ InvalidHeader, InvalidHeaderFormat, InvalidHeaderValue, - InvalidMessage, InvalidOrigin, InvalidParameterName, InvalidParameterValue, @@ -102,6 +102,7 @@ basic_auth_protocol_factory, ) from .legacy.client import WebSocketClientProtocol, connect, unix_connect + from .legacy.exceptions import InvalidMessage from .legacy.protocol import WebSocketCommonProtocol from .legacy.server import ( WebSocketServer, @@ -140,7 +141,6 @@ "InvalidHeader": ".exceptions", "InvalidHeaderFormat": ".exceptions", "InvalidHeaderValue": ".exceptions", - "InvalidMessage": ".exceptions", "InvalidOrigin": ".exceptions", "InvalidParameterName": ".exceptions", "InvalidParameterValue": ".exceptions", @@ -163,6 +163,8 @@ "WebSocketClientProtocol": ".legacy.client", "connect": ".legacy.client", "unix_connect": ".legacy.client", + # .legacy.exceptions + "InvalidMessage": ".legacy.exceptions", # .legacy.protocol "WebSocketCommonProtocol": ".legacy.protocol", # .legacy.server diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 8df319dae..b9b100150 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -7,7 +7,7 @@ * :exc:`ConnectionClosedOK` * :exc:`InvalidHandshake` * :exc:`SecurityError` - * :exc:`InvalidMessage` + * :exc:`InvalidMessage` (legacy) * :exc:`InvalidHeader` * :exc:`InvalidHeaderFormat` * :exc:`InvalidHeaderValue` @@ -31,9 +31,11 @@ from __future__ import annotations import http +import typing import warnings from . import datastructures, frames, http11 +from .imports import lazy_import from .typing import StatusLike @@ -175,13 +177,6 @@ class SecurityError(InvalidHandshake): """ -class InvalidMessage(InvalidHandshake): - """ - Raised when a handshake request or response is malformed. - - """ - - class InvalidHeader(InvalidHandshake): """ Raised when an HTTP header doesn't have a valid format or value. @@ -410,3 +405,15 @@ class ProtocolError(WebSocketException): WebSocketProtocolError = ProtocolError # for backwards compatibility + + +# When type checking, import non-deprecated aliases eagerly. Else, import on demand. +if typing.TYPE_CHECKING: + from .legacy.exceptions import InvalidMessage +else: + lazy_import( + globals(), + aliases={ + "InvalidMessage": ".legacy.exceptions", + }, + ) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 256bee14c..bd56c5d16 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -21,7 +21,6 @@ from ..exceptions import ( InvalidHandshake, InvalidHeader, - InvalidMessage, InvalidStatusCode, NegotiationError, RedirectHandshake, @@ -41,6 +40,7 @@ from ..http11 import USER_AGENT from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol from ..uri import WebSocketURI, parse_uri +from .exceptions import InvalidMessage from .handshake import build_request, check_response from .http import read_response from .protocol import WebSocketCommonProtocol diff --git a/src/websockets/legacy/exceptions.py b/src/websockets/legacy/exceptions.py new file mode 100644 index 000000000..9ea173e1a --- /dev/null +++ b/src/websockets/legacy/exceptions.py @@ -0,0 +1,8 @@ +from ..exceptions import InvalidHandshake + + +class InvalidMessage(InvalidHandshake): + """ + Raised when a handshake request or response is malformed. + + """ diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 43136db3e..1bf359f32 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -27,7 +27,6 @@ AbortHandshake, InvalidHandshake, InvalidHeader, - InvalidMessage, InvalidOrigin, InvalidUpgrade, NegotiationError, @@ -43,6 +42,7 @@ from ..http11 import SERVER from ..protocol import State from ..typing import ExtensionHeader, LoggerLike, Origin, StatusLike, Subprotocol +from .exceptions import InvalidMessage from .handshake import build_response, check_request from .http import read_request from .protocol import WebSocketCommonProtocol, broadcast diff --git a/tests/legacy/test_exceptions.py b/tests/legacy/test_exceptions.py new file mode 100644 index 000000000..1850b3bf2 --- /dev/null +++ b/tests/legacy/test_exceptions.py @@ -0,0 +1,15 @@ +import unittest + +from websockets.legacy.exceptions import * + + +class ExceptionsTests(unittest.TestCase): + def test_str(self): + for exception, exception_str in [ + ( + InvalidMessage("malformed HTTP message"), + "malformed HTTP message", + ), + ]: + with self.subTest(exception=exception): + self.assertEqual(str(exception), exception_str) diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 92ba7dda8..36a4d1724 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -87,10 +87,6 @@ def test_str(self): SecurityError("redirect from WSS to WS"), "redirect from WSS to WS", ), - ( - InvalidMessage("malformed HTTP message"), - "malformed HTTP message", - ), ( InvalidHeader("Name"), "missing Name header", From c02648e3ce9e96deec83f966dbb74fca1d9b2e0e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 23 Aug 2024 07:56:36 +0200 Subject: [PATCH 583/762] Move WebSocketProtocolError to the legacy package. This ensures that it will be deprecated and removed. --- src/websockets/exceptions.py | 6 +++--- src/websockets/legacy/exceptions.py | 5 ++++- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index b9b100150..1188ec432 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -404,16 +404,16 @@ class ProtocolError(WebSocketException): """ -WebSocketProtocolError = ProtocolError # for backwards compatibility - - # When type checking, import non-deprecated aliases eagerly. Else, import on demand. if typing.TYPE_CHECKING: from .legacy.exceptions import InvalidMessage + + WebSocketProtocolError = ProtocolError else: lazy_import( globals(), aliases={ "InvalidMessage": ".legacy.exceptions", + "WebSocketProtocolError": ".legacy.exceptions", }, ) diff --git a/src/websockets/legacy/exceptions.py b/src/websockets/legacy/exceptions.py index 9ea173e1a..79bfa7f4d 100644 --- a/src/websockets/legacy/exceptions.py +++ b/src/websockets/legacy/exceptions.py @@ -1,4 +1,7 @@ -from ..exceptions import InvalidHandshake +from ..exceptions import ( + InvalidHandshake, + ProtocolError as WebSocketProtocolError, # noqa: F401 +) class InvalidMessage(InvalidHandshake): From 938f1075ebcada7e46e9fb4caacfb32839c51605 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 23 Aug 2024 09:02:11 +0200 Subject: [PATCH 584/762] Move InvalidStatusCode to the legacy package. It is only used by the legacy implementation. --- src/websockets/__init__.py | 7 +++---- src/websockets/exceptions.py | 20 +++++--------------- src/websockets/legacy/client.py | 3 +-- src/websockets/legacy/exceptions.py | 15 +++++++++++++++ tests/legacy/test_exceptions.py | 5 +++++ tests/test_exceptions.py | 4 ---- 6 files changed, 29 insertions(+), 25 deletions(-) diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index 63b0a260b..09d7aac88 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -28,7 +28,6 @@ "InvalidParameterValue", "InvalidState", "InvalidStatus", - "InvalidStatusCode", "InvalidUpgrade", "InvalidURI", "NegotiationError", @@ -47,6 +46,7 @@ "unix_connect", # .legacy.exceptions "InvalidMessage", + "InvalidStatusCode", # .legacy.protocol "WebSocketCommonProtocol", # .legacy.server @@ -86,7 +86,6 @@ InvalidParameterValue, InvalidState, InvalidStatus, - InvalidStatusCode, InvalidUpgrade, InvalidURI, NegotiationError, @@ -102,7 +101,7 @@ basic_auth_protocol_factory, ) from .legacy.client import WebSocketClientProtocol, connect, unix_connect - from .legacy.exceptions import InvalidMessage + from .legacy.exceptions import InvalidMessage, InvalidStatusCode from .legacy.protocol import WebSocketCommonProtocol from .legacy.server import ( WebSocketServer, @@ -146,7 +145,6 @@ "InvalidParameterValue": ".exceptions", "InvalidState": ".exceptions", "InvalidStatus": ".exceptions", - "InvalidStatusCode": ".exceptions", "InvalidUpgrade": ".exceptions", "InvalidURI": ".exceptions", "NegotiationError": ".exceptions", @@ -165,6 +163,7 @@ "unix_connect": ".legacy.client", # .legacy.exceptions "InvalidMessage": ".legacy.exceptions", + "InvalidStatusCode": ".legacy.exceptions", # .legacy.protocol "WebSocketCommonProtocol": ".legacy.protocol", # .legacy.server diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 1188ec432..e35bf430e 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -250,20 +250,6 @@ def __str__(self) -> str: ) -class InvalidStatusCode(InvalidHandshake): - """ - Raised when a handshake response status code is invalid. - - """ - - def __init__(self, status_code: int, headers: datastructures.Headers) -> None: - self.status_code = status_code - self.headers = headers - - def __str__(self) -> str: - return f"server rejected WebSocket connection: HTTP {self.status_code}" - - class NegotiationError(InvalidHandshake): """ Raised when negotiating an extension fails. @@ -406,7 +392,10 @@ class ProtocolError(WebSocketException): # When type checking, import non-deprecated aliases eagerly. Else, import on demand. if typing.TYPE_CHECKING: - from .legacy.exceptions import InvalidMessage + from .legacy.exceptions import ( + InvalidMessage, + InvalidStatusCode, + ) WebSocketProtocolError = ProtocolError else: @@ -414,6 +403,7 @@ class ProtocolError(WebSocketException): globals(), aliases={ "InvalidMessage": ".legacy.exceptions", + "InvalidStatusCode": ".legacy.exceptions", "WebSocketProtocolError": ".legacy.exceptions", }, ) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index bd56c5d16..fd0803e7d 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -21,7 +21,6 @@ from ..exceptions import ( InvalidHandshake, InvalidHeader, - InvalidStatusCode, NegotiationError, RedirectHandshake, SecurityError, @@ -40,7 +39,7 @@ from ..http11 import USER_AGENT from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol from ..uri import WebSocketURI, parse_uri -from .exceptions import InvalidMessage +from .exceptions import InvalidMessage, InvalidStatusCode from .handshake import build_request, check_response from .http import read_response from .protocol import WebSocketCommonProtocol diff --git a/src/websockets/legacy/exceptions.py b/src/websockets/legacy/exceptions.py index 79bfa7f4d..d82676d5e 100644 --- a/src/websockets/legacy/exceptions.py +++ b/src/websockets/legacy/exceptions.py @@ -1,3 +1,4 @@ +from .. import datastructures from ..exceptions import ( InvalidHandshake, ProtocolError as WebSocketProtocolError, # noqa: F401 @@ -9,3 +10,17 @@ class InvalidMessage(InvalidHandshake): Raised when a handshake request or response is malformed. """ + + +class InvalidStatusCode(InvalidHandshake): + """ + Raised when a handshake response status code is invalid. + + """ + + def __init__(self, status_code: int, headers: datastructures.Headers) -> None: + self.status_code = status_code + self.headers = headers + + def __str__(self) -> str: + return f"server rejected WebSocket connection: HTTP {self.status_code}" diff --git a/tests/legacy/test_exceptions.py b/tests/legacy/test_exceptions.py index 1850b3bf2..8b9a616e9 100644 --- a/tests/legacy/test_exceptions.py +++ b/tests/legacy/test_exceptions.py @@ -1,5 +1,6 @@ import unittest +from websockets.datastructures import Headers from websockets.legacy.exceptions import * @@ -10,6 +11,10 @@ def test_str(self): InvalidMessage("malformed HTTP message"), "malformed HTTP message", ), + ( + InvalidStatusCode(403, Headers()), + "server rejected WebSocket connection: HTTP 403", + ), ]: with self.subTest(exception=exception): self.assertEqual(str(exception), exception_str) diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 36a4d1724..b79489e44 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -127,10 +127,6 @@ def test_str(self): InvalidStatus(Response(401, "Unauthorized", Headers())), "server rejected WebSocket connection: HTTP 401", ), - ( - InvalidStatusCode(403, Headers()), - "server rejected WebSocket connection: HTTP 403", - ), ( NegotiationError("unsupported subprotocol: spam"), "unsupported subprotocol: spam", From 30cbbc3adaf949c0801a2fd37f82d0bd1eaa0ffc Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 23 Aug 2024 09:09:41 +0200 Subject: [PATCH 585/762] Move AbortHandshake to the legacy package. It is only used by the legacy implementation. --- src/websockets/__init__.py | 11 +++++--- src/websockets/exceptions.py | 42 +++-------------------------- src/websockets/legacy/exceptions.py | 37 +++++++++++++++++++++++++ src/websockets/legacy/server.py | 3 +-- tests/legacy/test_exceptions.py | 4 +++ tests/test_exceptions.py | 4 --- 6 files changed, 53 insertions(+), 48 deletions(-) diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index 09d7aac88..a623d32c4 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -14,7 +14,6 @@ "HeadersLike", "MultipleValuesError", # .exceptions - "AbortHandshake", "ConnectionClosed", "ConnectionClosedError", "ConnectionClosedOK", @@ -45,6 +44,7 @@ "connect", "unix_connect", # .legacy.exceptions + "AbortHandshake", "InvalidMessage", "InvalidStatusCode", # .legacy.protocol @@ -72,7 +72,6 @@ from .client import ClientProtocol from .datastructures import Headers, HeadersLike, MultipleValuesError from .exceptions import ( - AbortHandshake, ConnectionClosed, ConnectionClosedError, ConnectionClosedOK, @@ -101,7 +100,11 @@ basic_auth_protocol_factory, ) from .legacy.client import WebSocketClientProtocol, connect, unix_connect - from .legacy.exceptions import InvalidMessage, InvalidStatusCode + from .legacy.exceptions import ( + AbortHandshake, + InvalidMessage, + InvalidStatusCode, + ) from .legacy.protocol import WebSocketCommonProtocol from .legacy.server import ( WebSocketServer, @@ -131,7 +134,6 @@ "HeadersLike": ".datastructures", "MultipleValuesError": ".datastructures", # .exceptions - "AbortHandshake": ".exceptions", "ConnectionClosed": ".exceptions", "ConnectionClosedError": ".exceptions", "ConnectionClosedOK": ".exceptions", @@ -162,6 +164,7 @@ "connect": ".legacy.client", "unix_connect": ".legacy.client", # .legacy.exceptions + "AbortHandshake": ".legacy.exceptions", "InvalidMessage": ".legacy.exceptions", "InvalidStatusCode": ".legacy.exceptions", # .legacy.protocol diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index e35bf430e..803e3fa53 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -19,7 +19,7 @@ * :exc:`DuplicateParameter` * :exc:`InvalidParameterName` * :exc:`InvalidParameterValue` - * :exc:`AbortHandshake` + * :exc:`AbortHandshake` (legacy) * :exc:`RedirectHandshake` * :exc:`InvalidState` * :exc:`InvalidURI` @@ -30,13 +30,11 @@ from __future__ import annotations -import http import typing import warnings -from . import datastructures, frames, http11 +from . import frames, http11 from .imports import lazy_import -from .typing import StatusLike __all__ = [ @@ -302,40 +300,6 @@ def __str__(self) -> str: return f"invalid value for parameter {self.name}: {self.value}" -class AbortHandshake(InvalidHandshake): - """ - Raised to abort the handshake on purpose and return an HTTP response. - - This exception is an implementation detail. - - The public API is - :meth:`~websockets.legacy.server.WebSocketServerProtocol.process_request`. - - Attributes: - status (~http.HTTPStatus): HTTP status code. - headers (Headers): HTTP response headers. - body (bytes): HTTP response body. - """ - - def __init__( - self, - status: StatusLike, - headers: datastructures.HeadersLike, - body: bytes = b"", - ) -> None: - # If a user passes an int instead of a HTTPStatus, fix it automatically. - self.status = http.HTTPStatus(status) - self.headers = datastructures.Headers(headers) - self.body = body - - def __str__(self) -> str: - return ( - f"HTTP {self.status:d}, " - f"{len(self.headers)} headers, " - f"{len(self.body)} bytes" - ) - - class RedirectHandshake(InvalidHandshake): """ Raised when a handshake gets redirected. @@ -393,6 +357,7 @@ class ProtocolError(WebSocketException): # When type checking, import non-deprecated aliases eagerly. Else, import on demand. if typing.TYPE_CHECKING: from .legacy.exceptions import ( + AbortHandshake, InvalidMessage, InvalidStatusCode, ) @@ -402,6 +367,7 @@ class ProtocolError(WebSocketException): lazy_import( globals(), aliases={ + "AbortHandshake": ".legacy.exceptions", "InvalidMessage": ".legacy.exceptions", "InvalidStatusCode": ".legacy.exceptions", "WebSocketProtocolError": ".legacy.exceptions", diff --git a/src/websockets/legacy/exceptions.py b/src/websockets/legacy/exceptions.py index d82676d5e..d02a1f933 100644 --- a/src/websockets/legacy/exceptions.py +++ b/src/websockets/legacy/exceptions.py @@ -1,8 +1,11 @@ +import http + from .. import datastructures from ..exceptions import ( InvalidHandshake, ProtocolError as WebSocketProtocolError, # noqa: F401 ) +from ..typing import StatusLike class InvalidMessage(InvalidHandshake): @@ -24,3 +27,37 @@ def __init__(self, status_code: int, headers: datastructures.Headers) -> None: def __str__(self) -> str: return f"server rejected WebSocket connection: HTTP {self.status_code}" + + +class AbortHandshake(InvalidHandshake): + """ + Raised to abort the handshake on purpose and return an HTTP response. + + This exception is an implementation detail. + + The public API is + :meth:`~websockets.legacy.server.WebSocketServerProtocol.process_request`. + + Attributes: + status (~http.HTTPStatus): HTTP status code. + headers (Headers): HTTP response headers. + body (bytes): HTTP response body. + """ + + def __init__( + self, + status: StatusLike, + headers: datastructures.HeadersLike, + body: bytes = b"", + ) -> None: + # If a user passes an int instead of a HTTPStatus, fix it automatically. + self.status = http.HTTPStatus(status) + self.headers = datastructures.Headers(headers) + self.body = body + + def __str__(self) -> str: + return ( + f"HTTP {self.status:d}, " + f"{len(self.headers)} headers, " + f"{len(self.body)} bytes" + ) diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 1bf359f32..a71996e45 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -24,7 +24,6 @@ from ..asyncio.compatibility import asyncio_timeout from ..datastructures import Headers, HeadersLike, MultipleValuesError from ..exceptions import ( - AbortHandshake, InvalidHandshake, InvalidHeader, InvalidOrigin, @@ -42,7 +41,7 @@ from ..http11 import SERVER from ..protocol import State from ..typing import ExtensionHeader, LoggerLike, Origin, StatusLike, Subprotocol -from .exceptions import InvalidMessage +from .exceptions import AbortHandshake, InvalidMessage from .handshake import build_response, check_request from .http import read_request from .protocol import WebSocketCommonProtocol, broadcast diff --git a/tests/legacy/test_exceptions.py b/tests/legacy/test_exceptions.py index 8b9a616e9..1bab24c4f 100644 --- a/tests/legacy/test_exceptions.py +++ b/tests/legacy/test_exceptions.py @@ -15,6 +15,10 @@ def test_str(self): InvalidStatusCode(403, Headers()), "server rejected WebSocket connection: HTTP 403", ), + ( + AbortHandshake(200, Headers(), b"OK\n"), + "HTTP 200, 0 headers, 3 bytes", + ), ]: with self.subTest(exception=exception): self.assertEqual(str(exception), exception_str) diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index b79489e44..b45642840 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -151,10 +151,6 @@ def test_str(self): InvalidParameterValue("a", "|"), "invalid value for parameter a: |", ), - ( - AbortHandshake(200, Headers(), b"OK\n"), - "HTTP 200, 0 headers, 3 bytes", - ), ( RedirectHandshake("wss://example.com"), "redirect to wss://example.com", From 5f31bcb5a52e4ec8f2b4fc52e9e9b1ffef356771 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 23 Aug 2024 09:12:20 +0200 Subject: [PATCH 586/762] Move RedirectHandshake to the legacy package. It is only used by the legacy implementation. --- src/websockets/__init__.py | 6 +++--- src/websockets/exceptions.py | 19 +++---------------- src/websockets/legacy/client.py | 4 ++-- src/websockets/legacy/exceptions.py | 15 +++++++++++++++ tests/legacy/test_exceptions.py | 4 ++++ tests/test_exceptions.py | 4 ---- 6 files changed, 27 insertions(+), 25 deletions(-) diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index a623d32c4..12141adb0 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -32,7 +32,6 @@ "NegotiationError", "PayloadTooBig", "ProtocolError", - "RedirectHandshake", "SecurityError", "WebSocketException", "WebSocketProtocolError", @@ -47,6 +46,7 @@ "AbortHandshake", "InvalidMessage", "InvalidStatusCode", + "RedirectHandshake", # .legacy.protocol "WebSocketCommonProtocol", # .legacy.server @@ -90,7 +90,6 @@ NegotiationError, PayloadTooBig, ProtocolError, - RedirectHandshake, SecurityError, WebSocketException, WebSocketProtocolError, @@ -104,6 +103,7 @@ AbortHandshake, InvalidMessage, InvalidStatusCode, + RedirectHandshake, ) from .legacy.protocol import WebSocketCommonProtocol from .legacy.server import ( @@ -152,7 +152,6 @@ "NegotiationError": ".exceptions", "PayloadTooBig": ".exceptions", "ProtocolError": ".exceptions", - "RedirectHandshake": ".exceptions", "SecurityError": ".exceptions", "WebSocketException": ".exceptions", "WebSocketProtocolError": ".exceptions", @@ -167,6 +166,7 @@ "AbortHandshake": ".legacy.exceptions", "InvalidMessage": ".legacy.exceptions", "InvalidStatusCode": ".legacy.exceptions", + "RedirectHandshake": ".legacy.exceptions", # .legacy.protocol "WebSocketCommonProtocol": ".legacy.protocol", # .legacy.server diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 803e3fa53..b5aabf75d 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -20,7 +20,7 @@ * :exc:`InvalidParameterName` * :exc:`InvalidParameterValue` * :exc:`AbortHandshake` (legacy) - * :exc:`RedirectHandshake` + * :exc:`RedirectHandshake` (legacy) * :exc:`InvalidState` * :exc:`InvalidURI` * :exc:`PayloadTooBig` @@ -300,21 +300,6 @@ def __str__(self) -> str: return f"invalid value for parameter {self.name}: {self.value}" -class RedirectHandshake(InvalidHandshake): - """ - Raised when a handshake gets redirected. - - This exception is an implementation detail. - - """ - - def __init__(self, uri: str) -> None: - self.uri = uri - - def __str__(self) -> str: - return f"redirect to {self.uri}" - - class InvalidState(WebSocketException, AssertionError): """ Raised when an operation is forbidden in the current state. @@ -360,6 +345,7 @@ class ProtocolError(WebSocketException): AbortHandshake, InvalidMessage, InvalidStatusCode, + RedirectHandshake, ) WebSocketProtocolError = ProtocolError @@ -370,6 +356,7 @@ class ProtocolError(WebSocketException): "AbortHandshake": ".legacy.exceptions", "InvalidMessage": ".legacy.exceptions", "InvalidStatusCode": ".legacy.exceptions", + "RedirectHandshake": ".legacy.exceptions", "WebSocketProtocolError": ".legacy.exceptions", }, ) diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index fd0803e7d..057acb656 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -22,7 +22,7 @@ InvalidHandshake, InvalidHeader, NegotiationError, - RedirectHandshake, + SecurityError, ) from ..extensions import ClientExtensionFactory, Extension @@ -39,7 +39,7 @@ from ..http11 import USER_AGENT from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol from ..uri import WebSocketURI, parse_uri -from .exceptions import InvalidMessage, InvalidStatusCode +from .exceptions import InvalidMessage, InvalidStatusCode, RedirectHandshake from .handshake import build_request, check_response from .http import read_response from .protocol import WebSocketCommonProtocol diff --git a/src/websockets/legacy/exceptions.py b/src/websockets/legacy/exceptions.py index d02a1f933..9ca9b7aff 100644 --- a/src/websockets/legacy/exceptions.py +++ b/src/websockets/legacy/exceptions.py @@ -61,3 +61,18 @@ def __str__(self) -> str: f"{len(self.headers)} headers, " f"{len(self.body)} bytes" ) + + +class RedirectHandshake(InvalidHandshake): + """ + Raised when a handshake gets redirected. + + This exception is an implementation detail. + + """ + + def __init__(self, uri: str) -> None: + self.uri = uri + + def __str__(self) -> str: + return f"redirect to {self.uri}" diff --git a/tests/legacy/test_exceptions.py b/tests/legacy/test_exceptions.py index 1bab24c4f..e5d22a917 100644 --- a/tests/legacy/test_exceptions.py +++ b/tests/legacy/test_exceptions.py @@ -19,6 +19,10 @@ def test_str(self): AbortHandshake(200, Headers(), b"OK\n"), "HTTP 200, 0 headers, 3 bytes", ), + ( + RedirectHandshake("wss://example.com"), + "redirect to wss://example.com", + ), ]: with self.subTest(exception=exception): self.assertEqual(str(exception), exception_str) diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index b45642840..b54903275 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -151,10 +151,6 @@ def test_str(self): InvalidParameterValue("a", "|"), "invalid value for parameter a: |", ), - ( - RedirectHandshake("wss://example.com"), - "redirect to wss://example.com", - ), ( InvalidState("WebSocket connection isn't established yet"), "WebSocket connection isn't established yet", From 58d72bc995b3e254b5804b36860d948a9c2cf38f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 23 Aug 2024 09:21:39 +0200 Subject: [PATCH 587/762] Document legacy exceptions in their own section. This change has the side effect of dropping the WebSocketProtocolError exception from the list. It was documented as an alias of ProtocolError as a side effect of using the automodule directive. Eventually it will get deprecated with the rest of the legacy module so potential users will be aware. --- docs/reference/exceptions.rst | 53 ++++++++++++++++++++++++++++++++++- 1 file changed, 52 insertions(+), 1 deletion(-) diff --git a/docs/reference/exceptions.rst b/docs/reference/exceptions.rst index 907a650d2..062e69df1 100644 --- a/docs/reference/exceptions.rst +++ b/docs/reference/exceptions.rst @@ -2,5 +2,56 @@ Exceptions ========== .. automodule:: websockets.exceptions - :members: +.. autoexception:: WebSocketException + +.. autoexception:: ConnectionClosed + +.. autoexception:: ConnectionClosedError + +.. autoexception:: ConnectionClosedOK + +.. autoexception:: InvalidHandshake + +.. autoexception:: SecurityError + +.. autoexception:: InvalidHeader + +.. autoexception:: InvalidHeaderFormat + +.. autoexception:: InvalidHeaderValue + +.. autoexception:: InvalidOrigin + +.. autoexception:: InvalidUpgrade + +.. autoexception:: InvalidStatus + +.. autoexception:: NegotiationError + +.. autoexception:: DuplicateParameter + +.. autoexception:: InvalidParameterName + +.. autoexception:: InvalidParameterValue + +.. autoexception:: InvalidState + +.. autoexception:: InvalidURI + +.. autoexception:: PayloadTooBig + +.. autoexception:: ProtocolError + +Legacy exceptions +----------------- + +These exceptions are only used by the legacy :mod:`asyncio` implementation. + +.. autoexception:: InvalidMessage + +.. autoexception:: InvalidStatusCode + +.. autoexception:: AbortHandshake + +.. autoexception:: RedirectHandshake From 60fc23788f8e7d20e97780a73573e68f5d29db68 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 23 Aug 2024 14:22:35 +0200 Subject: [PATCH 588/762] Always raise subclasses of InvalidHandshake. Also shorten messages for InvalidHeader exceptions. --- src/websockets/client.py | 15 +++++++-------- src/websockets/datastructures.py | 2 +- src/websockets/legacy/client.py | 13 +++++++------ src/websockets/legacy/handshake.py | 12 +++--------- src/websockets/legacy/server.py | 2 +- src/websockets/server.py | 11 +++-------- tests/test_client.py | 7 ++++--- tests/test_server.py | 8 +++----- 8 files changed, 29 insertions(+), 41 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 07d1d34ed..c408c1447 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -176,10 +176,7 @@ def process_response(self, response: Response) -> None: except KeyError as exc: raise InvalidHeader("Sec-WebSocket-Accept") from exc except MultipleValuesError as exc: - raise InvalidHeader( - "Sec-WebSocket-Accept", - "more than one Sec-WebSocket-Accept header found", - ) from exc + raise InvalidHeader("Sec-WebSocket-Accept", "multiple values") from exc if s_w_accept != accept_key(self.key): raise InvalidHeaderValue("Sec-WebSocket-Accept", s_w_accept) @@ -225,7 +222,7 @@ def process_extensions(self, headers: Headers) -> list[Extension]: if extensions: if self.available_extensions is None: - raise InvalidHandshake("no extensions supported") + raise NegotiationError("no extensions supported") parsed_extensions: list[ExtensionHeader] = sum( [parse_extension(header_value) for header_value in extensions], [] @@ -280,15 +277,17 @@ def process_subprotocol(self, headers: Headers) -> Subprotocol | None: if subprotocols: if self.available_subprotocols is None: - raise InvalidHandshake("no subprotocols supported") + raise NegotiationError("no subprotocols supported") parsed_subprotocols: Sequence[Subprotocol] = sum( [parse_subprotocol(header_value) for header_value in subprotocols], [] ) if len(parsed_subprotocols) > 1: - subprotocols_display = ", ".join(parsed_subprotocols) - raise InvalidHandshake(f"multiple subprotocols: {subprotocols_display}") + raise InvalidHeader( + "Sec-WebSocket-Protocol", + f"multiple values: {', '.join(parsed_subprotocols)}", + ) subprotocol = parsed_subprotocols[0] diff --git a/src/websockets/datastructures.py b/src/websockets/datastructures.py index 3d64d951e..106d6f393 100644 --- a/src/websockets/datastructures.py +++ b/src/websockets/datastructures.py @@ -17,7 +17,7 @@ class MultipleValuesError(LookupError): """ - Exception raised when :class:`Headers` has more than one value for a key. + Exception raised when :class:`Headers` has multiple values for a key. """ diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 057acb656..25142ea25 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -19,10 +19,9 @@ from ..asyncio.compatibility import asyncio_timeout from ..datastructures import Headers, HeadersLike from ..exceptions import ( - InvalidHandshake, InvalidHeader, + InvalidHeaderValue, NegotiationError, - SecurityError, ) from ..extensions import ClientExtensionFactory, Extension @@ -181,7 +180,7 @@ def process_extensions( if header_values: if available_extensions is None: - raise InvalidHandshake("no extensions supported") + raise NegotiationError("no extensions supported") parsed_header_values: list[ExtensionHeader] = sum( [parse_extension(header_value) for header_value in header_values], [] @@ -235,15 +234,17 @@ def process_subprotocol( if header_values: if available_subprotocols is None: - raise InvalidHandshake("no subprotocols supported") + raise NegotiationError("no subprotocols supported") parsed_header_values: Sequence[Subprotocol] = sum( [parse_subprotocol(header_value) for header_value in header_values], [] ) if len(parsed_header_values) > 1: - subprotocols = ", ".join(parsed_header_values) - raise InvalidHandshake(f"multiple subprotocols: {subprotocols}") + raise InvalidHeaderValue( + "Sec-WebSocket-Protocol", + f"multiple values: {', '.join(parsed_header_values)}", + ) subprotocol = parsed_header_values[0] diff --git a/src/websockets/legacy/handshake.py b/src/websockets/legacy/handshake.py index 2a39c1b03..6a7157c01 100644 --- a/src/websockets/legacy/handshake.py +++ b/src/websockets/legacy/handshake.py @@ -76,9 +76,7 @@ def check_request(headers: Headers) -> str: except KeyError as exc: raise InvalidHeader("Sec-WebSocket-Key") from exc except MultipleValuesError as exc: - raise InvalidHeader( - "Sec-WebSocket-Key", "more than one Sec-WebSocket-Key header found" - ) from exc + raise InvalidHeader("Sec-WebSocket-Key", "multiple values") from exc try: raw_key = base64.b64decode(s_w_key.encode(), validate=True) @@ -92,9 +90,7 @@ def check_request(headers: Headers) -> str: except KeyError as exc: raise InvalidHeader("Sec-WebSocket-Version") from exc except MultipleValuesError as exc: - raise InvalidHeader( - "Sec-WebSocket-Version", "more than one Sec-WebSocket-Version header found" - ) from exc + raise InvalidHeader("Sec-WebSocket-Version", "multiple values") from exc if s_w_version != "13": raise InvalidHeaderValue("Sec-WebSocket-Version", s_w_version) @@ -156,9 +152,7 @@ def check_response(headers: Headers, key: str) -> None: except KeyError as exc: raise InvalidHeader("Sec-WebSocket-Accept") from exc except MultipleValuesError as exc: - raise InvalidHeader( - "Sec-WebSocket-Accept", "more than one Sec-WebSocket-Accept header found" - ) from exc + raise InvalidHeader("Sec-WebSocket-Accept", "multiple values") from exc if s_w_accept != accept(key): raise InvalidHeaderValue("Sec-WebSocket-Accept", s_w_accept) diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index a71996e45..b31cc25b8 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -397,7 +397,7 @@ def process_origin( try: origin = headers.get("Origin") except MultipleValuesError as exc: - raise InvalidHeader("Origin", "more than one Origin header found") from exc + raise InvalidHeader("Origin", "multiple values") from exc if origin is not None: origin = cast(Origin, origin) if origins is not None: diff --git a/src/websockets/server.py b/src/websockets/server.py index 11ba8b425..3a03378b2 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -257,9 +257,7 @@ def process_request( except KeyError as exc: raise InvalidHeader("Sec-WebSocket-Key") from exc except MultipleValuesError as exc: - raise InvalidHeader( - "Sec-WebSocket-Key", "more than one Sec-WebSocket-Key header found" - ) from exc + raise InvalidHeader("Sec-WebSocket-Key", "multiple values") from exc try: raw_key = base64.b64decode(key.encode(), validate=True) @@ -273,10 +271,7 @@ def process_request( except KeyError as exc: raise InvalidHeader("Sec-WebSocket-Version") from exc except MultipleValuesError as exc: - raise InvalidHeader( - "Sec-WebSocket-Version", - "more than one Sec-WebSocket-Version header found", - ) from exc + raise InvalidHeader("Sec-WebSocket-Version", "multiple values") from exc if version != "13": raise InvalidHeaderValue("Sec-WebSocket-Version", version) @@ -315,7 +310,7 @@ def process_origin(self, headers: Headers) -> Origin | None: try: origin = headers.get("Origin") except MultipleValuesError as exc: - raise InvalidHeader("Origin", "more than one Origin header found") from exc + raise InvalidHeader("Origin", "multiple values") from exc if origin is not None: origin = cast(Origin, origin) if self.origins is not None: diff --git a/tests/test_client.py b/tests/test_client.py index c83c87038..d798a66f9 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -342,8 +342,7 @@ def test_multiple_accept(self): raise client.handshake_exc self.assertEqual( str(raised.exception), - "invalid Sec-WebSocket-Accept header: " - "more than one Sec-WebSocket-Accept header found", + "invalid Sec-WebSocket-Accept header: multiple values", ) def test_invalid_accept(self): @@ -556,7 +555,9 @@ def test_multiple_subprotocols(self): with self.assertRaises(InvalidHandshake) as raised: raise client.handshake_exc self.assertEqual( - str(raised.exception), "multiple subprotocols: superchat, chat" + str(raised.exception), + "invalid Sec-WebSocket-Protocol header: " + "multiple values: superchat, chat", ) def test_supported_subprotocol(self): diff --git a/tests/test_server.py b/tests/test_server.py index e4460dcba..e7d249f49 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -305,8 +305,7 @@ def test_multiple_key(self): raise server.handshake_exc self.assertEqual( str(raised.exception), - "invalid Sec-WebSocket-Key header: " - "more than one Sec-WebSocket-Key header found", + "invalid Sec-WebSocket-Key header: multiple values", ) def test_invalid_key(self): @@ -366,8 +365,7 @@ def test_multiple_version(self): raise server.handshake_exc self.assertEqual( str(raised.exception), - "invalid Sec-WebSocket-Version header: " - "more than one Sec-WebSocket-Version header found", + "invalid Sec-WebSocket-Version header: multiple values", ) def test_invalid_version(self): @@ -437,7 +435,7 @@ def test_multiple_origin(self): raise server.handshake_exc self.assertEqual( str(raised.exception), - "invalid Origin header: more than one Origin header found", + "invalid Origin header: multiple values", ) def test_supported_origin(self): From 4e17142e81d6d832061fd69694875e0d962c54ed Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 23 Aug 2024 14:10:18 +0200 Subject: [PATCH 589/762] Improve documentation of exceptions. Group and order them. Extend docstrings. --- docs/project/changelog.rst | 2 +- docs/reference/exceptions.rst | 33 +++++++-- src/websockets/exceptions.py | 132 +++++++++++++++++++--------------- tests/test_exceptions.py | 24 +++---- 4 files changed, 114 insertions(+), 77 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 1c87882f0..e82f61753 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -613,7 +613,7 @@ Bug fixes * Aligned maximum cookie size with popular web browsers. * Ensured cancellation always propagates, even on Python versions where - :exc:`~asyncio.CancelledError` inherits :exc:`Exception`. + :exc:`~asyncio.CancelledError` inherits from :exc:`Exception`. .. _8.1: diff --git a/docs/reference/exceptions.rst b/docs/reference/exceptions.rst index 062e69df1..14a8edcd1 100644 --- a/docs/reference/exceptions.rst +++ b/docs/reference/exceptions.rst @@ -5,16 +5,35 @@ Exceptions .. autoexception:: WebSocketException +Connection closed +----------------- + +:meth:`~websockets.asyncio.connection.Connection.recv`, +:meth:`~websockets.asyncio.connection.Connection.send`, and similar methods +raise the exceptions below when the connection is closed. This is the expected +way to detect disconnections. + .. autoexception:: ConnectionClosed +.. autoexception:: ConnectionClosedOK + .. autoexception:: ConnectionClosedError -.. autoexception:: ConnectionClosedOK +Connection failed +----------------- + +These exceptions are raised by :func:`~websockets.asyncio.client.connect` when +the opening handshake fails and the connection cannot be established. They are +also reported by :func:`~websockets.asyncio.server.serve` in logs. + +.. autoexception:: InvalidURI .. autoexception:: InvalidHandshake .. autoexception:: SecurityError +.. autoexception:: InvalidStatus + .. autoexception:: InvalidHeader .. autoexception:: InvalidHeaderFormat @@ -25,8 +44,6 @@ Exceptions .. autoexception:: InvalidUpgrade -.. autoexception:: InvalidStatus - .. autoexception:: NegotiationError .. autoexception:: DuplicateParameter @@ -35,13 +52,17 @@ Exceptions .. autoexception:: InvalidParameterValue -.. autoexception:: InvalidState +Sans-I/O exceptions +------------------- -.. autoexception:: InvalidURI +These exceptions are only raised by the Sans-I/O implementation. They are +translated to :exc:`ConnectionClosedError` in the other implementations. + +.. autoexception:: ProtocolError .. autoexception:: PayloadTooBig -.. autoexception:: ProtocolError +.. autoexception:: InvalidState Legacy exceptions ----------------- diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index b5aabf75d..8d998bfdb 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -1,30 +1,30 @@ """ -:mod:`websockets.exceptions` defines the following exception hierarchy: +:mod:`websockets.exceptions` defines the following hierarchy of exceptions. * :exc:`WebSocketException` * :exc:`ConnectionClosed` - * :exc:`ConnectionClosedError` * :exc:`ConnectionClosedOK` + * :exc:`ConnectionClosedError` + * :exc:`InvalidURI` * :exc:`InvalidHandshake` * :exc:`SecurityError` * :exc:`InvalidMessage` (legacy) + * :exc:`InvalidStatus` + * :exc:`InvalidStatusCode` (legacy) * :exc:`InvalidHeader` * :exc:`InvalidHeaderFormat` * :exc:`InvalidHeaderValue` * :exc:`InvalidOrigin` * :exc:`InvalidUpgrade` - * :exc:`InvalidStatus` - * :exc:`InvalidStatusCode` (legacy) * :exc:`NegotiationError` * :exc:`DuplicateParameter` * :exc:`InvalidParameterName` * :exc:`InvalidParameterValue` * :exc:`AbortHandshake` (legacy) * :exc:`RedirectHandshake` (legacy) - * :exc:`InvalidState` - * :exc:`InvalidURI` - * :exc:`PayloadTooBig` - * :exc:`ProtocolError` + * :exc:`ProtocolError` (Sans-I/O) + * :exc:`PayloadTooBig` (Sans-I/O) + * :exc:`InvalidState` (Sans-I/O) """ @@ -40,29 +40,29 @@ __all__ = [ "WebSocketException", "ConnectionClosed", - "ConnectionClosedError", "ConnectionClosedOK", + "ConnectionClosedError", + "InvalidURI", "InvalidHandshake", "SecurityError", "InvalidMessage", + "InvalidStatus", + "InvalidStatusCode", "InvalidHeader", "InvalidHeaderFormat", "InvalidHeaderValue", "InvalidOrigin", "InvalidUpgrade", - "InvalidStatus", - "InvalidStatusCode", "NegotiationError", "DuplicateParameter", "InvalidParameterName", "InvalidParameterValue", "AbortHandshake", "RedirectHandshake", - "InvalidState", - "InvalidURI", - "PayloadTooBig", "ProtocolError", "WebSocketProtocolError", + "PayloadTooBig", + "InvalidState", ] @@ -139,6 +139,16 @@ def reason(self) -> str: return self.rcvd.reason +class ConnectionClosedOK(ConnectionClosed): + """ + Like :exc:`ConnectionClosed`, when the connection terminated properly. + + A close code with code 1000 (OK) or 1001 (going away) or without a code was + received and sent. + + """ + + class ConnectionClosedError(ConnectionClosed): """ Like :exc:`ConnectionClosed`, when the connection terminated with an error. @@ -149,19 +159,23 @@ class ConnectionClosedError(ConnectionClosed): """ -class ConnectionClosedOK(ConnectionClosed): +class InvalidURI(WebSocketException): """ - Like :exc:`ConnectionClosed`, when the connection terminated properly. - - A close code with code 1000 (OK) or 1001 (going away) or without a code was - received and sent. + Raised when connecting to a URI that isn't a valid WebSocket URI. """ + def __init__(self, uri: str, msg: str) -> None: + self.uri = uri + self.msg = msg + + def __str__(self) -> str: + return f"{self.uri} isn't a valid URI: {self.msg}" + class InvalidHandshake(WebSocketException): """ - Raised during the handshake when the WebSocket connection fails. + Base class for exceptions raised when the opening handshake fails. """ @@ -170,10 +184,27 @@ class SecurityError(InvalidHandshake): """ Raised when a handshake request or response breaks a security rule. - Security limits are hard coded. + Security limits can be configured with :doc:`environment variables + <../reference/variables>`. + + """ + + +class InvalidStatus(InvalidHandshake): + """ + Raised when a handshake response rejects the WebSocket upgrade. """ + def __init__(self, response: http11.Response) -> None: + self.response = response + + def __str__(self) -> str: + return ( + "server rejected WebSocket connection: " + f"HTTP {self.response.status_code:d}" + ) + class InvalidHeader(InvalidHandshake): """ @@ -210,7 +241,7 @@ class InvalidHeaderValue(InvalidHeader): """ Raised when an HTTP header has a wrong value. - The format of the header is correct but a value isn't acceptable. + The format of the header is correct but the value isn't acceptable. """ @@ -232,25 +263,9 @@ class InvalidUpgrade(InvalidHeader): """ -class InvalidStatus(InvalidHandshake): - """ - Raised when a handshake response rejects the WebSocket upgrade. - - """ - - def __init__(self, response: http11.Response) -> None: - self.response = response - - def __str__(self) -> str: - return ( - "server rejected WebSocket connection: " - f"HTTP {self.response.status_code:d}" - ) - - class NegotiationError(InvalidHandshake): """ - Raised when negotiating an extension fails. + Raised when negotiating an extension or a subprotocol fails. """ @@ -300,41 +315,42 @@ def __str__(self) -> str: return f"invalid value for parameter {self.name}: {self.value}" -class InvalidState(WebSocketException, AssertionError): +class ProtocolError(WebSocketException): """ - Raised when an operation is forbidden in the current state. + Raised when receiving or sending a frame that breaks the protocol. - This exception is an implementation detail. + The Sans-I/O implementation raises this exception when: - It should never be raised in normal circumstances. + * receiving or sending a frame that contains invalid data; + * receiving or sending an invalid sequence of frames. """ -class InvalidURI(WebSocketException): +class PayloadTooBig(WebSocketException): """ - Raised when connecting to a URI that isn't a valid WebSocket URI. + Raised when parsing a frame with a payload that exceeds the maximum size. - """ - - def __init__(self, uri: str, msg: str) -> None: - self.uri = uri - self.msg = msg - - def __str__(self) -> str: - return f"{self.uri} isn't a valid URI: {self.msg}" + The Sans-I/O layer uses this exception internally. It doesn't bubble up to + the I/O layer. + The :meth:`~websockets.extensions.Extension.decode` method of extensions + must raise :exc:`PayloadTooBig` if decoding a frame would exceed the limit. -class PayloadTooBig(WebSocketException): """ - Raised when receiving a frame with a payload exceeding the maximum size. + +class InvalidState(WebSocketException, AssertionError): """ + Raised when sending a frame is forbidden in the current state. + Specifically, the Sans-I/O layer raises this exception when: -class ProtocolError(WebSocketException): - """ - Raised when a frame breaks the protocol. + * sending a data frame to a connection in a state other + :attr:`~websockets.protocol.State.OPEN`; + * sending a control frame to a connection in a state other than + :attr:`~websockets.protocol.State.OPEN` or + :attr:`~websockets.protocol.State.CLOSING`. """ diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index b54903275..5620b8a53 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -79,6 +79,10 @@ def test_str(self): ), "no close frame received or sent", ), + ( + InvalidURI("|", "not at all!"), + "| isn't a valid URI: not at all!", + ), ( InvalidHandshake("invalid request"), "invalid request", @@ -87,6 +91,10 @@ def test_str(self): SecurityError("redirect from WSS to WS"), "redirect from WSS to WS", ), + ( + InvalidStatus(Response(401, "Unauthorized", Headers())), + "server rejected WebSocket connection: HTTP 401", + ), ( InvalidHeader("Name"), "missing Name header", @@ -123,10 +131,6 @@ def test_str(self): InvalidUpgrade("Connection", "websocket"), "invalid Connection header: websocket", ), - ( - InvalidStatus(Response(401, "Unauthorized", Headers())), - "server rejected WebSocket connection: HTTP 401", - ), ( NegotiationError("unsupported subprotocol: spam"), "unsupported subprotocol: spam", @@ -152,20 +156,16 @@ def test_str(self): "invalid value for parameter a: |", ), ( - InvalidState("WebSocket connection isn't established yet"), - "WebSocket connection isn't established yet", - ), - ( - InvalidURI("|", "not at all!"), - "| isn't a valid URI: not at all!", + ProtocolError("invalid opcode: 7"), + "invalid opcode: 7", ), ( PayloadTooBig("payload length exceeds limit: 2 > 1 bytes"), "payload length exceeds limit: 2 > 1 bytes", ), ( - ProtocolError("invalid opcode: 7"), - "invalid opcode: 7", + InvalidState("WebSocket connection isn't established yet"), + "WebSocket connection isn't established yet", ), ]: with self.subTest(exception=exception): From 2a9dfb593eb7dd17424242cf344945af0510aa3f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 23 Aug 2024 14:40:58 +0200 Subject: [PATCH 590/762] Simplify untangling of import cycles Three imports at the bottom, related to type annotations, is a lesser evil compared to dozens of module-prefixed identifiers, a departure from the coding style of this library. Refs #989. --- src/websockets/exceptions.py | 4 +- src/websockets/extensions/base.py | 11 ++--- .../extensions/permessage_deflate.py | 48 +++++++++++-------- src/websockets/frames.py | 24 +++++----- src/websockets/headers.py | 32 +++++-------- src/websockets/http11.py | 23 ++++----- src/websockets/uri.py | 10 ++-- 7 files changed, 74 insertions(+), 78 deletions(-) diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 8d998bfdb..b2b679e6b 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -33,7 +33,6 @@ import typing import warnings -from . import frames, http11 from .imports import lazy_import @@ -376,3 +375,6 @@ class InvalidState(WebSocketException, AssertionError): "WebSocketProtocolError": ".legacy.exceptions", }, ) + +# At the bottom to break import cycles created by type annotations. +from . import frames, http11 # noqa: E402 diff --git a/src/websockets/extensions/base.py b/src/websockets/extensions/base.py index a6c76c3d4..75bae6b77 100644 --- a/src/websockets/extensions/base.py +++ b/src/websockets/extensions/base.py @@ -2,7 +2,7 @@ from typing import Sequence -from .. import frames +from ..frames import Frame from ..typing import ExtensionName, ExtensionParameter @@ -18,12 +18,7 @@ class Extension: name: ExtensionName """Extension identifier.""" - def decode( - self, - frame: frames.Frame, - *, - max_size: int | None = None, - ) -> frames.Frame: + def decode(self, frame: Frame, *, max_size: int | None = None) -> Frame: """ Decode an incoming frame. @@ -40,7 +35,7 @@ def decode( """ raise NotImplementedError - def encode(self, frame: frames.Frame) -> frames.Frame: + def encode(self, frame: Frame) -> Frame: """ Encode an outgoing frame. diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index 5b907b79f..25d2c1c45 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -4,7 +4,15 @@ import zlib from typing import Any, Sequence -from .. import exceptions, frames +from .. import frames +from ..exceptions import ( + DuplicateParameter, + InvalidParameterName, + InvalidParameterValue, + NegotiationError, + PayloadTooBig, + ProtocolError, +) from ..typing import ExtensionName, ExtensionParameter from .base import ClientExtensionFactory, Extension, ServerExtensionFactory @@ -129,9 +137,9 @@ def decode( try: data = self.decoder.decompress(data, max_length) except zlib.error as exc: - raise exceptions.ProtocolError("decompression failed") from exc + raise ProtocolError("decompression failed") from exc if self.decoder.unconsumed_tail: - raise exceptions.PayloadTooBig(f"over size limit (? > {max_size} bytes)") + raise PayloadTooBig(f"over size limit (? > {max_size} bytes)") # Allow garbage collection of the decoder if it won't be reused. if frame.fin and self.remote_no_context_takeover: @@ -215,40 +223,40 @@ def _extract_parameters( for name, value in params: if name == "server_no_context_takeover": if server_no_context_takeover: - raise exceptions.DuplicateParameter(name) + raise DuplicateParameter(name) if value is None: server_no_context_takeover = True else: - raise exceptions.InvalidParameterValue(name, value) + raise InvalidParameterValue(name, value) elif name == "client_no_context_takeover": if client_no_context_takeover: - raise exceptions.DuplicateParameter(name) + raise DuplicateParameter(name) if value is None: client_no_context_takeover = True else: - raise exceptions.InvalidParameterValue(name, value) + raise InvalidParameterValue(name, value) elif name == "server_max_window_bits": if server_max_window_bits is not None: - raise exceptions.DuplicateParameter(name) + raise DuplicateParameter(name) if value in _MAX_WINDOW_BITS_VALUES: server_max_window_bits = int(value) else: - raise exceptions.InvalidParameterValue(name, value) + raise InvalidParameterValue(name, value) elif name == "client_max_window_bits": if client_max_window_bits is not None: - raise exceptions.DuplicateParameter(name) + raise DuplicateParameter(name) if is_server and value is None: # only in handshake requests client_max_window_bits = True elif value in _MAX_WINDOW_BITS_VALUES: client_max_window_bits = int(value) else: - raise exceptions.InvalidParameterValue(name, value) + raise InvalidParameterValue(name, value) else: - raise exceptions.InvalidParameterName(name) + raise InvalidParameterName(name) return ( server_no_context_takeover, @@ -340,7 +348,7 @@ def process_response_params( """ if any(other.name == self.name for other in accepted_extensions): - raise exceptions.NegotiationError(f"received duplicate {self.name}") + raise NegotiationError(f"received duplicate {self.name}") # Request parameters are available in instance variables. @@ -366,7 +374,7 @@ def process_response_params( if self.server_no_context_takeover: if not server_no_context_takeover: - raise exceptions.NegotiationError("expected server_no_context_takeover") + raise NegotiationError("expected server_no_context_takeover") # client_no_context_takeover # @@ -396,9 +404,9 @@ def process_response_params( else: if server_max_window_bits is None: - raise exceptions.NegotiationError("expected server_max_window_bits") + raise NegotiationError("expected server_max_window_bits") elif server_max_window_bits > self.server_max_window_bits: - raise exceptions.NegotiationError("unsupported server_max_window_bits") + raise NegotiationError("unsupported server_max_window_bits") # client_max_window_bits @@ -414,7 +422,7 @@ def process_response_params( if self.client_max_window_bits is None: if client_max_window_bits is not None: - raise exceptions.NegotiationError("unexpected client_max_window_bits") + raise NegotiationError("unexpected client_max_window_bits") elif self.client_max_window_bits is True: pass @@ -423,7 +431,7 @@ def process_response_params( if client_max_window_bits is None: client_max_window_bits = self.client_max_window_bits elif client_max_window_bits > self.client_max_window_bits: - raise exceptions.NegotiationError("unsupported client_max_window_bits") + raise NegotiationError("unsupported client_max_window_bits") return PerMessageDeflate( server_no_context_takeover, # remote_no_context_takeover @@ -534,7 +542,7 @@ def process_request_params( """ if any(other.name == self.name for other in accepted_extensions): - raise exceptions.NegotiationError(f"skipped duplicate {self.name}") + raise NegotiationError(f"skipped duplicate {self.name}") # Load request parameters in local variables. ( @@ -613,7 +621,7 @@ def process_request_params( else: if client_max_window_bits is None: if self.require_client_max_window_bits: - raise exceptions.NegotiationError("required client_max_window_bits") + raise NegotiationError("required client_max_window_bits") elif client_max_window_bits is True: client_max_window_bits = self.client_max_window_bits elif self.client_max_window_bits < client_max_window_bits: diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 8e44dd3a2..a63bdc3b6 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -8,7 +8,7 @@ import struct from typing import Callable, Generator, Sequence -from . import exceptions, extensions +from .exceptions import PayloadTooBig, ProtocolError try: @@ -239,10 +239,10 @@ def parse( try: opcode = Opcode(head1 & 0b00001111) except ValueError as exc: - raise exceptions.ProtocolError("invalid opcode") from exc + raise ProtocolError("invalid opcode") from exc if (True if head2 & 0b10000000 else False) != mask: - raise exceptions.ProtocolError("incorrect masking") + raise ProtocolError("incorrect masking") length = head2 & 0b01111111 if length == 126: @@ -252,9 +252,7 @@ def parse( data = yield from read_exact(8) (length,) = struct.unpack("!Q", data) if max_size is not None and length > max_size: - raise exceptions.PayloadTooBig( - f"over size limit ({length} > {max_size} bytes)" - ) + raise PayloadTooBig(f"over size limit ({length} > {max_size} bytes)") if mask: mask_bytes = yield from read_exact(4) @@ -342,13 +340,13 @@ def check(self) -> None: """ if self.rsv1 or self.rsv2 or self.rsv3: - raise exceptions.ProtocolError("reserved bits must be 0") + raise ProtocolError("reserved bits must be 0") if self.opcode in CTRL_OPCODES: if len(self.data) > 125: - raise exceptions.ProtocolError("control frame too long") + raise ProtocolError("control frame too long") if not self.fin: - raise exceptions.ProtocolError("fragmented control frame") + raise ProtocolError("fragmented control frame") @dataclasses.dataclass @@ -405,7 +403,7 @@ def parse(cls, data: bytes) -> Close: elif len(data) == 0: return cls(CloseCode.NO_STATUS_RCVD, "") else: - raise exceptions.ProtocolError("close frame too short") + raise ProtocolError("close frame too short") def serialize(self) -> bytes: """ @@ -424,4 +422,8 @@ def check(self) -> None: """ if not (self.code in EXTERNAL_CLOSE_CODES or 3000 <= self.code < 5000): - raise exceptions.ProtocolError("invalid status code") + raise ProtocolError("invalid status code") + + +# At the bottom to break import cycles created by type annotations. +from . import extensions # noqa: E402 diff --git a/src/websockets/headers.py b/src/websockets/headers.py index 0ffd65233..9103018a0 100644 --- a/src/websockets/headers.py +++ b/src/websockets/headers.py @@ -6,7 +6,7 @@ import re from typing import Callable, Sequence, TypeVar, cast -from . import exceptions +from .exceptions import InvalidHeaderFormat, InvalidHeaderValue from .typing import ( ConnectionOption, ExtensionHeader, @@ -108,7 +108,7 @@ def parse_token(header: str, pos: int, header_name: str) -> tuple[str, int]: """ match = _token_re.match(header, pos) if match is None: - raise exceptions.InvalidHeaderFormat(header_name, "expected token", header, pos) + raise InvalidHeaderFormat(header_name, "expected token", header, pos) return match.group(), match.end() @@ -132,9 +132,7 @@ def parse_quoted_string(header: str, pos: int, header_name: str) -> tuple[str, i """ match = _quoted_string_re.match(header, pos) if match is None: - raise exceptions.InvalidHeaderFormat( - header_name, "expected quoted string", header, pos - ) + raise InvalidHeaderFormat(header_name, "expected quoted string", header, pos) return _unquote_re.sub(r"\1", match.group()[1:-1]), match.end() @@ -206,9 +204,7 @@ def parse_list( if peek_ahead(header, pos) == ",": pos = parse_OWS(header, pos + 1) else: - raise exceptions.InvalidHeaderFormat( - header_name, "expected comma", header, pos - ) + raise InvalidHeaderFormat(header_name, "expected comma", header, pos) # Remove extra delimiters before the next item. while peek_ahead(header, pos) == ",": @@ -276,9 +272,7 @@ def parse_upgrade_protocol( """ match = _protocol_re.match(header, pos) if match is None: - raise exceptions.InvalidHeaderFormat( - header_name, "expected protocol", header, pos - ) + raise InvalidHeaderFormat(header_name, "expected protocol", header, pos) return cast(UpgradeProtocol, match.group()), match.end() @@ -324,7 +318,7 @@ def parse_extension_item_param( # the value after quoted-string unescaping MUST conform to # the 'token' ABNF. if _token_re.fullmatch(value) is None: - raise exceptions.InvalidHeaderFormat( + raise InvalidHeaderFormat( header_name, "invalid quoted header content", header, pos_before ) else: @@ -510,9 +504,7 @@ def parse_token68(header: str, pos: int, header_name: str) -> tuple[str, int]: """ match = _token68_re.match(header, pos) if match is None: - raise exceptions.InvalidHeaderFormat( - header_name, "expected token68", header, pos - ) + raise InvalidHeaderFormat(header_name, "expected token68", header, pos) return match.group(), match.end() @@ -522,7 +514,7 @@ def parse_end(header: str, pos: int, header_name: str) -> None: """ if pos < len(header): - raise exceptions.InvalidHeaderFormat(header_name, "trailing data", header, pos) + raise InvalidHeaderFormat(header_name, "trailing data", header, pos) def parse_authorization_basic(header: str) -> tuple[str, str]: @@ -543,12 +535,12 @@ def parse_authorization_basic(header: str) -> tuple[str, str]: # https://datatracker.ietf.org/doc/html/rfc7617#section-2 scheme, pos = parse_token(header, 0, "Authorization") if scheme.lower() != "basic": - raise exceptions.InvalidHeaderValue( + raise InvalidHeaderValue( "Authorization", f"unsupported scheme: {scheme}", ) if peek_ahead(header, pos) != " ": - raise exceptions.InvalidHeaderFormat( + raise InvalidHeaderFormat( "Authorization", "expected space after scheme", header, pos ) pos += 1 @@ -558,14 +550,14 @@ def parse_authorization_basic(header: str) -> tuple[str, str]: try: user_pass = base64.b64decode(basic_credentials.encode()).decode() except binascii.Error: - raise exceptions.InvalidHeaderValue( + raise InvalidHeaderValue( "Authorization", "expected base64-encoded credentials", ) from None try: username, password = user_pass.split(":", 1) except ValueError: - raise exceptions.InvalidHeaderValue( + raise InvalidHeaderValue( "Authorization", "expected username:password credentials", ) from None diff --git a/src/websockets/http11.py b/src/websockets/http11.py index b86c6ca4a..562bcb72c 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -7,7 +7,8 @@ import warnings from typing import Callable, Generator -from . import datastructures, exceptions +from .datastructures import Headers +from .exceptions import SecurityError from .version import version as websockets_version @@ -79,7 +80,7 @@ class Request: """ path: str - headers: datastructures.Headers + headers: Headers # body isn't useful is the context of this library. _exception: Exception | None = None @@ -183,7 +184,7 @@ class Response: status_code: int reason_phrase: str - headers: datastructures.Headers + headers: Headers body: bytes | None = None _exception: Exception | None = None @@ -280,13 +281,9 @@ def parse( try: body = yield from read_to_eof(MAX_BODY_SIZE) except RuntimeError: - raise exceptions.SecurityError( - f"body too large: over {MAX_BODY_SIZE} bytes" - ) + raise SecurityError(f"body too large: over {MAX_BODY_SIZE} bytes") elif content_length > MAX_BODY_SIZE: - raise exceptions.SecurityError( - f"body too large: {content_length} bytes" - ) + raise SecurityError(f"body too large: {content_length} bytes") else: body = yield from read_exact(content_length) @@ -308,7 +305,7 @@ def serialize(self) -> bytes: def parse_headers( read_line: Callable[[int], Generator[None, None, bytes]], -) -> Generator[None, None, datastructures.Headers]: +) -> Generator[None, None, Headers]: """ Parse HTTP headers. @@ -328,7 +325,7 @@ def parse_headers( # We don't attempt to support obsolete line folding. - headers = datastructures.Headers() + headers = Headers() for _ in range(MAX_NUM_HEADERS + 1): try: line = yield from parse_line(read_line) @@ -352,7 +349,7 @@ def parse_headers( headers[name] = value else: - raise exceptions.SecurityError("too many HTTP headers") + raise SecurityError("too many HTTP headers") return headers @@ -377,7 +374,7 @@ def parse_line( try: line = yield from read_line(MAX_LINE_LENGTH) except RuntimeError: - raise exceptions.SecurityError("line too long") + raise SecurityError("line too long") # Not mandatory but safe - https://datatracker.ietf.org/doc/html/rfc7230#section-3.5 if not line.endswith(b"\r\n"): raise EOFError("line without CRLF") diff --git a/src/websockets/uri.py b/src/websockets/uri.py index 82b35f92a..16bb3f1c1 100644 --- a/src/websockets/uri.py +++ b/src/websockets/uri.py @@ -3,7 +3,7 @@ import dataclasses import urllib.parse -from . import exceptions +from .exceptions import InvalidURI __all__ = ["parse_uri", "WebSocketURI"] @@ -73,11 +73,11 @@ def parse_uri(uri: str) -> WebSocketURI: """ parsed = urllib.parse.urlparse(uri) if parsed.scheme not in ["ws", "wss"]: - raise exceptions.InvalidURI(uri, "scheme isn't ws or wss") + raise InvalidURI(uri, "scheme isn't ws or wss") if parsed.hostname is None: - raise exceptions.InvalidURI(uri, "hostname isn't provided") + raise InvalidURI(uri, "hostname isn't provided") if parsed.fragment != "": - raise exceptions.InvalidURI(uri, "fragment identifier is meaningless") + raise InvalidURI(uri, "fragment identifier is meaningless") secure = parsed.scheme == "wss" host = parsed.hostname @@ -89,7 +89,7 @@ def parse_uri(uri: str) -> WebSocketURI: # urllib.parse.urlparse accepts URLs with a username but without a # password. This doesn't make sense for HTTP Basic Auth credentials. if username is not None and password is None: - raise exceptions.InvalidURI(uri, "username provided without password") + raise InvalidURI(uri, "username provided without password") try: uri.encode("ascii") From ed7d392f58bc08ead4556c874d9d980b11874d2a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 23 Aug 2024 21:56:13 +0200 Subject: [PATCH 591/762] Note when each deprecation occurred. The legacy package was ignored because it will be deprecated and removed wholesale. --- docs/project/changelog.rst | 11 +++++++++++ src/websockets/__init__.py | 1 + src/websockets/client.py | 2 +- src/websockets/connection.py | 2 +- src/websockets/http.py | 2 +- src/websockets/http11.py | 4 ++-- src/websockets/server.py | 2 +- src/websockets/sync/client.py | 5 ++++- src/websockets/sync/server.py | 7 +++++-- 9 files changed, 27 insertions(+), 9 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index e82f61753..c239cfe5d 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -32,6 +32,17 @@ notice. *In development* +Backwards-incompatible changes +.............................. + +.. admonition:: The ``code`` and ``reason`` attributes of + :exc:`~exceptions.ConnectionClosed` are deprecated. + :class: note + + They were removed from the documentation in version 10.0, due to their + spec-compliant but counter-intuitive behavior, but they were kept in + the code for backwards compatibility. They're now formally deprecated. + New features ............ diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index 12141adb0..ac02a9f7e 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -187,6 +187,7 @@ "Subprotocol": ".typing", }, deprecated_aliases={ + # deprecated in 9.0 - 2021-09-01 "framing": ".legacy", "handshake": ".legacy", "parse_uri": ".uri", diff --git a/src/websockets/client.py b/src/websockets/client.py index c408c1447..ae467993a 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -352,7 +352,7 @@ def parse(self) -> Generator[None, None, None]: class ClientConnection(ClientProtocol): def __init__(self, *args: Any, **kwargs: Any) -> None: - warnings.warn( + warnings.warn( # deprecated in 11.0 - 2023-04-02 "ClientConnection was renamed to ClientProtocol", DeprecationWarning, ) diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 7942c1a28..5e78e3447 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -5,7 +5,7 @@ from .protocol import SEND_EOF, Protocol as Connection, Side, State # noqa: F401 -warnings.warn( +warnings.warn( # deprecated in 11.0 - 2023-04-02 "websockets.connection was renamed to websockets.protocol " "and Connection was renamed to Protocol", DeprecationWarning, diff --git a/src/websockets/http.py b/src/websockets/http.py index 3dc560062..0ff5598c7 100644 --- a/src/websockets/http.py +++ b/src/websockets/http.py @@ -6,7 +6,7 @@ from .legacy.http import read_request, read_response # noqa: F401 -warnings.warn( +warnings.warn( # deprecated in 9.0 - 2021-09-01 "Headers and MultipleValuesError were moved " "from websockets.http to websockets.datastructures" "and read_request and read_response were moved " diff --git a/src/websockets/http11.py b/src/websockets/http11.py index 562bcb72c..61865bb92 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -87,7 +87,7 @@ class Request: @property def exception(self) -> Exception | None: # pragma: no cover - warnings.warn( + warnings.warn( # deprecated in 10.3 - 2022-04-17 "Request.exception is deprecated; " "use ServerProtocol.handshake_exc instead", DeprecationWarning, @@ -191,7 +191,7 @@ class Response: @property def exception(self) -> Exception | None: # pragma: no cover - warnings.warn( + warnings.warn( # deprecated in 10.3 - 2022-04-17 "Response.exception is deprecated; " "use ClientProtocol.handshake_exc instead", DeprecationWarning, diff --git a/src/websockets/server.py b/src/websockets/server.py index 3a03378b2..ac62800d6 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -575,7 +575,7 @@ def parse(self) -> Generator[None, None, None]: class ServerConnection(ServerProtocol): def __init__(self, *args: Any, **kwargs: Any) -> None: - warnings.warn( + warnings.warn( # deprecated in 11.0 - 2023-04-02 "ServerConnection was renamed to ServerProtocol", DeprecationWarning, ) diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 3c700a377..6a04515f0 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -212,7 +212,10 @@ def connect( # Backwards compatibility: ssl used to be called ssl_context. if ssl is None and "ssl_context" in kwargs: ssl = kwargs.pop("ssl_context") - warnings.warn("ssl_context was renamed to ssl", DeprecationWarning) + warnings.warn( # deprecated in 13.0 - 2024-08-20 + "ssl_context was renamed to ssl", + DeprecationWarning, + ) wsuri = parse_uri(uri) if not wsuri.secure and ssl is not None: diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 5e22e112e..15de458b5 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -304,7 +304,7 @@ def __exit__( def __getattr__(name: str) -> Any: if name == "WebSocketServer": - warnings.warn( + warnings.warn( # deprecated in 13.0 - 2024-08-20 "WebSocketServer was renamed to Server", DeprecationWarning, ) @@ -446,7 +446,10 @@ def handler(websocket): # Backwards compatibility: ssl used to be called ssl_context. if ssl is None and "ssl_context" in kwargs: ssl = kwargs.pop("ssl_context") - warnings.warn("ssl_context was renamed to ssl", DeprecationWarning) + warnings.warn( # deprecated in 13.0 - 2024-08-20 + "ssl_context was renamed to ssl", + DeprecationWarning, + ) if subprotocols is not None: validate_subprotocols(subprotocols) From d61dc43614063b1620b30f50df3947e6079bb371 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 23 Aug 2024 22:46:57 +0200 Subject: [PATCH 592/762] Restructure upgrade guide. --- docs/howto/upgrade.rst | 346 +++++++++++++++++++++++------------------ 1 file changed, 194 insertions(+), 152 deletions(-) diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index 602d8a4e6..18c3cc127 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -16,8 +16,10 @@ The recommended upgrade process is: should stick to the original implementation until they're added. 3. `Update import paths`_. For straightforward usage of websockets, this could be the only step you need to take. Upgrading could be transparent. -4. `Review API changes`_ and adapt your application to preserve its current - functionality or take advantage of improvements in the new implementation. +4. Check out `new features and improvements`_ and consider taking advantage of + them to improve your application. +5. Review `API changes`_ and adapt your application to preserve its current + functionality. In the interest of brevity, only :func:`~asyncio.client.connect` and :func:`~asyncio.server.serve` are discussed below but everything also applies @@ -99,11 +101,12 @@ For context, the ``websockets`` package is structured as follows: * The new implementation is found in the ``websockets.asyncio`` package. * The original implementation was moved to the ``websockets.legacy`` package. -* The ``websockets`` package provides aliases for convenience. +* The ``websockets`` package provides aliases for convenience. Currently, they + point to the original implementation. They will be updated to point to the new + implementation when it feels mature. * The ``websockets.client`` and ``websockets.server`` packages provide aliases - for backwards-compatibility with earlier versions of websockets. -* Currently, all aliases point to the original implementation. In the future, - they will point to the new implementation or they will be deprecated. + for backwards-compatibility with earlier versions of websockets. They will + be deprecated together with the original implementation. To upgrade to the new :mod:`asyncio` implementation, change import paths as shown in the tables below. @@ -153,7 +156,7 @@ Server APIs | ``websockets.server.WebSocketServerProtocol`` |br| | | | :class:`websockets.legacy.server.WebSocketServerProtocol` | | +-------------------------------------------------------------------+-----------------------------------------------------+ -| ``websockets.broadcast`` |br| | :func:`websockets.asyncio.server.broadcast` | +| ``websockets.broadcast()`` |br| | :func:`websockets.asyncio.server.broadcast` | | :func:`websockets.legacy.server.broadcast()` | | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.BasicAuthWebSocketServerProtocol`` |br| | See below :ref:`how to migrate ` to | @@ -165,10 +168,32 @@ Server APIs | :func:`websockets.legacy.auth.basic_auth_protocol_factory` | | +-------------------------------------------------------------------+-----------------------------------------------------+ -.. _Review API changes: +.. _new features and improvements: -API changes ------------ +New features and improvements +----------------------------- + +Customizing the opening handshake +................................. + +On the server side, if you're customizing how :func:`~legacy.server.serve` +processes the opening handshake with the ``process_request``, ``extra_headers``, +or ``select_subprotocol``, you must update your code and you can probably make +it simpler. + +``process_request`` and ``select_subprotocol`` have new signatures. +``process_response`` replaces ``extra_headers`` and provides more flexibility. +See process_request_, select_subprotocol_, and process_response_ below. + +Tracking open connections +......................... + +The new implementation of :class:`~asyncio.server.Server` provides a +:attr:`~asyncio.server.Server.connections` property, which is a set of all open +connections. This didn't exist in the original implementation. + +If you're keeping track of open connections in order to broadcast messages to +all of them, you can simplify your code by using this property. Controlling UTF-8 decoding .......................... @@ -197,97 +222,76 @@ implementation. Depending on your use case, adopting this method may improve performance when streaming large messages. Specifically, it could reduce memory usage. -Tracking open connections -......................... - -The new implementation of :class:`~asyncio.server.Server` provides a -:attr:`~asyncio.server.Server.connections` property, which is a set of all open -connections. This didn't exist in the original implementation. - -If you were keeping track of open connections, you may be able to simplify your -code by using this property. - -.. _basic-auth: - -Performing HTTP Basic Authentication -.................................... +.. _API changes: -.. admonition:: This section applies only to servers. - :class: tip - - On the client side, :func:`~asyncio.client.connect` performs HTTP Basic - Authentication automatically when the URI contains credentials. +API changes +----------- -In the original implementation, the recommended way to add HTTP Basic -Authentication to a server was to set the ``create_protocol`` argument of -:func:`~legacy.server.serve` to a factory function generated by -:func:`~legacy.auth.basic_auth_protocol_factory`:: +Attributes of connection objects +................................ - from websockets.legacy.auth import basic_auth_protocol_factory - from websockets.legacy.server import serve +``path``, ``request_headers``, and ``response_headers`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - async with serve(..., create_protocol=basic_auth_protocol_factory(...)): - ... +The :attr:`~legacy.protocol.WebSocketCommonProtocol.path`, +:attr:`~legacy.protocol.WebSocketCommonProtocol.request_headers` and +:attr:`~legacy.protocol.WebSocketCommonProtocol.response_headers` properties are +replaced by :attr:`~asyncio.connection.Connection.request` and +:attr:`~asyncio.connection.Connection.response`. -In the new implementation, the :func:`~asyncio.server.basic_auth` function -generates a ``process_request`` coroutine that performs HTTP Basic -Authentication:: +If your code uses them, you can update it as follows. - from websockets.asyncio.server import basic_auth, serve +========================================== ========================================== +Legacy :mod:`asyncio` implementation New :mod:`asyncio` implementation +========================================== ========================================== +``connection.path`` ``connection.request.path`` +``connection.request_headers`` ``connection.request.headers`` +``connection.response_headers`` ``connection.response.headers`` +========================================== ========================================== - async with serve(..., process_request=basic_auth(...)): - ... +``open`` and ``closed`` +~~~~~~~~~~~~~~~~~~~~~~~ -:func:`~asyncio.server.basic_auth` accepts either hard coded ``credentials`` or -a ``check_credentials`` coroutine as well as an optional ``realm`` just like -:func:`~legacy.auth.basic_auth_protocol_factory`. Furthermore, -``check_credentials`` may be a function instead of a coroutine. +The :attr:`~legacy.protocol.WebSocketCommonProtocol.open` and +:attr:`~legacy.protocol.WebSocketCommonProtocol.closed` properties are removed. +Using them was discouraged. -This new API has more obvious semantics. That makes it easier to understand and -also easier to extend. +Instead, you should call :meth:`~asyncio.connection.Connection.recv` or +:meth:`~asyncio.connection.Connection.send` and handle +:exc:`~exceptions.ConnectionClosed` exceptions. -In the original implementation, overriding ``create_protocol`` changed the type -of connection objects to :class:`~legacy.auth.BasicAuthWebSocketServerProtocol`, -a subclass of :class:`~legacy.server.WebSocketServerProtocol` that performs HTTP -Basic Authentication in its ``process_request`` method. If you wanted to -customize ``process_request`` further, you had: +If your code uses them, you can update it as follows. -* an ill-defined option: add a ``process_request`` argument to - :func:`~legacy.server.serve`; to tell which one would run first, you had to - experiment or read the code; -* a cumbersome option: subclass - :class:`~legacy.auth.BasicAuthWebSocketServerProtocol`, then pass that - subclass in the ``create_protocol`` argument of - :func:`~legacy.auth.basic_auth_protocol_factory`. +========================================== ========================================== +Legacy :mod:`asyncio` implementation New :mod:`asyncio` implementation +========================================== ========================================== +.. ``from websockets.protocol import State`` +``connection.open`` ``connection.state is State.OPEN`` +``connection.closed`` ``connection.state is State.CLOSED`` +========================================== ========================================== -In the new implementation, you just write a ``process_request`` coroutine:: +Arguments of :func:`~asyncio.client.connect` +............................................ - from websockets.asyncio.server import basic_auth, serve +``extra_headers`` → ``additional_headers`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - process_basic_auth = basic_auth(...) +If you're adding headers to the handshake request sent by +:func:`~legacy.client.connect` with the ``extra_headers`` argument, you must +rename it to ``additional_headers``. - async def process_request(connection, request): - ... # some logic here - response = await process_basic_auth(connection, request) - if response is not None: - return response - ... # more logic here +Arguments of :func:`~asyncio.server.serve` +.......................................... - async with serve(..., process_request=process_request): - ... +``ws_handler`` → ``handler`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Customizing the opening handshake -................................. +The first argument of :func:`~asyncio.server.serve` is now called ``handler`` +instead of ``ws_handler``. It's usually passed as a positional argument, making +this change transparent. If you're passing it as a keyword argument, you must +update its name. -On the client side, if you're adding headers to the handshake request sent by -:func:`~legacy.client.connect` with the ``extra_headers`` argument, you must -rename it to ``additional_headers``. - -On the server side, if you're customizing how :func:`~legacy.server.serve` -processes the opening handshake with the ``process_request``, ``extra_headers``, -or ``select_subprotocol``, you must update your code. ``process_response`` and -``select_subprotocol`` have new signatures; ``process_response`` replaces -``extra_headers`` and provides more flexibility. +.. _process_request: ``process_request`` ~~~~~~~~~~~~~~~~~~~ @@ -302,8 +306,6 @@ an example:: def process_request(path, request_headers): return http.HTTPStatus.OK, [], b"OK\n" - serve(..., process_request=process_request, ...) - # New implementation def process_request(connection, request): @@ -312,16 +314,22 @@ an example:: serve(..., process_request=process_request, ...) ``connection`` is always available in ``process_request``. In the original -implementation, you had to write a subclass of +implementation, if you wanted to make the connection object available in a +``process_request`` method, you had to write a subclass of :class:`~legacy.server.WebSocketServerProtocol` and pass it in the -``create_protocol`` argument to make the connection object available in a -``process_request`` method. This pattern isn't useful anymore; you can replace -it with a ``process_request`` function or coroutine. +``create_protocol`` argument. This pattern isn't useful anymore; you can +replace it with a ``process_request`` function or coroutine. ``path`` and ``headers`` are available as attributes of the ``request`` object. -``process_response`` -~~~~~~~~~~~~~~~~~~~~ +.. _process_response: + +``extra_headers`` → ``process_response`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +If you're adding headers to the handshake response sent by +:func:`~legacy.server.serve` with the ``extra_headers`` argument, you must write +a ``process_response`` callable instead. ``process_request`` replaces ``extra_headers`` and provides more flexibility. In the most basic case, you would adapt your code as follows:: @@ -346,10 +354,13 @@ In addition, the ``request`` and ``response`` objects are available, which enables a broader range of use cases (e.g., logging) and makes ``process_response`` more useful than ``extra_headers``. +.. _select_subprotocol: + ``select_subprotocol`` ~~~~~~~~~~~~~~~~~~~~~~ -The signature of ``select_subprotocol`` changed. Here's an example:: +If you're selecting a subprotocol, you must update your code because the +signature of ``select_subprotocol`` changed. Here's an example:: # Original implementation @@ -366,37 +377,30 @@ The signature of ``select_subprotocol`` changed. Here's an example:: serve(..., select_subprotocol=select_subprotocol, ...) ``connection`` is always available in ``select_subprotocol``. This brings the -same benefits as in ``process_request``. It may remove the need to subclass of +same benefits as in ``process_request``. It may remove the need to subclass :class:`~legacy.server.WebSocketServerProtocol`. The ``subprotocols`` argument contains the list of subprotocols offered by the client. The list of subprotocols supported by the server was removed because -``select_subprotocols`` already knows which subprotocols it may select and under +``select_subprotocols`` has to know which subprotocols it may select and under which conditions. -Arguments of :func:`~asyncio.client.connect` and :func:`~asyncio.server.serve` -.............................................................................. - -``ws_handler`` → ``handler`` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -The first argument of :func:`~asyncio.server.serve` is now called ``handler`` -instead of ``ws_handler``. It's usually passed as a positional argument, making -this change transparent. If you're passing it as a keyword argument, you must -update its name. +Furthermore, the default behavior when ``select_subprotocol`` isn't provided +changed in two ways: -``create_protocol`` → ``create_connection`` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +1. In the original implementation, a server with a list of subprotocols accepted + to continue without a subprotocol. In the new implementation, a server that + is configured with subprotocols rejects connections that don't support any. +2. In the original implementation, when several subprotocols were available, the + server averaged the client's preferences with its own preferences. In the new + implementation, the server just picks the first subprotocol from its list. -The keyword argument of :func:`~asyncio.server.serve` for customizing the -creation of the connection object is now called ``create_connection`` instead of -``create_protocol``. It must return a :class:`~asyncio.server.ServerConnection` -instead of a :class:`~legacy.server.WebSocketServerProtocol`. +If you had a ``select_subprotocol`` for the sole purpose of rejecting +connections without a subprotocol, you can remove it and keep only the +``subprotocols`` argument. -If you were customizing connection objects, you should check the new -implementation and possibly redo your customization. Keep in mind that the -changes to ``process_request`` and ``select_subprotocol`` remove most use cases -for ``create_connection``. +Arguments of :func:`~asyncio.client.connect` and :func:`~asyncio.server.serve` +.............................................................................. ``max_queue`` ~~~~~~~~~~~~~ @@ -409,11 +413,15 @@ frames. It used to be the size of a buffer of incoming messages that refilled as soon as a message was read. It used to default to 32 messages. This can make a difference when messages are fragmented in several frames. In -that case, you may want to increase ``max_queue``. If you're writing a high -performance server and you know that you're receiving fragmented messages, -probably you should adopt :meth:`~asyncio.connection.Connection.recv_streaming` -and optimize the performance of reads again. In all other cases, given how -uncommon fragmentation is, you shouldn't worry about this change. +that case, you may want to increase ``max_queue``. + +If you're writing a high performance server and you know that you're receiving +fragmented messages, probably you should adopt +:meth:`~asyncio.connection.Connection.recv_streaming` and optimize the +performance of reads again. + +In all other cases, given how uncommon fragmentation is, you shouldn't worry +about this change. ``read_limit`` ~~~~~~~~~~~~~~ @@ -432,50 +440,84 @@ buffer now. The ``write_limit`` argument of :func:`~asyncio.client.connect` and :func:`~asyncio.server.serve` defaults to 32 KiB instead of 64 KiB. -Attributes of connections -......................... +``create_protocol`` → ``create_connection`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -``path``, ``request_headers`` and ``response_headers`` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +The keyword argument of :func:`~asyncio.server.serve` for customizing the +creation of the connection object is now called ``create_connection`` instead of +``create_protocol``. It must return a :class:`~asyncio.server.ServerConnection` +instead of a :class:`~legacy.server.WebSocketServerProtocol`. -The :attr:`~legacy.protocol.WebSocketCommonProtocol.path`, -:attr:`~legacy.protocol.WebSocketCommonProtocol.request_headers` and -:attr:`~legacy.protocol.WebSocketCommonProtocol.response_headers` properties are -replaced by :attr:`~asyncio.connection.Connection.request` and -:attr:`~asyncio.connection.Connection.response`, which provide a ``headers`` -attribute. +If you were customizing connection objects, probably you need to redo your +customization. Consider switching to ``process_request`` and +``select_subprotocol`` as their new design removes most use cases for +``create_connection``. -If your code relies on them, you can replace:: +.. _basic-auth: - connection.path - connection.request_headers - connection.response_headers +Performing HTTP Basic Authentication +.................................... -with:: +.. admonition:: This section applies only to servers. + :class: tip - connection.request.path - connection.request.headers - connection.response.headers + On the client side, :func:`~asyncio.client.connect` performs HTTP Basic + Authentication automatically when the URI contains credentials. -``open`` and ``closed`` -~~~~~~~~~~~~~~~~~~~~~~~ +In the original implementation, the recommended way to add HTTP Basic +Authentication to a server was to set the ``create_protocol`` argument of +:func:`~legacy.server.serve` to a factory function generated by +:func:`~legacy.auth.basic_auth_protocol_factory`:: -The :attr:`~legacy.protocol.WebSocketCommonProtocol.open` and -:attr:`~legacy.protocol.WebSocketCommonProtocol.closed` properties are removed. -Using them was discouraged. + from websockets.legacy.auth import basic_auth_protocol_factory + from websockets.legacy.server import serve -Instead, you should call :meth:`~asyncio.connection.Connection.recv` or -:meth:`~asyncio.connection.Connection.send` and handle -:exc:`~exceptions.ConnectionClosed` exceptions. + async with serve(..., create_protocol=basic_auth_protocol_factory(...)): + ... + +In the new implementation, the :func:`~asyncio.server.basic_auth` function +generates a ``process_request`` coroutine that performs HTTP Basic +Authentication:: -If your code relies on them, you can replace:: + from websockets.asyncio.server import basic_auth, serve - connection.open - connection.closed + async with serve(..., process_request=basic_auth(...)): + ... -with:: +:func:`~asyncio.server.basic_auth` accepts either hard coded ``credentials`` or +a ``check_credentials`` coroutine as well as an optional ``realm`` just like +:func:`~legacy.auth.basic_auth_protocol_factory`. Furthermore, +``check_credentials`` may be a function instead of a coroutine. - from websockets.protocol import State +This new API has more obvious semantics. That makes it easier to understand and +also easier to extend. - connection.state is State.OPEN - connection.state is State.CLOSED +In the original implementation, overriding ``create_protocol`` changed the type +of connection objects to :class:`~legacy.auth.BasicAuthWebSocketServerProtocol`, +a subclass of :class:`~legacy.server.WebSocketServerProtocol` that performs HTTP +Basic Authentication in its ``process_request`` method. If you wanted to +customize ``process_request`` further, you had: + +* an ill-defined option: add a ``process_request`` argument to + :func:`~legacy.server.serve`; to tell which one would run first, you had to + experiment or read the code; +* a cumbersome option: subclass + :class:`~legacy.auth.BasicAuthWebSocketServerProtocol`, then pass that + subclass in the ``create_protocol`` argument of + :func:`~legacy.auth.basic_auth_protocol_factory`. + +In the new implementation, you just write a ``process_request`` coroutine:: + + from websockets.asyncio.server import basic_auth, serve + + process_basic_auth = basic_auth(...) + + async def process_request(connection, request): + ... # some logic here + response = await process_basic_auth(connection, request) + if response is not None: + return response + ... # more logic here + + async with serve(..., process_request=process_request): + ... From 451944aa2d58e6d1c2d42042b1d2eea61a97b578 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 24 Aug 2024 08:35:34 +0200 Subject: [PATCH 593/762] process_request/response can be coroutines. --- src/websockets/asyncio/server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index 29860e565..f24281252 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -693,14 +693,14 @@ def __init__( process_request: ( Callable[ [ServerConnection, Request], - Response | None, + Awaitable[Response | None] | Response | None, ] | None ) = None, process_response: ( Callable[ [ServerConnection, Request, Response], - Response | None, + Awaitable[Response | None] | Response | None, ] | None ) = None, From 6261ad360db65c95d229a02f40a92892a20183b2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 24 Aug 2024 21:33:04 +0200 Subject: [PATCH 594/762] Remove run_(unix_)client abstraction from tests. It wasn't adding any value, as shown in the diff. --- tests/asyncio/client.py | 33 --------- tests/asyncio/server.py | 9 ++- tests/asyncio/test_client.py | 70 +++++++++----------- tests/asyncio/test_server.py | 125 +++++++++++++++++------------------ tests/sync/client.py | 32 --------- tests/sync/server.py | 8 +++ tests/sync/test_client.py | 67 +++++++++---------- tests/sync/test_server.py | 77 ++++++++++----------- 8 files changed, 178 insertions(+), 243 deletions(-) delete mode 100644 tests/asyncio/client.py delete mode 100644 tests/sync/client.py diff --git a/tests/asyncio/client.py b/tests/asyncio/client.py deleted file mode 100644 index a73079c6e..000000000 --- a/tests/asyncio/client.py +++ /dev/null @@ -1,33 +0,0 @@ -import contextlib - -from websockets.asyncio.client import * -from websockets.asyncio.server import Server - -from .server import get_server_host_port - - -__all__ = [ - "run_client", - "run_unix_client", -] - - -@contextlib.asynccontextmanager -async def run_client(wsuri_or_server, secure=None, resource_name="/", **kwargs): - if isinstance(wsuri_or_server, str): - wsuri = wsuri_or_server - else: - assert isinstance(wsuri_or_server, Server) - if secure is None: - secure = "ssl" in kwargs - protocol = "wss" if secure else "ws" - host, port = get_server_host_port(wsuri_or_server) - wsuri = f"{protocol}://{host}:{port}{resource_name}" - async with connect(wsuri, **kwargs) as client: - yield client - - -@contextlib.asynccontextmanager -async def run_unix_client(path, **kwargs): - async with unix_connect(path, **kwargs) as client: - yield client diff --git a/tests/asyncio/server.py b/tests/asyncio/server.py index 0fe20dc65..06fa92dea 100644 --- a/tests/asyncio/server.py +++ b/tests/asyncio/server.py @@ -5,13 +5,20 @@ from websockets.asyncio.server import * -def get_server_host_port(server): +def get_host_port(server): for sock in server.sockets: if sock.family == socket.AF_INET: # pragma: no branch return sock.getsockname() raise AssertionError("expected at least one IPv4 socket") +def get_uri(server): + secure = server.server._ssl_context is not None # hack + protocol = "wss" if secure else "ws" + host, port = get_host_port(server) + return f"{protocol}://{host}:{port}" + + async def eval_shell(ws): async for expr in ws: value = eval(expr) diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index 0bd2af4f1..a8ef6ef9d 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -9,49 +9,48 @@ from websockets.extensions.permessage_deflate import PerMessageDeflate from ..utils import CLIENT_CONTEXT, MS, SERVER_CONTEXT, temp_unix_socket_path -from .client import run_client, run_unix_client -from .server import do_nothing, get_server_host_port, run_server, run_unix_server +from .server import do_nothing, get_host_port, get_uri, run_server, run_unix_server class ClientTests(unittest.IsolatedAsyncioTestCase): async def test_connection(self): """Client connects to server and the handshake succeeds.""" async with run_server() as server: - async with run_client(server) as client: + async with connect(get_uri(server)) as client: self.assertEqual(client.protocol.state.name, "OPEN") async def test_existing_socket(self): """Client connects using a pre-existing socket.""" async with run_server() as server: - with socket.create_connection(get_server_host_port(server)) as sock: + with socket.create_connection(get_host_port(server)) as sock: # Use a non-existing domain to ensure we connect to the right socket. - async with run_client("ws://invalid/", sock=sock) as client: + async with connect("ws://invalid/", sock=sock) as client: self.assertEqual(client.protocol.state.name, "OPEN") async def test_additional_headers(self): """Client can set additional headers with additional_headers.""" async with run_server() as server: - async with run_client( - server, additional_headers={"Authorization": "Bearer ..."} + async with connect( + get_uri(server), additional_headers={"Authorization": "Bearer ..."} ) as client: self.assertEqual(client.request.headers["Authorization"], "Bearer ...") async def test_override_user_agent(self): """Client can override User-Agent header with user_agent_header.""" async with run_server() as server: - async with run_client(server, user_agent_header="Smith") as client: + async with connect(get_uri(server), user_agent_header="Smith") as client: self.assertEqual(client.request.headers["User-Agent"], "Smith") async def test_remove_user_agent(self): """Client can remove User-Agent header with user_agent_header.""" async with run_server() as server: - async with run_client(server, user_agent_header=None) as client: + async with connect(get_uri(server), user_agent_header=None) as client: self.assertNotIn("User-Agent", client.request.headers) async def test_compression_is_enabled(self): """Client enables compression by default.""" async with run_server() as server: - async with run_client(server) as client: + async with connect(get_uri(server)) as client: self.assertEqual( [type(ext) for ext in client.protocol.extensions], [PerMessageDeflate], @@ -60,13 +59,13 @@ async def test_compression_is_enabled(self): async def test_disable_compression(self): """Client disables compression.""" async with run_server() as server: - async with run_client(server, compression=None) as client: + async with connect(get_uri(server), compression=None) as client: self.assertEqual(client.protocol.extensions, []) async def test_keepalive_is_enabled(self): """Client enables keepalive and measures latency by default.""" async with run_server() as server: - async with run_client(server, ping_interval=MS) as client: + async with connect(get_uri(server), ping_interval=MS) as client: self.assertEqual(client.latency, 0) await asyncio.sleep(2 * MS) self.assertGreater(client.latency, 0) @@ -74,7 +73,7 @@ async def test_keepalive_is_enabled(self): async def test_disable_keepalive(self): """Client disables keepalive.""" async with run_server() as server: - async with run_client(server, ping_interval=None) as client: + async with connect(get_uri(server), ping_interval=None) as client: await asyncio.sleep(2 * MS) self.assertEqual(client.latency, 0) @@ -87,21 +86,21 @@ def create_connection(*args, **kwargs): return client async with run_server() as server: - async with run_client( - server, create_connection=create_connection + async with connect( + get_uri(server), create_connection=create_connection ) as client: self.assertTrue(client.create_connection_ran) async def test_invalid_uri(self): """Client receives an invalid URI.""" with self.assertRaises(InvalidURI): - async with run_client("http://localhost"): # invalid scheme + async with connect("http://localhost"): # invalid scheme self.fail("did not raise") async def test_tcp_connection_fails(self): """Client fails to connect to server.""" with self.assertRaises(OSError): - async with run_client("ws://localhost:54321"): # invalid port + async with connect("ws://localhost:54321"): # invalid port self.fail("did not raise") async def test_handshake_fails(self): @@ -116,7 +115,7 @@ def remove_accept_header(self, request, response): do_nothing, process_response=remove_accept_header ) as server: with self.assertRaises(InvalidHandshake) as raised: - async with run_client(server, close_timeout=MS): + async with connect(get_uri(server), close_timeout=MS): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -135,7 +134,7 @@ async def stall_connection(self, request): async with run_server(do_nothing, process_request=stall_connection) as server: try: with self.assertRaises(TimeoutError) as raised: - async with run_client(server, open_timeout=2 * MS): + async with connect(get_uri(server), open_timeout=2 * MS): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -152,7 +151,7 @@ def close_connection(self, request): async with run_server(process_request=close_connection) as server: with self.assertRaises(ConnectionError) as raised: - async with run_client(server): + async with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -164,7 +163,7 @@ class SecureClientTests(unittest.IsolatedAsyncioTestCase): async def test_connection(self): """Client connects to server securely.""" async with run_server(ssl=SERVER_CONTEXT) as server: - async with run_client(server, ssl=CLIENT_CONTEXT) as client: + async with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: self.assertEqual(client.protocol.state.name, "OPEN") ssl_object = client.transport.get_extra_info("ssl_object") self.assertEqual(ssl_object.version()[:3], "TLS") @@ -172,12 +171,9 @@ async def test_connection(self): async def test_set_server_hostname_implicitly(self): """Client sets server_hostname to the host in the WebSocket URI.""" async with run_server(ssl=SERVER_CONTEXT) as server: - host, port = get_server_host_port(server) - async with run_client( - "wss://overridden/", - host=host, - port=port, - ssl=CLIENT_CONTEXT, + host, port = get_host_port(server) + async with connect( + "wss://overridden/", host=host, port=port, ssl=CLIENT_CONTEXT ) as client: ssl_object = client.transport.get_extra_info("ssl_object") self.assertEqual(ssl_object.server_hostname, "overridden") @@ -185,10 +181,8 @@ async def test_set_server_hostname_implicitly(self): async def test_set_server_hostname_explicitly(self): """Client sets server_hostname to the value provided in argument.""" async with run_server(ssl=SERVER_CONTEXT) as server: - async with run_client( - server, - ssl=CLIENT_CONTEXT, - server_hostname="overridden", + async with connect( + get_uri(server), ssl=CLIENT_CONTEXT, server_hostname="overridden" ) as client: ssl_object = client.transport.get_extra_info("ssl_object") self.assertEqual(ssl_object.server_hostname, "overridden") @@ -198,7 +192,7 @@ async def test_reject_invalid_server_certificate(self): async with run_server(ssl=SERVER_CONTEXT) as server: with self.assertRaises(ssl.SSLCertVerificationError) as raised: # The test certificate isn't trusted system-wide. - async with run_client(server, secure=True): + async with connect(get_uri(server)): self.fail("did not raise") self.assertIn( "certificate verify failed: self signed certificate", @@ -210,8 +204,8 @@ async def test_reject_invalid_server_hostname(self): async with run_server(ssl=SERVER_CONTEXT) as server: with self.assertRaises(ssl.SSLCertVerificationError) as raised: # This hostname isn't included in the test certificate. - async with run_client( - server, ssl=CLIENT_CONTEXT, server_hostname="invalid" + async with connect( + get_uri(server), ssl=CLIENT_CONTEXT, server_hostname="invalid" ): self.fail("did not raise") self.assertIn( @@ -226,7 +220,7 @@ async def test_connection(self): """Client connects to server over a Unix socket.""" with temp_unix_socket_path() as path: async with run_unix_server(path): - async with run_unix_client(path) as client: + async with unix_connect(path) as client: self.assertEqual(client.protocol.state.name, "OPEN") async def test_set_host_header(self): @@ -234,7 +228,7 @@ async def test_set_host_header(self): # This is part of the documented behavior of unix_connect(). with temp_unix_socket_path() as path: async with run_unix_server(path): - async with run_unix_client(path, uri="ws://overridden/") as client: + async with unix_connect(path, uri="ws://overridden/") as client: self.assertEqual(client.request.headers["Host"], "overridden") @@ -244,7 +238,7 @@ async def test_connection(self): """Client connects to server securely over a Unix socket.""" with temp_unix_socket_path() as path: async with run_unix_server(path, ssl=SERVER_CONTEXT): - async with run_unix_client(path, ssl=CLIENT_CONTEXT) as client: + async with unix_connect(path, ssl=CLIENT_CONTEXT) as client: self.assertEqual(client.protocol.state.name, "OPEN") ssl_object = client.transport.get_extra_info("ssl_object") self.assertEqual(ssl_object.version()[:3], "TLS") @@ -254,7 +248,7 @@ async def test_set_server_hostname(self): # This is part of the documented behavior of unix_connect(). with temp_unix_socket_path() as path: async with run_unix_server(path, ssl=SERVER_CONTEXT): - async with run_unix_client( + async with unix_connect( path, ssl=CLIENT_CONTEXT, uri="wss://overridden/", diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index 38f226903..a16267439 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -6,6 +6,7 @@ import socket import unittest +from websockets.asyncio.client import connect, unix_connect from websockets.asyncio.compatibility import TimeoutError, asyncio_timeout from websockets.asyncio.server import * from websockets.exceptions import ( @@ -22,13 +23,13 @@ SERVER_CONTEXT, temp_unix_socket_path, ) -from .client import run_client, run_unix_client from .server import ( EvalShellMixin, crash, do_nothing, eval_shell, - get_server_host_port, + get_host_port, + get_uri, keep_running, run_server, run_unix_server, @@ -39,13 +40,13 @@ class ServerTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): async def test_connection(self): """Server receives connection from client and the handshake succeeds.""" async with run_server() as server: - async with run_client(server) as client: + async with connect(get_uri(server)) as client: await self.assertEval(client, "ws.protocol.state.name", "OPEN") async def test_connection_handler_returns(self): """Connection handler returns.""" async with run_server(do_nothing) as server: - async with run_client(server) as client: + async with connect(get_uri(server)) as client: with self.assertRaises(ConnectionClosedOK) as raised: await client.recv() self.assertEqual( @@ -56,7 +57,7 @@ async def test_connection_handler_returns(self): async def test_connection_handler_raises_exception(self): """Connection handler raises an exception.""" async with run_server(crash) as server: - async with run_client(server) as client: + async with connect(get_uri(server)) as client: with self.assertRaises(ConnectionClosedError) as raised: await client.recv() self.assertEqual( @@ -70,7 +71,7 @@ async def test_existing_socket(self): with socket.create_server(("localhost", 0)) as sock: async with run_server(sock=sock, host=None, port=None): uri = "ws://{}:{}/".format(*sock.getsockname()) - async with run_client(uri) as client: + async with connect(uri) as client: await self.assertEval(client, "ws.protocol.state.name", "OPEN") async def test_select_subprotocol(self): @@ -85,7 +86,7 @@ def select_subprotocol(ws, subprotocols): subprotocols=["chat"], select_subprotocol=select_subprotocol, ) as server: - async with run_client(server, subprotocols=["chat"]) as client: + async with connect(get_uri(server), subprotocols=["chat"]) as client: await self.assertEval(client, "ws.select_subprotocol_ran", "True") await self.assertEval(client, "ws.subprotocol", "chat") @@ -97,7 +98,7 @@ def select_subprotocol(ws, subprotocols): async with run_server(select_subprotocol=select_subprotocol) as server: with self.assertRaises(InvalidStatus) as raised: - async with run_client(server): + async with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -112,7 +113,7 @@ def select_subprotocol(ws, subprotocols): async with run_server(select_subprotocol=select_subprotocol) as server: with self.assertRaises(InvalidStatus) as raised: - async with run_client(server): + async with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -127,7 +128,7 @@ def process_request(ws, request): ws.process_request_ran = True async with run_server(process_request=process_request) as server: - async with run_client(server) as client: + async with connect(get_uri(server)) as client: await self.assertEval(client, "ws.process_request_ran", "True") async def test_async_process_request_returns_none(self): @@ -138,7 +139,7 @@ async def process_request(ws, request): ws.process_request_ran = True async with run_server(process_request=process_request) as server: - async with run_client(server) as client: + async with connect(get_uri(server)) as client: await self.assertEval(client, "ws.process_request_ran", "True") async def test_process_request_returns_response(self): @@ -149,7 +150,7 @@ def process_request(ws, request): async with run_server(process_request=process_request) as server: with self.assertRaises(InvalidStatus) as raised: - async with run_client(server): + async with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -164,7 +165,7 @@ async def process_request(ws, request): async with run_server(process_request=process_request) as server: with self.assertRaises(InvalidStatus) as raised: - async with run_client(server): + async with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -179,7 +180,7 @@ def process_request(ws, request): async with run_server(process_request=process_request) as server: with self.assertRaises(InvalidStatus) as raised: - async with run_client(server): + async with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -194,7 +195,7 @@ async def process_request(ws, request): async with run_server(process_request=process_request) as server: with self.assertRaises(InvalidStatus) as raised: - async with run_client(server): + async with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -210,7 +211,7 @@ def process_response(ws, request, response): ws.process_response_ran = True async with run_server(process_response=process_response) as server: - async with run_client(server) as client: + async with connect(get_uri(server)) as client: await self.assertEval(client, "ws.process_response_ran", "True") async def test_async_process_response_returns_none(self): @@ -222,7 +223,7 @@ async def process_response(ws, request, response): ws.process_response_ran = True async with run_server(process_response=process_response) as server: - async with run_client(server) as client: + async with connect(get_uri(server)) as client: await self.assertEval(client, "ws.process_response_ran", "True") async def test_process_response_modifies_response(self): @@ -232,7 +233,7 @@ def process_response(ws, request, response): response.headers["X-ProcessResponse"] = "OK" async with run_server(process_response=process_response) as server: - async with run_client(server) as client: + async with connect(get_uri(server)) as client: self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") async def test_async_process_response_modifies_response(self): @@ -242,7 +243,7 @@ async def process_response(ws, request, response): response.headers["X-ProcessResponse"] = "OK" async with run_server(process_response=process_response) as server: - async with run_client(server) as client: + async with connect(get_uri(server)) as client: self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") async def test_process_response_replaces_response(self): @@ -254,7 +255,7 @@ def process_response(ws, request, response): return dataclasses.replace(response, headers=headers) async with run_server(process_response=process_response) as server: - async with run_client(server) as client: + async with connect(get_uri(server)) as client: self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") async def test_async_process_response_replaces_response(self): @@ -266,7 +267,7 @@ async def process_response(ws, request, response): return dataclasses.replace(response, headers=headers) async with run_server(process_response=process_response) as server: - async with run_client(server) as client: + async with connect(get_uri(server)) as client: self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") async def test_process_response_raises_exception(self): @@ -277,7 +278,7 @@ def process_response(ws, request, response): async with run_server(process_response=process_response) as server: with self.assertRaises(InvalidStatus) as raised: - async with run_client(server): + async with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -292,7 +293,7 @@ async def process_response(ws, request, response): async with run_server(process_response=process_response) as server: with self.assertRaises(InvalidStatus) as raised: - async with run_client(server): + async with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -302,13 +303,13 @@ async def process_response(ws, request, response): async def test_override_server(self): """Server can override Server header with server_header.""" async with run_server(server_header="Neo") as server: - async with run_client(server) as client: + async with connect(get_uri(server)) as client: await self.assertEval(client, "ws.response.headers['Server']", "Neo") async def test_remove_server(self): """Server can remove Server header with server_header.""" async with run_server(server_header=None) as server: - async with run_client(server) as client: + async with connect(get_uri(server)) as client: await self.assertEval( client, "'Server' in ws.response.headers", "False" ) @@ -316,7 +317,7 @@ async def test_remove_server(self): async def test_compression_is_enabled(self): """Server enables compression by default.""" async with run_server() as server: - async with run_client(server) as client: + async with connect(get_uri(server)) as client: await self.assertEval( client, "[type(ext).__name__ for ext in ws.protocol.extensions]", @@ -326,13 +327,13 @@ async def test_compression_is_enabled(self): async def test_disable_compression(self): """Server disables compression.""" async with run_server(compression=None) as server: - async with run_client(server) as client: + async with connect(get_uri(server)) as client: await self.assertEval(client, "ws.protocol.extensions", "[]") async def test_keepalive_is_enabled(self): """Server enables keepalive and measures latency.""" async with run_server(ping_interval=MS) as server: - async with run_client(server) as client: + async with connect(get_uri(server)) as client: await client.send("ws.latency") latency = eval(await client.recv()) self.assertEqual(latency, 0) @@ -344,7 +345,7 @@ async def test_keepalive_is_enabled(self): async def test_disable_keepalive(self): """Client disables keepalive.""" async with run_server(ping_interval=None) as server: - async with run_client(server) as client: + async with connect(get_uri(server)) as client: await asyncio.sleep(2 * MS) await client.send("ws.latency") latency = eval(await client.recv()) @@ -365,14 +366,14 @@ def create_connection(*args, **kwargs): return server async with run_server(create_connection=create_connection) as server: - async with run_client(server) as client: + async with connect(get_uri(server)) as client: await self.assertEval(client, "ws.create_connection_ran", "True") async def test_connections(self): """Server provides a connections property.""" async with run_server() as server: self.assertEqual(server.connections, set()) - async with run_client(server) as client: + async with connect(get_uri(server)) as client: self.assertEqual(len(server.connections), 1) ws_id = str(next(iter(server.connections)).id) await self.assertEval(client, "ws.id", ws_id) @@ -386,7 +387,7 @@ def remove_key_header(self, request): async with run_server(process_request=remove_key_header) as server: with self.assertRaises(InvalidStatus) as raised: - async with run_client(server): + async with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -396,9 +397,7 @@ def remove_key_header(self, request): async def test_timeout_during_handshake(self): """Server times out before receiving handshake request from client.""" async with run_server(open_timeout=MS) as server: - reader, writer = await asyncio.open_connection( - *get_server_host_port(server) - ) + reader, writer = await asyncio.open_connection(*get_host_port(server)) try: self.assertEqual(await reader.read(4096), b"") finally: @@ -407,9 +406,7 @@ async def test_timeout_during_handshake(self): async def test_connection_closed_during_handshake(self): """Server reads EOF before receiving handshake request from client.""" async with run_server() as server: - _reader, writer = await asyncio.open_connection( - *get_server_host_port(server) - ) + _reader, writer = await asyncio.open_connection(*get_host_port(server)) writer.close() async def test_close_server_rejects_connecting_connections(self): @@ -422,7 +419,7 @@ async def process_request(ws, _request): async with run_server(process_request=process_request) as server: asyncio.get_running_loop().call_later(MS, server.close) with self.assertRaises(InvalidStatus) as raised: - async with run_client(server): + async with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -432,7 +429,7 @@ async def process_request(ws, _request): async def test_close_server_closes_open_connections(self): """Server closes open connections with close code 1001 when closing.""" async with run_server() as server: - async with run_client(server) as client: + async with connect(get_uri(server)) as client: server.close() with self.assertRaises(ConnectionClosedOK) as raised: await client.recv() @@ -444,7 +441,7 @@ async def test_close_server_closes_open_connections(self): async def test_close_server_keeps_connections_open(self): """Server waits for client to close open connections when closing.""" async with run_server() as server: - async with run_client(server) as client: + async with connect(get_uri(server)) as client: server.close(close_connections=False) # Server cannot receive new connections. @@ -464,7 +461,7 @@ async def test_close_server_keeps_connections_open(self): async def test_close_server_keeps_handlers_running(self): """Server waits for connection handlers to terminate.""" async with run_server(keep_running) as server: - async with run_client(server) as client: + async with connect(get_uri(server)) as client: # Delay termination of connection handler. await client.send(str(3 * MS)) @@ -486,16 +483,14 @@ class SecureServerTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): async def test_connection(self): """Server receives secure connection from client.""" async with run_server(ssl=SERVER_CONTEXT) as server: - async with run_client(server, ssl=CLIENT_CONTEXT) as client: + async with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: await self.assertEval(client, "ws.protocol.state.name", "OPEN") await self.assertEval(client, SSL_OBJECT + ".version()[:3]", "TLS") async def test_timeout_during_tls_handshake(self): """Server times out before receiving TLS handshake request from client.""" async with run_server(ssl=SERVER_CONTEXT, open_timeout=MS) as server: - reader, writer = await asyncio.open_connection( - *get_server_host_port(server) - ) + reader, writer = await asyncio.open_connection(*get_host_port(server)) try: self.assertEqual(await reader.read(4096), b"") finally: @@ -504,9 +499,7 @@ async def test_timeout_during_tls_handshake(self): async def test_connection_closed_during_tls_handshake(self): """Server reads EOF before receiving TLS handshake request from client.""" async with run_server(ssl=SERVER_CONTEXT) as server: - _reader, writer = await asyncio.open_connection( - *get_server_host_port(server) - ) + _reader, writer = await asyncio.open_connection(*get_host_port(server)) writer.close() @@ -516,7 +509,7 @@ async def test_connection(self): """Server receives connection from client over a Unix socket.""" with temp_unix_socket_path() as path: async with run_unix_server(path): - async with run_unix_client(path) as client: + async with unix_connect(path) as client: await self.assertEval(client, "ws.protocol.state.name", "OPEN") @@ -526,7 +519,7 @@ async def test_connection(self): """Server receives secure connection from client over a Unix socket.""" with temp_unix_socket_path() as path: async with run_unix_server(path, ssl=SERVER_CONTEXT): - async with run_unix_client(path, ssl=CLIENT_CONTEXT) as client: + async with unix_connect(path, ssl=CLIENT_CONTEXT) as client: await self.assertEval(client, "ws.protocol.state.name", "OPEN") await self.assertEval(client, SSL_OBJECT + ".version()[:3]", "TLS") @@ -577,8 +570,8 @@ async def test_valid_authorization(self): async with run_server( process_request=basic_auth(credentials=("hello", "iloveyou")), ) as server: - async with run_client( - server, + async with connect( + get_uri(server), additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, ) as client: await self.assertEval(client, "ws.username", "hello") @@ -589,7 +582,7 @@ async def test_missing_authorization(self): process_request=basic_auth(credentials=("hello", "iloveyou")), ) as server: with self.assertRaises(InvalidStatus) as raised: - async with run_client(server): + async with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -602,8 +595,8 @@ async def test_unsupported_authorization(self): process_request=basic_auth(credentials=("hello", "iloveyou")), ) as server: with self.assertRaises(InvalidStatus) as raised: - async with run_client( - server, + async with connect( + get_uri(server), additional_headers={"Authorization": "Negotiate ..."}, ): self.fail("did not raise") @@ -618,8 +611,8 @@ async def test_authorization_with_unknown_username(self): process_request=basic_auth(credentials=("hello", "iloveyou")), ) as server: with self.assertRaises(InvalidStatus) as raised: - async with run_client( - server, + async with connect( + get_uri(server), additional_headers={"Authorization": "Basic YnllOnlvdWxvdmVtZQ=="}, ): self.fail("did not raise") @@ -634,8 +627,8 @@ async def test_authorization_with_incorrect_password(self): process_request=basic_auth(credentials=("hello", "changeme")), ) as server: with self.assertRaises(InvalidStatus) as raised: - async with run_client( - server, + async with connect( + get_uri(server), additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, ): self.fail("did not raise") @@ -654,8 +647,8 @@ async def test_list_of_credentials(self): ] ), ) as server: - async with run_client( - server, + async with connect( + get_uri(server), additional_headers={"Authorization": "Basic YnllOnlvdWxvdmVtZQ=="}, ) as client: await self.assertEval(client, "ws.username", "bye") @@ -669,8 +662,8 @@ def check_credentials(username, password): async with run_server( process_request=basic_auth(check_credentials=check_credentials), ) as server: - async with run_client( - server, + async with connect( + get_uri(server), additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, ) as client: await self.assertEval(client, "ws.username", "hello") @@ -684,8 +677,8 @@ async def check_credentials(username, password): async with run_server( process_request=basic_auth(check_credentials=check_credentials), ) as server: - async with run_client( - server, + async with connect( + get_uri(server), additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, ) as client: await self.assertEval(client, "ws.username", "hello") diff --git a/tests/sync/client.py b/tests/sync/client.py deleted file mode 100644 index acbf97fa7..000000000 --- a/tests/sync/client.py +++ /dev/null @@ -1,32 +0,0 @@ -import contextlib - -from websockets.sync.client import * -from websockets.sync.server import Server - - -__all__ = [ - "run_client", - "run_unix_client", -] - - -@contextlib.contextmanager -def run_client(wsuri_or_server, secure=None, resource_name="/", **kwargs): - if isinstance(wsuri_or_server, str): - wsuri = wsuri_or_server - else: - assert isinstance(wsuri_or_server, Server) - if secure is None: - # Backwards compatibility: ssl used to be called ssl_context. - secure = "ssl" in kwargs or "ssl_context" in kwargs - protocol = "wss" if secure else "ws" - host, port = wsuri_or_server.socket.getsockname() - wsuri = f"{protocol}://{host}:{port}{resource_name}" - with connect(wsuri, **kwargs) as client: - yield client - - -@contextlib.contextmanager -def run_unix_client(path, **kwargs): - with unix_connect(path, **kwargs) as client: - yield client diff --git a/tests/sync/server.py b/tests/sync/server.py index d5295ccd8..a86cf88ce 100644 --- a/tests/sync/server.py +++ b/tests/sync/server.py @@ -1,9 +1,17 @@ import contextlib +import ssl import threading from websockets.sync.server import * +def get_uri(server): + secure = isinstance(server.socket, ssl.SSLSocket) # hack + protocol = "wss" if secure else "ws" + host, port = server.socket.getsockname() + return f"{protocol}://{host}:{port}" + + def crash(ws): raise RuntimeError diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index 03f4e972f..44cbdd6c4 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -14,15 +14,14 @@ DeprecationTestCase, temp_unix_socket_path, ) -from .client import run_client, run_unix_client -from .server import do_nothing, run_server, run_unix_server +from .server import do_nothing, get_uri, run_server, run_unix_server class ClientTests(unittest.TestCase): def test_connection(self): """Client connects to server and the handshake succeeds.""" with run_server() as server: - with run_client(server) as client: + with connect(get_uri(server)) as client: self.assertEqual(client.protocol.state.name, "OPEN") def test_existing_socket(self): @@ -30,33 +29,33 @@ def test_existing_socket(self): with run_server() as server: with socket.create_connection(server.socket.getsockname()) as sock: # Use a non-existing domain to ensure we connect to the right socket. - with run_client("ws://invalid/", sock=sock) as client: + with connect("ws://invalid/", sock=sock) as client: self.assertEqual(client.protocol.state.name, "OPEN") def test_additional_headers(self): """Client can set additional headers with additional_headers.""" with run_server() as server: - with run_client( - server, additional_headers={"Authorization": "Bearer ..."} + with connect( + get_uri(server), additional_headers={"Authorization": "Bearer ..."} ) as client: self.assertEqual(client.request.headers["Authorization"], "Bearer ...") def test_override_user_agent(self): """Client can override User-Agent header with user_agent_header.""" with run_server() as server: - with run_client(server, user_agent_header="Smith") as client: + with connect(get_uri(server), user_agent_header="Smith") as client: self.assertEqual(client.request.headers["User-Agent"], "Smith") def test_remove_user_agent(self): """Client can remove User-Agent header with user_agent_header.""" with run_server() as server: - with run_client(server, user_agent_header=None) as client: + with connect(get_uri(server), user_agent_header=None) as client: self.assertNotIn("User-Agent", client.request.headers) def test_compression_is_enabled(self): """Client enables compression by default.""" with run_server() as server: - with run_client(server) as client: + with connect(get_uri(server)) as client: self.assertEqual( [type(ext) for ext in client.protocol.extensions], [PerMessageDeflate], @@ -65,7 +64,7 @@ def test_compression_is_enabled(self): def test_disable_compression(self): """Client disables compression.""" with run_server() as server: - with run_client(server, compression=None) as client: + with connect(get_uri(server), compression=None) as client: self.assertEqual(client.protocol.extensions, []) def test_custom_connection_factory(self): @@ -77,19 +76,21 @@ def create_connection(*args, **kwargs): return client with run_server() as server: - with run_client(server, create_connection=create_connection) as client: + with connect( + get_uri(server), create_connection=create_connection + ) as client: self.assertTrue(client.create_connection_ran) def test_invalid_uri(self): """Client receives an invalid URI.""" with self.assertRaises(InvalidURI): - with run_client("http://localhost"): # invalid scheme + with connect("http://localhost"): # invalid scheme self.fail("did not raise") def test_tcp_connection_fails(self): """Client fails to connect to server.""" with self.assertRaises(OSError): - with run_client("ws://localhost:54321"): # invalid port + with connect("ws://localhost:54321"): # invalid port self.fail("did not raise") def test_handshake_fails(self): @@ -102,7 +103,7 @@ def remove_accept_header(self, request, response): # Use a connection handler that exits immediately to avoid an exception. with run_server(do_nothing, process_response=remove_accept_header) as server: with self.assertRaises(InvalidHandshake) as raised: - with run_client(server, close_timeout=MS): + with connect(get_uri(server), close_timeout=MS): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -121,7 +122,7 @@ def stall_connection(self, request): with run_server(do_nothing, process_request=stall_connection) as server: try: with self.assertRaises(TimeoutError) as raised: - with run_client(server, open_timeout=2 * MS): + with connect(get_uri(server), open_timeout=2 * MS): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -138,7 +139,7 @@ def close_connection(self, request): with run_server(process_request=close_connection) as server: with self.assertRaises(ConnectionError) as raised: - with run_client(server): + with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -150,7 +151,7 @@ class SecureClientTests(unittest.TestCase): def test_connection(self): """Client connects to server securely.""" with run_server(ssl=SERVER_CONTEXT) as server: - with run_client(server, ssl=CLIENT_CONTEXT) as client: + with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: self.assertEqual(client.protocol.state.name, "OPEN") self.assertEqual(client.socket.version()[:3], "TLS") @@ -158,10 +159,8 @@ def test_set_server_hostname_implicitly(self): """Client sets server_hostname to the host in the WebSocket URI.""" with temp_unix_socket_path() as path: with run_unix_server(path, ssl=SERVER_CONTEXT): - with run_unix_client( - path, - ssl=CLIENT_CONTEXT, - uri="wss://overridden/", + with unix_connect( + path, ssl=CLIENT_CONTEXT, uri="wss://overridden/" ) as client: self.assertEqual(client.socket.server_hostname, "overridden") @@ -169,10 +168,8 @@ def test_set_server_hostname_explicitly(self): """Client sets server_hostname to the value provided in argument.""" with temp_unix_socket_path() as path: with run_unix_server(path, ssl=SERVER_CONTEXT): - with run_unix_client( - path, - ssl=CLIENT_CONTEXT, - server_hostname="overridden", + with unix_connect( + path, ssl=CLIENT_CONTEXT, server_hostname="overridden" ) as client: self.assertEqual(client.socket.server_hostname, "overridden") @@ -181,7 +178,7 @@ def test_reject_invalid_server_certificate(self): with run_server(ssl=SERVER_CONTEXT) as server: with self.assertRaises(ssl.SSLCertVerificationError) as raised: # The test certificate isn't trusted system-wide. - with run_client(server, secure=True): + with connect(get_uri(server)): self.fail("did not raise") self.assertIn( "certificate verify failed: self signed certificate", @@ -193,7 +190,9 @@ def test_reject_invalid_server_hostname(self): with run_server(ssl=SERVER_CONTEXT) as server: with self.assertRaises(ssl.SSLCertVerificationError) as raised: # This hostname isn't included in the test certificate. - with run_client(server, ssl=CLIENT_CONTEXT, server_hostname="invalid"): + with connect( + get_uri(server), ssl=CLIENT_CONTEXT, server_hostname="invalid" + ): self.fail("did not raise") self.assertIn( "certificate verify failed: Hostname mismatch", @@ -207,7 +206,7 @@ def test_connection(self): """Client connects to server over a Unix socket.""" with temp_unix_socket_path() as path: with run_unix_server(path): - with run_unix_client(path) as client: + with unix_connect(path) as client: self.assertEqual(client.protocol.state.name, "OPEN") def test_set_host_header(self): @@ -215,7 +214,7 @@ def test_set_host_header(self): # This is part of the documented behavior of unix_connect(). with temp_unix_socket_path() as path: with run_unix_server(path): - with run_unix_client(path, uri="ws://overridden/") as client: + with unix_connect(path, uri="ws://overridden/") as client: self.assertEqual(client.request.headers["Host"], "overridden") @@ -225,7 +224,7 @@ def test_connection(self): """Client connects to server securely over a Unix socket.""" with temp_unix_socket_path() as path: with run_unix_server(path, ssl=SERVER_CONTEXT): - with run_unix_client(path, ssl=CLIENT_CONTEXT) as client: + with unix_connect(path, ssl=CLIENT_CONTEXT) as client: self.assertEqual(client.protocol.state.name, "OPEN") self.assertEqual(client.socket.version()[:3], "TLS") @@ -234,10 +233,8 @@ def test_set_server_hostname(self): # This is part of the documented behavior of unix_connect(). with temp_unix_socket_path() as path: with run_unix_server(path, ssl=SERVER_CONTEXT): - with run_unix_client( - path, - ssl=CLIENT_CONTEXT, - uri="wss://overridden/", + with unix_connect( + path, ssl=CLIENT_CONTEXT, uri="wss://overridden/" ) as client: self.assertEqual(client.socket.server_hostname, "overridden") @@ -296,5 +293,5 @@ def test_ssl_context_argument(self): """Client supports the deprecated ssl_context argument.""" with run_server(ssl=SERVER_CONTEXT) as server: with self.assertDeprecationWarning("ssl_context was renamed to ssl"): - with run_client(server, ssl_context=CLIENT_CONTEXT): + with connect(get_uri(server), ssl_context=CLIENT_CONTEXT): pass diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index a17634716..e6a17d02f 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -13,6 +13,7 @@ NegotiationError, ) from websockets.http11 import Request, Response +from websockets.sync.client import connect, unix_connect from websockets.sync.server import * from ..utils import ( @@ -22,12 +23,12 @@ DeprecationTestCase, temp_unix_socket_path, ) -from .client import run_client, run_unix_client from .server import ( EvalShellMixin, crash, do_nothing, eval_shell, + get_uri, run_server, run_unix_server, ) @@ -37,13 +38,13 @@ class ServerTests(EvalShellMixin, unittest.TestCase): def test_connection(self): """Server receives connection from client and the handshake succeeds.""" with run_server() as server: - with run_client(server) as client: + with connect(get_uri(server)) as client: self.assertEval(client, "ws.protocol.state.name", "OPEN") def test_connection_handler_returns(self): """Connection handler returns.""" with run_server(do_nothing) as server: - with run_client(server) as client: + with connect(get_uri(server)) as client: with self.assertRaises(ConnectionClosedOK) as raised: client.recv() self.assertEqual( @@ -54,7 +55,7 @@ def test_connection_handler_returns(self): def test_connection_handler_raises_exception(self): """Connection handler raises an exception.""" with run_server(crash) as server: - with run_client(server) as client: + with connect(get_uri(server)) as client: with self.assertRaises(ConnectionClosedError) as raised: client.recv() self.assertEqual( @@ -68,7 +69,7 @@ def test_existing_socket(self): with socket.create_server(("localhost", 0)) as sock: with run_server(sock=sock): uri = "ws://{}:{}/".format(*sock.getsockname()) - with run_client(uri) as client: + with connect(uri) as client: self.assertEval(client, "ws.protocol.state.name", "OPEN") def test_select_subprotocol(self): @@ -83,7 +84,7 @@ def select_subprotocol(ws, subprotocols): subprotocols=["chat"], select_subprotocol=select_subprotocol, ) as server: - with run_client(server, subprotocols=["chat"]) as client: + with connect(get_uri(server), subprotocols=["chat"]) as client: self.assertEval(client, "ws.select_subprotocol_ran", "True") self.assertEval(client, "ws.subprotocol", "chat") @@ -95,7 +96,7 @@ def select_subprotocol(ws, subprotocols): with run_server(select_subprotocol=select_subprotocol) as server: with self.assertRaises(InvalidStatus) as raised: - with run_client(server): + with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -110,7 +111,7 @@ def select_subprotocol(ws, subprotocols): with run_server(select_subprotocol=select_subprotocol) as server: with self.assertRaises(InvalidStatus) as raised: - with run_client(server): + with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -125,7 +126,7 @@ def process_request(ws, request): ws.process_request_ran = True with run_server(process_request=process_request) as server: - with run_client(server) as client: + with connect(get_uri(server)) as client: self.assertEval(client, "ws.process_request_ran", "True") def test_process_request_returns_response(self): @@ -136,7 +137,7 @@ def process_request(ws, request): with run_server(process_request=process_request) as server: with self.assertRaises(InvalidStatus) as raised: - with run_client(server): + with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -151,7 +152,7 @@ def process_request(ws, request): with run_server(process_request=process_request) as server: with self.assertRaises(InvalidStatus) as raised: - with run_client(server): + with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -167,7 +168,7 @@ def process_response(ws, request, response): ws.process_response_ran = True with run_server(process_response=process_response) as server: - with run_client(server) as client: + with connect(get_uri(server)) as client: self.assertEval(client, "ws.process_response_ran", "True") def test_process_response_modifies_response(self): @@ -177,7 +178,7 @@ def process_response(ws, request, response): response.headers["X-ProcessResponse"] = "OK" with run_server(process_response=process_response) as server: - with run_client(server) as client: + with connect(get_uri(server)) as client: self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") def test_process_response_replaces_response(self): @@ -189,7 +190,7 @@ def process_response(ws, request, response): return dataclasses.replace(response, headers=headers) with run_server(process_response=process_response) as server: - with run_client(server) as client: + with connect(get_uri(server)) as client: self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") def test_process_response_raises_exception(self): @@ -200,7 +201,7 @@ def process_response(ws, request, response): with run_server(process_response=process_response) as server: with self.assertRaises(InvalidStatus) as raised: - with run_client(server): + with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -210,19 +211,19 @@ def process_response(ws, request, response): def test_override_server(self): """Server can override Server header with server_header.""" with run_server(server_header="Neo") as server: - with run_client(server) as client: + with connect(get_uri(server)) as client: self.assertEval(client, "ws.response.headers['Server']", "Neo") def test_remove_server(self): """Server can remove Server header with server_header.""" with run_server(server_header=None) as server: - with run_client(server) as client: + with connect(get_uri(server)) as client: self.assertEval(client, "'Server' in ws.response.headers", "False") def test_compression_is_enabled(self): """Server enables compression by default.""" with run_server() as server: - with run_client(server) as client: + with connect(get_uri(server)) as client: self.assertEval( client, "[type(ext).__name__ for ext in ws.protocol.extensions]", @@ -232,7 +233,7 @@ def test_compression_is_enabled(self): def test_disable_compression(self): """Server disables compression.""" with run_server(compression=None) as server: - with run_client(server) as client: + with connect(get_uri(server)) as client: self.assertEval(client, "ws.protocol.extensions", "[]") def test_logger(self): @@ -250,7 +251,7 @@ def create_connection(*args, **kwargs): return server with run_server(create_connection=create_connection) as server: - with run_client(server) as client: + with connect(get_uri(server)) as client: self.assertEval(client, "ws.create_connection_ran", "True") def test_fileno(self): @@ -274,7 +275,7 @@ def remove_key_header(self, request): with run_server(process_request=remove_key_header) as server: with self.assertRaises(InvalidStatus) as raised: - with run_client(server): + with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -316,7 +317,7 @@ class SecureServerTests(EvalShellMixin, unittest.TestCase): def test_connection(self): """Server receives secure connection from client.""" with run_server(ssl=SERVER_CONTEXT) as server: - with run_client(server, ssl=CLIENT_CONTEXT) as client: + with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: self.assertEval(client, "ws.protocol.state.name", "OPEN") self.assertEval(client, "ws.socket.version()[:3]", "TLS") @@ -357,7 +358,7 @@ def test_connection(self): """Server receives connection from client over a Unix socket.""" with temp_unix_socket_path() as path: with run_unix_server(path): - with run_unix_client(path) as client: + with unix_connect(path) as client: self.assertEval(client, "ws.protocol.state.name", "OPEN") @@ -367,7 +368,7 @@ def test_connection(self): """Server receives secure connection from client over a Unix socket.""" with temp_unix_socket_path() as path: with run_unix_server(path, ssl=SERVER_CONTEXT): - with run_unix_client(path, ssl=CLIENT_CONTEXT) as client: + with unix_connect(path, ssl=CLIENT_CONTEXT) as client: self.assertEval(client, "ws.protocol.state.name", "OPEN") self.assertEval(client, "ws.socket.version()[:3]", "TLS") @@ -418,8 +419,8 @@ def test_valid_authorization(self): with run_server( process_request=basic_auth(credentials=("hello", "iloveyou")), ) as server: - with run_client( - server, + with connect( + get_uri(server), additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, ) as client: self.assertEval(client, "ws.username", "hello") @@ -430,7 +431,7 @@ def test_missing_authorization(self): process_request=basic_auth(credentials=("hello", "iloveyou")), ) as server: with self.assertRaises(InvalidStatus) as raised: - with run_client(server): + with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -443,8 +444,8 @@ def test_unsupported_authorization(self): process_request=basic_auth(credentials=("hello", "iloveyou")), ) as server: with self.assertRaises(InvalidStatus) as raised: - with run_client( - server, + with connect( + get_uri(server), additional_headers={"Authorization": "Negotiate ..."}, ): self.fail("did not raise") @@ -459,8 +460,8 @@ def test_authorization_with_unknown_username(self): process_request=basic_auth(credentials=("hello", "iloveyou")), ) as server: with self.assertRaises(InvalidStatus) as raised: - with run_client( - server, + with connect( + get_uri(server), additional_headers={"Authorization": "Basic YnllOnlvdWxvdmVtZQ=="}, ): self.fail("did not raise") @@ -475,8 +476,8 @@ def test_authorization_with_incorrect_password(self): process_request=basic_auth(credentials=("hello", "changeme")), ) as server: with self.assertRaises(InvalidStatus) as raised: - with run_client( - server, + with connect( + get_uri(server), additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, ): self.fail("did not raise") @@ -495,8 +496,8 @@ def test_list_of_credentials(self): ] ), ) as server: - with run_client( - server, + with connect( + get_uri(server), additional_headers={"Authorization": "Basic YnllOnlvdWxvdmVtZQ=="}, ) as client: self.assertEval(client, "ws.username", "bye") @@ -510,8 +511,8 @@ def check_credentials(username, password): with run_server( process_request=basic_auth(check_credentials=check_credentials), ) as server: - with run_client( - server, + with connect( + get_uri(server), additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, ) as client: self.assertEval(client, "ws.username", "hello") @@ -561,7 +562,7 @@ def test_ssl_context_argument(self): """Client supports the deprecated ssl_context argument.""" with self.assertDeprecationWarning("ssl_context was renamed to ssl"): with run_server(ssl_context=SERVER_CONTEXT) as server: - with run_client(server, ssl=CLIENT_CONTEXT): + with connect(get_uri(server), ssl=CLIENT_CONTEXT): pass def test_web_socket_server_class(self): From f0d1ebb238466ec48cffc657341c7928ef08f43e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 24 Aug 2024 21:48:31 +0200 Subject: [PATCH 595/762] Use the same connection handler for all tests. --- tests/asyncio/server.py | 39 ++++++++++++++++++------------------ tests/asyncio/test_client.py | 12 +++++------ tests/asyncio/test_server.py | 25 ++++++++++------------- tests/sync/server.py | 29 ++++++++++++++------------- tests/sync/test_client.py | 10 ++++----- tests/sync/test_server.py | 20 +++++++++--------- 6 files changed, 64 insertions(+), 71 deletions(-) diff --git a/tests/asyncio/server.py b/tests/asyncio/server.py index 06fa92dea..097950971 100644 --- a/tests/asyncio/server.py +++ b/tests/asyncio/server.py @@ -19,10 +19,23 @@ def get_uri(server): return f"{protocol}://{host}:{port}" -async def eval_shell(ws): - async for expr in ws: - value = eval(expr) - await ws.send(str(value)) +async def handler(ws): + path = ws.request.path + if path == "/": + # The default path is an eval shell. + async for expr in ws: + value = eval(expr) + await ws.send(str(value)) + elif path == "/crash": + raise RuntimeError + elif path == "/no-op": + pass + elif path == "/delay": + delay = float(await ws.recv()) + await ws.close() + await asyncio.sleep(delay) + else: + raise AssertionError(f"unexpected path: {path}") class EvalShellMixin: @@ -31,27 +44,13 @@ async def assertEval(self, client, expr, value): self.assertEqual(await client.recv(), value) -async def crash(ws): - raise RuntimeError - - -async def do_nothing(ws): - pass - - -async def keep_running(ws): - delay = float(await ws.recv()) - await ws.close() - await asyncio.sleep(delay) - - @contextlib.asynccontextmanager -async def run_server(handler=eval_shell, host="localhost", port=0, **kwargs): +async def run_server(handler=handler, host="localhost", port=0, **kwargs): async with serve(handler, host, port, **kwargs) as server: yield server @contextlib.asynccontextmanager -async def run_unix_server(path, handler=eval_shell, **kwargs): +async def run_unix_server(path, handler=handler, **kwargs): async with unix_serve(handler, path, **kwargs) as server: yield server diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index a8ef6ef9d..77261129a 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -9,7 +9,7 @@ from websockets.extensions.permessage_deflate import PerMessageDeflate from ..utils import CLIENT_CONTEXT, MS, SERVER_CONTEXT, temp_unix_socket_path -from .server import do_nothing, get_host_port, get_uri, run_server, run_unix_server +from .server import get_host_port, get_uri, run_server, run_unix_server class ClientTests(unittest.IsolatedAsyncioTestCase): @@ -111,11 +111,9 @@ def remove_accept_header(self, request, response): # The connection will be open for the server but failed for the client. # Use a connection handler that exits immediately to avoid an exception. - async with run_server( - do_nothing, process_response=remove_accept_header - ) as server: + async with run_server(process_response=remove_accept_header) as server: with self.assertRaises(InvalidHandshake) as raised: - async with connect(get_uri(server), close_timeout=MS): + async with connect(get_uri(server) + "/no-op", close_timeout=MS): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -131,10 +129,10 @@ async def stall_connection(self, request): # The connection will be open for the server but failed for the client. # Use a connection handler that exits immediately to avoid an exception. - async with run_server(do_nothing, process_request=stall_connection) as server: + async with run_server(process_request=stall_connection) as server: try: with self.assertRaises(TimeoutError) as raised: - async with connect(get_uri(server), open_timeout=2 * MS): + async with connect(get_uri(server) + "/no-op", open_timeout=2 * MS): self.fail("did not raise") self.assertEqual( str(raised.exception), diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index a16267439..fd9c69c40 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -25,12 +25,9 @@ ) from .server import ( EvalShellMixin, - crash, - do_nothing, - eval_shell, get_host_port, get_uri, - keep_running, + handler, run_server, run_unix_server, ) @@ -45,8 +42,8 @@ async def test_connection(self): async def test_connection_handler_returns(self): """Connection handler returns.""" - async with run_server(do_nothing) as server: - async with connect(get_uri(server)) as client: + async with run_server() as server: + async with connect(get_uri(server) + "/no-op") as client: with self.assertRaises(ConnectionClosedOK) as raised: await client.recv() self.assertEqual( @@ -56,8 +53,8 @@ async def test_connection_handler_returns(self): async def test_connection_handler_raises_exception(self): """Connection handler raises an exception.""" - async with run_server(crash) as server: - async with connect(get_uri(server)) as client: + async with run_server() as server: + async with connect(get_uri(server) + "/crash") as client: with self.assertRaises(ConnectionClosedError) as raised: await client.recv() self.assertEqual( @@ -460,8 +457,8 @@ async def test_close_server_keeps_connections_open(self): async def test_close_server_keeps_handlers_running(self): """Server waits for connection handlers to terminate.""" - async with run_server(keep_running) as server: - async with connect(get_uri(server)) as client: + async with run_server() as server: + async with connect(get_uri(server) + "/delay") as client: # Delay termination of connection handler. await client.send(str(3 * MS)) @@ -528,7 +525,7 @@ class ServerUsageErrorsTests(unittest.IsolatedAsyncioTestCase): async def test_unix_without_path_or_sock(self): """Unix server requires path when sock isn't provided.""" with self.assertRaises(ValueError) as raised: - await unix_serve(eval_shell) + await unix_serve(handler) self.assertEqual( str(raised.exception), "path was not specified, and no sock specified", @@ -539,7 +536,7 @@ async def test_unix_with_path_and_sock(self): sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) self.addCleanup(sock.close) with self.assertRaises(ValueError) as raised: - await unix_serve(eval_shell, path="/", sock=sock) + await unix_serve(handler, path="/", sock=sock) self.assertEqual( str(raised.exception), "path and sock can not be specified at the same time", @@ -548,7 +545,7 @@ async def test_unix_with_path_and_sock(self): async def test_invalid_subprotocol(self): """Server rejects single value of subprotocols.""" with self.assertRaises(TypeError) as raised: - await serve(eval_shell, subprotocols="chat") + await serve(handler, subprotocols="chat") self.assertEqual( str(raised.exception), "subprotocols must be a list, not a str", @@ -557,7 +554,7 @@ async def test_invalid_subprotocol(self): async def test_unsupported_compression(self): """Server rejects incorrect value of compression.""" with self.assertRaises(ValueError) as raised: - await serve(eval_shell, compression=False) + await serve(handler, compression=False) self.assertEqual( str(raised.exception), "unsupported compression: False", diff --git a/tests/sync/server.py b/tests/sync/server.py index a86cf88ce..114c1545b 100644 --- a/tests/sync/server.py +++ b/tests/sync/server.py @@ -12,18 +12,19 @@ def get_uri(server): return f"{protocol}://{host}:{port}" -def crash(ws): - raise RuntimeError - - -def do_nothing(ws): - pass - - -def eval_shell(ws): - for expr in ws: - value = eval(expr) - ws.send(str(value)) +def handler(ws): + path = ws.request.path + if path == "/": + # The default path is an eval shell. + for expr in ws: + value = eval(expr) + ws.send(str(value)) + elif path == "/crash": + raise RuntimeError + elif path == "/no-op": + pass + else: + raise AssertionError(f"unexpected path: {path}") class EvalShellMixin: @@ -33,7 +34,7 @@ def assertEval(self, client, expr, value): @contextlib.contextmanager -def run_server(handler=eval_shell, host="localhost", port=0, **kwargs): +def run_server(handler=handler, host="localhost", port=0, **kwargs): with serve(handler, host, port, **kwargs) as server: thread = threading.Thread(target=server.serve_forever) thread.start() @@ -45,7 +46,7 @@ def run_server(handler=eval_shell, host="localhost", port=0, **kwargs): @contextlib.contextmanager -def run_unix_server(path, handler=eval_shell, **kwargs): +def run_unix_server(path, handler=handler, **kwargs): with unix_serve(handler, path, **kwargs) as server: thread = threading.Thread(target=server.serve_forever) thread.start() diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index 44cbdd6c4..0d5273d12 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -14,7 +14,7 @@ DeprecationTestCase, temp_unix_socket_path, ) -from .server import do_nothing, get_uri, run_server, run_unix_server +from .server import get_uri, run_server, run_unix_server class ClientTests(unittest.TestCase): @@ -101,9 +101,9 @@ def remove_accept_header(self, request, response): # The connection will be open for the server but failed for the client. # Use a connection handler that exits immediately to avoid an exception. - with run_server(do_nothing, process_response=remove_accept_header) as server: + with run_server(process_response=remove_accept_header) as server: with self.assertRaises(InvalidHandshake) as raised: - with connect(get_uri(server), close_timeout=MS): + with connect(get_uri(server) + "/no-op", close_timeout=MS): self.fail("did not raise") self.assertEqual( str(raised.exception), @@ -119,10 +119,10 @@ def stall_connection(self, request): # The connection will be open for the server but failed for the client. # Use a connection handler that exits immediately to avoid an exception. - with run_server(do_nothing, process_request=stall_connection) as server: + with run_server(process_request=stall_connection) as server: try: with self.assertRaises(TimeoutError) as raised: - with connect(get_uri(server), open_timeout=2 * MS): + with connect(get_uri(server) + "/no-op", open_timeout=2 * MS): self.fail("did not raise") self.assertEqual( str(raised.exception), diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index e6a17d02f..a4b537c66 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -25,10 +25,8 @@ ) from .server import ( EvalShellMixin, - crash, - do_nothing, - eval_shell, get_uri, + handler, run_server, run_unix_server, ) @@ -43,8 +41,8 @@ def test_connection(self): def test_connection_handler_returns(self): """Connection handler returns.""" - with run_server(do_nothing) as server: - with connect(get_uri(server)) as client: + with run_server() as server: + with connect(get_uri(server) + "/no-op") as client: with self.assertRaises(ConnectionClosedOK) as raised: client.recv() self.assertEqual( @@ -54,8 +52,8 @@ def test_connection_handler_returns(self): def test_connection_handler_raises_exception(self): """Connection handler raises an exception.""" - with run_server(crash) as server: - with connect(get_uri(server)) as client: + with run_server() as server: + with connect(get_uri(server) + "/crash") as client: with self.assertRaises(ConnectionClosedError) as raised: client.recv() self.assertEqual( @@ -377,7 +375,7 @@ class ServerUsageErrorsTests(unittest.TestCase): def test_unix_without_path_or_sock(self): """Unix server requires path when sock isn't provided.""" with self.assertRaises(TypeError) as raised: - unix_serve(eval_shell) + unix_serve(handler) self.assertEqual( str(raised.exception), "missing path argument", @@ -388,7 +386,7 @@ def test_unix_with_path_and_sock(self): sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) self.addCleanup(sock.close) with self.assertRaises(TypeError) as raised: - unix_serve(eval_shell, path="/", sock=sock) + unix_serve(handler, path="/", sock=sock) self.assertEqual( str(raised.exception), "path and sock arguments are incompatible", @@ -397,7 +395,7 @@ def test_unix_with_path_and_sock(self): def test_invalid_subprotocol(self): """Server rejects single value of subprotocols.""" with self.assertRaises(TypeError) as raised: - serve(eval_shell, subprotocols="chat") + serve(handler, subprotocols="chat") self.assertEqual( str(raised.exception), "subprotocols must be a list, not a str", @@ -406,7 +404,7 @@ def test_invalid_subprotocol(self): def test_unsupported_compression(self): """Server rejects incorrect value of compression.""" with self.assertRaises(ValueError) as raised: - serve(eval_shell, compression=False) + serve(handler, compression=False) self.assertEqual( str(raised.exception), "unsupported compression: False", From 15eb223e52aafad92849cc784837691f47c474cb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 24 Aug 2024 22:06:59 +0200 Subject: [PATCH 596/762] Remove run_(unix_)server abstraction from asyncio tests. It wasn't adding much value. Calling serve directly is more obvious. --- tests/asyncio/server.py | 19 ++---- tests/asyncio/test_client.py | 47 +++++++------- tests/asyncio/test_server.py | 116 +++++++++++++++++++---------------- 3 files changed, 90 insertions(+), 92 deletions(-) diff --git a/tests/asyncio/server.py b/tests/asyncio/server.py index 097950971..acf6500c6 100644 --- a/tests/asyncio/server.py +++ b/tests/asyncio/server.py @@ -1,9 +1,6 @@ import asyncio -import contextlib import socket -from websockets.asyncio.server import * - def get_host_port(server): for sock in server.sockets: @@ -38,19 +35,11 @@ async def handler(ws): raise AssertionError(f"unexpected path: {path}") +# This shortcut avoids repeating serve(handler, "localhost", 0) for every test. +args = handler, "localhost", 0 + + class EvalShellMixin: async def assertEval(self, client, expr, value): await client.send(expr) self.assertEqual(await client.recv(), value) - - -@contextlib.asynccontextmanager -async def run_server(handler=handler, host="localhost", port=0, **kwargs): - async with serve(handler, host, port, **kwargs) as server: - yield server - - -@contextlib.asynccontextmanager -async def run_unix_server(path, handler=handler, **kwargs): - async with unix_serve(handler, path, **kwargs) as server: - yield server diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index 77261129a..53a6eaaf1 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -5,23 +5,24 @@ from websockets.asyncio.client import * from websockets.asyncio.compatibility import TimeoutError +from websockets.asyncio.server import serve, unix_serve from websockets.exceptions import InvalidHandshake, InvalidURI from websockets.extensions.permessage_deflate import PerMessageDeflate from ..utils import CLIENT_CONTEXT, MS, SERVER_CONTEXT, temp_unix_socket_path -from .server import get_host_port, get_uri, run_server, run_unix_server +from .server import args, get_host_port, get_uri, handler class ClientTests(unittest.IsolatedAsyncioTestCase): async def test_connection(self): """Client connects to server and the handshake succeeds.""" - async with run_server() as server: + async with serve(*args) as server: async with connect(get_uri(server)) as client: self.assertEqual(client.protocol.state.name, "OPEN") async def test_existing_socket(self): """Client connects using a pre-existing socket.""" - async with run_server() as server: + async with serve(*args) as server: with socket.create_connection(get_host_port(server)) as sock: # Use a non-existing domain to ensure we connect to the right socket. async with connect("ws://invalid/", sock=sock) as client: @@ -29,7 +30,7 @@ async def test_existing_socket(self): async def test_additional_headers(self): """Client can set additional headers with additional_headers.""" - async with run_server() as server: + async with serve(*args) as server: async with connect( get_uri(server), additional_headers={"Authorization": "Bearer ..."} ) as client: @@ -37,19 +38,19 @@ async def test_additional_headers(self): async def test_override_user_agent(self): """Client can override User-Agent header with user_agent_header.""" - async with run_server() as server: + async with serve(*args) as server: async with connect(get_uri(server), user_agent_header="Smith") as client: self.assertEqual(client.request.headers["User-Agent"], "Smith") async def test_remove_user_agent(self): """Client can remove User-Agent header with user_agent_header.""" - async with run_server() as server: + async with serve(*args) as server: async with connect(get_uri(server), user_agent_header=None) as client: self.assertNotIn("User-Agent", client.request.headers) async def test_compression_is_enabled(self): """Client enables compression by default.""" - async with run_server() as server: + async with serve(*args) as server: async with connect(get_uri(server)) as client: self.assertEqual( [type(ext) for ext in client.protocol.extensions], @@ -58,13 +59,13 @@ async def test_compression_is_enabled(self): async def test_disable_compression(self): """Client disables compression.""" - async with run_server() as server: + async with serve(*args) as server: async with connect(get_uri(server), compression=None) as client: self.assertEqual(client.protocol.extensions, []) async def test_keepalive_is_enabled(self): """Client enables keepalive and measures latency by default.""" - async with run_server() as server: + async with serve(*args) as server: async with connect(get_uri(server), ping_interval=MS) as client: self.assertEqual(client.latency, 0) await asyncio.sleep(2 * MS) @@ -72,7 +73,7 @@ async def test_keepalive_is_enabled(self): async def test_disable_keepalive(self): """Client disables keepalive.""" - async with run_server() as server: + async with serve(*args) as server: async with connect(get_uri(server), ping_interval=None) as client: await asyncio.sleep(2 * MS) self.assertEqual(client.latency, 0) @@ -85,7 +86,7 @@ def create_connection(*args, **kwargs): client.create_connection_ran = True return client - async with run_server() as server: + async with serve(*args) as server: async with connect( get_uri(server), create_connection=create_connection ) as client: @@ -111,7 +112,7 @@ def remove_accept_header(self, request, response): # The connection will be open for the server but failed for the client. # Use a connection handler that exits immediately to avoid an exception. - async with run_server(process_response=remove_accept_header) as server: + async with serve(*args, process_response=remove_accept_header) as server: with self.assertRaises(InvalidHandshake) as raised: async with connect(get_uri(server) + "/no-op", close_timeout=MS): self.fail("did not raise") @@ -129,7 +130,7 @@ async def stall_connection(self, request): # The connection will be open for the server but failed for the client. # Use a connection handler that exits immediately to avoid an exception. - async with run_server(process_request=stall_connection) as server: + async with serve(*args, process_request=stall_connection) as server: try: with self.assertRaises(TimeoutError) as raised: async with connect(get_uri(server) + "/no-op", open_timeout=2 * MS): @@ -147,7 +148,7 @@ async def test_connection_closed_during_handshake(self): def close_connection(self, request): self.close_transport() - async with run_server(process_request=close_connection) as server: + async with serve(*args, process_request=close_connection) as server: with self.assertRaises(ConnectionError) as raised: async with connect(get_uri(server)): self.fail("did not raise") @@ -160,7 +161,7 @@ def close_connection(self, request): class SecureClientTests(unittest.IsolatedAsyncioTestCase): async def test_connection(self): """Client connects to server securely.""" - async with run_server(ssl=SERVER_CONTEXT) as server: + async with serve(*args, ssl=SERVER_CONTEXT) as server: async with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: self.assertEqual(client.protocol.state.name, "OPEN") ssl_object = client.transport.get_extra_info("ssl_object") @@ -168,7 +169,7 @@ async def test_connection(self): async def test_set_server_hostname_implicitly(self): """Client sets server_hostname to the host in the WebSocket URI.""" - async with run_server(ssl=SERVER_CONTEXT) as server: + async with serve(*args, ssl=SERVER_CONTEXT) as server: host, port = get_host_port(server) async with connect( "wss://overridden/", host=host, port=port, ssl=CLIENT_CONTEXT @@ -178,7 +179,7 @@ async def test_set_server_hostname_implicitly(self): async def test_set_server_hostname_explicitly(self): """Client sets server_hostname to the value provided in argument.""" - async with run_server(ssl=SERVER_CONTEXT) as server: + async with serve(*args, ssl=SERVER_CONTEXT) as server: async with connect( get_uri(server), ssl=CLIENT_CONTEXT, server_hostname="overridden" ) as client: @@ -187,7 +188,7 @@ async def test_set_server_hostname_explicitly(self): async def test_reject_invalid_server_certificate(self): """Client rejects certificate where server certificate isn't trusted.""" - async with run_server(ssl=SERVER_CONTEXT) as server: + async with serve(*args, ssl=SERVER_CONTEXT) as server: with self.assertRaises(ssl.SSLCertVerificationError) as raised: # The test certificate isn't trusted system-wide. async with connect(get_uri(server)): @@ -199,7 +200,7 @@ async def test_reject_invalid_server_certificate(self): async def test_reject_invalid_server_hostname(self): """Client rejects certificate where server hostname doesn't match.""" - async with run_server(ssl=SERVER_CONTEXT) as server: + async with serve(*args, ssl=SERVER_CONTEXT) as server: with self.assertRaises(ssl.SSLCertVerificationError) as raised: # This hostname isn't included in the test certificate. async with connect( @@ -217,7 +218,7 @@ class UnixClientTests(unittest.IsolatedAsyncioTestCase): async def test_connection(self): """Client connects to server over a Unix socket.""" with temp_unix_socket_path() as path: - async with run_unix_server(path): + async with unix_serve(handler, path): async with unix_connect(path) as client: self.assertEqual(client.protocol.state.name, "OPEN") @@ -225,7 +226,7 @@ async def test_set_host_header(self): """Client sets the Host header to the host in the WebSocket URI.""" # This is part of the documented behavior of unix_connect(). with temp_unix_socket_path() as path: - async with run_unix_server(path): + async with unix_serve(handler, path): async with unix_connect(path, uri="ws://overridden/") as client: self.assertEqual(client.request.headers["Host"], "overridden") @@ -235,7 +236,7 @@ class SecureUnixClientTests(unittest.IsolatedAsyncioTestCase): async def test_connection(self): """Client connects to server securely over a Unix socket.""" with temp_unix_socket_path() as path: - async with run_unix_server(path, ssl=SERVER_CONTEXT): + async with unix_serve(handler, path, ssl=SERVER_CONTEXT): async with unix_connect(path, ssl=CLIENT_CONTEXT) as client: self.assertEqual(client.protocol.state.name, "OPEN") ssl_object = client.transport.get_extra_info("ssl_object") @@ -245,7 +246,7 @@ async def test_set_server_hostname(self): """Client sets server_hostname to the host in the WebSocket URI.""" # This is part of the documented behavior of unix_connect(). with temp_unix_socket_path() as path: - async with run_unix_server(path, ssl=SERVER_CONTEXT): + async with unix_serve(handler, path, ssl=SERVER_CONTEXT): async with unix_connect( path, ssl=CLIENT_CONTEXT, diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index fd9c69c40..fe0cafe13 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -25,24 +25,23 @@ ) from .server import ( EvalShellMixin, + args, get_host_port, get_uri, handler, - run_server, - run_unix_server, ) class ServerTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): async def test_connection(self): """Server receives connection from client and the handshake succeeds.""" - async with run_server() as server: + async with serve(*args) as server: async with connect(get_uri(server)) as client: await self.assertEval(client, "ws.protocol.state.name", "OPEN") async def test_connection_handler_returns(self): """Connection handler returns.""" - async with run_server() as server: + async with serve(*args) as server: async with connect(get_uri(server) + "/no-op") as client: with self.assertRaises(ConnectionClosedOK) as raised: await client.recv() @@ -53,7 +52,7 @@ async def test_connection_handler_returns(self): async def test_connection_handler_raises_exception(self): """Connection handler raises an exception.""" - async with run_server() as server: + async with serve(*args) as server: async with connect(get_uri(server) + "/crash") as client: with self.assertRaises(ConnectionClosedError) as raised: await client.recv() @@ -66,7 +65,7 @@ async def test_connection_handler_raises_exception(self): async def test_existing_socket(self): """Server receives connection using a pre-existing socket.""" with socket.create_server(("localhost", 0)) as sock: - async with run_server(sock=sock, host=None, port=None): + async with serve(handler, sock=sock, host=None, port=None): uri = "ws://{}:{}/".format(*sock.getsockname()) async with connect(uri) as client: await self.assertEval(client, "ws.protocol.state.name", "OPEN") @@ -79,7 +78,8 @@ def select_subprotocol(ws, subprotocols): assert "chat" in subprotocols return "chat" - async with run_server( + async with serve( + *args, subprotocols=["chat"], select_subprotocol=select_subprotocol, ) as server: @@ -93,7 +93,7 @@ async def test_select_subprotocol_rejects_handshake(self): def select_subprotocol(ws, subprotocols): raise NegotiationError - async with run_server(select_subprotocol=select_subprotocol) as server: + async with serve(*args, select_subprotocol=select_subprotocol) as server: with self.assertRaises(InvalidStatus) as raised: async with connect(get_uri(server)): self.fail("did not raise") @@ -108,7 +108,7 @@ async def test_select_subprotocol_raises_exception(self): def select_subprotocol(ws, subprotocols): raise RuntimeError - async with run_server(select_subprotocol=select_subprotocol) as server: + async with serve(*args, select_subprotocol=select_subprotocol) as server: with self.assertRaises(InvalidStatus) as raised: async with connect(get_uri(server)): self.fail("did not raise") @@ -124,7 +124,7 @@ def process_request(ws, request): self.assertIsInstance(request, Request) ws.process_request_ran = True - async with run_server(process_request=process_request) as server: + async with serve(*args, process_request=process_request) as server: async with connect(get_uri(server)) as client: await self.assertEval(client, "ws.process_request_ran", "True") @@ -135,7 +135,7 @@ async def process_request(ws, request): self.assertIsInstance(request, Request) ws.process_request_ran = True - async with run_server(process_request=process_request) as server: + async with serve(*args, process_request=process_request) as server: async with connect(get_uri(server)) as client: await self.assertEval(client, "ws.process_request_ran", "True") @@ -145,7 +145,7 @@ async def test_process_request_returns_response(self): def process_request(ws, request): return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden") - async with run_server(process_request=process_request) as server: + async with serve(*args, process_request=process_request) as server: with self.assertRaises(InvalidStatus) as raised: async with connect(get_uri(server)): self.fail("did not raise") @@ -160,7 +160,7 @@ async def test_async_process_request_returns_response(self): async def process_request(ws, request): return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden") - async with run_server(process_request=process_request) as server: + async with serve(*args, process_request=process_request) as server: with self.assertRaises(InvalidStatus) as raised: async with connect(get_uri(server)): self.fail("did not raise") @@ -175,7 +175,7 @@ async def test_process_request_raises_exception(self): def process_request(ws, request): raise RuntimeError - async with run_server(process_request=process_request) as server: + async with serve(*args, process_request=process_request) as server: with self.assertRaises(InvalidStatus) as raised: async with connect(get_uri(server)): self.fail("did not raise") @@ -190,7 +190,7 @@ async def test_async_process_request_raises_exception(self): async def process_request(ws, request): raise RuntimeError - async with run_server(process_request=process_request) as server: + async with serve(*args, process_request=process_request) as server: with self.assertRaises(InvalidStatus) as raised: async with connect(get_uri(server)): self.fail("did not raise") @@ -207,7 +207,7 @@ def process_response(ws, request, response): self.assertIsInstance(response, Response) ws.process_response_ran = True - async with run_server(process_response=process_response) as server: + async with serve(*args, process_response=process_response) as server: async with connect(get_uri(server)) as client: await self.assertEval(client, "ws.process_response_ran", "True") @@ -219,7 +219,7 @@ async def process_response(ws, request, response): self.assertIsInstance(response, Response) ws.process_response_ran = True - async with run_server(process_response=process_response) as server: + async with serve(*args, process_response=process_response) as server: async with connect(get_uri(server)) as client: await self.assertEval(client, "ws.process_response_ran", "True") @@ -229,7 +229,7 @@ async def test_process_response_modifies_response(self): def process_response(ws, request, response): response.headers["X-ProcessResponse"] = "OK" - async with run_server(process_response=process_response) as server: + async with serve(*args, process_response=process_response) as server: async with connect(get_uri(server)) as client: self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") @@ -239,7 +239,7 @@ async def test_async_process_response_modifies_response(self): async def process_response(ws, request, response): response.headers["X-ProcessResponse"] = "OK" - async with run_server(process_response=process_response) as server: + async with serve(*args, process_response=process_response) as server: async with connect(get_uri(server)) as client: self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") @@ -251,7 +251,7 @@ def process_response(ws, request, response): headers["X-ProcessResponse"] = "OK" return dataclasses.replace(response, headers=headers) - async with run_server(process_response=process_response) as server: + async with serve(*args, process_response=process_response) as server: async with connect(get_uri(server)) as client: self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") @@ -263,7 +263,7 @@ async def process_response(ws, request, response): headers["X-ProcessResponse"] = "OK" return dataclasses.replace(response, headers=headers) - async with run_server(process_response=process_response) as server: + async with serve(*args, process_response=process_response) as server: async with connect(get_uri(server)) as client: self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") @@ -273,7 +273,7 @@ async def test_process_response_raises_exception(self): def process_response(ws, request, response): raise RuntimeError - async with run_server(process_response=process_response) as server: + async with serve(*args, process_response=process_response) as server: with self.assertRaises(InvalidStatus) as raised: async with connect(get_uri(server)): self.fail("did not raise") @@ -288,7 +288,7 @@ async def test_async_process_response_raises_exception(self): async def process_response(ws, request, response): raise RuntimeError - async with run_server(process_response=process_response) as server: + async with serve(*args, process_response=process_response) as server: with self.assertRaises(InvalidStatus) as raised: async with connect(get_uri(server)): self.fail("did not raise") @@ -299,13 +299,13 @@ async def process_response(ws, request, response): async def test_override_server(self): """Server can override Server header with server_header.""" - async with run_server(server_header="Neo") as server: + async with serve(*args, server_header="Neo") as server: async with connect(get_uri(server)) as client: await self.assertEval(client, "ws.response.headers['Server']", "Neo") async def test_remove_server(self): """Server can remove Server header with server_header.""" - async with run_server(server_header=None) as server: + async with serve(*args, server_header=None) as server: async with connect(get_uri(server)) as client: await self.assertEval( client, "'Server' in ws.response.headers", "False" @@ -313,7 +313,7 @@ async def test_remove_server(self): async def test_compression_is_enabled(self): """Server enables compression by default.""" - async with run_server() as server: + async with serve(*args) as server: async with connect(get_uri(server)) as client: await self.assertEval( client, @@ -323,13 +323,13 @@ async def test_compression_is_enabled(self): async def test_disable_compression(self): """Server disables compression.""" - async with run_server(compression=None) as server: + async with serve(*args, compression=None) as server: async with connect(get_uri(server)) as client: await self.assertEval(client, "ws.protocol.extensions", "[]") async def test_keepalive_is_enabled(self): """Server enables keepalive and measures latency.""" - async with run_server(ping_interval=MS) as server: + async with serve(*args, ping_interval=MS) as server: async with connect(get_uri(server)) as client: await client.send("ws.latency") latency = eval(await client.recv()) @@ -341,7 +341,7 @@ async def test_keepalive_is_enabled(self): async def test_disable_keepalive(self): """Client disables keepalive.""" - async with run_server(ping_interval=None) as server: + async with serve(*args, ping_interval=None) as server: async with connect(get_uri(server)) as client: await asyncio.sleep(2 * MS) await client.send("ws.latency") @@ -351,7 +351,7 @@ async def test_disable_keepalive(self): async def test_logger(self): """Server accepts a logger argument.""" logger = logging.getLogger("test") - async with run_server(logger=logger) as server: + async with serve(*args, logger=logger) as server: self.assertIs(server.logger, logger) async def test_custom_connection_factory(self): @@ -362,13 +362,13 @@ def create_connection(*args, **kwargs): server.create_connection_ran = True return server - async with run_server(create_connection=create_connection) as server: + async with serve(*args, create_connection=create_connection) as server: async with connect(get_uri(server)) as client: await self.assertEval(client, "ws.create_connection_ran", "True") async def test_connections(self): """Server provides a connections property.""" - async with run_server() as server: + async with serve(*args) as server: self.assertEqual(server.connections, set()) async with connect(get_uri(server)) as client: self.assertEqual(len(server.connections), 1) @@ -382,7 +382,7 @@ async def test_handshake_fails(self): def remove_key_header(self, request): del request.headers["Sec-WebSocket-Key"] - async with run_server(process_request=remove_key_header) as server: + async with serve(*args, process_request=remove_key_header) as server: with self.assertRaises(InvalidStatus) as raised: async with connect(get_uri(server)): self.fail("did not raise") @@ -393,7 +393,7 @@ def remove_key_header(self, request): async def test_timeout_during_handshake(self): """Server times out before receiving handshake request from client.""" - async with run_server(open_timeout=MS) as server: + async with serve(*args, open_timeout=MS) as server: reader, writer = await asyncio.open_connection(*get_host_port(server)) try: self.assertEqual(await reader.read(4096), b"") @@ -402,7 +402,7 @@ async def test_timeout_during_handshake(self): async def test_connection_closed_during_handshake(self): """Server reads EOF before receiving handshake request from client.""" - async with run_server() as server: + async with serve(*args) as server: _reader, writer = await asyncio.open_connection(*get_host_port(server)) writer.close() @@ -413,7 +413,7 @@ async def process_request(ws, _request): while ws.server.is_serving(): await asyncio.sleep(0) # pragma: no cover - async with run_server(process_request=process_request) as server: + async with serve(*args, process_request=process_request) as server: asyncio.get_running_loop().call_later(MS, server.close) with self.assertRaises(InvalidStatus) as raised: async with connect(get_uri(server)): @@ -425,7 +425,7 @@ async def process_request(ws, _request): async def test_close_server_closes_open_connections(self): """Server closes open connections with close code 1001 when closing.""" - async with run_server() as server: + async with serve(*args) as server: async with connect(get_uri(server)) as client: server.close() with self.assertRaises(ConnectionClosedOK) as raised: @@ -437,7 +437,7 @@ async def test_close_server_closes_open_connections(self): async def test_close_server_keeps_connections_open(self): """Server waits for client to close open connections when closing.""" - async with run_server() as server: + async with serve(*args) as server: async with connect(get_uri(server)) as client: server.close(close_connections=False) @@ -457,7 +457,7 @@ async def test_close_server_keeps_connections_open(self): async def test_close_server_keeps_handlers_running(self): """Server waits for connection handlers to terminate.""" - async with run_server() as server: + async with serve(*args) as server: async with connect(get_uri(server) + "/delay") as client: # Delay termination of connection handler. await client.send(str(3 * MS)) @@ -479,14 +479,14 @@ async def test_close_server_keeps_handlers_running(self): class SecureServerTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): async def test_connection(self): """Server receives secure connection from client.""" - async with run_server(ssl=SERVER_CONTEXT) as server: + async with serve(*args, ssl=SERVER_CONTEXT) as server: async with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: await self.assertEval(client, "ws.protocol.state.name", "OPEN") await self.assertEval(client, SSL_OBJECT + ".version()[:3]", "TLS") async def test_timeout_during_tls_handshake(self): """Server times out before receiving TLS handshake request from client.""" - async with run_server(ssl=SERVER_CONTEXT, open_timeout=MS) as server: + async with serve(*args, ssl=SERVER_CONTEXT, open_timeout=MS) as server: reader, writer = await asyncio.open_connection(*get_host_port(server)) try: self.assertEqual(await reader.read(4096), b"") @@ -495,7 +495,7 @@ async def test_timeout_during_tls_handshake(self): async def test_connection_closed_during_tls_handshake(self): """Server reads EOF before receiving TLS handshake request from client.""" - async with run_server(ssl=SERVER_CONTEXT) as server: + async with serve(*args, ssl=SERVER_CONTEXT) as server: _reader, writer = await asyncio.open_connection(*get_host_port(server)) writer.close() @@ -505,7 +505,7 @@ class UnixServerTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): async def test_connection(self): """Server receives connection from client over a Unix socket.""" with temp_unix_socket_path() as path: - async with run_unix_server(path): + async with unix_serve(handler, path): async with unix_connect(path) as client: await self.assertEval(client, "ws.protocol.state.name", "OPEN") @@ -515,7 +515,7 @@ class SecureUnixServerTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): async def test_connection(self): """Server receives secure connection from client over a Unix socket.""" with temp_unix_socket_path() as path: - async with run_unix_server(path, ssl=SERVER_CONTEXT): + async with unix_serve(handler, path, ssl=SERVER_CONTEXT): async with unix_connect(path, ssl=CLIENT_CONTEXT) as client: await self.assertEval(client, "ws.protocol.state.name", "OPEN") await self.assertEval(client, SSL_OBJECT + ".version()[:3]", "TLS") @@ -545,7 +545,7 @@ async def test_unix_with_path_and_sock(self): async def test_invalid_subprotocol(self): """Server rejects single value of subprotocols.""" with self.assertRaises(TypeError) as raised: - await serve(handler, subprotocols="chat") + await serve(*args, subprotocols="chat") self.assertEqual( str(raised.exception), "subprotocols must be a list, not a str", @@ -554,7 +554,7 @@ async def test_invalid_subprotocol(self): async def test_unsupported_compression(self): """Server rejects incorrect value of compression.""" with self.assertRaises(ValueError) as raised: - await serve(handler, compression=False) + await serve(*args, compression=False) self.assertEqual( str(raised.exception), "unsupported compression: False", @@ -564,7 +564,8 @@ async def test_unsupported_compression(self): class BasicAuthTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): async def test_valid_authorization(self): """basic_auth authenticates client with HTTP Basic Authentication.""" - async with run_server( + async with serve( + *args, process_request=basic_auth(credentials=("hello", "iloveyou")), ) as server: async with connect( @@ -575,7 +576,8 @@ async def test_valid_authorization(self): async def test_missing_authorization(self): """basic_auth rejects client without credentials.""" - async with run_server( + async with serve( + *args, process_request=basic_auth(credentials=("hello", "iloveyou")), ) as server: with self.assertRaises(InvalidStatus) as raised: @@ -588,7 +590,8 @@ async def test_missing_authorization(self): async def test_unsupported_authorization(self): """basic_auth rejects client with unsupported credentials.""" - async with run_server( + async with serve( + *args, process_request=basic_auth(credentials=("hello", "iloveyou")), ) as server: with self.assertRaises(InvalidStatus) as raised: @@ -604,7 +607,8 @@ async def test_unsupported_authorization(self): async def test_authorization_with_unknown_username(self): """basic_auth rejects client with unknown username.""" - async with run_server( + async with serve( + *args, process_request=basic_auth(credentials=("hello", "iloveyou")), ) as server: with self.assertRaises(InvalidStatus) as raised: @@ -620,7 +624,8 @@ async def test_authorization_with_unknown_username(self): async def test_authorization_with_incorrect_password(self): """basic_auth rejects client with incorrect password.""" - async with run_server( + async with serve( + *args, process_request=basic_auth(credentials=("hello", "changeme")), ) as server: with self.assertRaises(InvalidStatus) as raised: @@ -636,7 +641,8 @@ async def test_authorization_with_incorrect_password(self): async def test_list_of_credentials(self): """basic_auth accepts a list of hard coded credentials.""" - async with run_server( + async with serve( + *args, process_request=basic_auth( credentials=[ ("hello", "iloveyou"), @@ -656,7 +662,8 @@ async def test_check_credentials_function(self): def check_credentials(username, password): return hmac.compare_digest(password, "iloveyou") - async with run_server( + async with serve( + *args, process_request=basic_auth(check_credentials=check_credentials), ) as server: async with connect( @@ -671,7 +678,8 @@ async def test_check_credentials_coroutine(self): async def check_credentials(username, password): return hmac.compare_digest(password, "iloveyou") - async with run_server( + async with serve( + *args, process_request=basic_auth(check_credentials=check_credentials), ) as server: async with connect( From 317123119434e22e11c273bf422917bfa0f2a626 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 28 Aug 2024 21:40:33 +0200 Subject: [PATCH 597/762] Add tests for the logger argument of clients. --- tests/asyncio/test_client.py | 8 ++++++++ tests/asyncio/test_server.py | 2 +- tests/sync/test_client.py | 8 ++++++++ tests/sync/test_server.py | 2 +- 4 files changed, 18 insertions(+), 2 deletions(-) diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index 53a6eaaf1..15178f8b8 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -1,4 +1,5 @@ import asyncio +import logging import socket import ssl import unittest @@ -78,6 +79,13 @@ async def test_disable_keepalive(self): await asyncio.sleep(2 * MS) self.assertEqual(client.latency, 0) + async def test_logger(self): + """Client accepts a logger argument.""" + logger = logging.getLogger("test") + async with serve(*args) as server: + async with connect(get_uri(server), logger=logger) as client: + self.assertEqual(client.logger.name, logger.name) + async def test_custom_connection_factory(self): """Client runs ClientConnection factory provided in create_connection.""" diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index fe0cafe13..ceb0417a7 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -352,7 +352,7 @@ async def test_logger(self): """Server accepts a logger argument.""" logger = logging.getLogger("test") async with serve(*args, logger=logger) as server: - self.assertIs(server.logger, logger) + self.assertEqual(server.logger.name, logger.name) async def test_custom_connection_factory(self): """Server runs ServerConnection factory provided in create_connection.""" diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index 0d5273d12..812412203 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -1,3 +1,4 @@ +import logging import socket import ssl import threading @@ -67,6 +68,13 @@ def test_disable_compression(self): with connect(get_uri(server), compression=None) as client: self.assertEqual(client.protocol.extensions, []) + def test_logger(self): + """Client accepts a logger argument.""" + logger = logging.getLogger("test") + with run_server() as server: + with connect(get_uri(server), logger=logger) as client: + self.assertEqual(client.logger.name, logger.name) + def test_custom_connection_factory(self): """Client runs ClientConnection factory provided in create_connection.""" diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index a4b537c66..541a14602 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -238,7 +238,7 @@ def test_logger(self): """Server accepts a logger argument.""" logger = logging.getLogger("test") with run_server(logger=logger) as server: - self.assertIs(server.logger, logger) + self.assertEqual(server.logger.name, logger.name) def test_custom_connection_factory(self): """Server runs ServerConnection factory provided in create_connection.""" From 04ac475805f20cf73d79bb3fd003e412c8526396 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 28 Aug 2024 21:40:47 +0200 Subject: [PATCH 598/762] Fix copy-paste errors in test docstrings. --- tests/asyncio/test_server.py | 2 +- tests/sync/test_server.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index ceb0417a7..bc1d0444c 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -340,7 +340,7 @@ async def test_keepalive_is_enabled(self): self.assertGreater(latency, 0) async def test_disable_keepalive(self): - """Client disables keepalive.""" + """Server disables keepalive.""" async with serve(*args, ping_interval=None) as server: async with connect(get_uri(server)) as client: await asyncio.sleep(2 * MS) diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index 541a14602..7bcf144a2 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -557,7 +557,7 @@ def test_bad_list_of_credentials(self): class BackwardsCompatibilityTests(DeprecationTestCase): def test_ssl_context_argument(self): - """Client supports the deprecated ssl_context argument.""" + """Server supports the deprecated ssl_context argument.""" with self.assertDeprecationWarning("ssl_context was renamed to ssl"): with run_server(ssl_context=SERVER_CONTEXT) as server: with connect(get_uri(server), ssl=CLIENT_CONTEXT): From 033bf40c692029e3acaf2bb3adc946e7f8cf15ac Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 24 Aug 2024 08:06:30 +0200 Subject: [PATCH 599/762] Add a generator of backoff delays. --- docs/reference/variables.rst | 39 +++++++++++++++++++++++++++++---- src/websockets/client.py | 32 +++++++++++++++++++++++++++ src/websockets/legacy/client.py | 11 +++++----- tests/test_client.py | 13 +++++++++++ 4 files changed, 86 insertions(+), 9 deletions(-) diff --git a/docs/reference/variables.rst b/docs/reference/variables.rst index 4bca112da..498132251 100644 --- a/docs/reference/variables.rst +++ b/docs/reference/variables.rst @@ -13,7 +13,7 @@ Logging See the :doc:`logging guide <../topics/logging>` for details. Security -........ +-------- .. envvar:: WEBSOCKETS_SERVER @@ -31,18 +31,49 @@ Security Maximum length of the request or status line in the opening handshake. - The default value is ``8192``. + The default value is ``8192`` bytes. .. envvar:: WEBSOCKETS_MAX_NUM_HEADERS Maximum number of HTTP headers in the opening handshake. - The default value is ``128``. + The default value is ``128`` bytes. .. envvar:: WEBSOCKETS_MAX_BODY_SIZE Maximum size of the body of an HTTP response in the opening handshake. - The default value is ``1_048_576`` (1 MiB). + The default value is ``1_048_576`` bytes (1 MiB). See the :doc:`security guide <../topics/security>` for details. + +Reconnection +------------ + +Reconnection attempts are spaced out with truncated exponential backoff. + +.. envvar:: BACKOFF_INITIAL_DELAY + + The first attempt is delayed by a random amount of time between ``0`` and + ``BACKOFF_INITIAL_DELAY`` seconds. + + The default value is ``5.0`` seconds. + +.. envvar:: BACKOFF_MIN_DELAY + + The second attempt is delayed by ``BACKOFF_MIN_DELAY`` seconds. + + The default value is ``3.1`` seconds. + +.. envvar:: BACKOFF_FACTOR + + After the second attempt, the delay is multiplied by ``BACKOFF_FACTOR`` + between each attempt. + + The default value is ``1.618``. + +.. envvar:: BACKOFF_MAX_DELAY + + The delay between attempts is capped at ``BACKOFF_MAX_DELAY`` seconds. + + The default value is ``90.0`` seconds. diff --git a/src/websockets/client.py b/src/websockets/client.py index ae467993a..95de99dc5 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -1,5 +1,7 @@ from __future__ import annotations +import os +import random import warnings from typing import Any, Generator, Sequence @@ -357,3 +359,33 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: DeprecationWarning, ) super().__init__(*args, **kwargs) + + +BACKOFF_INITIAL_DELAY = float(os.environ.get("WEBSOCKETS_BACKOFF_INITIAL_DELAY", "5")) +BACKOFF_MIN_DELAY = float(os.environ.get("WEBSOCKETS_BACKOFF_MIN_DELAY", "3.1")) +BACKOFF_MAX_DELAY = float(os.environ.get("WEBSOCKETS_BACKOFF_MAX_DELAY", "90.0")) +BACKOFF_FACTOR = float(os.environ.get("WEBSOCKETS_BACKOFF_FACTOR", "1.618")) + + +def backoff( + initial_delay: float = BACKOFF_INITIAL_DELAY, + min_delay: float = BACKOFF_MIN_DELAY, + max_delay: float = BACKOFF_MAX_DELAY, + factor: float = BACKOFF_FACTOR, +) -> Generator[float, None, None]: + """ + Generate a series of backoff delays between reconnection attempts. + + Yields: + How many seconds to wait before retrying to connect. + + """ + # Add a random initial delay between 0 and 5 seconds. + # See 7.2.3. Recovering from Abnormal Closure in RFC 6455. + yield random.random() * initial_delay + delay = min_delay + while delay < max_delay: + yield delay + delay *= factor + while True: + yield max_delay diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 25142ea25..a1bc5cbae 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -3,6 +3,7 @@ import asyncio import functools import logging +import os import random import urllib.parse import warnings @@ -591,13 +592,13 @@ def handle_redirect(self, uri: str) -> None: # async for ... in connect(...): - BACKOFF_MIN = 1.92 - BACKOFF_MAX = 60.0 - BACKOFF_FACTOR = 1.618 - BACKOFF_INITIAL = 5 + BACKOFF_INITIAL = float(os.environ.get("WEBSOCKETS_BACKOFF_INITIAL_DELAY", "5")) + BACKOFF_MIN = float(os.environ.get("WEBSOCKETS_BACKOFF_MIN_DELAY", "3.1")) + BACKOFF_MAX = float(os.environ.get("WEBSOCKETS_BACKOFF_MAX_DELAY", "90.0")) + BACKOFF_FACTOR = float(os.environ.get("WEBSOCKETS_BACKOFF_FACTOR", "1.618")) async def __aiter__(self) -> AsyncIterator[WebSocketClientProtocol]: - backoff_delay = self.BACKOFF_MIN + backoff_delay = self.BACKOFF_MIN / self.BACKOFF_FACTOR while True: try: async with self as protocol: diff --git a/tests/test_client.py b/tests/test_client.py index d798a66f9..8b3bf4232 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -3,6 +3,7 @@ import unittest.mock from websockets.client import * +from websockets.client import backoff from websockets.datastructures import Headers from websockets.exceptions import InvalidHandshake, InvalidHeader from websockets.frames import OP_TEXT, Frame @@ -613,3 +614,15 @@ def test_client_connection_class(self): client = ClientConnection("ws://localhost/") self.assertIsInstance(client, ClientProtocol) + + +class BackoffTests(unittest.TestCase): + def test_backoff(self): + backoff_gen = backoff() + + initial_delay = next(backoff_gen) + self.assertGreaterEqual(initial_delay, 0) + self.assertLess(initial_delay, 5) + + following_delays = [int(next(backoff_gen)) for _ in range(9)] + self.assertEqual(following_delays, [3, 5, 8, 13, 21, 34, 55, 89, 90]) From 5ada5f776be2d1693f56a13de781d4b1b13f4ecf Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 24 Aug 2024 08:22:58 +0200 Subject: [PATCH 600/762] Reorder functions logically in ClientConnection. Define callers after callees. --- src/websockets/asyncio/client.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 033887e87..170ae5912 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -295,19 +295,6 @@ def factory() -> ClientConnection: self._open_timeout = open_timeout - # async with connect(...) as ...: ... - - async def __aenter__(self) -> ClientConnection: - return await self - - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc_value: BaseException | None, - traceback: TracebackType | None, - ) -> None: - await self.connection.close() - # ... = await connect(...) def __await__(self) -> Generator[Any, None, ClientConnection]: @@ -333,6 +320,19 @@ async def __await_impl__(self) -> ClientConnection: __iter__ = __await__ + # async with connect(...) as ...: ... + + async def __aenter__(self) -> ClientConnection: + return await self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + await self.connection.close() + def unix_connect( path: str | None = None, From 6ffb6b057caeb1d1d35b60a1a82afa7629694f82 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 24 Aug 2024 08:28:53 +0200 Subject: [PATCH 601/762] Remove _-prefixed attributes. Overall websockets doesn't use this convention for private attributes. --- src/websockets/asyncio/client.py | 14 +++++++------- src/websockets/asyncio/server.py | 6 +++--- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 170ae5912..f669359a5 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -281,19 +281,19 @@ def factory() -> ClientConnection: loop = asyncio.get_running_loop() if kwargs.pop("unix", False): - self._create_connection = loop.create_unix_connection(factory, **kwargs) + self.create_connection = loop.create_unix_connection(factory, **kwargs) else: if kwargs.get("sock") is None: kwargs.setdefault("host", wsuri.host) kwargs.setdefault("port", wsuri.port) - self._create_connection = loop.create_connection(factory, **kwargs) + self.create_connection = loop.create_connection(factory, **kwargs) - self._handshake_args = ( + self.handshake_args = ( additional_headers, user_agent_header, ) - self._open_timeout = open_timeout + self.open_timeout = open_timeout # ... = await connect(...) @@ -303,10 +303,10 @@ def __await__(self) -> Generator[Any, None, ClientConnection]: async def __await_impl__(self) -> ClientConnection: try: - async with asyncio_timeout(self._open_timeout): - _transport, self.connection = await self._create_connection + async with asyncio_timeout(self.open_timeout): + _transport, self.connection = await self.create_connection try: - await self.connection.handshake(*self._handshake_args) + await self.connection.handshake(*self.handshake_args) except (Exception, asyncio.CancelledError): self.connection.transport.close() raise diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index f24281252..5e71a892b 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -796,10 +796,10 @@ def protocol_select_subprotocol( loop = asyncio.get_running_loop() if kwargs.pop("unix", False): - self._create_server = loop.create_unix_server(factory, **kwargs) + self.create_server = loop.create_unix_server(factory, **kwargs) else: # mypy cannot tell that kwargs must provide sock when port is None. - self._create_server = loop.create_server(factory, host, port, **kwargs) # type: ignore[arg-type] + self.create_server = loop.create_server(factory, host, port, **kwargs) # type: ignore[arg-type] # async with serve(...) as ...: ... @@ -822,7 +822,7 @@ def __await__(self) -> Generator[Any, None, Server]: return self.__await_impl__().__await__() async def __await_impl__(self) -> Server: - server = await self._create_server + server = await self.create_server self.server.wrap(server) return self.server From 16456e2bec3ad324d545c8eb5ea9d1eed583dea4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 28 Aug 2024 22:17:39 +0200 Subject: [PATCH 602/762] Restore id-token permission. Removing it in ed2f21e3 was a mistake. --- .github/workflows/release.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index a68714ed1..0a30c049b 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -71,6 +71,7 @@ jobs: # Don't release when running the workflow manually from GitHub's UI. if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') permissions: + id-token: write contents: write steps: - name: Download artifacts From 62d70f4a8fb5e30d8744b1378659cc999deb0715 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 28 Aug 2024 22:06:53 +0200 Subject: [PATCH 603/762] Restore speedups.c in source distribution. Fix #1494. --- MANIFEST.in | 1 + docs/project/changelog.rst | 10 ++++++++++ src/websockets/version.py | 2 +- 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/MANIFEST.in b/MANIFEST.in index 1c660b95b..d4598bda0 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,2 +1,3 @@ include LICENSE include src/websockets/py.typed +include src/websockets/speedups.c # required when BUILD_EXTENSION=no diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 7c5998288..2292c5aa1 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -25,6 +25,16 @@ fixing regressions shortly after a release. Only documented APIs are public. Undocumented, private APIs may change without notice. +13.0.1 +------ + +*August 28, 2024* + +Bug fixes +......... + +* Restored the C extension in the source distribution. + .. _13.0: 13.0 diff --git a/src/websockets/version.py b/src/websockets/version.py index 56c321940..e82b285db 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -20,7 +20,7 @@ released = True -tag = version = commit = "13.0" +tag = version = commit = "13.0.1" if not released: # pragma: no cover From 157f7908c33cb540f7cf76568dde8aa6cf400504 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 28 Aug 2024 22:22:02 +0200 Subject: [PATCH 604/762] Add provenance attestations. --- .github/workflows/release.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 0a30c049b..863d88aa9 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -72,6 +72,7 @@ jobs: if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') permissions: id-token: write + attestations: write contents: write steps: - name: Download artifacts @@ -80,6 +81,10 @@ jobs: pattern: dist-* merge-multiple: true path: dist + - name: Attest provenance + uses: actions/attest-build-provenance@v1 + with: + subject-path: dist/* - name: Upload to PyPI uses: pypa/gh-action-pypi-publish@release/v1 - name: Create GitHub release From d8ab09b2566cdfcb7a0a95dc992454bb9f4b1647 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 24 Aug 2024 14:59:28 +0200 Subject: [PATCH 605/762] Add automatic reconnection to the new asyncio implementation. Missing tests for now. Fix #1480. --- docs/howto/upgrade.rst | 33 ++++---- docs/project/changelog.rst | 8 +- docs/reference/asyncio/client.rst | 2 + src/websockets/asyncio/client.py | 123 ++++++++++++++++++++++++++++-- tests/asyncio/test_client.py | 120 ++++++++++++++++++++++++++++- 5 files changed, 263 insertions(+), 23 deletions(-) diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index 18c3cc127..42edb978a 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -79,19 +79,6 @@ Following redirects The new implementation of :func:`~asyncio.client.connect` doesn't follow HTTP redirects yet. -Automatic reconnection -...................... - -The new implementation of :func:`~asyncio.client.connect` doesn't provide -automatic reconnection yet. - -In other words, the following pattern isn't supported:: - - from websockets.asyncio.client import connect - - async for websocket in connect(...): # this doesn't work yet - ... - .. _Update import paths: Import paths @@ -185,6 +172,26 @@ it simpler. ``process_response`` replaces ``extra_headers`` and provides more flexibility. See process_request_, select_subprotocol_, and process_response_ below. +Customizing automatic reconnection +.................................. + +On the client side, if you're reconnecting automatically with ``async for ... in +connect(...)``, the behavior when a connection attempt fails was enhanced and +made configurable. + +The original implementation retried on any error. The new implementation uses an +heuristic to determine whether an error is retryable or fatal. By default, only +network errors and server errors (HTTP 500, 502, 503, or 504) are considered +retryable. You can customize this behavior with the ``process_exception`` +argument of :func:`~asyncio.client.connect`. + +See :func:`~asyncio.client.process_exception` for more information. + +Here's how to revert to the behavior of the original implementation:: + + async for ... in connect(..., process_exception=lambda exc: exc): + ... + Tracking open connections ......................... diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index c239cfe5d..78a547d0b 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -46,12 +46,16 @@ Backwards-incompatible changes New features ............ -* Made the set of active connections available in the :attr:`Server.connections - ` property. +* Added support for reconnecting automatically by using + :func:`~asyncio.client.connect` as an asynchronous iterator to the new + :mod:`asyncio` implementation. * Added HTTP Basic Auth to the new :mod:`asyncio` and :mod:`threading` implementations of servers. +* Made the set of active connections available in the :attr:`Server.connections + ` property. + .. _13.0: 13.0 diff --git a/docs/reference/asyncio/client.rst b/docs/reference/asyncio/client.rst index 77a3c5d53..e2b0ff550 100644 --- a/docs/reference/asyncio/client.rst +++ b/docs/reference/asyncio/client.rst @@ -12,6 +12,8 @@ Opening a connection .. autofunction:: unix_connect :async: +.. autofunction:: process_exception + Using a connection ------------------ diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index f669359a5..860db3238 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -1,11 +1,14 @@ from __future__ import annotations import asyncio +import functools +import logging from types import TracebackType -from typing import Any, Generator, Sequence +from typing import Any, AsyncIterator, Callable, Generator, Sequence -from ..client import ClientProtocol +from ..client import ClientProtocol, backoff from ..datastructures import HeadersLike +from ..exceptions import InvalidStatus from ..extensions.base import ClientExtensionFactory from ..extensions.permessage_deflate import enable_client_permessage_deflate from ..headers import validate_subprotocols @@ -121,6 +124,46 @@ def connection_lost(self, exc: Exception | None) -> None: self.response_rcvd.set_result(None) +def process_exception(exc: Exception) -> Exception | None: + """ + Determine whether an error is retryable or fatal. + + When reconnecting automatically with ``async for ... in connect(...)``, if a + connection attempt fails, :func:`process_exception` is called to determine + whether to retry connecting or to raise the exception. + + This function defines the default behavior, which is to retry on: + + * :exc:`OSError` and :exc:`asyncio.TimeoutError`: network errors; + * :exc:`~websockets.exceptions.InvalidStatus` when the status code is 500, + 502, 503, or 504: server or proxy errors. + + All other exceptions are considered fatal. + + You can change this behavior with the ``process_exception`` argument of + :func:`connect`. + + Return :obj:`None` if the exception is retryable i.e. when the error could + be transient and trying to reconnect with the same parameters could succeed. + The exception will be logged at the ``INFO`` level. + + Return an exception, either ``exc`` or a new exception, if the exception is + fatal i.e. when trying to reconnect will most likely produce the same error. + That exception will be raised, breaking out of the retry loop. + + """ + if isinstance(exc, (OSError, asyncio.TimeoutError)): + return None + if isinstance(exc, InvalidStatus) and exc.response.status_code in [ + 500, # Internal Server Error + 502, # Bad Gateway + 503, # Service Unavailable + 504, # Gateway Timeout + ]: + return None + return exc + + # This is spelled in lower case because it's exposed as a callable in the API. class connect: """ @@ -138,6 +181,21 @@ class connect: The connection is closed automatically when exiting the context. + :func:`connect` can be used as an infinite asynchronous iterator to + reconnect automatically on errors:: + + async for websocket in connect(...): + try: + ... + except websockets.ConnectionClosed: + continue + + If the connection fails with a transient error, it is retried with + exponential backoff. If it fails with a fatal error, the exception is + raised, breaking out of the loop. + + The connection is closed automatically after each iteration of the loop. + Args: uri: URI of the WebSocket server. origin: Value of the ``Origin`` header, for servers that require it. @@ -153,6 +211,9 @@ class connect: compression: The "permessage-deflate" extension is enabled by default. Set ``compression`` to :obj:`None` to disable it. See the :doc:`compression guide <../../topics/compression>` for details. + process_exception: When reconnecting automatically, tell whether an + error is transient or fatal. The default behavior is defined by + :func:`process_exception`. Refer to its documentation for details. open_timeout: Timeout for opening the connection in seconds. :obj:`None` disables the timeout. ping_interval: Interval between keepalive pings in seconds. @@ -219,6 +280,7 @@ def __init__( additional_headers: HeadersLike | None = None, user_agent_header: str | None = USER_AGENT, compression: str | None = "deflate", + process_exception: Callable[[Exception], Exception | None] = process_exception, # Timeouts open_timeout: float | None = 10, ping_interval: float | None = 20, @@ -281,19 +343,26 @@ def factory() -> ClientConnection: loop = asyncio.get_running_loop() if kwargs.pop("unix", False): - self.create_connection = loop.create_unix_connection(factory, **kwargs) + self.create_connection = functools.partial( + loop.create_unix_connection, factory, **kwargs + ) else: if kwargs.get("sock") is None: kwargs.setdefault("host", wsuri.host) kwargs.setdefault("port", wsuri.port) - self.create_connection = loop.create_connection(factory, **kwargs) + self.create_connection = functools.partial( + loop.create_connection, factory, **kwargs + ) self.handshake_args = ( additional_headers, user_agent_header, ) - + self.process_exception = process_exception self.open_timeout = open_timeout + if logger is None: + logger = logging.getLogger("websockets.client") + self.logger = logger # ... = await connect(...) @@ -304,7 +373,7 @@ def __await__(self) -> Generator[Any, None, ClientConnection]: async def __await_impl__(self) -> ClientConnection: try: async with asyncio_timeout(self.open_timeout): - _transport, self.connection = await self.create_connection + _transport, self.connection = await self.create_connection() try: await self.connection.handshake(*self.handshake_args) except (Exception, asyncio.CancelledError): @@ -333,6 +402,48 @@ async def __aexit__( ) -> None: await self.connection.close() + # async for ... in connect(...): + + async def __aiter__(self) -> AsyncIterator[ClientConnection]: + delays: Generator[float, None, None] | None = None + while True: + try: + async with self as protocol: + yield protocol + except Exception as exc: + # Determine whether the exception is retryable or fatal. + # The API of process_exception is "return an exception or None"; + # "raise an exception" is also supported because it's a frequent + # mistake. It isn't documented in order to keep the API simple. + try: + new_exc = self.process_exception(exc) + except Exception as raised_exc: + new_exc = raised_exc + + # The connection failed with a fatal error. + # Raise the exception and exit the loop. + if new_exc is exc: + raise + if new_exc is not None: + raise new_exc from exc + + # The connection failed with a retryable error. + # Start or continue backoff and reconnect. + if delays is None: + delays = backoff() + delay = next(delays) + self.logger.info( + "! connect failed; reconnecting in %.1f seconds", + delay, + exc_info=True, + ) + await asyncio.sleep(delay) + continue + + else: + # The connection succeeded. Reset backoff. + delays = None + def unix_connect( path: str | None = None, diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index 15178f8b8..7467d215a 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -1,4 +1,6 @@ import asyncio +import contextlib +import http import logging import socket import ssl @@ -7,20 +9,134 @@ from websockets.asyncio.client import * from websockets.asyncio.compatibility import TimeoutError from websockets.asyncio.server import serve, unix_serve -from websockets.exceptions import InvalidHandshake, InvalidURI +from websockets.client import backoff +from websockets.exceptions import InvalidHandshake, InvalidStatus, InvalidURI from websockets.extensions.permessage_deflate import PerMessageDeflate from ..utils import CLIENT_CONTEXT, MS, SERVER_CONTEXT, temp_unix_socket_path from .server import args, get_host_port, get_uri, handler +# Decorate tests that need it with @short_backoff_delay() instead of using it as +# a context manager when dropping support for Python < 3.10. +@contextlib.asynccontextmanager +async def short_backoff_delay(): + defaults = backoff.__defaults__ + backoff.__defaults__ = ( + defaults[0] * MS, + defaults[1] * MS, + defaults[2] * MS, + defaults[3], + ) + try: + yield + finally: + backoff.__defaults__ = defaults + + class ClientTests(unittest.IsolatedAsyncioTestCase): async def test_connection(self): - """Client connects to server and the handshake succeeds.""" + """Client connects to server.""" async with serve(*args) as server: async with connect(get_uri(server)) as client: self.assertEqual(client.protocol.state.name, "OPEN") + async def test_reconnection(self): + """Client reconnects to server.""" + iterations = 0 + successful = 0 + + def process_request(connection, request): + nonlocal iterations + iterations += 1 + # Retriable errors + if iterations == 1: + connection.transport.close() + elif iterations == 2: + return connection.respond(http.HTTPStatus.SERVICE_UNAVAILABLE, "🚒") + # Fatal error + elif iterations == 5: + return connection.respond(http.HTTPStatus.PAYMENT_REQUIRED, "💸") + + async with serve(*args, process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + async with short_backoff_delay(): + async for client in connect(get_uri(server)): + self.assertEqual(client.protocol.state.name, "OPEN") + successful += 1 + + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 402", + ) + self.assertEqual(iterations, 5) + self.assertEqual(successful, 2) + + @unittest.skipUnless( + hasattr(http.HTTPStatus, "IM_A_TEAPOT"), + "test requires Python 3.9", + ) + async def test_reconnection_with_custom_process_exception(self): + """Client runs process_exception to tell if errors are retryable or fatal.""" + iteration = 0 + + def process_request(connection, request): + nonlocal iteration + iteration += 1 + if iteration == 1: + return connection.respond(http.HTTPStatus.SERVICE_UNAVAILABLE, "🚒") + return connection.respond(http.HTTPStatus.IM_A_TEAPOT, "🫖") + + def process_exception(exc): + if isinstance(exc, InvalidStatus): + if 500 <= exc.response.status_code < 600: + return None + if exc.response.status_code == 418: + return Exception("🫖 💔 ☕️") + self.fail("unexpected exception") + + async with serve(*args, process_request=process_request) as server: + with self.assertRaises(Exception) as raised: + async with short_backoff_delay(): + async for _ in connect( + get_uri(server), process_exception=process_exception + ): + self.fail("did not raise") + + self.assertEqual(iteration, 2) + self.assertEqual( + str(raised.exception), + "🫖 💔 ☕️", + ) + + @unittest.skipUnless( + hasattr(http.HTTPStatus, "IM_A_TEAPOT"), + "test requires Python 3.9", + ) + async def test_reconnection_with_custom_process_exception_raising_exception(self): + """Client supports raising an exception in process_exception.""" + + def process_request(connection, request): + return connection.respond(http.HTTPStatus.IM_A_TEAPOT, "🫖") + + def process_exception(exc): + if isinstance(exc, InvalidStatus) and exc.response.status_code == 418: + raise Exception("🫖 💔 ☕️") + self.fail("unexpected exception") + + async with serve(*args, process_request=process_request) as server: + with self.assertRaises(Exception) as raised: + async with short_backoff_delay(): + async for _ in connect( + get_uri(server), process_exception=process_exception + ): + self.fail("did not raise") + + self.assertEqual( + str(raised.exception), + "🫖 💔 ☕️", + ) + async def test_existing_socket(self): """Client connects using a pre-existing socket.""" async with serve(*args) as server: From 032463e4b9a27cb349c67e21d10ad8a34f22bb59 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 28 Aug 2024 22:56:43 +0200 Subject: [PATCH 606/762] Prevent a warning in twine upload. --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index ae0aaa65d..7ea3a5e5f 100644 --- a/setup.py +++ b/setup.py @@ -34,5 +34,6 @@ setuptools.setup( version=version, long_description=long_description, + long_description_content_type="text/x-rst", ext_modules=ext_modules, ) From 5be975e440e00dafc32c303975b12aaaf723c5e8 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 28 Aug 2024 23:12:02 +0200 Subject: [PATCH 607/762] Make make build the C extension by default. --- Makefile | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/Makefile b/Makefile index a69248b6e..3eb3ae6b2 100644 --- a/Makefile +++ b/Makefile @@ -4,7 +4,8 @@ export PYTHONASYNCIODEBUG=1 export PYTHONPATH=src export PYTHONWARNINGS=default -default: style types tests +build: + python setup.py build_ext --inplace style: black src tests @@ -26,9 +27,6 @@ maxi_cov: coverage html coverage report --show-missing --fail-under=100 -build: - python setup.py build_ext --inplace - clean: find src -name '*.so' -delete find . -name '*.pyc' -delete From 2b990e88b97d385d72d91c8354e8c9ed4a1f04c4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 28 Aug 2024 23:12:37 +0200 Subject: [PATCH 608/762] Update FAQ after implementing reconnection. --- docs/faq/client.rst | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/docs/faq/client.rst b/docs/faq/client.rst index 0dfc84253..0a7aab6e2 100644 --- a/docs/faq/client.rst +++ b/docs/faq/client.rst @@ -81,15 +81,9 @@ The connection is closed when exiting the context manager. How do I reconnect when the connection drops? --------------------------------------------- -.. admonition:: This feature is only supported by the legacy :mod:`asyncio` - implementation. - :class: warning +Use :func:`~websockets.asyncio.client.connect` as an asynchronous iterator:: - It will be added to the new :mod:`asyncio` implementation soon. - -Use :func:`~websockets.legacy.client.connect` as an asynchronous iterator:: - - from websockets.legacy.client import connect + from websockets.asyncio.client import connect async for websocket in connect(...): try: From f842a13a234c61cc0c75743db54e698474e58c05 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 7 Sep 2024 21:52:19 +0200 Subject: [PATCH 609/762] Prevent false positives with latest ruff. --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index c1d34c90b..fde9c3226 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,9 @@ exclude_lines = [ "@unittest.skip", ] +[tool.ruff] +target-version = "py312" + [tool.ruff.lint] select = [ "E", # pycodestyle From 1f89db78ca6363a7e4d9830b6b280e311e3934f3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 7 Sep 2024 21:59:44 +0200 Subject: [PATCH 610/762] Switch from black to ruff for code formatting. It's faster and achieves pretty much the same result. --- .github/workflows/tests.yml | 4 +--- Makefile | 2 +- src/websockets/asyncio/client.py | 1 - src/websockets/asyncio/server.py | 1 - src/websockets/legacy/server.py | 11 +++++++---- tests/legacy/test_client_server.py | 1 - tox.ini | 15 +++++++-------- 7 files changed, 16 insertions(+), 19 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index b9172b7fb..43193ea50 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -41,9 +41,7 @@ jobs: python-version: "3.x" - name: Install tox run: pip install tox - - name: Check code formatting - run: tox -e black - - name: Check code style + - name: Check code formatting & style run: tox -e ruff - name: Check types statically run: tox -e mypy diff --git a/Makefile b/Makefile index 3eb3ae6b2..fd36d0367 100644 --- a/Makefile +++ b/Makefile @@ -8,7 +8,7 @@ build: python setup.py build_ext --inplace style: - black src tests + ruff format src tests ruff check --fix src tests types: diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 860db3238..5f7a37198 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -297,7 +297,6 @@ def __init__( # Other keyword arguments are passed to loop.create_connection **kwargs: Any, ) -> None: - wsuri = parse_uri(uri) if wsuri.secure: diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index 5e71a892b..1aa47af88 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -722,7 +722,6 @@ def __init__( # Other keyword arguments are passed to loop.create_server **kwargs: Any, ) -> None: - if subprotocols is not None: validate_subprotocols(subprotocols) diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index b31cc25b8..2cb9b1abb 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -100,9 +100,10 @@ class WebSocketServerProtocol(WebSocketCommonProtocol): def __init__( self, + # The version that accepts the path in the second argument is deprecated. ws_handler: ( Callable[[WebSocketServerProtocol], Awaitable[Any]] - | Callable[[WebSocketServerProtocol, str], Awaitable[Any]] # deprecated + | Callable[[WebSocketServerProtocol, str], Awaitable[Any]] ), ws_server: WebSocketServer, *, @@ -983,9 +984,10 @@ class Serve: def __init__( self, + # The version that accepts the path in the second argument is deprecated. ws_handler: ( Callable[[WebSocketServerProtocol], Awaitable[Any]] - | Callable[[WebSocketServerProtocol, str], Awaitable[Any]] # deprecated + | Callable[[WebSocketServerProtocol, str], Awaitable[Any]] ), host: str | Sequence[str] | None = None, port: int | None = None, @@ -1140,9 +1142,10 @@ async def __await_impl__(self) -> WebSocketServer: def unix_serve( + # The version that accepts the path in the second argument is deprecated. ws_handler: ( Callable[[WebSocketServerProtocol], Awaitable[Any]] - | Callable[[WebSocketServerProtocol, str], Awaitable[Any]] # deprecated + | Callable[[WebSocketServerProtocol, str], Awaitable[Any]] ), path: str | None = None, **kwargs: Any, @@ -1169,7 +1172,7 @@ def remove_path_argument( ws_handler: ( Callable[[WebSocketServerProtocol], Awaitable[Any]] | Callable[[WebSocketServerProtocol, str], Awaitable[Any]] - ) + ), ) -> Callable[[WebSocketServerProtocol], Awaitable[Any]]: try: inspect.signature(ws_handler).bind(None) diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 0c5d66c92..2f3ba9b77 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -1293,7 +1293,6 @@ def test_connection_error_during_closing_handshake(self, close): class ClientServerTests( CommonClientServerTests, ClientServerTestsMixin, AsyncioTestCase ): - def test_redirect_secure(self): with temp_test_redirecting_server(self): # websockets doesn't support serving non-TLS and TLS connections diff --git a/tox.ini b/tox.ini index 16d9c9f16..cba9b290b 100644 --- a/tox.ini +++ b/tox.ini @@ -7,12 +7,12 @@ env_list = py312 py313 coverage - black ruff mypy [testenv] -commands = python -W error::DeprecationWarning -W error::PendingDeprecationWarning -m unittest {posargs} +commands = + python -W error::DeprecationWarning -W error::PendingDeprecationWarning -m unittest {posargs} pass_env = WEBSOCKETS_* [testenv:coverage] @@ -27,14 +27,13 @@ commands = python -m coverage report --show-missing --fail-under=100 deps = coverage -[testenv:black] -commands = black --check src tests -deps = black - [testenv:ruff] -commands = ruff check src tests +commands = + ruff format --check src tests + ruff check src tests deps = ruff [testenv:mypy] -commands = mypy --strict src +commands = + mypy --strict src deps = mypy From 6b2f06083573be750ac70c894ae13d87c03ef624 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 7 Sep 2024 21:40:07 +0200 Subject: [PATCH 611/762] Follow redirects in the new asyncio implementation. Fix #631. --- docs/howto/upgrade.rst | 27 +-- docs/project/changelog.rst | 3 + docs/reference/features.rst | 2 +- docs/reference/variables.rst | 11 ++ src/websockets/asyncio/client.py | 169 ++++++++++++---- src/websockets/legacy/client.py | 2 +- tests/asyncio/test_client.py | 319 ++++++++++++++++++++++++------- 7 files changed, 403 insertions(+), 130 deletions(-) diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index 42edb978a..120509c9f 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -10,15 +10,13 @@ It provides a very similar API. However, there are a few differences. The recommended upgrade process is: -1. Make sure that your application doesn't use any `deprecated APIs`_. If it +#. Make sure that your application doesn't use any `deprecated APIs`_. If it doesn't raise any warnings, you can skip this step. -2. Check if your application depends on `missing features`_. If it does, you - should stick to the original implementation until they're added. -3. `Update import paths`_. For straightforward usage of websockets, this could +#. `Update import paths`_. For straightforward usage of websockets, this could be the only step you need to take. Upgrading could be transparent. -4. Check out `new features and improvements`_ and consider taking advantage of +#. Check out `new features and improvements`_ and consider taking advantage of them to improve your application. -5. Review `API changes`_ and adapt your application to preserve its current +#. Review `API changes`_ and adapt your application to preserve its current functionality. In the interest of brevity, only :func:`~asyncio.client.connect` and @@ -62,23 +60,6 @@ the release notes of the version in which the feature was deprecated. * The ``host``, ``port``, and ``secure`` attributes of connections — deprecated in :ref:`8.0`. -.. _missing features: - -Missing features ----------------- - -.. admonition:: All features listed below will be provided in a future release. - :class: tip - - If your application relies on one of them, you should stick to the original - implementation until the new implementation supports it in a future release. - -Following redirects -................... - -The new implementation of :func:`~asyncio.client.connect` doesn't follow HTTP -redirects yet. - .. _Update import paths: Import paths diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 61113fb81..c77876a73 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -50,6 +50,9 @@ New features :func:`~asyncio.client.connect` as an asynchronous iterator to the new :mod:`asyncio` implementation. +* :func:`~asyncio.client.connect` now follows redirects in the new + :mod:`asyncio` implementation. + * Added HTTP Basic Auth to the new :mod:`asyncio` and :mod:`threading` implementations of servers. diff --git a/docs/reference/features.rst b/docs/reference/features.rst index d9941e408..32fc05baf 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -154,7 +154,7 @@ Client +------------------------------------+--------+--------+--------+--------+ | Connect to non-ASCII IRIs | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ - | Follow HTTP redirects | ❌ | ❌ | — | ✅ | + | Follow HTTP redirects | ✅ | ❌ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ | Perform HTTP Basic Authentication | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ diff --git a/docs/reference/variables.rst b/docs/reference/variables.rst index 498132251..b766e02a1 100644 --- a/docs/reference/variables.rst +++ b/docs/reference/variables.rst @@ -1,6 +1,8 @@ Environment variables ===================== +.. currentmodule:: websockets + Logging ------- @@ -77,3 +79,12 @@ Reconnection attempts are spaced out with truncated exponential backoff. The delay between attempts is capped at ``BACKOFF_MAX_DELAY`` seconds. The default value is ``90.0`` seconds. + +Redirections +------------ + +.. envvar:: WEBSOCKETS_MAX_REDIRECTS + + Maximum number of redirects that :func:`~asyncio.client.connect` follows. + + The default value is ``10``. diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 5f7a37198..50f67b95f 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -1,27 +1,30 @@ from __future__ import annotations import asyncio -import functools import logging +import os +import urllib.parse from types import TracebackType from typing import Any, AsyncIterator, Callable, Generator, Sequence from ..client import ClientProtocol, backoff from ..datastructures import HeadersLike -from ..exceptions import InvalidStatus +from ..exceptions import InvalidStatus, SecurityError from ..extensions.base import ClientExtensionFactory from ..extensions.permessage_deflate import enable_client_permessage_deflate from ..headers import validate_subprotocols from ..http11 import USER_AGENT, Response from ..protocol import CONNECTING, Event from ..typing import LoggerLike, Origin, Subprotocol -from ..uri import parse_uri +from ..uri import WebSocketURI, parse_uri from .compatibility import TimeoutError, asyncio_timeout from .connection import Connection __all__ = ["connect", "unix_connect", "ClientConnection"] +MAX_REDIRECTS = int(os.environ.get("WEBSOCKETS_MAX_REDIRECTS", "10")) + class ClientConnection(Connection): """ @@ -126,7 +129,7 @@ def connection_lost(self, exc: Exception | None) -> None: def process_exception(exc: Exception) -> Exception | None: """ - Determine whether an error is retryable or fatal. + Determine whether a connection error is retryable or fatal. When reconnecting automatically with ``async for ... in connect(...)``, if a connection attempt fails, :func:`process_exception` is called to determine @@ -297,16 +300,7 @@ def __init__( # Other keyword arguments are passed to loop.create_connection **kwargs: Any, ) -> None: - wsuri = parse_uri(uri) - - if wsuri.secure: - kwargs.setdefault("ssl", True) - kwargs.setdefault("server_hostname", wsuri.host) - if kwargs.get("ssl") is None: - raise TypeError("ssl=None is incompatible with a wss:// URI") - else: - if kwargs.get("ssl") is not None: - raise TypeError("ssl argument is incompatible with a ws:// URI") + self.uri = uri if subprotocols is not None: validate_subprotocols(subprotocols) @@ -316,10 +310,13 @@ def __init__( elif compression is not None: raise ValueError(f"unsupported compression: {compression}") + if logger is None: + logger = logging.getLogger("websockets.client") + if create_connection is None: create_connection = ClientConnection - def factory() -> ClientConnection: + def protocol_factory(wsuri: WebSocketURI) -> ClientConnection: # This is a protocol in the Sans-I/O implementation of websockets. protocol = ClientProtocol( wsuri, @@ -340,28 +337,104 @@ def factory() -> ClientConnection: ) return connection + self.protocol_factory = protocol_factory + self.handshake_args = ( + additional_headers, + user_agent_header, + ) + self.process_exception = process_exception + self.open_timeout = open_timeout + self.logger = logger + self.connection_kwargs = kwargs + + async def create_connection(self) -> ClientConnection: + """Create TCP or Unix connection.""" loop = asyncio.get_running_loop() + + wsuri = parse_uri(self.uri) + kwargs = self.connection_kwargs.copy() + + def factory() -> ClientConnection: + return self.protocol_factory(wsuri) + + if wsuri.secure: + kwargs.setdefault("ssl", True) + kwargs.setdefault("server_hostname", wsuri.host) + if kwargs.get("ssl") is None: + raise TypeError("ssl=None is incompatible with a wss:// URI") + else: + if kwargs.get("ssl") is not None: + raise TypeError("ssl argument is incompatible with a ws:// URI") + if kwargs.pop("unix", False): - self.create_connection = functools.partial( - loop.create_unix_connection, factory, **kwargs - ) + _, connection = await loop.create_unix_connection(factory, **kwargs) else: if kwargs.get("sock") is None: kwargs.setdefault("host", wsuri.host) kwargs.setdefault("port", wsuri.port) - self.create_connection = functools.partial( - loop.create_connection, factory, **kwargs + _, connection = await loop.create_connection(factory, **kwargs) + return connection + + def process_redirect(self, exc: Exception) -> Exception | str: + """ + Determine whether a connection error is a redirect that can be followed. + + Return the new URI if it's a valid redirect. Else, return an exception. + + """ + if not ( + isinstance(exc, InvalidStatus) + and exc.response.status_code + in [ + 300, # Multiple Choices + 301, # Moved Permanently + 302, # Found + 303, # See Other + 307, # Temporary Redirect + 308, # Permanent Redirect + ] + and "Location" in exc.response.headers + ): + return exc + + old_wsuri = parse_uri(self.uri) + new_uri = urllib.parse.urljoin(self.uri, exc.response.headers["Location"]) + new_wsuri = parse_uri(new_uri) + + # If connect() received a socket, it is closed and cannot be reused. + if self.connection_kwargs.get("sock") is not None: + return ValueError( + f"cannot follow redirect to {new_uri} with a preexisting socket" ) - self.handshake_args = ( - additional_headers, - user_agent_header, - ) - self.process_exception = process_exception - self.open_timeout = open_timeout - if logger is None: - logger = logging.getLogger("websockets.client") - self.logger = logger + # TLS downgrade is forbidden. + if old_wsuri.secure and not new_wsuri.secure: + return SecurityError(f"cannot follow redirect to non-secure URI {new_uri}") + + # Apply restrictions to cross-origin redirects. + if ( + old_wsuri.secure != new_wsuri.secure + or old_wsuri.host != new_wsuri.host + or old_wsuri.port != new_wsuri.port + ): + # Cross-origin redirects on Unix sockets don't quite make sense. + if self.connection_kwargs.get("unix", False): + return ValueError( + f"cannot follow cross-origin redirect to {new_uri} " + f"with a Unix socket" + ) + + # Cross-origin redirects when host and port are overridden are ill-defined. + if ( + self.connection_kwargs.get("host") is not None + or self.connection_kwargs.get("port") is not None + ): + return ValueError( + f"cannot follow cross-origin redirect to {new_uri} " + f"with an explicit host or port" + ) + + return new_uri # ... = await connect(...) @@ -372,14 +445,38 @@ def __await__(self) -> Generator[Any, None, ClientConnection]: async def __await_impl__(self) -> ClientConnection: try: async with asyncio_timeout(self.open_timeout): - _transport, self.connection = await self.create_connection() - try: - await self.connection.handshake(*self.handshake_args) - except (Exception, asyncio.CancelledError): - self.connection.transport.close() - raise + for _ in range(MAX_REDIRECTS): + self.connection = await self.create_connection() + try: + await self.connection.handshake(*self.handshake_args) + except asyncio.CancelledError: + self.connection.transport.close() + raise + except Exception as exc: + # Always close the connection even though keep-alive is + # the default in HTTP/1.1 because create_connection ties + # opening the network connection with initializing the + # protocol. In the current design of connect(), there is + # no easy way to reuse the network connection that works + # in every case nor to reinitialize the protocol. + self.connection.transport.close() + + uri_or_exc = self.process_redirect(exc) + # Response is a valid redirect; follow it. + if isinstance(uri_or_exc, str): + self.uri = uri_or_exc + continue + # Response isn't a valid redirect; raise the exception. + if uri_or_exc is exc: + raise + else: + raise uri_or_exc from exc + + else: + return self.connection else: - return self.connection + raise SecurityError(f"more than {MAX_REDIRECTS} redirects") + except TimeoutError: # Re-raise exception with an informative error message. raise TimeoutError("timed out during handshake") from None diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index a1bc5cbae..ec4c2ff64 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -418,7 +418,7 @@ class Connect: """ - MAX_REDIRECTS_ALLOWED = 10 + MAX_REDIRECTS_ALLOWED = int(os.environ.get("WEBSOCKETS_MAX_REDIRECTS", "10")) def __init__( self, diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index 7467d215a..b0487552e 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -10,7 +10,12 @@ from websockets.asyncio.compatibility import TimeoutError from websockets.asyncio.server import serve, unix_serve from websockets.client import backoff -from websockets.exceptions import InvalidHandshake, InvalidStatus, InvalidURI +from websockets.exceptions import ( + InvalidHandshake, + InvalidStatus, + InvalidURI, + SecurityError, +) from websockets.extensions.permessage_deflate import PerMessageDeflate from ..utils import CLIENT_CONTEXT, MS, SERVER_CONTEXT, temp_unix_socket_path @@ -34,6 +39,20 @@ async def short_backoff_delay(): backoff.__defaults__ = defaults +# Decorate tests that need it with @few_redirects() instead of using it as a +# context manager when dropping support for Python < 3.10. +@contextlib.asynccontextmanager +async def few_redirects(): + from websockets.asyncio import client + + max_redirects = client.MAX_REDIRECTS + client.MAX_REDIRECTS = 2 + try: + yield + finally: + client.MAX_REDIRECTS = max_redirects + + class ClientTests(unittest.IsolatedAsyncioTestCase): async def test_connection(self): """Client connects to server.""" @@ -41,7 +60,93 @@ async def test_connection(self): async with connect(get_uri(server)) as client: self.assertEqual(client.protocol.state.name, "OPEN") - async def test_reconnection(self): + async def test_explicit_host_port(self): + """Client connects using an explicit host / port.""" + async with serve(*args) as server: + host, port = get_host_port(server) + async with connect("ws://overridden/", host=host, port=port) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + + async def test_existing_socket(self): + """Client connects using a pre-existing socket.""" + async with serve(*args) as server: + with socket.create_connection(get_host_port(server)) as sock: + # Use a non-existing domain to ensure we connect to sock. + async with connect("ws://invalid/", sock=sock) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + + async def test_additional_headers(self): + """Client can set additional headers with additional_headers.""" + async with serve(*args) as server: + async with connect( + get_uri(server), additional_headers={"Authorization": "Bearer ..."} + ) as client: + self.assertEqual(client.request.headers["Authorization"], "Bearer ...") + + async def test_override_user_agent(self): + """Client can override User-Agent header with user_agent_header.""" + async with serve(*args) as server: + async with connect(get_uri(server), user_agent_header="Smith") as client: + self.assertEqual(client.request.headers["User-Agent"], "Smith") + + async def test_remove_user_agent(self): + """Client can remove User-Agent header with user_agent_header.""" + async with serve(*args) as server: + async with connect(get_uri(server), user_agent_header=None) as client: + self.assertNotIn("User-Agent", client.request.headers) + + async def test_compression_is_enabled(self): + """Client enables compression by default.""" + async with serve(*args) as server: + async with connect(get_uri(server)) as client: + self.assertEqual( + [type(ext) for ext in client.protocol.extensions], + [PerMessageDeflate], + ) + + async def test_disable_compression(self): + """Client disables compression.""" + async with serve(*args) as server: + async with connect(get_uri(server), compression=None) as client: + self.assertEqual(client.protocol.extensions, []) + + async def test_keepalive_is_enabled(self): + """Client enables keepalive and measures latency by default.""" + async with serve(*args) as server: + async with connect(get_uri(server), ping_interval=MS) as client: + self.assertEqual(client.latency, 0) + await asyncio.sleep(2 * MS) + self.assertGreater(client.latency, 0) + + async def test_disable_keepalive(self): + """Client disables keepalive.""" + async with serve(*args) as server: + async with connect(get_uri(server), ping_interval=None) as client: + await asyncio.sleep(2 * MS) + self.assertEqual(client.latency, 0) + + async def test_logger(self): + """Client accepts a logger argument.""" + logger = logging.getLogger("test") + async with serve(*args) as server: + async with connect(get_uri(server), logger=logger) as client: + self.assertEqual(client.logger.name, logger.name) + + async def test_custom_connection_factory(self): + """Client runs ClientConnection factory provided in create_connection.""" + + def create_connection(*args, **kwargs): + client = ClientConnection(*args, **kwargs) + client.create_connection_ran = True + return client + + async with serve(*args) as server: + async with connect( + get_uri(server), create_connection=create_connection + ) as client: + self.assertTrue(client.create_connection_ran) + + async def test_reconnect(self): """Client reconnects to server.""" iterations = 0 successful = 0 @@ -76,7 +181,7 @@ def process_request(connection, request): hasattr(http.HTTPStatus, "IM_A_TEAPOT"), "test requires Python 3.9", ) - async def test_reconnection_with_custom_process_exception(self): + async def test_reconnect_with_custom_process_exception(self): """Client runs process_exception to tell if errors are retryable or fatal.""" iteration = 0 @@ -113,7 +218,7 @@ def process_exception(exc): hasattr(http.HTTPStatus, "IM_A_TEAPOT"), "test requires Python 3.9", ) - async def test_reconnection_with_custom_process_exception_raising_exception(self): + async def test_reconnect_with_custom_process_exception_raising_exception(self): """Client supports raising an exception in process_exception.""" def process_request(connection, request): @@ -137,84 +242,107 @@ def process_exception(exc): "🫖 💔 ☕️", ) - async def test_existing_socket(self): - """Client connects using a pre-existing socket.""" - async with serve(*args) as server: - with socket.create_connection(get_host_port(server)) as sock: - # Use a non-existing domain to ensure we connect to the right socket. - async with connect("ws://invalid/", sock=sock) as client: - self.assertEqual(client.protocol.state.name, "OPEN") + async def test_redirect(self): + """Client follows redirect.""" - async def test_additional_headers(self): - """Client can set additional headers with additional_headers.""" - async with serve(*args) as server: - async with connect( - get_uri(server), additional_headers={"Authorization": "Bearer ..."} - ) as client: - self.assertEqual(client.request.headers["Authorization"], "Bearer ...") + def redirect(connection, request): + if request.path == "/redirect": + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = "/" + return response - async def test_override_user_agent(self): - """Client can override User-Agent header with user_agent_header.""" - async with serve(*args) as server: - async with connect(get_uri(server), user_agent_header="Smith") as client: - self.assertEqual(client.request.headers["User-Agent"], "Smith") + async with serve(*args, process_request=redirect) as server: + async with connect(get_uri(server) + "/redirect") as client: + self.assertEqual(client.protocol.wsuri.path, "/") - async def test_remove_user_agent(self): - """Client can remove User-Agent header with user_agent_header.""" - async with serve(*args) as server: - async with connect(get_uri(server), user_agent_header=None) as client: - self.assertNotIn("User-Agent", client.request.headers) + async def test_cross_origin_redirect(self): + """Client follows redirect to a secure URI on a different origin.""" - async def test_compression_is_enabled(self): - """Client enables compression by default.""" - async with serve(*args) as server: - async with connect(get_uri(server)) as client: - self.assertEqual( - [type(ext) for ext in client.protocol.extensions], - [PerMessageDeflate], - ) + def redirect(connection, request): + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = get_uri(other_server) + return response - async def test_disable_compression(self): - """Client disables compression.""" - async with serve(*args) as server: - async with connect(get_uri(server), compression=None) as client: - self.assertEqual(client.protocol.extensions, []) + async with serve(*args, process_request=redirect) as server: + async with serve(*args) as other_server: + async with connect(get_uri(server)): + self.assertFalse(server.connections) + self.assertTrue(other_server.connections) - async def test_keepalive_is_enabled(self): - """Client enables keepalive and measures latency by default.""" - async with serve(*args) as server: - async with connect(get_uri(server), ping_interval=MS) as client: - self.assertEqual(client.latency, 0) - await asyncio.sleep(2 * MS) - self.assertGreater(client.latency, 0) + async def test_redirect_limit(self): + """Client stops following redirects after limit is reached.""" - async def test_disable_keepalive(self): - """Client disables keepalive.""" - async with serve(*args) as server: - async with connect(get_uri(server), ping_interval=None) as client: - await asyncio.sleep(2 * MS) - self.assertEqual(client.latency, 0) + def redirect(connection, request): + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = request.path + return response - async def test_logger(self): - """Client accepts a logger argument.""" - logger = logging.getLogger("test") - async with serve(*args) as server: - async with connect(get_uri(server), logger=logger) as client: - self.assertEqual(client.logger.name, logger.name) + async with serve(*args, process_request=redirect) as server: + async with few_redirects(): + with self.assertRaises(SecurityError) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") - async def test_custom_connection_factory(self): - """Client runs ClientConnection factory provided in create_connection.""" + self.assertEqual( + str(raised.exception), + "more than 2 redirects", + ) - def create_connection(*args, **kwargs): - client = ClientConnection(*args, **kwargs) - client.create_connection_ran = True - return client + async def test_redirect_with_explicit_host_port(self): + """Client follows redirect with an explicit host / port.""" - async with serve(*args) as server: + def redirect(connection, request): + if request.path == "/redirect": + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = "/" + return response + + async with serve(*args, process_request=redirect) as server: + host, port = get_host_port(server) async with connect( - get_uri(server), create_connection=create_connection + "ws://overridden/redirect", host=host, port=port ) as client: - self.assertTrue(client.create_connection_ran) + self.assertEqual(client.protocol.wsuri.path, "/") + + async def test_cross_origin_redirect_with_explicit_host_port(self): + """Client doesn't follow cross-origin redirect with an explicit host / port.""" + + def redirect(connection, request): + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = "ws://other/" + return response + + async with serve(*args, process_request=redirect) as server: + host, port = get_host_port(server) + with self.assertRaises(ValueError) as raised: + async with connect("ws://overridden/", host=host, port=port): + self.fail("did not raise") + + self.assertEqual( + str(raised.exception), + "cannot follow cross-origin redirect to ws://other/ " + "with an explicit host or port", + ) + + async def test_redirect_with_existing_socket(self): + """Client doesn't follow redirect when using a pre-existing socket.""" + + def redirect(connection, request): + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = "/" + return response + + async with serve(*args, process_request=redirect) as server: + with socket.create_connection(get_host_port(server)) as sock: + with self.assertRaises(ValueError) as raised: + # Use a non-existing domain to ensure we connect to sock. + async with connect("ws://invalid/redirect", sock=sock): + self.fail("did not raise") + + self.assertEqual( + str(raised.exception), + "cannot follow redirect to ws://invalid/ with a preexisting socket", + ) async def test_invalid_uri(self): """Client receives an invalid URI.""" @@ -336,6 +464,40 @@ async def test_reject_invalid_server_hostname(self): str(raised.exception), ) + async def test_cross_origin_redirect(self): + """Client follows redirect to a secure URI on a different origin.""" + + def redirect(connection, request): + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = get_uri(other_server) + return response + + async with serve(*args, ssl=SERVER_CONTEXT, process_request=redirect) as server: + async with serve(*args, ssl=SERVER_CONTEXT) as other_server: + async with connect(get_uri(server), ssl=CLIENT_CONTEXT): + self.assertFalse(server.connections) + self.assertTrue(other_server.connections) + + async def test_redirect_to_insecure_uri(self): + """Client doesn't follow redirect from secure URI to non-secure URI.""" + + def redirect(connection, request): + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = insecure_uri + return response + + async with serve(*args, ssl=SERVER_CONTEXT, process_request=redirect) as server: + with self.assertRaises(SecurityError) as raised: + secure_uri = get_uri(server) + insecure_uri = secure_uri.replace("wss://", "ws://") + async with connect(secure_uri, ssl=CLIENT_CONTEXT): + self.fail("did not raise") + + self.assertEqual( + str(raised.exception), + f"cannot follow redirect to non-secure URI {insecure_uri}", + ) + @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") class UnixClientTests(unittest.IsolatedAsyncioTestCase): @@ -354,6 +516,25 @@ async def test_set_host_header(self): async with unix_connect(path, uri="ws://overridden/") as client: self.assertEqual(client.request.headers["Host"], "overridden") + async def test_cross_origin_redirect(self): + """Client doesn't follows redirect to a URI on a different origin.""" + + def redirect(connection, request): + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = "ws://other/" + return response + + with temp_unix_socket_path() as path: + async with unix_serve(handler, path, process_request=redirect): + with self.assertRaises(ValueError) as raised: + async with unix_connect(path): + self.fail("did not raise") + + self.assertEqual( + str(raised.exception), + "cannot follow cross-origin redirect to ws://other/ with a Unix socket", + ) + @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") class SecureUnixClientTests(unittest.IsolatedAsyncioTestCase): From 566ab1d4165f9dfaa469cd19858bc5b451b182cd Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 7 Sep 2024 22:15:28 +0200 Subject: [PATCH 612/762] The new asyncio implementation has reached parity. --- docs/howto/upgrade.rst | 3 ++- docs/index.rst | 10 +++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index 120509c9f..70254d93e 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -31,7 +31,8 @@ respectively. The next steps are: - 1. Deprecating it once the new implementation reaches feature parity. + 1. Deprecating it once the new implementation is considered sufficiently + robust. 2. Maintaining it for five years per the :ref:`backwards-compatibility policy `. 3. Removing it. This is expected to happen around 2030. diff --git a/docs/index.rst b/docs/index.rst index 218a489a3..b8cd300e3 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -26,7 +26,7 @@ with a focus on correctness, simplicity, robustness, and performance. .. _WebSocket: https://developer.mozilla.org/en-US/docs/Web/API/WebSockets_API -It supports several network I/O and control flow paradigms: +It supports several network I/O and control flow paradigms. 1. The primary implementation builds upon :mod:`asyncio`, Python's standard asynchronous I/O framework. It provides an elegant coroutine-based API. It's @@ -44,10 +44,10 @@ It supports several network I/O and control flow paradigms: the Sans-I/O implementation. It adds a few features that were impossible to implement within the original design. - The new implementation will become the default as soon as it reaches - feature parity. If you're using the historical implementation, you should - :doc:`ugrade to the new implementation `. It's usually - straightforward. + The new implementation provides all features of the historical + implementation, and a few more. If you're using the historical + implementation, you should :doc:`ugrade to the new implementation + `. It's usually straightforward. 2. The :mod:`threading` implementation is a good alternative for clients, especially if you aren't familiar with :mod:`asyncio`. It may also be used From ef7b1e32d26b2104d8bc5e2ac3a001e0504aba11 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 7 Sep 2024 22:24:21 +0200 Subject: [PATCH 613/762] Proof-read upgrade guide. --- docs/howto/upgrade.rst | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index 70254d93e..f3e42591e 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -10,14 +10,14 @@ It provides a very similar API. However, there are a few differences. The recommended upgrade process is: -#. Make sure that your application doesn't use any `deprecated APIs`_. If it - doesn't raise any warnings, you can skip this step. -#. `Update import paths`_. For straightforward usage of websockets, this could - be the only step you need to take. Upgrading could be transparent. -#. Check out `new features and improvements`_ and consider taking advantage of - them to improve your application. -#. Review `API changes`_ and adapt your application to preserve its current - functionality. +#. Make sure that your code doesn't use any `deprecated APIs`_. If it doesn't + raise warnings, you're fine. +#. `Update import paths`_. For straightforward use cases, this could be the only + step you need to take. +#. Check out `new features and improvements`_. Consider taking advantage of them + in your code. +#. Review `API changes`_. If needed, update your application to preserve its + current behavior. In the interest of brevity, only :func:`~asyncio.client.connect` and :func:`~asyncio.server.serve` are discussed below but everything also applies @@ -146,9 +146,8 @@ Customizing the opening handshake ................................. On the server side, if you're customizing how :func:`~legacy.server.serve` -processes the opening handshake with the ``process_request``, ``extra_headers``, -or ``select_subprotocol``, you must update your code and you can probably make -it simpler. +processes the opening handshake with ``process_request``, ``extra_headers``, or +``select_subprotocol``, you must update your code. Probably you can simplify it! ``process_request`` and ``select_subprotocol`` have new signatures. ``process_response`` replaces ``extra_headers`` and provides more flexibility. @@ -481,16 +480,17 @@ a ``check_credentials`` coroutine as well as an optional ``realm`` just like This new API has more obvious semantics. That makes it easier to understand and also easier to extend. -In the original implementation, overriding ``create_protocol`` changed the type +In the original implementation, overriding ``create_protocol`` changes the type of connection objects to :class:`~legacy.auth.BasicAuthWebSocketServerProtocol`, a subclass of :class:`~legacy.server.WebSocketServerProtocol` that performs HTTP -Basic Authentication in its ``process_request`` method. If you wanted to -customize ``process_request`` further, you had: +Basic Authentication in its ``process_request`` method. -* an ill-defined option: add a ``process_request`` argument to +To customize ``process_request`` further, you had only bad options: + +* the ill-defined option: add a ``process_request`` argument to :func:`~legacy.server.serve`; to tell which one would run first, you had to experiment or read the code; -* a cumbersome option: subclass +* the cumbersome option: subclass :class:`~legacy.auth.BasicAuthWebSocketServerProtocol`, then pass that subclass in the ``create_protocol`` argument of :func:`~legacy.auth.basic_auth_protocol_factory`. From 0158a246683437de2f85a27fecef06f5e5393e99 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 7 Sep 2024 22:39:43 +0200 Subject: [PATCH 614/762] Forgotten in d8ab09b2. --- docs/reference/features.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/reference/features.rst b/docs/reference/features.rst index 32fc05baf..eeade1462 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -142,7 +142,7 @@ Client +------------------------------------+--------+--------+--------+--------+ | Close connection on context exit | ✅ | ✅ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ - | Reconnect automatically | ❌ | ❌ | — | ✅ | + | Reconnect automatically | ✅ | ❌ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ | Configure ``Origin`` header | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ From 4a9dae23b3661210331463968cf4a5eeb54c41dd Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 8 Sep 2024 16:58:50 +0200 Subject: [PATCH 615/762] Simplify handling of connection close during handshake. --- src/websockets/asyncio/client.py | 13 ++++--------- src/websockets/asyncio/server.py | 13 ++++--------- 2 files changed, 8 insertions(+), 18 deletions(-) diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 50f67b95f..2b8fbfd3a 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -90,7 +90,10 @@ async def handshake( self.protocol.send_request(self.request) # May raise CancelledError if open_timeout is exceeded. - await self.response_rcvd + await asyncio.wait( + [self.response_rcvd, self.connection_lost_waiter], + return_when=asyncio.FIRST_COMPLETED, + ) if self.response is None: raise ConnectionError("connection closed during handshake") @@ -118,14 +121,6 @@ def process_event(self, event: Event) -> None: else: super().process_event(event) - def connection_lost(self, exc: Exception | None) -> None: - try: - super().connection_lost(exc) - finally: - # If the connection is closed during the handshake, unblock it. - if not self.response_rcvd.done(): - self.response_rcvd.set_result(None) - def process_exception(exc: Exception) -> Exception | None: """ diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index 1aa47af88..c04c5202f 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -139,7 +139,10 @@ async def handshake( """ # May raise CancelledError if open_timeout is exceeded. - await self.request_rcvd + await asyncio.wait( + [self.request_rcvd, self.connection_lost_waiter], + return_when=asyncio.FIRST_COMPLETED, + ) if self.request is None: raise ConnectionError("connection closed during handshake") @@ -229,14 +232,6 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None: super().connection_made(transport) self.server.start_connection_handler(self) - def connection_lost(self, exc: Exception | None) -> None: - try: - super().connection_lost(exc) - finally: - # If the connection is closed during the handshake, unblock it. - if not self.request_rcvd.done(): - self.request_rcvd.set_result(None) - class Server: """ From 96055073d3cbeb6e52bb419205b2753d3898d612 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 8 Sep 2024 18:53:57 +0200 Subject: [PATCH 616/762] Close connection when client receives bad response. This avoids stalling until the opening handshake timeouts. --- src/websockets/asyncio/client.py | 21 +++++++++----------- src/websockets/client.py | 2 ++ src/websockets/protocol.py | 12 ++++++------ src/websockets/sync/client.py | 28 ++++++++++++--------------- tests/asyncio/test_client.py | 24 +++++++++++++++++++++-- tests/sync/test_client.py | 33 ++++++++++++++++++++++++++++++-- tests/test_client.py | 2 +- 7 files changed, 83 insertions(+), 39 deletions(-) diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 2b8fbfd3a..3985bfb6a 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -89,23 +89,19 @@ async def handshake( self.request.headers["User-Agent"] = user_agent_header self.protocol.send_request(self.request) - # May raise CancelledError if open_timeout is exceeded. await asyncio.wait( [self.response_rcvd, self.connection_lost_waiter], return_when=asyncio.FIRST_COMPLETED, ) - if self.response is None: - raise ConnectionError("connection closed during handshake") + # self.protocol.handshake_exc is always set when the connection is lost + # before receiving a response, when the response cannot be parsed, or + # when the response fails the handshake. if self.protocol.handshake_exc is None: self.start_keepalive() else: - try: - async with asyncio_timeout(self.close_timeout): - await self.connection_lost_waiter - finally: - raise self.protocol.handshake_exc + raise self.protocol.handshake_exc def process_event(self, event: Event) -> None: """ @@ -132,7 +128,8 @@ def process_exception(exc: Exception) -> Exception | None: This function defines the default behavior, which is to retry on: - * :exc:`OSError` and :exc:`asyncio.TimeoutError`: network errors; + * :exc:`EOFError`, :exc:`OSError`, :exc:`asyncio.TimeoutError`: network + errors; * :exc:`~websockets.exceptions.InvalidStatus` when the status code is 500, 502, 503, or 504: server or proxy errors. @@ -150,7 +147,7 @@ def process_exception(exc: Exception) -> Exception | None: That exception will be raised, breaking out of the retry loop. """ - if isinstance(exc, (OSError, asyncio.TimeoutError)): + if isinstance(exc, (EOFError, OSError, asyncio.TimeoutError)): return None if isinstance(exc, InvalidStatus) and exc.response.status_code in [ 500, # Internal Server Error @@ -445,7 +442,7 @@ async def __await_impl__(self) -> ClientConnection: try: await self.connection.handshake(*self.handshake_args) except asyncio.CancelledError: - self.connection.transport.close() + self.connection.close_transport() raise except Exception as exc: # Always close the connection even though keep-alive is @@ -454,7 +451,7 @@ async def __await_impl__(self) -> ClientConnection: # protocol. In the current design of connect(), there is # no easy way to reuse the network connection that works # in every case nor to reinitialize the protocol. - self.connection.transport.close() + self.connection.close_transport() uri_or_exc = self.process_redirect(exc) # Response is a valid redirect; follow it. diff --git a/src/websockets/client.py b/src/websockets/client.py index 95de99dc5..0e36fd028 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -323,6 +323,7 @@ def parse(self) -> Generator[None, None, None]: ) except Exception as exc: self.handshake_exc = exc + self.send_eof() self.parser = self.discard() next(self.parser) # start coroutine yield @@ -341,6 +342,7 @@ def parse(self) -> Generator[None, None, None]: response._exception = exc self.events.append(response) self.handshake_exc = exc + self.send_eof() self.parser = self.discard() next(self.parser) # start coroutine yield diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 3b3e80cf5..8751ebdb4 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -610,18 +610,18 @@ def discard(self) -> Generator[None, None, None]: - after sending a close frame, during an abnormal closure (7.1.7). """ - # The server close the TCP connection in the same circumstances where - # discard() replaces parse(). The client closes the connection later, - # after the server closes the connection or a timeout elapses. - # (The latter case cannot be handled in this Sans-I/O layer.) - assert (self.side is SERVER) == (self.eof_sent) + # After the opening handshake completes, the server closes the TCP + # connection in the same circumstances where discard() replaces parse(). + # The client closes it when it receives EOF from the server or times + # out. (The latter case cannot be handled in this Sans-I/O layer.) + assert (self.state == CONNECTING or self.side is SERVER) == (self.eof_sent) while not (yield from self.reader.at_eof()): self.reader.discard() if self.debug: self.logger.debug("< EOF") # A server closes the TCP connection immediately, while a client # waits for the server to close the TCP connection. - if self.side is CLIENT: + if self.state != CONNECTING and self.side is CLIENT: self.send_eof() self.state = CLOSED # If discard() completes normally, execution ends here. diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 6a04515f0..d1e20a757 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -12,7 +12,7 @@ from ..extensions.permessage_deflate import enable_client_permessage_deflate from ..headers import validate_subprotocols from ..http11 import USER_AGENT, Response -from ..protocol import CONNECTING, OPEN, Event +from ..protocol import CONNECTING, Event from ..typing import LoggerLike, Origin, Subprotocol from ..uri import parse_uri from .connection import Connection @@ -80,19 +80,11 @@ def handshake( self.protocol.send_request(self.request) if not self.response_rcvd.wait(timeout): - self.close_socket() - self.recv_events_thread.join() raise TimeoutError("timed out during handshake") - if self.response is None: - self.close_socket() - self.recv_events_thread.join() - raise ConnectionError("connection closed during handshake") - - if self.protocol.state is not OPEN: - self.recv_events_thread.join(self.close_timeout) - self.close_socket() - self.recv_events_thread.join() + # self.protocol.handshake_exc is always set when the connection is lost + # before receiving a response, when the response cannot be parsed, or + # when the response fails the handshake. if self.protocol.handshake_exc is not None: raise self.protocol.handshake_exc @@ -295,16 +287,20 @@ def connect( protocol, close_timeout=close_timeout, ) - # On failure, handshake() closes the socket and raises an exception. + except Exception: + if sock is not None: + sock.close() + raise + + try: connection.handshake( additional_headers, user_agent_header, deadline.timeout(), ) - except Exception: - if sock is not None: - sock.close() + connection.close_socket() + connection.recv_events_thread.join() raise return connection diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index b0487552e..725bac92b 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -401,12 +401,32 @@ def close_connection(self, request): self.close_transport() async with serve(*args, process_request=close_connection) as server: - with self.assertRaises(ConnectionError) as raised: + with self.assertRaises(EOFError) as raised: async with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), - "connection closed during handshake", + "connection closed while reading HTTP status line", + ) + + async def test_junk_handshake(self): + """Client closes the connection when receiving non-HTTP response from server.""" + + async def junk(reader, writer): + await asyncio.sleep(MS) # wait for the client to send the handshake request + writer.write(b"220 smtp.invalid ESMTP Postfix\r\n") + await reader.read(4096) # wait for the client to close the connection + writer.close() + + server = await asyncio.start_server(junk, "localhost", 0) + host, port = get_host_port(server) + async with server: + with self.assertRaises(ValueError) as raised: + async with connect(f"ws://{host}:{port}"): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "unsupported HTTP version: 220", ) diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index 812412203..96f7f0c90 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -1,7 +1,9 @@ import logging import socket +import socketserver import ssl import threading +import time import unittest from websockets.exceptions import InvalidHandshake, InvalidURI @@ -146,14 +148,41 @@ def close_connection(self, request): self.close_socket() with run_server(process_request=close_connection) as server: - with self.assertRaises(ConnectionError) as raised: + with self.assertRaises(EOFError) as raised: with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), - "connection closed during handshake", + "connection closed while reading HTTP status line", ) + def test_junk_handshake(self): + """Client closes the connection when receiving non-HTTP response from server.""" + + class JunkHandler(socketserver.BaseRequestHandler): + def handle(self): + time.sleep(MS) # wait for the client to send the handshake request + self.request.send(b"220 smtp.invalid ESMTP Postfix\r\n") + self.request.recv(4096) # wait for the client to close the connection + self.request.close() + + server = socketserver.TCPServer(("localhost", 0), JunkHandler) + host, port = server.server_address + with server: + thread = threading.Thread(target=server.serve_forever, args=(MS,)) + thread.start() + try: + with self.assertRaises(ValueError) as raised: + with connect(f"ws://{host}:{port}"): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "unsupported HTTP version: 220", + ) + finally: + server.shutdown() + thread.join() + class SecureClientTests(unittest.TestCase): def test_connection(self): diff --git a/tests/test_client.py b/tests/test_client.py index 8b3bf4232..47558c1c0 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -154,7 +154,7 @@ def test_receive_reject(self): ) [response] = client.events_received() self.assertIsInstance(response, Response) - self.assertEqual(client.data_to_send(), []) + self.assertEqual(client.data_to_send(), [b""]) self.assertTrue(client.close_expected()) self.assertEqual(client.state, CONNECTING) From 560f6eed24a516bf9c8c4c7eaf83f84d2ce606f0 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 8 Sep 2024 21:26:32 +0200 Subject: [PATCH 617/762] Log error when server receives bad request. --- src/websockets/asyncio/server.py | 155 ++++++++++++++++--------------- src/websockets/sync/server.py | 136 ++++++++++++++------------- tests/asyncio/test_server.py | 21 +++++ tests/sync/test_server.py | 32 +++++++ tests/test_server.py | 2 +- 5 files changed, 206 insertions(+), 140 deletions(-) diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index c04c5202f..228b20012 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -138,81 +138,76 @@ async def handshake( Perform the opening handshake. """ - # May raise CancelledError if open_timeout is exceeded. await asyncio.wait( [self.request_rcvd, self.connection_lost_waiter], return_when=asyncio.FIRST_COMPLETED, ) - if self.request is None: - raise ConnectionError("connection closed during handshake") - - async with self.send_context(expected_state=CONNECTING): - response = None - - if process_request is not None: - try: - response = process_request(self, self.request) - if isinstance(response, Awaitable): - response = await response - except Exception as exc: - self.protocol.handshake_exc = exc - self.logger.error("opening handshake failed", exc_info=True) - response = self.protocol.reject( - http.HTTPStatus.INTERNAL_SERVER_ERROR, - ( - "Failed to open a WebSocket connection.\n" - "See server log for more information.\n" - ), - ) - - if response is None: - if self.server.is_serving(): - self.response = self.protocol.accept(self.request) + if self.request is not None: + async with self.send_context(expected_state=CONNECTING): + response = None + + if process_request is not None: + try: + response = process_request(self, self.request) + if isinstance(response, Awaitable): + response = await response + except Exception as exc: + self.protocol.handshake_exc = exc + response = self.protocol.reject( + http.HTTPStatus.INTERNAL_SERVER_ERROR, + ( + "Failed to open a WebSocket connection.\n" + "See server log for more information.\n" + ), + ) + + if response is None: + if self.server.is_serving(): + self.response = self.protocol.accept(self.request) + else: + self.response = self.protocol.reject( + http.HTTPStatus.SERVICE_UNAVAILABLE, + "Server is shutting down.\n", + ) else: - self.response = self.protocol.reject( - http.HTTPStatus.SERVICE_UNAVAILABLE, - "Server is shutting down.\n", - ) - else: - assert isinstance(response, Response) # help mypy - self.response = response - - if server_header: - self.response.headers["Server"] = server_header - - response = None - - if process_response is not None: - try: - response = process_response(self, self.request, self.response) - if isinstance(response, Awaitable): - response = await response - except Exception as exc: - self.protocol.handshake_exc = exc - self.logger.error("opening handshake failed", exc_info=True) - response = self.protocol.reject( - http.HTTPStatus.INTERNAL_SERVER_ERROR, - ( - "Failed to open a WebSocket connection.\n" - "See server log for more information.\n" - ), - ) - - if response is not None: - assert isinstance(response, Response) # help mypy - self.response = response - - self.protocol.send_response(self.response) + assert isinstance(response, Response) # help mypy + self.response = response + + if server_header: + self.response.headers["Server"] = server_header + + response = None + + if process_response is not None: + try: + response = process_response(self, self.request, self.response) + if isinstance(response, Awaitable): + response = await response + except Exception as exc: + self.protocol.handshake_exc = exc + response = self.protocol.reject( + http.HTTPStatus.INTERNAL_SERVER_ERROR, + ( + "Failed to open a WebSocket connection.\n" + "See server log for more information.\n" + ), + ) + + if response is not None: + assert isinstance(response, Response) # help mypy + self.response = response + + self.protocol.send_response(self.response) + + # self.protocol.handshake_exc is always set when the connection is lost + # before receiving a request, when the request cannot be parsed, or when + # the response fails the handshake. if self.protocol.handshake_exc is None: self.start_keepalive() else: - try: - async with asyncio_timeout(self.close_timeout): - await self.connection_lost_waiter - finally: - raise self.protocol.handshake_exc + raise self.protocol.handshake_exc def process_event(self, event: Event) -> None: """ @@ -359,25 +354,35 @@ async def conn_handler(self, connection: ServerConnection) -> None: """ try: - # On failure, handshake() closes the transport, raises an - # exception, and logs it. async with asyncio_timeout(self.open_timeout): - await connection.handshake( - self.process_request, - self.process_response, - self.server_header, - ) + try: + await connection.handshake( + self.process_request, + self.process_response, + self.server_header, + ) + except asyncio.CancelledError: + connection.close_transport() + raise + except Exception: + connection.logger.error("opening handshake failed", exc_info=True) + connection.close_transport() + return try: await self.handler(connection) except Exception: - self.logger.error("connection handler failed", exc_info=True) + connection.logger.error("connection handler failed", exc_info=True) await connection.close(CloseCode.INTERNAL_ERROR) else: await connection.close() - except Exception: - # Don't leak connections on errors. + except TimeoutError: + # When the opening handshake times out, there's nothing to log. + pass + + except Exception: # pragma: no cover + # Don't leak connections on unexpected errors. connection.transport.abort() finally: diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 15de458b5..eb0536013 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -23,7 +23,7 @@ validate_subprotocols, ) from ..http11 import SERVER, Request, Response -from ..protocol import CONNECTING, OPEN, Event +from ..protocol import CONNECTING, Event from ..server import ServerProtocol from ..typing import LoggerLike, Origin, StatusLike, Subprotocol from .connection import Connection @@ -118,61 +118,56 @@ def handshake( """ if not self.request_rcvd.wait(timeout): - self.close_socket() - self.recv_events_thread.join() raise TimeoutError("timed out during handshake") - if self.request is None: - self.close_socket() - self.recv_events_thread.join() - raise ConnectionError("connection closed during handshake") - - with self.send_context(expected_state=CONNECTING): - self.response = None - - if process_request is not None: - try: - self.response = process_request(self, self.request) - except Exception as exc: - self.protocol.handshake_exc = exc - self.logger.error("opening handshake failed", exc_info=True) - self.response = self.protocol.reject( - http.HTTPStatus.INTERNAL_SERVER_ERROR, - ( - "Failed to open a WebSocket connection.\n" - "See server log for more information.\n" - ), - ) - - if self.response is None: - self.response = self.protocol.accept(self.request) - - if server_header: - self.response.headers["Server"] = server_header - - if process_response is not None: - try: - response = process_response(self, self.request, self.response) - except Exception as exc: - self.protocol.handshake_exc = exc - self.logger.error("opening handshake failed", exc_info=True) - self.response = self.protocol.reject( - http.HTTPStatus.INTERNAL_SERVER_ERROR, - ( - "Failed to open a WebSocket connection.\n" - "See server log for more information.\n" - ), - ) + if self.request is not None: + with self.send_context(expected_state=CONNECTING): + response = None + + if process_request is not None: + try: + response = process_request(self, self.request) + except Exception as exc: + self.protocol.handshake_exc = exc + response = self.protocol.reject( + http.HTTPStatus.INTERNAL_SERVER_ERROR, + ( + "Failed to open a WebSocket connection.\n" + "See server log for more information.\n" + ), + ) + + if response is None: + self.response = self.protocol.accept(self.request) else: + self.response = response + + if server_header: + self.response.headers["Server"] = server_header + + response = None + + if process_response is not None: + try: + response = process_response(self, self.request, self.response) + except Exception as exc: + self.protocol.handshake_exc = exc + response = self.protocol.reject( + http.HTTPStatus.INTERNAL_SERVER_ERROR, + ( + "Failed to open a WebSocket connection.\n" + "See server log for more information.\n" + ), + ) + if response is not None: self.response = response - self.protocol.send_response(self.response) + self.protocol.send_response(self.response) - if self.protocol.state is not OPEN: - self.recv_events_thread.join(self.close_timeout) - self.close_socket() - self.recv_events_thread.join() + # self.protocol.handshake_exc is always set when the connection is lost + # before receiving a request, when the request cannot be parsed, or when + # the response fails the handshake. if self.protocol.handshake_exc is not None: raise self.protocol.handshake_exc @@ -552,26 +547,39 @@ def protocol_select_subprotocol( protocol, close_timeout=close_timeout, ) - # On failure, handshake() closes the socket, raises an exception, and - # logs it. - connection.handshake( - process_request, - process_response, - server_header, - deadline.timeout(), - ) - except Exception: sock.close() return try: - handler(connection) - except Exception: - protocol.logger.error("connection handler failed", exc_info=True) - connection.close(CloseCode.INTERNAL_ERROR) - else: - connection.close() + try: + connection.handshake( + process_request, + process_response, + server_header, + deadline.timeout(), + ) + except TimeoutError: + connection.close_socket() + connection.recv_events_thread.join() + return + except Exception: + connection.logger.error("opening handshake failed", exc_info=True) + connection.close_socket() + connection.recv_events_thread.join() + return + + try: + handler(connection) + except Exception: + connection.logger.error("connection handler failed", exc_info=True) + connection.close(CloseCode.INTERNAL_ERROR) + else: + connection.close() + + except Exception: # pragma: no cover + # Don't leak sockets on unexpected errors. + sock.close() # Initialize server diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index bc1d0444c..fdcbf9780 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -406,6 +406,27 @@ async def test_connection_closed_during_handshake(self): _reader, writer = await asyncio.open_connection(*get_host_port(server)) writer.close() + async def test_junk_handshake(self): + """Server closes the connection when receiving non-HTTP request from client.""" + with self.assertLogs("websockets", logging.ERROR) as logs: + async with serve(*args) as server: + reader, writer = await asyncio.open_connection(*get_host_port(server)) + writer.write(b"HELO relay.invalid\r\n") + try: + # Wait for the server to close the connection. + self.assertEqual(await reader.read(4096), b"") + finally: + writer.close() + + self.assertEqual( + [record.getMessage() for record in logs.records], + ["opening handshake failed"], + ) + self.assertEqual( + [str(record.exc_info[1]) for record in logs.records], + ["invalid HTTP request line: HELO relay.invalid"], + ) + async def test_close_server_rejects_connecting_connections(self): """Server rejects connecting connections with HTTP 503 when closing.""" diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index 7bcf144a2..56565dab7 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -310,6 +310,38 @@ def handler(sock, addr): # Wait for the server thread to terminate. server_thread.join() + def test_junk_handshake(self): + """Server closes the connection when receiving non-HTTP request from client.""" + with self.assertLogs("websockets.server", logging.ERROR) as logs: + with run_server() as server: + # Patch handler to record a reference to the thread running it. + server_thread = None + original_handler = server.handler + + def handler(sock, addr): + nonlocal server_thread + server_thread = threading.current_thread() + original_handler(sock, addr) + + server.handler = handler + + with socket.create_connection(server.socket.getsockname()) as sock: + sock.send(b"HELO relay.invalid\r\n") + # Wait for the server to close the connection. + self.assertEqual(sock.recv(4096), b"") + + # Wait for the server thread to terminate. + server_thread.join() + + self.assertEqual( + [record.getMessage() for record in logs.records], + ["opening handshake failed"], + ) + self.assertEqual( + [str(record.exc_info[1]) for record in logs.records], + ["invalid HTTP request line: HELO relay.invalid"], + ) + class SecureServerTests(EvalShellMixin, unittest.TestCase): def test_connection(self): diff --git a/tests/test_server.py b/tests/test_server.py index e7d249f49..d34c8e83d 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -83,7 +83,7 @@ def test_partial_request(self): server.receive_eof() self.assertEqual(server.events_received(), []) - def test_random_request(self): + def test_junk_request(self): server = ServerProtocol() server.receive_data(b"HELO relay.invalid\r\n") server.receive_data(b"MAIL FROM: \r\n") From cb42484bb9e79b12f5d7a19bc96e43a022bc511a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 8 Sep 2024 22:15:33 +0200 Subject: [PATCH 618/762] Improve error messages on HTTP parsing errors. --- src/websockets/http11.py | 30 +++++++++++++++++------------- tests/asyncio/test_client.py | 3 ++- tests/sync/test_client.py | 3 ++- tests/test_http11.py | 25 ++++++++++++------------- 4 files changed, 33 insertions(+), 28 deletions(-) diff --git a/src/websockets/http11.py b/src/websockets/http11.py index 61865bb92..47cef7a9b 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -135,14 +135,15 @@ def parse( raise EOFError("connection closed while reading HTTP request line") from exc try: - method, raw_path, version = request_line.split(b" ", 2) + method, raw_path, protocol = request_line.split(b" ", 2) except ValueError: # not enough values to unpack (expected 3, got 1-2) raise ValueError(f"invalid HTTP request line: {d(request_line)}") from None - + if protocol != b"HTTP/1.1": + raise ValueError( + f"unsupported protocol; expected HTTP/1.1: {d(request_line)}" + ) if method != b"GET": - raise ValueError(f"unsupported HTTP method: {d(method)}") - if version != b"HTTP/1.1": - raise ValueError(f"unsupported HTTP version: {d(version)}") + raise ValueError(f"unsupported HTTP method; expected GET; got {d(method)}") path = raw_path.decode("ascii", "surrogateescape") headers = yield from parse_headers(read_line) @@ -236,23 +237,26 @@ def parse( raise EOFError("connection closed while reading HTTP status line") from exc try: - version, raw_status_code, raw_reason = status_line.split(b" ", 2) + protocol, raw_status_code, raw_reason = status_line.split(b" ", 2) except ValueError: # not enough values to unpack (expected 3, got 1-2) raise ValueError(f"invalid HTTP status line: {d(status_line)}") from None - - if version != b"HTTP/1.1": - raise ValueError(f"unsupported HTTP version: {d(version)}") + if protocol != b"HTTP/1.1": + raise ValueError( + f"unsupported protocol; expected HTTP/1.1: {d(status_line)}" + ) try: status_code = int(raw_status_code) except ValueError: # invalid literal for int() with base 10 raise ValueError( - f"invalid HTTP status code: {d(raw_status_code)}" + f"invalid status code; expected integer; got {d(raw_status_code)}" ) from None - if not 100 <= status_code < 1000: - raise ValueError(f"unsupported HTTP status code: {d(raw_status_code)}") + if not 100 <= status_code < 600: + raise ValueError( + f"invalid status code; expected 100–599; got {d(raw_status_code)}" + ) if not _value_re.fullmatch(raw_reason): raise ValueError(f"invalid HTTP reason phrase: {d(raw_reason)}") - reason = raw_reason.decode() + reason = raw_reason.decode("ascii", "surrogateescape") headers = yield from parse_headers(read_line) diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index 725bac92b..999ef1b71 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -426,7 +426,8 @@ async def junk(reader, writer): self.fail("did not raise") self.assertEqual( str(raised.exception), - "unsupported HTTP version: 220", + "unsupported protocol; expected HTTP/1.1: " + "220 smtp.invalid ESMTP Postfix", ) diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index 96f7f0c90..e63d774b7 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -177,7 +177,8 @@ def handle(self): self.fail("did not raise") self.assertEqual( str(raised.exception), - "unsupported HTTP version: 220", + "unsupported protocol; expected HTTP/1.1: " + "220 smtp.invalid ESMTP Postfix", ) finally: server.shutdown() diff --git a/tests/test_http11.py b/tests/test_http11.py index d2e5e0462..1fbcb3ba4 100644 --- a/tests/test_http11.py +++ b/tests/test_http11.py @@ -50,22 +50,22 @@ def test_parse_invalid_request_line(self): "invalid HTTP request line: GET /", ) - def test_parse_unsupported_method(self): - self.reader.feed_data(b"OPTIONS * HTTP/1.1\r\n\r\n") + def test_parse_unsupported_protocol(self): + self.reader.feed_data(b"GET /chat HTTP/1.0\r\n\r\n") with self.assertRaises(ValueError) as raised: next(self.parse()) self.assertEqual( str(raised.exception), - "unsupported HTTP method: OPTIONS", + "unsupported protocol; expected HTTP/1.1: GET /chat HTTP/1.0", ) - def test_parse_unsupported_version(self): - self.reader.feed_data(b"GET /chat HTTP/1.0\r\n\r\n") + def test_parse_unsupported_method(self): + self.reader.feed_data(b"OPTIONS * HTTP/1.1\r\n\r\n") with self.assertRaises(ValueError) as raised: next(self.parse()) self.assertEqual( str(raised.exception), - "unsupported HTTP version: HTTP/1.0", + "unsupported HTTP method; expected GET; got OPTIONS", ) def test_parse_invalid_header(self): @@ -171,31 +171,30 @@ def test_parse_invalid_status_line(self): "invalid HTTP status line: Hello!", ) - def test_parse_unsupported_version(self): + def test_parse_unsupported_protocol(self): self.reader.feed_data(b"HTTP/1.0 400 Bad Request\r\n\r\n") with self.assertRaises(ValueError) as raised: next(self.parse()) self.assertEqual( str(raised.exception), - "unsupported HTTP version: HTTP/1.0", + "unsupported protocol; expected HTTP/1.1: HTTP/1.0 400 Bad Request", ) - def test_parse_invalid_status(self): + def test_parse_non_integer_status(self): self.reader.feed_data(b"HTTP/1.1 OMG WTF\r\n\r\n") with self.assertRaises(ValueError) as raised: next(self.parse()) self.assertEqual( str(raised.exception), - "invalid HTTP status code: OMG", + "invalid status code; expected integer; got OMG", ) - def test_parse_unsupported_status(self): + def test_parse_non_three_digit_status(self): self.reader.feed_data(b"HTTP/1.1 007 My name is Bond\r\n\r\n") with self.assertRaises(ValueError) as raised: next(self.parse()) self.assertEqual( - str(raised.exception), - "unsupported HTTP status code: 007", + str(raised.exception), "invalid status code; expected 100–599; got 007" ) def test_parse_invalid_reason(self): From 90270d8b1bc19ee0c4ac94c6a200c95614ab8771 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 8 Sep 2024 22:20:24 +0200 Subject: [PATCH 619/762] Add changelog for previous commits. --- docs/project/changelog.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index c77876a73..5716c6f25 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -59,6 +59,11 @@ New features * Made the set of active connections available in the :attr:`Server.connections ` property. +Improvements +............ + +* Improved reporting of errors during the opening handshake. + 13.0.1 ------ From 14d9d40acd23ecc66a8ac8d0535af9d8a1b0fa07 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 8 Sep 2024 22:52:04 +0200 Subject: [PATCH 620/762] Fix typo in convenience imports. Fix #1496. --- src/websockets/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index ac02a9f7e..7bd35cfa7 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -183,7 +183,7 @@ "ExtensionParameter": ".typing", "LoggerLike": ".typing", "Origin": ".typing", - "StatusLike": "typing", + "StatusLike": ".typing", "Subprotocol": ".typing", }, deprecated_aliases={ From f9cea9cca568dc92704de3744639eb4248278a8f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 9 Sep 2024 21:45:49 +0200 Subject: [PATCH 621/762] Improve isolation of tests of sync implementation. Before this change, threads handling requests could continue running after the end of the test. This caused spurious failures. Specifically, a test expecting an error log could get an error log from a previous tests. This happened sporadically on PyPy. --- tests/sync/server.py | 18 +++++++++++++ tests/sync/test_server.py | 54 +++------------------------------------ 2 files changed, 21 insertions(+), 51 deletions(-) diff --git a/tests/sync/server.py b/tests/sync/server.py index 114c1545b..fd7a03d82 100644 --- a/tests/sync/server.py +++ b/tests/sync/server.py @@ -38,12 +38,30 @@ def run_server(handler=handler, host="localhost", port=0, **kwargs): with serve(handler, host, port, **kwargs) as server: thread = threading.Thread(target=server.serve_forever) thread.start() + + # HACK: since the sync server doesn't track connections (yet), we record + # a reference to the thread handling the most recent connection, then we + # can wait for that thread to terminate when exiting the context. + handler_thread = None + original_handler = server.handler + + def handler(sock, addr): + nonlocal handler_thread + handler_thread = threading.current_thread() + original_handler(sock, addr) + + server.handler = handler + try: yield server finally: server.shutdown() thread.join() + # HACK: wait for the thread handling the most recent connection. + if handler_thread is not None: + handler_thread.join() + @contextlib.contextmanager def run_unix_server(path, handler=handler, **kwargs): diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index 56565dab7..d0d2c0955 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -3,7 +3,7 @@ import http import logging import socket -import threading +import time import unittest from websockets.exceptions import ( @@ -289,50 +289,19 @@ def test_timeout_during_handshake(self): def test_connection_closed_during_handshake(self): """Server reads EOF before receiving handshake request from client.""" with run_server() as server: - # Patch handler to record a reference to the thread running it. - server_thread = None - conn_received = threading.Event() - original_handler = server.handler - - def handler(sock, addr): - nonlocal server_thread - server_thread = threading.current_thread() - nonlocal conn_received - conn_received.set() - original_handler(sock, addr) - - server.handler = handler - with socket.create_connection(server.socket.getsockname()): # Wait for the server to receive the connection, then close it. - conn_received.wait() - - # Wait for the server thread to terminate. - server_thread.join() + time.sleep(MS) def test_junk_handshake(self): """Server closes the connection when receiving non-HTTP request from client.""" with self.assertLogs("websockets.server", logging.ERROR) as logs: with run_server() as server: - # Patch handler to record a reference to the thread running it. - server_thread = None - original_handler = server.handler - - def handler(sock, addr): - nonlocal server_thread - server_thread = threading.current_thread() - original_handler(sock, addr) - - server.handler = handler - with socket.create_connection(server.socket.getsockname()) as sock: sock.send(b"HELO relay.invalid\r\n") # Wait for the server to close the connection. self.assertEqual(sock.recv(4096), b"") - # Wait for the server thread to terminate. - server_thread.join() - self.assertEqual( [record.getMessage() for record in logs.records], ["opening handshake failed"], @@ -360,26 +329,9 @@ def test_timeout_during_tls_handshake(self): def test_connection_closed_during_tls_handshake(self): """Server reads EOF before receiving TLS handshake request from client.""" with run_server(ssl=SERVER_CONTEXT) as server: - # Patch handler to record a reference to the thread running it. - server_thread = None - conn_received = threading.Event() - original_handler = server.handler - - def handler(sock, addr): - nonlocal server_thread - server_thread = threading.current_thread() - nonlocal conn_received - conn_received.set() - original_handler(sock, addr) - - server.handler = handler - with socket.create_connection(server.socket.getsockname()): # Wait for the server to receive the connection, then close it. - conn_received.wait() - - # Wait for the server thread to terminate. - server_thread.join() + time.sleep(MS) @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") From 070ff1a3e536e497a9ee11ea3b0649f82f3974c9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 9 Sep 2024 22:36:05 +0200 Subject: [PATCH 622/762] Add dedicated ConcurrencyError exception. Previously, a generic RuntimeError was used. Fix #1499. --- docs/project/changelog.rst | 4 +++ docs/reference/exceptions.rst | 5 ++++ src/websockets/__init__.py | 3 +++ src/websockets/asyncio/connection.py | 36 ++++++++++++++++----------- src/websockets/asyncio/messages.py | 19 +++++++------- src/websockets/exceptions.py | 12 +++++++++ src/websockets/sync/connection.py | 37 ++++++++++++++++------------ src/websockets/sync/messages.py | 15 +++++------ tests/asyncio/test_connection.py | 28 ++++++++++++--------- tests/asyncio/test_messages.py | 19 +++++++------- tests/sync/test_connection.py | 28 ++++++++++++--------- tests/sync/test_messages.py | 17 +++++++------ tests/test_exceptions.py | 4 +++ 13 files changed, 140 insertions(+), 87 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 5716c6f25..f92ca68b6 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -64,6 +64,10 @@ Improvements * Improved reporting of errors during the opening handshake. +* Raised :exc:`~exceptions.ConcurrencyError` on unsupported concurrent calls. + Previously, :exc:`RuntimeError` was raised. For backwards compatibility, + :exc:`~exceptions.ConcurrencyError` is a subclass of :exc:`RuntimeError`. + 13.0.1 ------ diff --git a/docs/reference/exceptions.rst b/docs/reference/exceptions.rst index 14a8edcd1..75934ef99 100644 --- a/docs/reference/exceptions.rst +++ b/docs/reference/exceptions.rst @@ -64,6 +64,11 @@ translated to :exc:`ConnectionClosedError` in the other implementations. .. autoexception:: InvalidState +Miscellaneous exceptions +------------------------ + +.. autoexception:: ConcurrencyError + Legacy exceptions ----------------- diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index 7bd35cfa7..54591e9fd 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -14,6 +14,7 @@ "HeadersLike", "MultipleValuesError", # .exceptions + "ConcurrencyError", "ConnectionClosed", "ConnectionClosedError", "ConnectionClosedOK", @@ -72,6 +73,7 @@ from .client import ClientProtocol from .datastructures import Headers, HeadersLike, MultipleValuesError from .exceptions import ( + ConcurrencyError, ConnectionClosed, ConnectionClosedError, ConnectionClosedOK, @@ -134,6 +136,7 @@ "HeadersLike": ".datastructures", "MultipleValuesError": ".datastructures", # .exceptions + "ConcurrencyError": ".exceptions", "ConnectionClosed": ".exceptions", "ConnectionClosedError": ".exceptions", "ConnectionClosedOK": ".exceptions", diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 069b3e1d2..1b24f9af0 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -19,7 +19,12 @@ cast, ) -from ..exceptions import ConnectionClosed, ConnectionClosedOK, ProtocolError +from ..exceptions import ( + ConcurrencyError, + ConnectionClosed, + ConnectionClosedOK, + ProtocolError, +) from ..frames import DATA_OPCODES, BytesLike, CloseCode, Frame, Opcode from ..http11 import Request, Response from ..protocol import CLOSED, OPEN, Event, Protocol, State @@ -262,7 +267,7 @@ async def recv(self, decode: bool | None = None) -> Data: Raises: ConnectionClosed: When the connection is closed. - RuntimeError: If two coroutines call :meth:`recv` or + ConcurrencyError: If two coroutines call :meth:`recv` or :meth:`recv_streaming` concurrently. """ @@ -270,8 +275,8 @@ async def recv(self, decode: bool | None = None) -> Data: return await self.recv_messages.get(decode) except EOFError: raise self.protocol.close_exc from self.recv_exc - except RuntimeError: - raise RuntimeError( + except ConcurrencyError: + raise ConcurrencyError( "cannot call recv while another coroutine " "is already running recv or recv_streaming" ) from None @@ -283,8 +288,9 @@ async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data This method is designed for receiving fragmented messages. It returns an asynchronous iterator that yields each fragment as it is received. This iterator must be fully consumed. Else, future calls to :meth:`recv` or - :meth:`recv_streaming` will raise :exc:`RuntimeError`, making the - connection unusable. + :meth:`recv_streaming` will raise + :exc:`~websockets.exceptions.ConcurrencyError`, making the connection + unusable. :meth:`recv_streaming` raises the same exceptions as :meth:`recv`. @@ -315,7 +321,7 @@ async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data Raises: ConnectionClosed: When the connection is closed. - RuntimeError: If two coroutines call :meth:`recv` or + ConcurrencyError: If two coroutines call :meth:`recv` or :meth:`recv_streaming` concurrently. """ @@ -324,8 +330,8 @@ async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data yield frame except EOFError: raise self.protocol.close_exc from self.recv_exc - except RuntimeError: - raise RuntimeError( + except ConcurrencyError: + raise ConcurrencyError( "cannot call recv_streaming while another coroutine " "is already running recv or recv_streaming" ) from None @@ -593,7 +599,7 @@ async def ping(self, data: Data | None = None) -> Awaitable[float]: Raises: ConnectionClosed: When the connection is closed. - RuntimeError: If another ping was sent with the same data and + ConcurrencyError: If another ping was sent with the same data and the corresponding pong wasn't received yet. """ @@ -607,7 +613,7 @@ async def ping(self, data: Data | None = None) -> Awaitable[float]: async with self.send_context(): # Protect against duplicates if a payload is explicitly set. if data in self.pong_waiters: - raise RuntimeError("already waiting for a pong with the same data") + raise ConcurrencyError("already waiting for a pong with the same data") # Generate a unique random payload otherwise. while data is None or data in self.pong_waiters: @@ -793,7 +799,7 @@ async def send_context( # Let the caller interact with the protocol. try: yield - except (ProtocolError, RuntimeError): + except (ProtocolError, ConcurrencyError): # The protocol state wasn't changed. Exit immediately. raise except Exception as exc: @@ -1092,15 +1098,17 @@ def broadcast( if raise_exceptions: if sys.version_info[:2] < (3, 11): # pragma: no cover raise ValueError("raise_exceptions requires at least Python 3.11") - exceptions = [] + exceptions: list[Exception] = [] for connection in connections: + exception: Exception + if connection.protocol.state is not OPEN: continue if connection.fragmented_send_waiter is not None: if raise_exceptions: - exception = RuntimeError("sending a fragmented message") + exception = ConcurrencyError("sending a fragmented message") exceptions.append(exception) else: connection.logger.warning( diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index 33ab6a5e9..c2b4afd67 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -12,6 +12,7 @@ TypeVar, ) +from ..exceptions import ConcurrencyError from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame from ..typing import Data @@ -49,7 +50,7 @@ async def get(self) -> T: """Remove and return an item from the queue, waiting if necessary.""" if not self.queue: if self.get_waiter is not None: - raise RuntimeError("get is already running") + raise ConcurrencyError("get is already running") self.get_waiter = self.loop.create_future() try: await self.get_waiter @@ -135,15 +136,15 @@ async def get(self, decode: bool | None = None) -> Data: Raises: EOFError: If the stream of frames has ended. - RuntimeError: If two coroutines run :meth:`get` or :meth:`get_iter` - concurrently. + ConcurrencyError: If two coroutines run :meth:`get` or + :meth:`get_iter` concurrently. """ if self.closed: raise EOFError("stream of frames ended") if self.get_in_progress: - raise RuntimeError("get() or get_iter() is already running") + raise ConcurrencyError("get() or get_iter() is already running") # Locking with get_in_progress ensures only one coroutine can get here. self.get_in_progress = True @@ -190,7 +191,7 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: :class:`str` or :class:`bytes` for each frame in the message. The iterator must be fully consumed before calling :meth:`get_iter` or - :meth:`get` again. Else, :exc:`RuntimeError` is raised. + :meth:`get` again. Else, :exc:`ConcurrencyError` is raised. This method only makes sense for fragmented messages. If messages aren't fragmented, use :meth:`get` instead. @@ -202,15 +203,15 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: Raises: EOFError: If the stream of frames has ended. - RuntimeError: If two coroutines run :meth:`get` or :meth:`get_iter` - concurrently. + ConcurrencyError: If two coroutines run :meth:`get` or + :meth:`get_iter` concurrently. """ if self.closed: raise EOFError("stream of frames ended") if self.get_in_progress: - raise RuntimeError("get() or get_iter() is already running") + raise ConcurrencyError("get() or get_iter() is already running") # Locking with get_in_progress ensures only one coroutine can get here. self.get_in_progress = True @@ -236,7 +237,7 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: # We cannot handle asyncio.CancelledError because we don't buffer # previous fragments — we're streaming them. Canceling get_iter() # here will leave the assembler in a stuck state. Future calls to - # get() or get_iter() will raise RuntimeError. + # get() or get_iter() will raise ConcurrencyError. frame = await self.frames.get() self.maybe_resume() assert frame.opcode is OP_CONT diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index b2b679e6b..d723f2fec 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -25,6 +25,7 @@ * :exc:`ProtocolError` (Sans-I/O) * :exc:`PayloadTooBig` (Sans-I/O) * :exc:`InvalidState` (Sans-I/O) + * :exc:`ConcurrencyError` """ @@ -62,6 +63,7 @@ "WebSocketProtocolError", "PayloadTooBig", "InvalidState", + "ConcurrencyError", ] @@ -354,6 +356,16 @@ class InvalidState(WebSocketException, AssertionError): """ +class ConcurrencyError(WebSocketException, RuntimeError): + """ + Raised when receiving or sending messages concurrently. + + WebSocket is a connection-oriented protocol. Reads must be serialized; so + must be writes. However, reading and writing concurrently is possible. + + """ + + # When type checking, import non-deprecated aliases eagerly. Else, import on demand. if typing.TYPE_CHECKING: from .legacy.exceptions import ( diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 16e51abda..65a7b63ed 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -10,7 +10,12 @@ from types import TracebackType from typing import Any, Iterable, Iterator, Mapping -from ..exceptions import ConnectionClosed, ConnectionClosedOK, ProtocolError +from ..exceptions import ( + ConcurrencyError, + ConnectionClosed, + ConnectionClosedOK, + ProtocolError, +) from ..frames import DATA_OPCODES, BytesLike, CloseCode, Frame, Opcode from ..http11 import Request, Response from ..protocol import CLOSED, OPEN, Event, Protocol, State @@ -194,7 +199,7 @@ def recv(self, timeout: float | None = None) -> Data: Raises: ConnectionClosed: When the connection is closed. - RuntimeError: If two threads call :meth:`recv` or + ConcurrencyError: If two threads call :meth:`recv` or :meth:`recv_streaming` concurrently. """ @@ -202,8 +207,8 @@ def recv(self, timeout: float | None = None) -> Data: return self.recv_messages.get(timeout) except EOFError: raise self.protocol.close_exc from self.recv_exc - except RuntimeError: - raise RuntimeError( + except ConcurrencyError: + raise ConcurrencyError( "cannot call recv while another thread " "is already running recv or recv_streaming" ) from None @@ -227,7 +232,7 @@ def recv_streaming(self) -> Iterator[Data]: Raises: ConnectionClosed: When the connection is closed. - RuntimeError: If two threads call :meth:`recv` or + ConcurrencyError: If two threads call :meth:`recv` or :meth:`recv_streaming` concurrently. """ @@ -236,8 +241,8 @@ def recv_streaming(self) -> Iterator[Data]: yield frame except EOFError: raise self.protocol.close_exc from self.recv_exc - except RuntimeError: - raise RuntimeError( + except ConcurrencyError: + raise ConcurrencyError( "cannot call recv_streaming while another thread " "is already running recv or recv_streaming" ) from None @@ -277,7 +282,7 @@ def send(self, message: Data | Iterable[Data]) -> None: Raises: ConnectionClosed: When the connection is closed. - RuntimeError: If the connection is sending a fragmented message. + ConcurrencyError: If the connection is sending a fragmented message. TypeError: If ``message`` doesn't have a supported type. """ @@ -287,7 +292,7 @@ def send(self, message: Data | Iterable[Data]) -> None: if isinstance(message, str): with self.send_context(): if self.send_in_progress: - raise RuntimeError( + raise ConcurrencyError( "cannot call send while another thread " "is already running send" ) @@ -296,7 +301,7 @@ def send(self, message: Data | Iterable[Data]) -> None: elif isinstance(message, BytesLike): with self.send_context(): if self.send_in_progress: - raise RuntimeError( + raise ConcurrencyError( "cannot call send while another thread " "is already running send" ) @@ -322,7 +327,7 @@ def send(self, message: Data | Iterable[Data]) -> None: text = True with self.send_context(): if self.send_in_progress: - raise RuntimeError( + raise ConcurrencyError( "cannot call send while another thread " "is already running send" ) @@ -335,7 +340,7 @@ def send(self, message: Data | Iterable[Data]) -> None: text = False with self.send_context(): if self.send_in_progress: - raise RuntimeError( + raise ConcurrencyError( "cannot call send while another thread " "is already running send" ) @@ -371,7 +376,7 @@ def send(self, message: Data | Iterable[Data]) -> None: self.protocol.send_continuation(b"", fin=True) self.send_in_progress = False - except RuntimeError: + except ConcurrencyError: # We didn't start sending a fragmented message. # The connection is still usable. raise @@ -445,7 +450,7 @@ def ping(self, data: Data | None = None) -> threading.Event: Raises: ConnectionClosed: When the connection is closed. - RuntimeError: If another ping was sent with the same data and + ConcurrencyError: If another ping was sent with the same data and the corresponding pong wasn't received yet. """ @@ -459,7 +464,7 @@ def ping(self, data: Data | None = None) -> threading.Event: with self.send_context(): # Protect against duplicates if a payload is explicitly set. if data in self.ping_waiters: - raise RuntimeError("already waiting for a pong with the same data") + raise ConcurrencyError("already waiting for a pong with the same data") # Generate a unique random payload otherwise. while data is None or data in self.ping_waiters: @@ -665,7 +670,7 @@ def send_context( # Let the caller interact with the protocol. try: yield - except (ProtocolError, RuntimeError): + except (ProtocolError, ConcurrencyError): # The protocol state wasn't changed. Exit immediately. raise except Exception as exc: diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index ff90345ac..8d090538f 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -5,6 +5,7 @@ import threading from typing import Iterator, cast +from ..exceptions import ConcurrencyError from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame from ..typing import Data @@ -74,7 +75,7 @@ def get(self, timeout: float | None = None) -> Data: Raises: EOFError: If the stream of frames has ended. - RuntimeError: If two threads run :meth:`get` or :meth:`get_iter` + ConcurrencyError: If two threads run :meth:`get` or :meth:`get_iter` concurrently. TimeoutError: If a timeout is provided and elapses before a complete message is received. @@ -85,7 +86,7 @@ def get(self, timeout: float | None = None) -> Data: raise EOFError("stream of frames ended") if self.get_in_progress: - raise RuntimeError("get() or get_iter() is already running") + raise ConcurrencyError("get() or get_iter() is already running") self.get_in_progress = True @@ -128,14 +129,14 @@ def get_iter(self) -> Iterator[Data]: :class:`bytes` for each frame in the message. The iterator must be fully consumed before calling :meth:`get_iter` or - :meth:`get` again. Else, :exc:`RuntimeError` is raised. + :meth:`get` again. Else, :exc:`ConcurrencyError` is raised. This method only makes sense for fragmented messages. If messages aren't fragmented, use :meth:`get` instead. Raises: EOFError: If the stream of frames has ended. - RuntimeError: If two threads run :meth:`get` or :meth:`get_iter` + ConcurrencyError: If two threads run :meth:`get` or :meth:`get_iter` concurrently. """ @@ -144,7 +145,7 @@ def get_iter(self) -> Iterator[Data]: raise EOFError("stream of frames ended") if self.get_in_progress: - raise RuntimeError("get() or get_iter() is already running") + raise ConcurrencyError("get() or get_iter() is already running") chunks = self.chunks self.chunks = [] @@ -198,7 +199,7 @@ def put(self, frame: Frame) -> None: Raises: EOFError: If the stream of frames has ended. - RuntimeError: If two threads run :meth:`put` concurrently. + ConcurrencyError: If two threads run :meth:`put` concurrently. """ with self.mutex: @@ -206,7 +207,7 @@ def put(self, frame: Frame) -> None: raise EOFError("stream of frames ended") if self.put_in_progress: - raise RuntimeError("put is already running") + raise ConcurrencyError("put is already running") if frame.opcode is OP_TEXT: self.decoder = UTF8Decoder(errors="strict") diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 78f3adf68..70d9dad63 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -10,7 +10,11 @@ from websockets.asyncio.compatibility import TimeoutError, aiter, anext, asyncio_timeout from websockets.asyncio.connection import * from websockets.asyncio.connection import broadcast -from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK +from websockets.exceptions import ( + ConcurrencyError, + ConnectionClosedError, + ConnectionClosedOK, +) from websockets.frames import CloseCode, Frame, Opcode from websockets.protocol import CLIENT, SERVER, Protocol, State @@ -219,12 +223,12 @@ async def test_recv_connection_closed_error(self): await self.connection.recv() async def test_recv_during_recv(self): - """recv raises RuntimeError when called concurrently with itself.""" + """recv raises ConcurrencyError when called concurrently.""" recv_task = asyncio.create_task(self.connection.recv()) await asyncio.sleep(0) # let the event loop start recv_task self.addCleanup(recv_task.cancel) - with self.assertRaises(RuntimeError) as raised: + with self.assertRaises(ConcurrencyError) as raised: await self.connection.recv() self.assertEqual( str(raised.exception), @@ -233,14 +237,14 @@ async def test_recv_during_recv(self): ) async def test_recv_during_recv_streaming(self): - """recv raises RuntimeError when called concurrently with recv_streaming.""" + """recv raises ConcurrencyError when called concurrently with recv_streaming.""" recv_streaming_task = asyncio.create_task( alist(self.connection.recv_streaming()) ) await asyncio.sleep(0) # let the event loop start recv_streaming_task self.addCleanup(recv_streaming_task.cancel) - with self.assertRaises(RuntimeError) as raised: + with self.assertRaises(ConcurrencyError) as raised: await self.connection.recv() self.assertEqual( str(raised.exception), @@ -349,12 +353,12 @@ async def test_recv_streaming_connection_closed_error(self): self.fail("did not raise") async def test_recv_streaming_during_recv(self): - """recv_streaming raises RuntimeError when called concurrently with recv.""" + """recv_streaming raises ConcurrencyError when called concurrently with recv.""" recv_task = asyncio.create_task(self.connection.recv()) await asyncio.sleep(0) # let the event loop start recv_task self.addCleanup(recv_task.cancel) - with self.assertRaises(RuntimeError) as raised: + with self.assertRaises(ConcurrencyError) as raised: async for _ in self.connection.recv_streaming(): self.fail("did not raise") self.assertEqual( @@ -364,14 +368,14 @@ async def test_recv_streaming_during_recv(self): ) async def test_recv_streaming_during_recv_streaming(self): - """recv_streaming raises RuntimeError when called concurrently with itself.""" + """recv_streaming raises ConcurrencyError when called concurrently.""" recv_streaming_task = asyncio.create_task( alist(self.connection.recv_streaming()) ) await asyncio.sleep(0) # let the event loop start recv_streaming_task self.addCleanup(recv_streaming_task.cancel) - with self.assertRaises(RuntimeError) as raised: + with self.assertRaises(ConcurrencyError) as raised: async for _ in self.connection.recv_streaming(): self.fail("did not raise") self.assertEqual( @@ -419,7 +423,7 @@ async def fragments(): gate.set_result(None) # Running recv_streaming again fails. - with self.assertRaises(RuntimeError): + with self.assertRaises(ConcurrencyError): await alist(self.connection.recv_streaming()) # Test send. @@ -856,7 +860,7 @@ async def test_ping_duplicate_payload(self): async with self.drop_frames_rcvd(): # drop automatic response to ping pong_waiter = await self.connection.ping("idem") - with self.assertRaises(RuntimeError) as raised: + with self.assertRaises(ConcurrencyError) as raised: await self.connection.ping("idem") self.assertEqual( str(raised.exception), @@ -1252,7 +1256,7 @@ async def fragments(): self.assertEqual(str(raised.exception), "skipped broadcast (1 sub-exception)") exc = raised.exception.exceptions[0] self.assertEqual(str(exc), "sending a fragmented message") - self.assertIsInstance(exc, RuntimeError) + self.assertIsInstance(exc, ConcurrencyError) gate.set_result(None) await send_task diff --git a/tests/asyncio/test_messages.py b/tests/asyncio/test_messages.py index 615b1f3a8..d2cf25c9c 100644 --- a/tests/asyncio/test_messages.py +++ b/tests/asyncio/test_messages.py @@ -5,6 +5,7 @@ from websockets.asyncio.compatibility import aiter, anext from websockets.asyncio.messages import * from websockets.asyncio.messages import SimpleQueue +from websockets.exceptions import ConcurrencyError from websockets.frames import OP_BINARY, OP_CONT, OP_TEXT, Frame from .utils import alist @@ -37,10 +38,10 @@ async def test_get_then_put(self): self.assertEqual(item, 42) async def test_get_concurrently(self): - """get cannot be called concurrently with itself.""" + """get cannot be called concurrently.""" getter_task = asyncio.create_task(self.queue.get()) await asyncio.sleep(0) # let the task start - with self.assertRaises(RuntimeError): + with self.assertRaises(ConcurrencyError): await self.queue.get() getter_task.cancel() @@ -361,7 +362,7 @@ async def test_cancel_get_iter_after_first_frame(self): self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) self.assembler.put(Frame(OP_CONT, b"\xa9")) - with self.assertRaises(RuntimeError): + with self.assertRaises(ConcurrencyError): await alist(self.assembler.get_iter()) # Test put @@ -423,10 +424,10 @@ async def test_close_is_idempotent(self): # Test (non-)concurrency async def test_get_fails_when_get_is_running(self): - """get cannot be called concurrently with itself.""" + """get cannot be called concurrently.""" asyncio.create_task(self.assembler.get()) await asyncio.sleep(0) - with self.assertRaises(RuntimeError): + with self.assertRaises(ConcurrencyError): await self.assembler.get() self.assembler.close() # let task terminate @@ -434,7 +435,7 @@ async def test_get_fails_when_get_iter_is_running(self): """get cannot be called concurrently with get_iter.""" asyncio.create_task(alist(self.assembler.get_iter())) await asyncio.sleep(0) - with self.assertRaises(RuntimeError): + with self.assertRaises(ConcurrencyError): await self.assembler.get() self.assembler.close() # let task terminate @@ -442,15 +443,15 @@ async def test_get_iter_fails_when_get_is_running(self): """get_iter cannot be called concurrently with get.""" asyncio.create_task(self.assembler.get()) await asyncio.sleep(0) - with self.assertRaises(RuntimeError): + with self.assertRaises(ConcurrencyError): await alist(self.assembler.get_iter()) self.assembler.close() # let task terminate async def test_get_iter_fails_when_get_iter_is_running(self): - """get_iter cannot be called concurrently with itself.""" + """get_iter cannot be called concurrently.""" asyncio.create_task(alist(self.assembler.get_iter())) await asyncio.sleep(0) - with self.assertRaises(RuntimeError): + with self.assertRaises(ConcurrencyError): await alist(self.assembler.get_iter()) self.assembler.close() # let task terminate diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index d9fb2093b..16f92e164 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -9,7 +9,11 @@ import uuid from unittest.mock import patch -from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK +from websockets.exceptions import ( + ConcurrencyError, + ConnectionClosedError, + ConnectionClosedOK, +) from websockets.frames import CloseCode, Frame, Opcode from websockets.protocol import CLIENT, SERVER, Protocol from websockets.sync.connection import * @@ -173,11 +177,11 @@ def test_recv_connection_closed_error(self): self.connection.recv() def test_recv_during_recv(self): - """recv raises RuntimeError when called concurrently with itself.""" + """recv raises ConcurrencyError when called concurrently.""" recv_thread = threading.Thread(target=self.connection.recv) recv_thread.start() - with self.assertRaises(RuntimeError) as raised: + with self.assertRaises(ConcurrencyError) as raised: self.connection.recv() self.assertEqual( str(raised.exception), @@ -189,13 +193,13 @@ def test_recv_during_recv(self): recv_thread.join() def test_recv_during_recv_streaming(self): - """recv raises RuntimeError when called concurrently with recv_streaming.""" + """recv raises ConcurrencyError when called concurrently with recv_streaming.""" recv_streaming_thread = threading.Thread( target=lambda: list(self.connection.recv_streaming()) ) recv_streaming_thread.start() - with self.assertRaises(RuntimeError) as raised: + with self.assertRaises(ConcurrencyError) as raised: self.connection.recv() self.assertEqual( str(raised.exception), @@ -257,11 +261,11 @@ def test_recv_streaming_connection_closed_error(self): self.fail("did not raise") def test_recv_streaming_during_recv(self): - """recv_streaming raises RuntimeError when called concurrently with recv.""" + """recv_streaming raises ConcurrencyError when called concurrently with recv.""" recv_thread = threading.Thread(target=self.connection.recv) recv_thread.start() - with self.assertRaises(RuntimeError) as raised: + with self.assertRaises(ConcurrencyError) as raised: for _ in self.connection.recv_streaming(): self.fail("did not raise") self.assertEqual( @@ -274,13 +278,13 @@ def test_recv_streaming_during_recv(self): recv_thread.join() def test_recv_streaming_during_recv_streaming(self): - """recv_streaming raises RuntimeError when called concurrently with itself.""" + """recv_streaming raises ConcurrencyError when called concurrently.""" recv_streaming_thread = threading.Thread( target=lambda: list(self.connection.recv_streaming()) ) recv_streaming_thread.start() - with self.assertRaises(RuntimeError) as raised: + with self.assertRaises(ConcurrencyError) as raised: for _ in self.connection.recv_streaming(): self.fail("did not raise") self.assertEqual( @@ -335,7 +339,7 @@ def test_send_connection_closed_error(self): self.connection.send("😀") def test_send_during_send(self): - """send raises RuntimeError when called concurrently with itself.""" + """send raises ConcurrencyError when called concurrently.""" recv_thread = threading.Thread(target=self.remote_connection.recv) recv_thread.start() @@ -363,7 +367,7 @@ def fragments(): [b"\x01\x02", b"\xfe\xff"], ]: with self.subTest(message=message): - with self.assertRaises(RuntimeError) as raised: + with self.assertRaises(ConcurrencyError) as raised: self.connection.send(message) self.assertEqual( str(raised.exception), @@ -653,7 +657,7 @@ def test_ping_duplicate_payload(self): with self.drop_frames_rcvd(): # drop automatic response to ping pong_waiter = self.connection.ping("idem") - with self.assertRaises(RuntimeError) as raised: + with self.assertRaises(ConcurrencyError) as raised: self.connection.ping("idem") self.assertEqual( str(raised.exception), diff --git a/tests/sync/test_messages.py b/tests/sync/test_messages.py index c134b8304..d44b39b88 100644 --- a/tests/sync/test_messages.py +++ b/tests/sync/test_messages.py @@ -1,5 +1,6 @@ import time +from websockets.exceptions import ConcurrencyError from websockets.frames import OP_BINARY, OP_CONT, OP_TEXT, Frame from websockets.sync.messages import * @@ -411,40 +412,40 @@ def test_close_is_idempotent(self): # Test (non-)concurrency def test_get_fails_when_get_is_running(self): - """get cannot be called concurrently with itself.""" + """get cannot be called concurrently.""" with self.run_in_thread(self.assembler.get): - with self.assertRaises(RuntimeError): + with self.assertRaises(ConcurrencyError): self.assembler.get() self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread def test_get_fails_when_get_iter_is_running(self): """get cannot be called concurrently with get_iter.""" with self.run_in_thread(lambda: list(self.assembler.get_iter())): - with self.assertRaises(RuntimeError): + with self.assertRaises(ConcurrencyError): self.assembler.get() self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread def test_get_iter_fails_when_get_is_running(self): """get_iter cannot be called concurrently with get.""" with self.run_in_thread(self.assembler.get): - with self.assertRaises(RuntimeError): + with self.assertRaises(ConcurrencyError): list(self.assembler.get_iter()) self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread def test_get_iter_fails_when_get_iter_is_running(self): - """get_iter cannot be called concurrently with itself.""" + """get_iter cannot be called concurrently.""" with self.run_in_thread(lambda: list(self.assembler.get_iter())): - with self.assertRaises(RuntimeError): + with self.assertRaises(ConcurrencyError): list(self.assembler.get_iter()) self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread def test_put_fails_when_put_is_running(self): - """put cannot be called concurrently with itself.""" + """put cannot be called concurrently.""" def putter(): self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) with self.run_in_thread(putter): - with self.assertRaises(RuntimeError): + with self.assertRaises(ConcurrencyError): self.assembler.put(Frame(OP_BINARY, b"tea")) self.assembler.get() # unblock other thread diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 5620b8a53..8d41bf915 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -167,6 +167,10 @@ def test_str(self): InvalidState("WebSocket connection isn't established yet"), "WebSocket connection isn't established yet", ), + ( + ConcurrencyError("get() or get_iter() is already running"), + "get() or get_iter() is already running", + ), ]: with self.subTest(exception=exception): self.assertEqual(str(exception), exception_str) From d19ed267b5e04ded752f78da6c85ad2905cf89e5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 11 Sep 2024 07:01:12 +0200 Subject: [PATCH 623/762] Run spellcheck. --- docs/reference/variables.rst | 4 ++-- docs/spelling_wordlist.txt | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/reference/variables.rst b/docs/reference/variables.rst index b766e02a1..a24773074 100644 --- a/docs/reference/variables.rst +++ b/docs/reference/variables.rst @@ -80,8 +80,8 @@ Reconnection attempts are spaced out with truncated exponential backoff. The default value is ``90.0`` seconds. -Redirections ------------- +Redirects +--------- .. envvar:: WEBSOCKETS_MAX_REDIRECTS diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index a1ba59a37..11b13250a 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -59,6 +59,7 @@ reconnection redis redistributions retransmit +retryable runtime scalable stateful From 98f236f89529d317628fe8ee3d4d0564e675f68d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 10 Sep 2024 08:01:49 +0200 Subject: [PATCH 624/762] Run handler only when opening handshake succeeds. When process_request() or process_response() returned a HTTP response without calling accept() or reject() and with a status code other than 101, the connection handler used to start, which was incorrect. Fix #1419. Also move start_keepalive() outside of handshake() and bring it together with starting the connection handler, which is more logical. --- docs/project/changelog.rst | 7 +++++ src/websockets/asyncio/client.py | 5 ++-- src/websockets/asyncio/server.py | 11 ++++---- src/websockets/server.py | 23 +++++++++------- src/websockets/sync/server.py | 8 +++--- tests/asyncio/test_server.py | 10 +++++-- tests/sync/test_server.py | 5 +++- tests/test_server.py | 45 +++++++++++++++++++++++++++++--- 8 files changed, 88 insertions(+), 26 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index f92ca68b6..69051d287 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -68,6 +68,13 @@ Improvements Previously, :exc:`RuntimeError` was raised. For backwards compatibility, :exc:`~exceptions.ConcurrencyError` is a subclass of :exc:`RuntimeError`. +Bug fixes +......... + +* The new :mod:`asyncio` and :mod:`threading` implementations of servers don't + start the connection handler anymore when ``process_request`` or + ``process_response`` returns a HTTP response. + 13.0.1 ------ diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 3985bfb6a..b1beb3e00 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -98,9 +98,7 @@ async def handshake( # before receiving a response, when the response cannot be parsed, or # when the response fails the handshake. - if self.protocol.handshake_exc is None: - self.start_keepalive() - else: + if self.protocol.handshake_exc is not None: raise self.protocol.handshake_exc def process_event(self, event: Event) -> None: @@ -465,6 +463,7 @@ async def __await_impl__(self) -> ClientConnection: raise uri_or_exc from exc else: + self.connection.start_keepalive() return self.connection else: raise SecurityError(f"more than {MAX_REDIRECTS} redirects") diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index 228b20012..78ee760d2 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -201,12 +201,11 @@ async def handshake( self.protocol.send_response(self.response) # self.protocol.handshake_exc is always set when the connection is lost - # before receiving a request, when the request cannot be parsed, or when - # the response fails the handshake. + # before receiving a request, when the request cannot be parsed, when + # the handshake encounters an error, or when process_request or + # process_response sends a HTTP response that rejects the handshake. - if self.protocol.handshake_exc is None: - self.start_keepalive() - else: + if self.protocol.handshake_exc is not None: raise self.protocol.handshake_exc def process_event(self, event: Event) -> None: @@ -369,7 +368,9 @@ async def conn_handler(self, connection: ServerConnection) -> None: connection.close_transport() return + assert connection.protocol.state is OPEN try: + connection.start_keepalive() await self.handler(connection) except Exception: connection.logger.error("connection handler failed", exc_info=True) diff --git a/src/websockets/server.py b/src/websockets/server.py index ac62800d6..b2671f402 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -204,7 +204,6 @@ def accept(self, request: Request) -> Response: if protocol_header is not None: headers["Sec-WebSocket-Protocol"] = protocol_header - self.logger.info("connection open") return Response(101, "Switching Protocols", headers) def process_request( @@ -515,14 +514,7 @@ def reject(self, status: StatusLike, text: str) -> Response: ("Content-Type", "text/plain; charset=utf-8"), ] ) - response = Response(status.value, status.phrase, headers, body) - # When reject() is called from accept(), handshake_exc is already set. - # If a user calls reject(), set handshake_exc to guarantee invariant: - # "handshake_exc is None if and only if opening handshake succeeded." - if self.handshake_exc is None: - self.handshake_exc = InvalidStatus(response) - self.logger.info("connection rejected (%d %s)", status.value, status.phrase) - return response + return Response(status.value, status.phrase, headers, body) def send_response(self, response: Response) -> None: """ @@ -545,7 +537,20 @@ def send_response(self, response: Response) -> None: if response.status_code == 101: assert self.state is CONNECTING self.state = OPEN + self.logger.info("connection open") + else: + # handshake_exc may be already set if accept() encountered an error. + # If the connection isn't open, set handshake_exc to guarantee that + # handshake_exc is None if and only if opening handshake succeeded. + if self.handshake_exc is None: + self.handshake_exc = InvalidStatus(response) + self.logger.info( + "connection rejected (%d %s)", + response.status_code, + response.reason_phrase, + ) + self.send_eof() self.parser = self.discard() next(self.parser) # start coroutine diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index eb0536013..0b19201a9 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -23,7 +23,7 @@ validate_subprotocols, ) from ..http11 import SERVER, Request, Response -from ..protocol import CONNECTING, Event +from ..protocol import CONNECTING, OPEN, Event from ..server import ServerProtocol from ..typing import LoggerLike, Origin, StatusLike, Subprotocol from .connection import Connection @@ -166,8 +166,9 @@ def handshake( self.protocol.send_response(self.response) # self.protocol.handshake_exc is always set when the connection is lost - # before receiving a request, when the request cannot be parsed, or when - # the response fails the handshake. + # before receiving a request, when the request cannot be parsed, when + # the handshake encounters an error, or when process_request or + # process_response sends a HTTP response that rejects the handshake. if self.protocol.handshake_exc is not None: raise self.protocol.handshake_exc @@ -569,6 +570,7 @@ def protocol_select_subprotocol( connection.recv_events_thread.join() return + assert connection.protocol.state is OPEN try: handler(connection) except Exception: diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index fdcbf9780..47e0148a6 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -145,7 +145,10 @@ async def test_process_request_returns_response(self): def process_request(ws, request): return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden") - async with serve(*args, process_request=process_request) as server: + async def handler(ws): + self.fail("handler must not run") + + async with serve(handler, *args[1:], process_request=process_request) as server: with self.assertRaises(InvalidStatus) as raised: async with connect(get_uri(server)): self.fail("did not raise") @@ -160,7 +163,10 @@ async def test_async_process_request_returns_response(self): async def process_request(ws, request): return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden") - async with serve(*args, process_request=process_request) as server: + async def handler(ws): + self.fail("handler must not run") + + async with serve(handler, *args[1:], process_request=process_request) as server: with self.assertRaises(InvalidStatus) as raised: async with connect(get_uri(server)): self.fail("did not raise") diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index d0d2c0955..3bc6f76cd 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -133,7 +133,10 @@ def test_process_request_returns_response(self): def process_request(ws, request): return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden") - with run_server(process_request=process_request) as server: + def handler(ws): + self.fail("handler must not run") + + with run_server(handler, process_request=process_request) as server: with self.assertRaises(InvalidStatus) as raised: with connect(get_uri(server)): self.fail("did not raise") diff --git a/tests/test_server.py b/tests/test_server.py index d34c8e83d..52c8a2b99 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -106,10 +106,11 @@ def make_request(self): ), ) - def test_send_accept(self): + def test_send_response_after_successful_accept(self): server = ServerProtocol() + request = self.make_request() with unittest.mock.patch("email.utils.formatdate", return_value=DATE): - response = server.accept(self.make_request()) + response = server.accept(request) self.assertIsInstance(response, Response) server.send_response(response) self.assertEqual( @@ -126,7 +127,32 @@ def test_send_accept(self): self.assertFalse(server.close_expected()) self.assertEqual(server.state, OPEN) - def test_send_reject(self): + def test_send_response_after_failed_accept(self): + server = ServerProtocol() + request = self.make_request() + del request.headers["Sec-WebSocket-Key"] + with unittest.mock.patch("email.utils.formatdate", return_value=DATE): + response = server.accept(request) + self.assertIsInstance(response, Response) + server.send_response(response) + self.assertEqual( + server.data_to_send(), + [ + f"HTTP/1.1 400 Bad Request\r\n" + f"Date: {DATE}\r\n" + f"Connection: close\r\n" + f"Content-Length: 94\r\n" + f"Content-Type: text/plain; charset=utf-8\r\n" + f"\r\n" + f"Failed to open a WebSocket connection: " + f"missing Sec-WebSocket-Key header; 'sec-websocket-key'.\n".encode(), + b"", + ], + ) + self.assertTrue(server.close_expected()) + self.assertEqual(server.state, CONNECTING) + + def test_send_response_after_reject(self): server = ServerProtocol() with unittest.mock.patch("email.utils.formatdate", return_value=DATE): response = server.reject(http.HTTPStatus.NOT_FOUND, "Sorry folks.\n") @@ -148,6 +174,19 @@ def test_send_reject(self): self.assertTrue(server.close_expected()) self.assertEqual(server.state, CONNECTING) + def test_send_response_without_accept_or_reject(self): + server = ServerProtocol() + server.send_response(Response(410, "Gone", Headers(), b"AWOL.\n")) + self.assertEqual( + server.data_to_send(), + [ + "HTTP/1.1 410 Gone\r\n\r\nAWOL.\n".encode(), + b"", + ], + ) + self.assertTrue(server.close_expected()) + self.assertEqual(server.state, CONNECTING) + def test_accept_response(self): server = ServerProtocol() with unittest.mock.patch("email.utils.formatdate", return_value=DATE): From 206624a6a700b8ee572b1730df5072adc47a17e5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Sep 2024 14:01:59 +0200 Subject: [PATCH 625/762] Standard spelling on "an HTTP". --- docs/project/changelog.rst | 2 +- docs/reference/features.rst | 7 +++---- src/websockets/asyncio/server.py | 2 +- src/websockets/sync/server.py | 2 +- 4 files changed, 6 insertions(+), 7 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 69051d287..456c15dac 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -73,7 +73,7 @@ Bug fixes * The new :mod:`asyncio` and :mod:`threading` implementations of servers don't start the connection handler anymore when ``process_request`` or - ``process_response`` returns a HTTP response. + ``process_response`` returns an HTTP response. 13.0.1 ------ diff --git a/docs/reference/features.rst b/docs/reference/features.rst index eeade1462..8b04034eb 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -161,10 +161,9 @@ Client | Perform HTTP Digest Authentication | ❌ | ❌ | ❌ | ❌ | | (`#784`_) | | | | | +------------------------------------+--------+--------+--------+--------+ - | Connect via a HTTP proxy (`#364`_) | ❌ | ❌ | — | ❌ | + | Connect via HTTP proxy (`#364`_) | ❌ | ❌ | — | ❌ | +------------------------------------+--------+--------+--------+--------+ - | Connect via a SOCKS5 proxy | ❌ | ❌ | — | ❌ | - | (`#475`_) | | | | | + | Connect via SOCKS5 proxy (`#475`_) | ❌ | ❌ | — | ❌ | +------------------------------------+--------+--------+--------+--------+ .. _#364: https://github.com/python-websockets/websockets/issues/364 @@ -179,7 +178,7 @@ There is no way to control compression of outgoing frames on a per-frame basis .. _#538: https://github.com/python-websockets/websockets/issues/538 -The server doesn't check the Host header and doesn't respond with a HTTP 400 Bad +The server doesn't check the Host header and doesn't respond with HTTP 400 Bad Request if it is missing or invalid (`#1246`). .. _#1246: https://github.com/python-websockets/websockets/issues/1246 diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index 78ee760d2..19dae44b7 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -203,7 +203,7 @@ async def handshake( # self.protocol.handshake_exc is always set when the connection is lost # before receiving a request, when the request cannot be parsed, when # the handshake encounters an error, or when process_request or - # process_response sends a HTTP response that rejects the handshake. + # process_response sends an HTTP response that rejects the handshake. if self.protocol.handshake_exc is not None: raise self.protocol.handshake_exc diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 0b19201a9..1b7cbb4b4 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -168,7 +168,7 @@ def handshake( # self.protocol.handshake_exc is always set when the connection is lost # before receiving a request, when the request cannot be parsed, when # the handshake encounters an error, or when process_request or - # process_response sends a HTTP response that rejects the handshake. + # process_response sends an HTTP response that rejects the handshake. if self.protocol.handshake_exc is not None: raise self.protocol.handshake_exc From 20739e03ec6ccb010391b5179315368a2dd3a594 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Sep 2024 13:58:33 +0200 Subject: [PATCH 626/762] Improve exception handling during handshake. Also refactor tests for Sans-I/O client and server. --- src/websockets/client.py | 8 +- src/websockets/server.py | 22 +- tests/test_client.py | 726 ++++++++++++++++++++------------------- tests/test_connection.py | 1 + tests/test_server.py | 658 +++++++++++++++++++++-------------- 5 files changed, 799 insertions(+), 616 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 0e36fd028..e5f294986 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -175,10 +175,10 @@ def process_response(self, response: Response) -> None: try: s_w_accept = headers["Sec-WebSocket-Accept"] - except KeyError as exc: - raise InvalidHeader("Sec-WebSocket-Accept") from exc - except MultipleValuesError as exc: - raise InvalidHeader("Sec-WebSocket-Accept", "multiple values") from exc + except KeyError: + raise InvalidHeader("Sec-WebSocket-Accept") from None + except MultipleValuesError: + raise InvalidHeader("Sec-WebSocket-Accept", "multiple values") from None if s_w_accept != accept_key(self.key): raise InvalidHeaderValue("Sec-WebSocket-Accept", s_w_accept) diff --git a/src/websockets/server.py b/src/websockets/server.py index b2671f402..006d5bdd5 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -253,10 +253,10 @@ def process_request( try: key = headers["Sec-WebSocket-Key"] - except KeyError as exc: - raise InvalidHeader("Sec-WebSocket-Key") from exc - except MultipleValuesError as exc: - raise InvalidHeader("Sec-WebSocket-Key", "multiple values") from exc + except KeyError: + raise InvalidHeader("Sec-WebSocket-Key") from None + except MultipleValuesError: + raise InvalidHeader("Sec-WebSocket-Key", "multiple values") from None try: raw_key = base64.b64decode(key.encode(), validate=True) @@ -267,10 +267,10 @@ def process_request( try: version = headers["Sec-WebSocket-Version"] - except KeyError as exc: - raise InvalidHeader("Sec-WebSocket-Version") from exc - except MultipleValuesError as exc: - raise InvalidHeader("Sec-WebSocket-Version", "multiple values") from exc + except KeyError: + raise InvalidHeader("Sec-WebSocket-Version") from None + except MultipleValuesError: + raise InvalidHeader("Sec-WebSocket-Version", "multiple values") from None if version != "13": raise InvalidHeaderValue("Sec-WebSocket-Version", version) @@ -308,8 +308,8 @@ def process_origin(self, headers: Headers) -> Origin | None: # per https://datatracker.ietf.org/doc/html/rfc6454#section-7.3. try: origin = headers.get("Origin") - except MultipleValuesError as exc: - raise InvalidHeader("Origin", "multiple values") from exc + except MultipleValuesError: + raise InvalidHeader("Origin", "multiple values") from None if origin is not None: origin = cast(Origin, origin) if self.origins is not None: @@ -503,7 +503,7 @@ def reject(self, status: StatusLike, text: str) -> Response: HTTP response to send to the client. """ - # If a user passes an int instead of a HTTPStatus, fix it automatically. + # If status is an int instead of an HTTPStatus, fix it automatically. status = http.HTTPStatus(status) body = text.encode() headers = Headers( diff --git a/tests/test_client.py b/tests/test_client.py index 47558c1c0..2468be85e 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,11 +1,14 @@ +import contextlib +import dataclasses import logging +import types import unittest -import unittest.mock +from unittest.mock import patch from websockets.client import * from websockets.client import backoff from websockets.datastructures import Headers -from websockets.exceptions import InvalidHandshake, InvalidHeader +from websockets.exceptions import InvalidHandshake, InvalidHeader, InvalidStatus from websockets.frames import OP_TEXT, Frame from websockets.http11 import Request, Response from websockets.protocol import CONNECTING, OPEN @@ -22,13 +25,19 @@ from .utils import DATE, DeprecationTestCase -class ConnectTests(unittest.TestCase): - def test_send_connect(self): - with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientProtocol(parse_uri("wss://example.com/test")) +URI = parse_uri("wss://example.com/test") # for tests where the URI doesn't matter + + +@patch("websockets.client.generate_key", return_value=KEY) +class BasicTests(unittest.TestCase): + """Test basic opening handshake scenarios.""" + + def test_send_request(self, _generate_key): + """Client sends a handshake request.""" + client = ClientProtocol(URI) request = client.connect() - self.assertIsInstance(request, Request) client.send_request(request) + self.assertEqual( client.data_to_send(), [ @@ -42,11 +51,56 @@ def test_send_connect(self): ], ) self.assertFalse(client.close_expected()) + self.assertEqual(client.state, CONNECTING) + + def test_receive_successful_response(self, _generate_key): + """Client receives a successful handshake response.""" + client = ClientProtocol(URI) + client.receive_data( + ( + f"HTTP/1.1 101 Switching Protocols\r\n" + f"Upgrade: websocket\r\n" + f"Connection: Upgrade\r\n" + f"Sec-WebSocket-Accept: {ACCEPT}\r\n" + f"Date: {DATE}\r\n" + f"\r\n" + ).encode(), + ) + + self.assertEqual(client.data_to_send(), []) + self.assertFalse(client.close_expected()) + self.assertEqual(client.state, OPEN) + + def test_receive_failed_response(self, _generate_key): + """Client receives a failed handshake response.""" + client = ClientProtocol(URI) + client.receive_data( + ( + f"HTTP/1.1 404 Not Found\r\n" + f"Date: {DATE}\r\n" + f"Content-Length: 13\r\n" + f"Content-Type: text/plain; charset=utf-8\r\n" + f"Connection: close\r\n" + f"\r\n" + f"Sorry folks.\n" + ).encode(), + ) + + self.assertEqual(client.data_to_send(), [b""]) + self.assertTrue(client.close_expected()) + self.assertEqual(client.state, CONNECTING) + + +class RequestTests(unittest.TestCase): + """Test generating opening handshake requests.""" - def test_connect_request(self): - with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientProtocol(parse_uri("wss://example.com/test")) + @patch("websockets.client.generate_key", return_value=KEY) + def test_connect(self, _generate_key): + """connect() creates an opening handshake request.""" + client = ClientProtocol(URI) request = client.connect() + + self.assertIsInstance(request, Request) self.assertEqual(request.path, "/test") self.assertEqual( request.headers, @@ -62,12 +116,14 @@ def test_connect_request(self): ) def test_path(self): + """connect() uses the path from the URI.""" client = ClientProtocol(parse_uri("wss://example.com/endpoint?test=1")) request = client.connect() self.assertEqual(request.path, "/endpoint?test=1") def test_port(self): + """connect() uses the port from the URI or the default port.""" for uri, host in [ ("ws://example.com/", "example.com"), ("ws://example.com:80/", "example.com"), @@ -83,85 +139,41 @@ def test_port(self): self.assertEqual(request.headers["Host"], host) def test_user_info(self): + """connect() perfoms HTTP Basic Authentication with user info from the URI.""" client = ClientProtocol(parse_uri("wss://hello:iloveyou@example.com/")) request = client.connect() self.assertEqual(request.headers["Authorization"], "Basic aGVsbG86aWxvdmV5b3U=") def test_origin(self): - client = ClientProtocol( - parse_uri("wss://example.com/"), - origin="https://example.com", - ) + """connect(origin=...) generates an Origin header.""" + client = ClientProtocol(URI, origin="https://example.com") request = client.connect() self.assertEqual(request.headers["Origin"], "https://example.com") def test_extensions(self): - client = ClientProtocol( - parse_uri("wss://example.com/"), - extensions=[ClientOpExtensionFactory()], - ) + """connect(extensions=...) generates a Sec-WebSocket-Extensions header.""" + client = ClientProtocol(URI, extensions=[ClientOpExtensionFactory()]) request = client.connect() self.assertEqual(request.headers["Sec-WebSocket-Extensions"], "x-op; op") def test_subprotocols(self): - client = ClientProtocol( - parse_uri("wss://example.com/"), - subprotocols=["chat"], - ) + """connect(subprotocols=...) generates a Sec-WebSocket-Protocol header.""" + client = ClientProtocol(URI, subprotocols=["chat"]) request = client.connect() self.assertEqual(request.headers["Sec-WebSocket-Protocol"], "chat") -class AcceptRejectTests(unittest.TestCase): - def test_receive_accept(self): - with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientProtocol(parse_uri("ws://example.com/test")) - client.connect() - client.receive_data( - ( - f"HTTP/1.1 101 Switching Protocols\r\n" - f"Upgrade: websocket\r\n" - f"Connection: Upgrade\r\n" - f"Sec-WebSocket-Accept: {ACCEPT}\r\n" - f"Date: {DATE}\r\n" - f"\r\n" - ).encode(), - ) - [response] = client.events_received() - self.assertIsInstance(response, Response) - self.assertEqual(client.data_to_send(), []) - self.assertFalse(client.close_expected()) - self.assertEqual(client.state, OPEN) - - def test_receive_reject(self): - with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientProtocol(parse_uri("ws://example.com/test")) - client.connect() - client.receive_data( - ( - f"HTTP/1.1 404 Not Found\r\n" - f"Date: {DATE}\r\n" - f"Content-Length: 13\r\n" - f"Content-Type: text/plain; charset=utf-8\r\n" - f"Connection: close\r\n" - f"\r\n" - f"Sorry folks.\n" - ).encode(), - ) - [response] = client.events_received() - self.assertIsInstance(response, Response) - self.assertEqual(client.data_to_send(), [b""]) - self.assertTrue(client.close_expected()) - self.assertEqual(client.state, CONNECTING) +@patch("websockets.client.generate_key", return_value=KEY) +class ResponseTests(unittest.TestCase): + """Test receiving opening handshake responses.""" - def test_accept_response(self): - with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientProtocol(parse_uri("ws://example.com/test")) - client.connect() + def test_receive_successful_response(self, _generate_key): + """Client receives a successful handshake response.""" + client = ClientProtocol(URI) client.receive_data( ( f"HTTP/1.1 101 Switching Protocols\r\n" @@ -173,6 +185,7 @@ def test_accept_response(self): ).encode(), ) [response] = client.events_received() + self.assertEqual(response.status_code, 101) self.assertEqual(response.reason_phrase, "Switching Protocols") self.assertEqual( @@ -187,11 +200,11 @@ def test_accept_response(self): ), ) self.assertIsNone(response.body) + self.assertIsNone(client.handshake_exc) - def test_reject_response(self): - with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientProtocol(parse_uri("ws://example.com/test")) - client.connect() + def test_receive_failed_response(self, _generate_key): + """Client receives a failed handshake response.""" + client = ClientProtocol(URI) client.receive_data( ( f"HTTP/1.1 404 Not Found\r\n" @@ -204,6 +217,7 @@ def test_reject_response(self): ).encode(), ) [response] = client.events_received() + self.assertEqual(response.status_code, 404) self.assertEqual(response.reason_phrase, "Not Found") self.assertEqual( @@ -218,394 +232,416 @@ def test_reject_response(self): ), ) self.assertEqual(response.body, b"Sorry folks.\n") + self.assertIsInstance(client.handshake_exc, InvalidStatus) + self.assertEqual( + str(client.handshake_exc), + "server rejected WebSocket connection: HTTP 404", + ) - def test_no_response(self): - with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientProtocol(parse_uri("ws://example.com/test")) - client.connect() + def test_receive_no_response(self, _generate_key): + """Client receives no handshake response.""" + client = ClientProtocol(URI) client.receive_eof() + self.assertEqual(client.events_received(), []) + self.assertIsInstance(client.handshake_exc, EOFError) + self.assertEqual( + str(client.handshake_exc), + "connection closed while reading HTTP status line", + ) - def test_partial_response(self): - with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientProtocol(parse_uri("ws://example.com/test")) - client.connect() + def test_receive_truncated_response(self, _generate_key): + """Client receives a truncated handshake response.""" + client = ClientProtocol(URI) client.receive_data(b"HTTP/1.1 101 Switching Protocols\r\n") client.receive_eof() + self.assertEqual(client.events_received(), []) + self.assertIsInstance(client.handshake_exc, EOFError) + self.assertEqual( + str(client.handshake_exc), + "connection closed while reading HTTP headers", + ) - def test_random_response(self): - with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): - client = ClientProtocol(parse_uri("ws://example.com/test")) - client.connect() + def test_receive_random_response(self, _generate_key): + """Client receives a junk handshake response.""" + client = ClientProtocol(URI) client.receive_data(b"220 smtp.invalid\r\n") client.receive_data(b"250 Hello relay.invalid\r\n") client.receive_data(b"250 Ok\r\n") client.receive_data(b"250 Ok\r\n") - client.receive_eof() - self.assertEqual(client.events_received(), []) - def make_accept_response(self, client): - request = client.connect() - return Response( - status_code=101, - reason_phrase="Switching Protocols", - headers=Headers( - { - "Upgrade": "websocket", - "Connection": "Upgrade", - "Sec-WebSocket-Accept": accept_key( - request.headers["Sec-WebSocket-Key"] - ), - } - ), + self.assertEqual(client.events_received(), []) + self.assertIsInstance(client.handshake_exc, ValueError) + self.assertEqual( + str(client.handshake_exc), + "invalid HTTP status line: 220 smtp.invalid", ) - def test_basic(self): - client = ClientProtocol(parse_uri("wss://example.com/")) - response = self.make_accept_response(client) - client.receive_data(response.serialize()) - [response] = client.events_received() +@contextlib.contextmanager +def alter_and_receive_response(client): + """Generate a handshake response that can be altered for testing.""" + # We could start by sending a handshake request, i.e.: + # request = client.connect() + # client.send_request(request) + # However, in the current implementation, these calls have no effect on the + # state of the client. Therefore, they're unnecessary and can be skipped. + response = Response( + status_code=101, + reason_phrase="Switching Protocols", + headers=Headers( + { + "Upgrade": "websocket", + "Connection": "Upgrade", + "Sec-WebSocket-Accept": accept_key(client.key), + } + ), + ) + yield response + client.receive_data(response.serialize()) + [parsed_response] = client.events_received() + assert response == dataclasses.replace(parsed_response, _exception=None) + + +class HandshakeTests(unittest.TestCase): + """Test processing of handshake responses to configure the connection.""" + + def assertHandshakeSuccess(self, client): + """Assert that the opening handshake succeeded.""" self.assertEqual(client.state, OPEN) + self.assertIsNone(client.handshake_exc) - def test_missing_connection(self): - client = ClientProtocol(parse_uri("wss://example.com/")) - response = self.make_accept_response(client) - del response.headers["Connection"] - client.receive_data(response.serialize()) - [response] = client.events_received() - + def assertHandshakeError(self, client, exc_type, msg): + """Assert that the opening handshake failed with the given exception.""" self.assertEqual(client.state, CONNECTING) - with self.assertRaises(InvalidHeader) as raised: - raise client.handshake_exc - self.assertEqual(str(raised.exception), "missing Connection header") + self.assertIsInstance(client.handshake_exc, exc_type) + # Exception chaining isn't used is client handshake implementation. + assert client.handshake_exc.__cause__ is None + self.assertEqual(str(client.handshake_exc), msg) - def test_invalid_connection(self): - client = ClientProtocol(parse_uri("wss://example.com/")) - response = self.make_accept_response(client) - del response.headers["Connection"] - response.headers["Connection"] = "close" - client.receive_data(response.serialize()) - [response] = client.events_received() + def test_basic(self): + """Handshake succeeds.""" + client = ClientProtocol(URI) + with alter_and_receive_response(client): + pass - self.assertEqual(client.state, CONNECTING) - with self.assertRaises(InvalidHeader) as raised: - raise client.handshake_exc - self.assertEqual(str(raised.exception), "invalid Connection header: close") + self.assertHandshakeSuccess(client) - def test_missing_upgrade(self): - client = ClientProtocol(parse_uri("wss://example.com/")) - response = self.make_accept_response(client) - del response.headers["Upgrade"] - client.receive_data(response.serialize()) - [response] = client.events_received() + def test_missing_connection(self): + """Handshake fails when the Connection header is missing.""" + client = ClientProtocol(URI) + with alter_and_receive_response(client) as response: + del response.headers["Connection"] + + self.assertHandshakeError( + client, + InvalidHeader, + "missing Connection header", + ) - self.assertEqual(client.state, CONNECTING) - with self.assertRaises(InvalidHeader) as raised: - raise client.handshake_exc - self.assertEqual(str(raised.exception), "missing Upgrade header") + def test_invalid_connection(self): + """Handshake fails when the Connection header is invalid.""" + client = ClientProtocol(URI) + with alter_and_receive_response(client) as response: + del response.headers["Connection"] + response.headers["Connection"] = "close" + + self.assertHandshakeError( + client, + InvalidHeader, + "invalid Connection header: close", + ) - def test_invalid_upgrade(self): - client = ClientProtocol(parse_uri("wss://example.com/")) - response = self.make_accept_response(client) - del response.headers["Upgrade"] - response.headers["Upgrade"] = "h2c" - client.receive_data(response.serialize()) - [response] = client.events_received() + def test_missing_upgrade(self): + """Handshake fails when the Upgrade header is missing.""" + client = ClientProtocol(URI) + with alter_and_receive_response(client) as response: + del response.headers["Upgrade"] + + self.assertHandshakeError( + client, + InvalidHeader, + "missing Upgrade header", + ) - self.assertEqual(client.state, CONNECTING) - with self.assertRaises(InvalidHeader) as raised: - raise client.handshake_exc - self.assertEqual(str(raised.exception), "invalid Upgrade header: h2c") + def test_invalid_upgrade(self): + """Handshake fails when the Upgrade header is invalid.""" + client = ClientProtocol(URI) + with alter_and_receive_response(client) as response: + del response.headers["Upgrade"] + response.headers["Upgrade"] = "h2c" + + self.assertHandshakeError( + client, + InvalidHeader, + "invalid Upgrade header: h2c", + ) def test_missing_accept(self): - client = ClientProtocol(parse_uri("wss://example.com/")) - response = self.make_accept_response(client) - del response.headers["Sec-WebSocket-Accept"] - client.receive_data(response.serialize()) - [response] = client.events_received() - - self.assertEqual(client.state, CONNECTING) - with self.assertRaises(InvalidHeader) as raised: - raise client.handshake_exc - self.assertEqual(str(raised.exception), "missing Sec-WebSocket-Accept header") + """Handshake fails when the Sec-WebSocket-Accept header is missing.""" + client = ClientProtocol(URI) + with alter_and_receive_response(client) as response: + del response.headers["Sec-WebSocket-Accept"] + + self.assertHandshakeError( + client, + InvalidHeader, + "missing Sec-WebSocket-Accept header", + ) def test_multiple_accept(self): - client = ClientProtocol(parse_uri("wss://example.com/")) - response = self.make_accept_response(client) - response.headers["Sec-WebSocket-Accept"] = ACCEPT - client.receive_data(response.serialize()) - [response] = client.events_received() - - self.assertEqual(client.state, CONNECTING) - with self.assertRaises(InvalidHeader) as raised: - raise client.handshake_exc - self.assertEqual( - str(raised.exception), + """Handshake fails when the Sec-WebSocket-Accept header is repeated.""" + client = ClientProtocol(URI) + with alter_and_receive_response(client) as response: + response.headers["Sec-WebSocket-Accept"] = ACCEPT + + self.assertHandshakeError( + client, + InvalidHeader, "invalid Sec-WebSocket-Accept header: multiple values", ) def test_invalid_accept(self): - client = ClientProtocol(parse_uri("wss://example.com/")) - response = self.make_accept_response(client) - del response.headers["Sec-WebSocket-Accept"] - response.headers["Sec-WebSocket-Accept"] = ACCEPT - client.receive_data(response.serialize()) - [response] = client.events_received() - - self.assertEqual(client.state, CONNECTING) - with self.assertRaises(InvalidHeader) as raised: - raise client.handshake_exc - self.assertEqual( - str(raised.exception), f"invalid Sec-WebSocket-Accept header: {ACCEPT}" + """Handshake fails when the Sec-WebSocket-Accept header is invalid.""" + client = ClientProtocol(URI) + with alter_and_receive_response(client) as response: + del response.headers["Sec-WebSocket-Accept"] + response.headers["Sec-WebSocket-Accept"] = ACCEPT + + self.assertHandshakeError( + client, + InvalidHeader, + f"invalid Sec-WebSocket-Accept header: {ACCEPT}", ) def test_no_extensions(self): - client = ClientProtocol(parse_uri("wss://example.com/")) - response = self.make_accept_response(client) - client.receive_data(response.serialize()) - [response] = client.events_received() + """Handshake succeeds without extensions.""" + client = ClientProtocol(URI) + with alter_and_receive_response(client): + pass - self.assertEqual(client.state, OPEN) + self.assertHandshakeSuccess(client) self.assertEqual(client.extensions, []) - def test_no_extension(self): - client = ClientProtocol( - parse_uri("wss://example.com/"), - extensions=[ClientOpExtensionFactory()], - ) - response = self.make_accept_response(client) - response.headers["Sec-WebSocket-Extensions"] = "x-op; op" - client.receive_data(response.serialize()) - [response] = client.events_received() + def test_offer_extension(self): + """Client offers an extension.""" + client = ClientProtocol(URI, extensions=[ClientRsv2ExtensionFactory()]) + request = client.connect() - self.assertEqual(client.state, OPEN) - self.assertEqual(client.extensions, [OpExtension()]) + self.assertEqual(request.headers["Sec-WebSocket-Extensions"], "x-rsv2") - def test_extension(self): - client = ClientProtocol( - parse_uri("wss://example.com/"), - extensions=[ClientRsv2ExtensionFactory()], - ) - response = self.make_accept_response(client) - response.headers["Sec-WebSocket-Extensions"] = "x-rsv2" - client.receive_data(response.serialize()) - [response] = client.events_received() + def test_enable_extension(self): + """Client offers an extension and the server enables it.""" + client = ClientProtocol(URI, extensions=[ClientRsv2ExtensionFactory()]) + with alter_and_receive_response(client) as response: + response.headers["Sec-WebSocket-Extensions"] = "x-rsv2" - self.assertEqual(client.state, OPEN) + self.assertHandshakeSuccess(client) self.assertEqual(client.extensions, [Rsv2Extension()]) - def test_unexpected_extension(self): - client = ClientProtocol(parse_uri("wss://example.com/")) - response = self.make_accept_response(client) - response.headers["Sec-WebSocket-Extensions"] = "x-op; op" - client.receive_data(response.serialize()) - [response] = client.events_received() + def test_extension_not_enabled(self): + """Client offers an extension, but the server doesn't enable it.""" + client = ClientProtocol(URI, extensions=[ClientRsv2ExtensionFactory()]) + with alter_and_receive_response(client): + pass - self.assertEqual(client.state, CONNECTING) - with self.assertRaises(InvalidHandshake) as raised: - raise client.handshake_exc - self.assertEqual(str(raised.exception), "no extensions supported") + self.assertHandshakeSuccess(client) + self.assertEqual(client.extensions, []) - def test_unsupported_extension(self): - client = ClientProtocol( - parse_uri("wss://example.com/"), - extensions=[ClientRsv2ExtensionFactory()], + def test_no_extensions_offered(self): + """Server enables an extension when the client didn't offer any.""" + client = ClientProtocol(URI) + with alter_and_receive_response(client) as response: + response.headers["Sec-WebSocket-Extensions"] = "x-rsv2" + + self.assertHandshakeError( + client, + InvalidHandshake, + "no extensions supported", ) - response = self.make_accept_response(client) - response.headers["Sec-WebSocket-Extensions"] = "x-op; op" - client.receive_data(response.serialize()) - [response] = client.events_received() - self.assertEqual(client.state, CONNECTING) - with self.assertRaises(InvalidHandshake) as raised: - raise client.handshake_exc - self.assertEqual( - str(raised.exception), + def test_extension_not_offered(self): + """Server enables an extension that the client didn't offer.""" + client = ClientProtocol(URI, extensions=[ClientRsv2ExtensionFactory()]) + with alter_and_receive_response(client) as response: + response.headers["Sec-WebSocket-Extensions"] = "x-op; op" + + self.assertHandshakeError( + client, + InvalidHandshake, "Unsupported extension: name = x-op, params = [('op', None)]", ) def test_supported_extension_parameters(self): - client = ClientProtocol( - parse_uri("wss://example.com/"), - extensions=[ClientOpExtensionFactory("this")], - ) - response = self.make_accept_response(client) - response.headers["Sec-WebSocket-Extensions"] = "x-op; op=this" - client.receive_data(response.serialize()) - [response] = client.events_received() + """Server enables an extension with parameters supported by the client.""" + client = ClientProtocol(URI, extensions=[ClientOpExtensionFactory("this")]) + with alter_and_receive_response(client) as response: + response.headers["Sec-WebSocket-Extensions"] = "x-op; op=this" - self.assertEqual(client.state, OPEN) + self.assertHandshakeSuccess(client) self.assertEqual(client.extensions, [OpExtension("this")]) def test_unsupported_extension_parameters(self): - client = ClientProtocol( - parse_uri("wss://example.com/"), - extensions=[ClientOpExtensionFactory("this")], - ) - response = self.make_accept_response(client) - response.headers["Sec-WebSocket-Extensions"] = "x-op; op=that" - client.receive_data(response.serialize()) - [response] = client.events_received() - - self.assertEqual(client.state, CONNECTING) - with self.assertRaises(InvalidHandshake) as raised: - raise client.handshake_exc - self.assertEqual( - str(raised.exception), + """Server enables an extension with parameters unsupported by the client.""" + client = ClientProtocol(URI, extensions=[ClientOpExtensionFactory("this")]) + with alter_and_receive_response(client) as response: + response.headers["Sec-WebSocket-Extensions"] = "x-op; op=that" + + self.assertHandshakeError( + client, + InvalidHandshake, "Unsupported extension: name = x-op, params = [('op', 'that')]", ) def test_multiple_supported_extension_parameters(self): + """Client offers the same extension with several parameters.""" client = ClientProtocol( - parse_uri("wss://example.com/"), + URI, extensions=[ ClientOpExtensionFactory("this"), ClientOpExtensionFactory("that"), ], ) - response = self.make_accept_response(client) - response.headers["Sec-WebSocket-Extensions"] = "x-op; op=that" - client.receive_data(response.serialize()) - [response] = client.events_received() + with alter_and_receive_response(client) as response: + response.headers["Sec-WebSocket-Extensions"] = "x-op; op=that" - self.assertEqual(client.state, OPEN) + self.assertHandshakeSuccess(client) self.assertEqual(client.extensions, [OpExtension("that")]) def test_multiple_extensions(self): + """Client offers several extensions and the server enables them.""" client = ClientProtocol( - parse_uri("wss://example.com/"), - extensions=[ClientOpExtensionFactory(), ClientRsv2ExtensionFactory()], + URI, + extensions=[ + ClientOpExtensionFactory(), + ClientRsv2ExtensionFactory(), + ], ) - response = self.make_accept_response(client) - response.headers["Sec-WebSocket-Extensions"] = "x-op; op" - response.headers["Sec-WebSocket-Extensions"] = "x-rsv2" - client.receive_data(response.serialize()) - [response] = client.events_received() + with alter_and_receive_response(client) as response: + response.headers["Sec-WebSocket-Extensions"] = "x-op; op" + response.headers["Sec-WebSocket-Extensions"] = "x-rsv2" - self.assertEqual(client.state, OPEN) + self.assertHandshakeSuccess(client) self.assertEqual(client.extensions, [OpExtension(), Rsv2Extension()]) def test_multiple_extensions_order(self): + """Client respects the order of extensions chosen by the server.""" client = ClientProtocol( - parse_uri("wss://example.com/"), - extensions=[ClientOpExtensionFactory(), ClientRsv2ExtensionFactory()], + URI, + extensions=[ + ClientOpExtensionFactory(), + ClientRsv2ExtensionFactory(), + ], ) - response = self.make_accept_response(client) - response.headers["Sec-WebSocket-Extensions"] = "x-rsv2" - response.headers["Sec-WebSocket-Extensions"] = "x-op; op" - client.receive_data(response.serialize()) - [response] = client.events_received() + with alter_and_receive_response(client) as response: + response.headers["Sec-WebSocket-Extensions"] = "x-rsv2" + response.headers["Sec-WebSocket-Extensions"] = "x-op; op" - self.assertEqual(client.state, OPEN) + self.assertHandshakeSuccess(client) self.assertEqual(client.extensions, [Rsv2Extension(), OpExtension()]) def test_no_subprotocols(self): - client = ClientProtocol(parse_uri("wss://example.com/")) - response = self.make_accept_response(client) - client.receive_data(response.serialize()) - [response] = client.events_received() + """Handshake succeeds without subprotocols.""" + client = ClientProtocol(URI) + with alter_and_receive_response(client): + pass - self.assertEqual(client.state, OPEN) + self.assertHandshakeSuccess(client) self.assertIsNone(client.subprotocol) - def test_no_subprotocol(self): - client = ClientProtocol(parse_uri("wss://example.com/"), subprotocols=["chat"]) - response = self.make_accept_response(client) - client.receive_data(response.serialize()) - [response] = client.events_received() + def test_no_subprotocol_requested(self): + """Client doesn't offer a subprotocol, but the server enables one.""" + client = ClientProtocol(URI) + with alter_and_receive_response(client) as response: + response.headers["Sec-WebSocket-Protocol"] = "chat" - self.assertEqual(client.state, OPEN) - self.assertIsNone(client.subprotocol) + self.assertHandshakeError( + client, + InvalidHandshake, + "no subprotocols supported", + ) - def test_subprotocol(self): - client = ClientProtocol(parse_uri("wss://example.com/"), subprotocols=["chat"]) - response = self.make_accept_response(client) - response.headers["Sec-WebSocket-Protocol"] = "chat" - client.receive_data(response.serialize()) - [response] = client.events_received() + def test_offer_subprotocol(self): + """Client offers a subprotocol.""" + client = ClientProtocol(URI, subprotocols=["chat"]) + request = client.connect() - self.assertEqual(client.state, OPEN) - self.assertEqual(client.subprotocol, "chat") + self.assertEqual(request.headers["Sec-WebSocket-Protocol"], "chat") - def test_unexpected_subprotocol(self): - client = ClientProtocol(parse_uri("wss://example.com/")) - response = self.make_accept_response(client) - response.headers["Sec-WebSocket-Protocol"] = "chat" - client.receive_data(response.serialize()) - [response] = client.events_received() + def test_enable_subprotocol(self): + """Client offers a subprotocol and the server enables it.""" + client = ClientProtocol(URI, subprotocols=["chat"]) + with alter_and_receive_response(client) as response: + response.headers["Sec-WebSocket-Protocol"] = "chat" - self.assertEqual(client.state, CONNECTING) - with self.assertRaises(InvalidHandshake) as raised: - raise client.handshake_exc - self.assertEqual(str(raised.exception), "no subprotocols supported") + self.assertHandshakeSuccess(client) + self.assertEqual(client.subprotocol, "chat") - def test_multiple_subprotocols(self): - client = ClientProtocol( - parse_uri("wss://example.com/"), - subprotocols=["superchat", "chat"], - ) - response = self.make_accept_response(client) - response.headers["Sec-WebSocket-Protocol"] = "superchat" - response.headers["Sec-WebSocket-Protocol"] = "chat" - client.receive_data(response.serialize()) - [response] = client.events_received() + def test_no_subprotocol_accepted(self): + """Client offers a subprotocol, but the server doesn't enable it.""" + client = ClientProtocol(URI, subprotocols=["chat"]) + with alter_and_receive_response(client): + pass - self.assertEqual(client.state, CONNECTING) - with self.assertRaises(InvalidHandshake) as raised: - raise client.handshake_exc - self.assertEqual( - str(raised.exception), - "invalid Sec-WebSocket-Protocol header: " - "multiple values: superchat, chat", - ) + self.assertHandshakeSuccess(client) + self.assertIsNone(client.subprotocol) - def test_supported_subprotocol(self): - client = ClientProtocol( - parse_uri("wss://example.com/"), - subprotocols=["superchat", "chat"], - ) - response = self.make_accept_response(client) - response.headers["Sec-WebSocket-Protocol"] = "chat" - client.receive_data(response.serialize()) - [response] = client.events_received() + def test_multiple_subprotocols(self): + """Client offers several subprotocols and the server enables one.""" + client = ClientProtocol(URI, subprotocols=["superchat", "chat"]) + with alter_and_receive_response(client) as response: + response.headers["Sec-WebSocket-Protocol"] = "chat" - self.assertEqual(client.state, OPEN) + self.assertHandshakeSuccess(client) self.assertEqual(client.subprotocol, "chat") def test_unsupported_subprotocol(self): - client = ClientProtocol( - parse_uri("wss://example.com/"), - subprotocols=["superchat", "chat"], + """Client offers subprotocols but the server enables another one.""" + client = ClientProtocol(URI, subprotocols=["superchat", "chat"]) + with alter_and_receive_response(client) as response: + response.headers["Sec-WebSocket-Protocol"] = "otherchat" + + self.assertHandshakeError( + client, + InvalidHandshake, + "unsupported subprotocol: otherchat", ) - response = self.make_accept_response(client) - response.headers["Sec-WebSocket-Protocol"] = "otherchat" - client.receive_data(response.serialize()) - [response] = client.events_received() - self.assertEqual(client.state, CONNECTING) - with self.assertRaises(InvalidHandshake) as raised: - raise client.handshake_exc - self.assertEqual(str(raised.exception), "unsupported subprotocol: otherchat") + def test_multiple_subprotocols_accepted(self): + """Server attempts to enable multiple subprotocols.""" + client = ClientProtocol(URI, subprotocols=["superchat", "chat"]) + with alter_and_receive_response(client) as response: + response.headers["Sec-WebSocket-Protocol"] = "superchat" + response.headers["Sec-WebSocket-Protocol"] = "chat" + + self.assertHandshakeError( + client, + InvalidHandshake, + "invalid Sec-WebSocket-Protocol header: " + "multiple values: superchat, chat", + ) class MiscTests(unittest.TestCase): def test_bypass_handshake(self): - client = ClientProtocol(parse_uri("ws://example.com/test"), state=OPEN) + """ClientProtocol bypasses the opening handshake.""" + client = ClientProtocol(URI, state=OPEN) client.receive_data(b"\x81\x06Hello!") [frame] = client.events_received() self.assertEqual(frame, Frame(OP_TEXT, b"Hello!")) def test_custom_logger(self): + """ClientProtocol accepts a logger argument.""" logger = logging.getLogger("test") with self.assertLogs("test", logging.DEBUG) as logs: - ClientProtocol(parse_uri("wss://example.com/test"), logger=logger) + ClientProtocol(URI, logger=logger) self.assertEqual(len(logs.records), 1) class BackwardsCompatibilityTests(DeprecationTestCase): def test_client_connection_class(self): + """ClientConnection is a deprecated alias for ClientProtocol.""" with self.assertDeprecationWarning( "ClientConnection was renamed to ClientProtocol" ): @@ -618,7 +654,9 @@ def test_client_connection_class(self): class BackoffTests(unittest.TestCase): def test_backoff(self): + """backoff() yields a random delay, then exponentially increasing delays.""" backoff_gen = backoff() + self.assertIsInstance(backoff_gen, types.GeneratorType) initial_delay = next(backoff_gen) self.assertGreaterEqual(initial_delay, 0) diff --git a/tests/test_connection.py b/tests/test_connection.py index 6592d67d0..9ad2ebea4 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -5,6 +5,7 @@ class BackwardsCompatibilityTests(DeprecationTestCase): def test_connection_class(self): + """Connection is a deprecated alias for Protocol.""" with self.assertDeprecationWarning( "websockets.connection was renamed to websockets.protocol " "and Connection was renamed to Protocol" diff --git a/tests/test_server.py b/tests/test_server.py index 52c8a2b99..844ba64ec 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,7 +1,8 @@ import http import logging +import sys import unittest -import unittest.mock +from unittest.mock import patch from websockets.datastructures import Headers from websockets.exceptions import ( @@ -25,8 +26,28 @@ from .utils import DATE, DeprecationTestCase -class ConnectTests(unittest.TestCase): - def test_receive_connect(self): +def make_request(): + """Generate a handshake request that can be altered for testing.""" + return Request( + path="/test", + headers=Headers( + { + "Host": "example.com", + "Upgrade": "websocket", + "Connection": "Upgrade", + "Sec-WebSocket-Key": KEY, + "Sec-WebSocket-Version": "13", + } + ), + ) + + +@patch("email.utils.formatdate", return_value=DATE) +class BasicTests(unittest.TestCase): + """Test basic opening handshake scenarios.""" + + def test_receive_request(self, _formatdate): + """Server receives a handshake request.""" server = ServerProtocol() server.receive_data( ( @@ -39,80 +60,18 @@ def test_receive_connect(self): f"\r\n" ).encode(), ) - [request] = server.events_received() - self.assertIsInstance(request, Request) + self.assertEqual(server.data_to_send(), []) self.assertFalse(server.close_expected()) + self.assertEqual(server.state, CONNECTING) - def test_connect_request(self): - server = ServerProtocol() - server.receive_data( - ( - f"GET /test HTTP/1.1\r\n" - f"Host: example.com\r\n" - f"Upgrade: websocket\r\n" - f"Connection: Upgrade\r\n" - f"Sec-WebSocket-Key: {KEY}\r\n" - f"Sec-WebSocket-Version: 13\r\n" - f"\r\n" - ).encode(), - ) - [request] = server.events_received() - self.assertEqual(request.path, "/test") - self.assertEqual( - request.headers, - Headers( - { - "Host": "example.com", - "Upgrade": "websocket", - "Connection": "Upgrade", - "Sec-WebSocket-Key": KEY, - "Sec-WebSocket-Version": "13", - } - ), - ) - - def test_no_request(self): - server = ServerProtocol() - server.receive_eof() - self.assertEqual(server.events_received(), []) - - def test_partial_request(self): - server = ServerProtocol() - server.receive_data(b"GET /test HTTP/1.1\r\n") - server.receive_eof() - self.assertEqual(server.events_received(), []) - - def test_junk_request(self): - server = ServerProtocol() - server.receive_data(b"HELO relay.invalid\r\n") - server.receive_data(b"MAIL FROM: \r\n") - server.receive_data(b"RCPT TO: \r\n") - self.assertEqual(server.events_received(), []) - - -class AcceptRejectTests(unittest.TestCase): - def make_request(self): - return Request( - path="/test", - headers=Headers( - { - "Host": "example.com", - "Upgrade": "websocket", - "Connection": "Upgrade", - "Sec-WebSocket-Key": KEY, - "Sec-WebSocket-Version": "13", - } - ), - ) - - def test_send_response_after_successful_accept(self): + def test_accept_and_send_successful_response(self, _formatdate): + """Server accepts a handshake request and sends a successful response.""" server = ServerProtocol() - request = self.make_request() - with unittest.mock.patch("email.utils.formatdate", return_value=DATE): - response = server.accept(request) - self.assertIsInstance(response, Response) + request = make_request() + response = server.accept(request) server.send_response(response) + self.assertEqual( server.data_to_send(), [ @@ -127,37 +86,37 @@ def test_send_response_after_successful_accept(self): self.assertFalse(server.close_expected()) self.assertEqual(server.state, OPEN) - def test_send_response_after_failed_accept(self): + def test_send_response_after_failed_accept(self, _formatdate): + """Server accepts a handshake request but sends a failed response.""" server = ServerProtocol() - request = self.make_request() + request = make_request() del request.headers["Sec-WebSocket-Key"] - with unittest.mock.patch("email.utils.formatdate", return_value=DATE): - response = server.accept(request) - self.assertIsInstance(response, Response) + response = server.accept(request) server.send_response(response) + self.assertEqual( server.data_to_send(), [ f"HTTP/1.1 400 Bad Request\r\n" f"Date: {DATE}\r\n" f"Connection: close\r\n" - f"Content-Length: 94\r\n" + f"Content-Length: 73\r\n" f"Content-Type: text/plain; charset=utf-8\r\n" f"\r\n" f"Failed to open a WebSocket connection: " - f"missing Sec-WebSocket-Key header; 'sec-websocket-key'.\n".encode(), + f"missing Sec-WebSocket-Key header.\n".encode(), b"", ], ) self.assertTrue(server.close_expected()) self.assertEqual(server.state, CONNECTING) - def test_send_response_after_reject(self): + def test_send_response_after_reject(self, _formatdate): + """Server rejects a handshake request and sends a failed response.""" server = ServerProtocol() - with unittest.mock.patch("email.utils.formatdate", return_value=DATE): - response = server.reject(http.HTTPStatus.NOT_FOUND, "Sorry folks.\n") - self.assertIsInstance(response, Response) + response = server.reject(http.HTTPStatus.NOT_FOUND, "Sorry folks.\n") server.send_response(response) + self.assertEqual( server.data_to_send(), [ @@ -174,23 +133,124 @@ def test_send_response_after_reject(self): self.assertTrue(server.close_expected()) self.assertEqual(server.state, CONNECTING) - def test_send_response_without_accept_or_reject(self): + def test_send_response_without_accept_or_reject(self, _formatdate): + """Server doesn't accept or reject and sends a failed response.""" server = ServerProtocol() - server.send_response(Response(410, "Gone", Headers(), b"AWOL.\n")) + server.send_response( + Response( + 410, + "Gone", + Headers( + { + "Connection": "close", + "Content-Length": 6, + "Content-Type": "text/plain", + } + ), + b"AWOL.\n", + ) + ) self.assertEqual( server.data_to_send(), [ - "HTTP/1.1 410 Gone\r\n\r\nAWOL.\n".encode(), + "HTTP/1.1 410 Gone\r\n" + "Connection: close\r\n" + "Content-Length: 6\r\n" + "Content-Type: text/plain\r\n" + "\r\n" + "AWOL.\n".encode(), b"", ], ) self.assertTrue(server.close_expected()) self.assertEqual(server.state, CONNECTING) - def test_accept_response(self): + +class RequestTests(unittest.TestCase): + """Test receiving opening handshake requests.""" + + def test_receive_request(self): + """Server receives a handshake request.""" server = ServerProtocol() - with unittest.mock.patch("email.utils.formatdate", return_value=DATE): - response = server.accept(self.make_request()) + server.receive_data( + ( + f"GET /test HTTP/1.1\r\n" + f"Host: example.com\r\n" + f"Upgrade: websocket\r\n" + f"Connection: Upgrade\r\n" + f"Sec-WebSocket-Key: {KEY}\r\n" + f"Sec-WebSocket-Version: 13\r\n" + f"\r\n" + ).encode(), + ) + [request] = server.events_received() + + self.assertIsInstance(request, Request) + self.assertEqual(request.path, "/test") + self.assertEqual( + request.headers, + Headers( + { + "Host": "example.com", + "Upgrade": "websocket", + "Connection": "Upgrade", + "Sec-WebSocket-Key": KEY, + "Sec-WebSocket-Version": "13", + } + ), + ) + self.assertIsNone(server.handshake_exc) + + def test_receive_no_request(self): + """Server receives no handshake request.""" + server = ServerProtocol() + server.receive_eof() + + self.assertEqual(server.events_received(), []) + self.assertIsInstance(server.handshake_exc, EOFError) + self.assertEqual( + str(server.handshake_exc), + "connection closed while reading HTTP request line", + ) + + def test_receive_truncated_request(self): + """Server receives a truncated handshake request.""" + server = ServerProtocol() + server.receive_data(b"GET /test HTTP/1.1\r\n") + server.receive_eof() + + self.assertEqual(server.events_received(), []) + self.assertIsInstance(server.handshake_exc, EOFError) + self.assertEqual( + str(server.handshake_exc), + "connection closed while reading HTTP headers", + ) + + def test_receive_junk_request(self): + """Server receives a junk handshake request.""" + server = ServerProtocol() + server.receive_data(b"HELO relay.invalid\r\n") + server.receive_data(b"MAIL FROM: \r\n") + server.receive_data(b"RCPT TO: \r\n") + + self.assertEqual(server.events_received(), []) + self.assertIsInstance(server.handshake_exc, ValueError) + self.assertEqual( + str(server.handshake_exc), + "invalid HTTP request line: HELO relay.invalid", + ) + + +class ResponseTests(unittest.TestCase): + """Test generating opening handshake responses.""" + + @patch("email.utils.formatdate", return_value=DATE) + def test_accept_response(self, _formatdate): + """accept() creates a successful opening handshake response.""" + server = ServerProtocol() + request = make_request() + response = server.accept(request) + self.assertIsInstance(response, Response) self.assertEqual(response.status_code, 101) self.assertEqual(response.reason_phrase, "Switching Protocols") @@ -207,10 +267,12 @@ def test_accept_response(self): ) self.assertIsNone(response.body) - def test_reject_response(self): + @patch("email.utils.formatdate", return_value=DATE) + def test_reject_response(self, _formatdate): + """reject() creates a failed opening handshake response.""" server = ServerProtocol() - with unittest.mock.patch("email.utils.formatdate", return_value=DATE): - response = server.reject(http.HTTPStatus.NOT_FOUND, "Sorry folks.\n") + response = server.reject(http.HTTPStatus.NOT_FOUND, "Sorry folks.\n") + self.assertIsInstance(response, Response) self.assertEqual(response.status_code, 404) self.assertEqual(response.reason_phrase, "Not Found") @@ -228,477 +290,552 @@ def test_reject_response(self): self.assertEqual(response.body, b"Sorry folks.\n") def test_reject_response_supports_int_status(self): + """reject() accepts an integer status code instead of an HTTPStatus.""" server = ServerProtocol() response = server.reject(404, "Sorry folks.\n") + self.assertEqual(response.status_code, 404) self.assertEqual(response.reason_phrase, "Not Found") - def test_basic(self): + @patch("websockets.server.ServerProtocol.process_request") + def test_unexpected_error(self, process_request): + """accept() handles unexpected errors and returns an error response.""" server = ServerProtocol() - request = self.make_request() + request = make_request() + process_request.side_effect = (Exception("BOOM"),) response = server.accept(request) - self.assertEqual(response.status_code, 101) + self.assertEqual(response.status_code, 500) + self.assertIsInstance(server.handshake_exc, Exception) + self.assertEqual(str(server.handshake_exc), "BOOM") - def test_unexpected_exception(self): + +class HandshakeTests(unittest.TestCase): + """Test processing of handshake responses to configure the connection.""" + + def assertHandshakeSuccess(self, server): + """Assert that the opening handshake succeeded.""" + self.assertEqual(server.state, OPEN) + self.assertIsNone(server.handshake_exc) + + def assertHandshakeError(self, server, exc_type, msg): + """Assert that the opening handshake failed with the given exception.""" + self.assertEqual(server.state, CONNECTING) + self.assertIsInstance(server.handshake_exc, exc_type) + exc = server.handshake_exc + exc_str = str(exc) + while exc.__cause__ is not None: + exc = exc.__cause__ + exc_str += "; " + str(exc) + self.assertEqual(exc_str, msg) + + def test_basic(self): + """Handshake succeeds.""" server = ServerProtocol() - request = self.make_request() - with unittest.mock.patch( - "websockets.server.ServerProtocol.process_request", - side_effect=Exception("BOOM"), - ): - response = server.accept(request) + request = make_request() + response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 500) - with self.assertRaises(Exception) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), - "BOOM", - ) + self.assertHandshakeSuccess(server) def test_missing_connection(self): + """Handshake fails when the Connection header is missing.""" server = ServerProtocol() - request = self.make_request() + request = make_request() del request.headers["Connection"] response = server.accept(request) + server.send_response(response) self.assertEqual(response.status_code, 426) self.assertEqual(response.headers["Upgrade"], "websocket") - with self.assertRaises(InvalidUpgrade) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), + self.assertHandshakeError( + server, + InvalidUpgrade, "missing Connection header", ) def test_invalid_connection(self): + """Handshake fails when the Connection header is invalid.""" server = ServerProtocol() - request = self.make_request() + request = make_request() del request.headers["Connection"] request.headers["Connection"] = "close" response = server.accept(request) + server.send_response(response) self.assertEqual(response.status_code, 426) self.assertEqual(response.headers["Upgrade"], "websocket") - with self.assertRaises(InvalidUpgrade) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), + self.assertHandshakeError( + server, + InvalidUpgrade, "invalid Connection header: close", ) def test_missing_upgrade(self): + """Handshake fails when the Upgrade header is missing.""" server = ServerProtocol() - request = self.make_request() + request = make_request() del request.headers["Upgrade"] response = server.accept(request) + server.send_response(response) self.assertEqual(response.status_code, 426) self.assertEqual(response.headers["Upgrade"], "websocket") - with self.assertRaises(InvalidUpgrade) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), + self.assertHandshakeError( + server, + InvalidUpgrade, "missing Upgrade header", ) def test_invalid_upgrade(self): + """Handshake fails when the Upgrade header is invalid.""" server = ServerProtocol() - request = self.make_request() + request = make_request() del request.headers["Upgrade"] request.headers["Upgrade"] = "h2c" response = server.accept(request) + server.send_response(response) self.assertEqual(response.status_code, 426) self.assertEqual(response.headers["Upgrade"], "websocket") - with self.assertRaises(InvalidUpgrade) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), + self.assertHandshakeError( + server, + InvalidUpgrade, "invalid Upgrade header: h2c", ) def test_missing_key(self): + """Handshake fails when the Sec-WebSocket-Key header is missing.""" server = ServerProtocol() - request = self.make_request() + request = make_request() del request.headers["Sec-WebSocket-Key"] response = server.accept(request) + server.send_response(response) self.assertEqual(response.status_code, 400) - with self.assertRaises(InvalidHeader) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), + self.assertHandshakeError( + server, + InvalidHeader, "missing Sec-WebSocket-Key header", ) def test_multiple_key(self): + """Handshake fails when the Sec-WebSocket-Key header is repeated.""" server = ServerProtocol() - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Key"] = KEY response = server.accept(request) + server.send_response(response) self.assertEqual(response.status_code, 400) - with self.assertRaises(InvalidHeader) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), + self.assertHandshakeError( + server, + InvalidHeader, "invalid Sec-WebSocket-Key header: multiple values", ) def test_invalid_key(self): + """Handshake fails when the Sec-WebSocket-Key header is invalid.""" server = ServerProtocol() - request = self.make_request() + request = make_request() del request.headers["Sec-WebSocket-Key"] - request.headers["Sec-WebSocket-Key"] = "not Base64 data!" + request.headers["Sec-WebSocket-Key"] = "" response = server.accept(request) + server.send_response(response) self.assertEqual(response.status_code, 400) - with self.assertRaises(InvalidHeader) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), - "invalid Sec-WebSocket-Key header: not Base64 data!", + if sys.version_info[:2] >= (3, 11): + b64_exc = "Only base64 data is allowed" + else: # pragma: no cover + b64_exc = "Non-base64 digit found" + self.assertHandshakeError( + server, + InvalidHeader, + f"invalid Sec-WebSocket-Key header: ; {b64_exc}", ) def test_truncated_key(self): + """Handshake fails when the Sec-WebSocket-Key header is truncated.""" server = ServerProtocol() - request = self.make_request() + request = make_request() del request.headers["Sec-WebSocket-Key"] - request.headers["Sec-WebSocket-Key"] = KEY[ - :16 - ] # 12 bytes instead of 16, Base64-encoded + # 12 bytes instead of 16, Base64-encoded + request.headers["Sec-WebSocket-Key"] = KEY[:16] response = server.accept(request) + server.send_response(response) self.assertEqual(response.status_code, 400) - with self.assertRaises(InvalidHeader) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), + self.assertHandshakeError( + server, + InvalidHeader, f"invalid Sec-WebSocket-Key header: {KEY[:16]}", ) def test_missing_version(self): + """Handshake fails when the Sec-WebSocket-Version header is missing.""" server = ServerProtocol() - request = self.make_request() + request = make_request() del request.headers["Sec-WebSocket-Version"] response = server.accept(request) + server.send_response(response) self.assertEqual(response.status_code, 400) - with self.assertRaises(InvalidHeader) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), + self.assertHandshakeError( + server, + InvalidHeader, "missing Sec-WebSocket-Version header", ) def test_multiple_version(self): + """Handshake fails when the Sec-WebSocket-Version header is repeated.""" server = ServerProtocol() - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Version"] = "11" response = server.accept(request) + server.send_response(response) self.assertEqual(response.status_code, 400) - with self.assertRaises(InvalidHeader) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), + self.assertHandshakeError( + server, + InvalidHeader, "invalid Sec-WebSocket-Version header: multiple values", ) def test_invalid_version(self): + """Handshake fails when the Sec-WebSocket-Version header is invalid.""" server = ServerProtocol() - request = self.make_request() + request = make_request() del request.headers["Sec-WebSocket-Version"] request.headers["Sec-WebSocket-Version"] = "11" response = server.accept(request) + server.send_response(response) self.assertEqual(response.status_code, 400) - with self.assertRaises(InvalidHeader) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), + self.assertHandshakeError( + server, + InvalidHeader, "invalid Sec-WebSocket-Version header: 11", ) - def test_no_origin(self): + def test_origin(self): + """Handshake succeeds when checking origin.""" server = ServerProtocol(origins=["https://example.com"]) - request = self.make_request() + request = make_request() + request.headers["Origin"] = "https://example.com" response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 403) - with self.assertRaises(InvalidOrigin) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), - "missing Origin header", - ) + self.assertHandshakeSuccess(server) + self.assertEqual(server.origin, "https://example.com") - def test_origin(self): + def test_no_origin(self): + """Handshake fails when checking origin and the Origin header is missing.""" server = ServerProtocol(origins=["https://example.com"]) - request = self.make_request() - request.headers["Origin"] = "https://example.com" + request = make_request() response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) - self.assertEqual(server.origin, "https://example.com") + self.assertEqual(response.status_code, 403) + self.assertHandshakeError( + server, + InvalidOrigin, + "missing Origin header", + ) def test_unexpected_origin(self): + """Handshake fails when checking origin and the Origin header is unexpected.""" server = ServerProtocol(origins=["https://example.com"]) - request = self.make_request() + request = make_request() request.headers["Origin"] = "https://other.example.com" response = server.accept(request) + server.send_response(response) self.assertEqual(response.status_code, 403) - with self.assertRaises(InvalidOrigin) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), + self.assertHandshakeError( + server, + InvalidOrigin, "invalid Origin header: https://other.example.com", ) def test_multiple_origin(self): + """Handshake fails when checking origins and the Origin header is repeated.""" server = ServerProtocol( origins=["https://example.com", "https://other.example.com"] ) - request = self.make_request() + request = make_request() request.headers["Origin"] = "https://example.com" request.headers["Origin"] = "https://other.example.com" response = server.accept(request) + server.send_response(response) # This is prohibited by the HTTP specification, so the return code is # 400 Bad Request rather than 403 Forbidden. self.assertEqual(response.status_code, 400) - with self.assertRaises(InvalidHeader) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), + self.assertHandshakeError( + server, + InvalidHeader, "invalid Origin header: multiple values", ) def test_supported_origin(self): + """Handshake succeeds when checking origins and the origin is supported.""" server = ServerProtocol( origins=["https://example.com", "https://other.example.com"] ) - request = self.make_request() + request = make_request() request.headers["Origin"] = "https://other.example.com" response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertEqual(server.origin, "https://other.example.com") def test_unsupported_origin(self): + """Handshake succeeds when checking origins and the origin is unsupported.""" server = ServerProtocol( origins=["https://example.com", "https://other.example.com"] ) - request = self.make_request() + request = make_request() request.headers["Origin"] = "https://original.example.com" response = server.accept(request) + server.send_response(response) self.assertEqual(response.status_code, 403) - with self.assertRaises(InvalidOrigin) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), + self.assertHandshakeError( + server, + InvalidOrigin, "invalid Origin header: https://original.example.com", ) def test_no_origin_accepted(self): + """Handshake succeeds when the lack of an origin is accepted.""" server = ServerProtocol(origins=[None]) - request = self.make_request() + request = make_request() response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertIsNone(server.origin) def test_no_extensions(self): + """Handshake succeeds without extensions.""" server = ServerProtocol() - request = self.make_request() - response = server.accept(request) - - self.assertEqual(response.status_code, 101) - self.assertNotIn("Sec-WebSocket-Extensions", response.headers) - self.assertEqual(server.extensions, []) - - def test_no_extension(self): - server = ServerProtocol(extensions=[ServerOpExtensionFactory()]) - request = self.make_request() + request = make_request() response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertNotIn("Sec-WebSocket-Extensions", response.headers) self.assertEqual(server.extensions, []) def test_extension(self): + """Server enables an extension when the client offers it.""" server = ServerProtocol(extensions=[ServerOpExtensionFactory()]) - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Extensions"] = "x-op; op" response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertEqual(response.headers["Sec-WebSocket-Extensions"], "x-op; op") self.assertEqual(server.extensions, [OpExtension()]) - def test_unexpected_extension(self): + def test_extension_not_enabled(self): + """Server doesn't enable an extension when the client doesn't offer it.""" + server = ServerProtocol(extensions=[ServerOpExtensionFactory()]) + request = make_request() + response = server.accept(request) + server.send_response(response) + + self.assertHandshakeSuccess(server) + self.assertNotIn("Sec-WebSocket-Extensions", response.headers) + self.assertEqual(server.extensions, []) + + def test_no_extensions_supported(self): + """Client offers an extension, but the server doesn't support any.""" server = ServerProtocol() - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Extensions"] = "x-op; op" response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertNotIn("Sec-WebSocket-Extensions", response.headers) self.assertEqual(server.extensions, []) - def test_unsupported_extension(self): + def test_extension_not_supported(self): + """Client offers an extension, but the server doesn't support it.""" server = ServerProtocol(extensions=[ServerRsv2ExtensionFactory()]) - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Extensions"] = "x-op; op" response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertNotIn("Sec-WebSocket-Extensions", response.headers) self.assertEqual(server.extensions, []) def test_supported_extension_parameters(self): + """Client offers an extension with parameters supported by the server.""" server = ServerProtocol(extensions=[ServerOpExtensionFactory("this")]) - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Extensions"] = "x-op; op=this" response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertEqual(response.headers["Sec-WebSocket-Extensions"], "x-op; op=this") self.assertEqual(server.extensions, [OpExtension("this")]) def test_unsupported_extension_parameters(self): + """Client offers an extension with parameters unsupported by the server.""" server = ServerProtocol(extensions=[ServerOpExtensionFactory("this")]) - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Extensions"] = "x-op; op=that" response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertNotIn("Sec-WebSocket-Extensions", response.headers) self.assertEqual(server.extensions, []) def test_multiple_supported_extension_parameters(self): + """Server supports the same extension with several parameters.""" server = ServerProtocol( extensions=[ ServerOpExtensionFactory("this"), ServerOpExtensionFactory("that"), ] ) - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Extensions"] = "x-op; op=that" response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertEqual(response.headers["Sec-WebSocket-Extensions"], "x-op; op=that") self.assertEqual(server.extensions, [OpExtension("that")]) def test_multiple_extensions(self): + """Server enables several extensions when the client offers them.""" server = ServerProtocol( extensions=[ServerOpExtensionFactory(), ServerRsv2ExtensionFactory()] ) - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Extensions"] = "x-op; op" request.headers["Sec-WebSocket-Extensions"] = "x-rsv2" response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertEqual( response.headers["Sec-WebSocket-Extensions"], "x-op; op, x-rsv2" ) self.assertEqual(server.extensions, [OpExtension(), Rsv2Extension()]) def test_multiple_extensions_order(self): + """Server respects the order of extensions set in its configuration.""" server = ServerProtocol( extensions=[ServerOpExtensionFactory(), ServerRsv2ExtensionFactory()] ) - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Extensions"] = "x-rsv2" request.headers["Sec-WebSocket-Extensions"] = "x-op; op" response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertEqual( response.headers["Sec-WebSocket-Extensions"], "x-rsv2, x-op; op" ) self.assertEqual(server.extensions, [Rsv2Extension(), OpExtension()]) def test_no_subprotocols(self): + """Handshake succeeds without subprotocols.""" server = ServerProtocol() - request = self.make_request() + request = make_request() response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertNotIn("Sec-WebSocket-Protocol", response.headers) self.assertIsNone(server.subprotocol) - def test_no_subprotocol(self): + def test_no_subprotocol_requested(self): + """Server expects a subprotocol, but the client doesn't offer it.""" server = ServerProtocol(subprotocols=["chat"]) - request = self.make_request() + request = make_request() response = server.accept(request) + server.send_response(response) self.assertEqual(response.status_code, 400) - with self.assertRaises(NegotiationError) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), + self.assertHandshakeError( + server, + NegotiationError, "missing subprotocol", ) def test_subprotocol(self): + """Server enables a subprotocol when the client offers it.""" server = ServerProtocol(subprotocols=["chat"]) - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Protocol"] = "chat" response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertEqual(response.headers["Sec-WebSocket-Protocol"], "chat") self.assertEqual(server.subprotocol, "chat") - def test_unexpected_subprotocol(self): + def test_no_subprotocols_supported(self): + """Client offers a subprotocol, but the server doesn't support any.""" server = ServerProtocol() - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Protocol"] = "chat" response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertNotIn("Sec-WebSocket-Protocol", response.headers) self.assertIsNone(server.subprotocol) def test_multiple_subprotocols(self): + """Server enables all of the subprotocols when the client offers them.""" server = ServerProtocol(subprotocols=["superchat", "chat"]) - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Protocol"] = "chat" request.headers["Sec-WebSocket-Protocol"] = "superchat" response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertEqual(response.headers["Sec-WebSocket-Protocol"], "superchat") self.assertEqual(server.subprotocol, "superchat") def test_supported_subprotocol(self): + """Server enables one of the subprotocols when the client offers it.""" server = ServerProtocol(subprotocols=["superchat", "chat"]) - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Protocol"] = "chat" response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertEqual(response.headers["Sec-WebSocket-Protocol"], "chat") self.assertEqual(server.subprotocol, "chat") def test_unsupported_subprotocol(self): + """Server expects one of the subprotocols, but the client doesn't offer any.""" server = ServerProtocol(subprotocols=["superchat", "chat"]) - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Protocol"] = "otherchat" response = server.accept(request) + server.send_response(response) self.assertEqual(response.status_code, 400) - with self.assertRaises(NegotiationError) as raised: - raise server.handshake_exc - self.assertEqual( - str(raised.exception), + self.assertHandshakeError( + server, + NegotiationError, "invalid subprotocol; expected one of superchat, chat", ) @@ -708,34 +845,40 @@ def optional_chat(protocol, subprotocols): return "chat" def test_select_subprotocol(self): + """Server enables a subprotocol with select_subprotocol.""" server = ServerProtocol(select_subprotocol=self.optional_chat) - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Protocol"] = "chat" response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertEqual(response.headers["Sec-WebSocket-Protocol"], "chat") self.assertEqual(server.subprotocol, "chat") def test_select_no_subprotocol(self): + """Server doesn't enable any subprotocol with select_subprotocol.""" server = ServerProtocol(select_subprotocol=self.optional_chat) - request = self.make_request() + request = make_request() request.headers["Sec-WebSocket-Protocol"] = "otherchat" response = server.accept(request) + server.send_response(response) - self.assertEqual(response.status_code, 101) + self.assertHandshakeSuccess(server) self.assertNotIn("Sec-WebSocket-Protocol", response.headers) self.assertIsNone(server.subprotocol) class MiscTests(unittest.TestCase): def test_bypass_handshake(self): + """ServerProtocol bypasses the opening handshake.""" server = ServerProtocol(state=OPEN) server.receive_data(b"\x81\x86\x00\x00\x00\x00Hello!") [frame] = server.events_received() self.assertEqual(frame, Frame(OP_TEXT, b"Hello!")) def test_custom_logger(self): + """ServerProtocol accepts a logger argument.""" logger = logging.getLogger("test") with self.assertLogs("test", logging.DEBUG) as logs: ServerProtocol(logger=logger) @@ -744,6 +887,7 @@ def test_custom_logger(self): class BackwardsCompatibilityTests(DeprecationTestCase): def test_server_connection_class(self): + """ServerConnection is a deprecated alias for ServerProtocol.""" with self.assertDeprecationWarning( "ServerConnection was renamed to ServerProtocol" ): From 36409237c9377fb8fc2fbb90f9f74192777b518c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Sep 2024 19:04:08 +0200 Subject: [PATCH 627/762] Wait until state is CLOSED to acces close_exc. Fix #1449. --- docs/project/changelog.rst | 4 ++++ src/websockets/asyncio/connection.py | 14 +++++++++++--- src/websockets/sync/connection.py | 14 ++++++++++++-- 3 files changed, 27 insertions(+), 5 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 456c15dac..615d3ab71 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -75,6 +75,10 @@ Bug fixes start the connection handler anymore when ``process_request`` or ``process_response`` returns an HTTP response. +* Fixed a bug in the :mod:`threading` implementation that could lead to + incorrect error reporting when closing a connection while + :meth:`~sync.connection.Connection.recv` is running. + 13.0.1 ------ diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 1b24f9af0..6af61a4a9 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -274,6 +274,8 @@ async def recv(self, decode: bool | None = None) -> Data: try: return await self.recv_messages.get(decode) except EOFError: + # Wait for the protocol state to be CLOSED before accessing close_exc. + await asyncio.shield(self.connection_lost_waiter) raise self.protocol.close_exc from self.recv_exc except ConcurrencyError: raise ConcurrencyError( @@ -329,6 +331,8 @@ async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data async for frame in self.recv_messages.get_iter(decode): yield frame except EOFError: + # Wait for the protocol state to be CLOSED before accessing close_exc. + await asyncio.shield(self.connection_lost_waiter) raise self.protocol.close_exc from self.recv_exc except ConcurrencyError: raise ConcurrencyError( @@ -864,6 +868,7 @@ async def send_context( # raise an exception. if raise_close_exc: self.close_transport() + # Wait for the protocol state to be CLOSED before accessing close_exc. await asyncio.shield(self.connection_lost_waiter) raise self.protocol.close_exc from original_exc @@ -926,11 +931,14 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None: self.transport = transport def connection_lost(self, exc: Exception | None) -> None: - self.protocol.receive_eof() # receive_eof is idempotent + # Calling protocol.receive_eof() is safe because it's idempotent. + # This guarantees that the protocol state becomes CLOSED. + self.protocol.receive_eof() + assert self.protocol.state is CLOSED - # Abort recv() and pending pings with a ConnectionClosed exception. - # Set recv_exc first to get proper exception reporting. self.set_recv_exc(exc) + + # Abort recv() and pending pings with a ConnectionClosed exception. self.recv_messages.close() self.abort_pings() diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 65a7b63ed..77b488093 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -206,6 +206,8 @@ def recv(self, timeout: float | None = None) -> Data: try: return self.recv_messages.get(timeout) except EOFError: + # Wait for the protocol state to be CLOSED before accessing close_exc. + self.recv_events_thread.join() raise self.protocol.close_exc from self.recv_exc except ConcurrencyError: raise ConcurrencyError( @@ -240,6 +242,8 @@ def recv_streaming(self) -> Iterator[Data]: for frame in self.recv_messages.get_iter(): yield frame except EOFError: + # Wait for the protocol state to be CLOSED before accessing close_exc. + self.recv_events_thread.join() raise self.protocol.close_exc from self.recv_exc except ConcurrencyError: raise ConcurrencyError( @@ -629,8 +633,6 @@ def recv_events(self) -> None: self.logger.error("unexpected internal error", exc_info=True) with self.protocol_mutex: self.set_recv_exc(exc) - # We don't know where we crashed. Force protocol state to CLOSED. - self.protocol.state = CLOSED finally: # This isn't expected to raise an exception. self.close_socket() @@ -738,6 +740,7 @@ def send_context( # raise an exception. if raise_close_exc: self.close_socket() + # Wait for the protocol state to be CLOSED before accessing close_exc. self.recv_events_thread.join() raise self.protocol.close_exc from original_exc @@ -788,4 +791,11 @@ def close_socket(self) -> None: except OSError: pass # socket is already closed self.socket.close() + + # Calling protocol.receive_eof() is safe because it's idempotent. + # This guarantees that the protocol state becomes CLOSED. + self.protocol.receive_eof() + assert self.protocol.state is CLOSED + + # Abort recv() with a ConnectionClosed exception. self.recv_messages.close() From 0afccc956639684af5082513da8ba8f5105448f5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Sep 2024 19:12:07 +0200 Subject: [PATCH 628/762] Clarify comment. --- src/websockets/sync/connection.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 77b488093..97588870e 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -87,9 +87,9 @@ def __init__( # Mapping of ping IDs to pong waiters, in chronological order. self.ping_waiters: dict[bytes, threading.Event] = {} - # Receiving events from the socket. This thread explicitly is marked as - # to support creating a connection in a non-daemon thread then using it - # in a daemon thread; this shouldn't block the intpreter from exiting. + # Receiving events from the socket. This thread is marked as daemon to + # allow creating a connection in a non-daemon thread and using it in a + # daemon thread. This mustn't prevent the interpreter from exiting. self.recv_events_thread = threading.Thread( target=self.recv_events, daemon=True, From 4d229bf9f583d593aa103287aee0a77c9fbc3a79 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Sep 2024 19:13:12 +0200 Subject: [PATCH 629/762] Release version 13.1. --- docs/project/changelog.rst | 2 +- src/websockets/version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 615d3ab71..7e4bce9c6 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -30,7 +30,7 @@ notice. 13.1 ---- -*In development* +*September 21, 2024* Backwards-incompatible changes .............................. diff --git a/src/websockets/version.py b/src/websockets/version.py index bbda56d6b..00b0a985e 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -18,7 +18,7 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = False +released = True tag = version = commit = "13.1" From 37c7f6529a0877c87f191bc566d2e14f7c96e192 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Sep 2024 19:16:44 +0200 Subject: [PATCH 630/762] Start version 14.0. --- docs/project/changelog.rst | 7 +++++++ src/websockets/version.py | 4 ++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 7e4bce9c6..65e26008f 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -25,6 +25,13 @@ fixing regressions shortly after a release. Only documented APIs are public. Undocumented, private APIs may change without notice. +.. _14.0: + +14.0 +---- + +*In development* + .. _13.1: 13.1 diff --git a/src/websockets/version.py b/src/websockets/version.py index 00b0a985e..34fc2eaef 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -18,9 +18,9 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = True +released = False -tag = version = commit = "13.1" +tag = version = commit = "14.0" if not released: # pragma: no cover From 44ccee17c519ea1397ee28a8ac3a7d7685cd0b89 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Sep 2024 19:30:56 +0200 Subject: [PATCH 631/762] Drop Python 3.8. It is EOL at the end of October. --- .github/workflows/tests.yml | 1 - docs/howto/django.rst | 3 +-- docs/intro/index.rst | 2 +- docs/project/changelog.rst | 14 +++++++++----- pyproject.toml | 3 +-- src/websockets/asyncio/client.py | 5 +++-- src/websockets/asyncio/connection.py | 11 ++--------- src/websockets/asyncio/messages.py | 10 ++-------- src/websockets/asyncio/server.py | 16 ++++------------ src/websockets/client.py | 7 ++++--- src/websockets/datastructures.py | 15 +++------------ src/websockets/extensions/base.py | 2 +- src/websockets/extensions/permessage_deflate.py | 3 ++- src/websockets/frames.py | 3 ++- src/websockets/headers.py | 3 ++- src/websockets/http11.py | 3 ++- src/websockets/imports.py | 3 ++- src/websockets/legacy/auth.py | 6 +++--- src/websockets/legacy/client.py | 10 ++-------- src/websockets/legacy/framing.py | 3 ++- src/websockets/legacy/protocol.py | 13 ++----------- src/websockets/legacy/server.py | 16 +++------------- src/websockets/protocol.py | 7 ++++--- src/websockets/server.py | 5 +++-- src/websockets/streams.py | 2 +- src/websockets/sync/client.py | 3 ++- src/websockets/sync/connection.py | 6 +++--- src/websockets/sync/messages.py | 6 +++--- src/websockets/sync/server.py | 7 ++++--- src/websockets/typing.py | 8 +++----- tests/asyncio/test_client.py | 8 -------- tox.ini | 1 - 32 files changed, 76 insertions(+), 129 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 43193ea50..beaf9d12b 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -55,7 +55,6 @@ jobs: strategy: matrix: python: - - "3.8" - "3.9" - "3.10" - "3.11" diff --git a/docs/howto/django.rst b/docs/howto/django.rst index dada9c5e4..4fe2311cb 100644 --- a/docs/howto/django.rst +++ b/docs/howto/django.rst @@ -121,8 +121,7 @@ authentication fails, it closes the connection and exits. When we call an API that makes a database query such as ``get_user()``, we wrap the call in :func:`~asyncio.to_thread`. Indeed, the Django ORM doesn't support asynchronous I/O. It would block the event loop if it didn't run in a -separate thread. :func:`~asyncio.to_thread` is available since Python 3.9. In -earlier versions, use :meth:`~asyncio.loop.run_in_executor` instead. +separate thread. Finally, we start a server with :func:`~websockets.asyncio.server.serve`. diff --git a/docs/intro/index.rst b/docs/intro/index.rst index 095262a20..642e50094 100644 --- a/docs/intro/index.rst +++ b/docs/intro/index.rst @@ -6,7 +6,7 @@ Getting started Requirements ------------ -websockets requires Python ≥ 3.8. +websockets requires Python ≥ 3.9. .. admonition:: Use the most recent Python release :class: tip diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 65e26008f..5f07fc09f 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -32,6 +32,15 @@ notice. *In development* +Backwards-incompatible changes +.............................. + +.. admonition:: websockets 14.0 requires Python ≥ 3.9. + :class: tip + + websockets 13.1 is the last version supporting Python 3.8. + + .. _13.1: 13.1 @@ -106,11 +115,6 @@ Bug fixes Backwards-incompatible changes .............................. -.. admonition:: websockets 13.0 requires Python ≥ 3.8. - :class: tip - - websockets 12.0 is the last version supporting Python 3.7. - .. admonition:: Receiving the request path in the second parameter of connection handlers is deprecated. :class: note diff --git a/pyproject.toml b/pyproject.toml index fde9c3226..6a0ab8d7c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta" [project] name = "websockets" description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)" -requires-python = ">=3.8" +requires-python = ">=3.9" license = { text = "BSD-3-Clause" } authors = [ { name = "Aymeric Augustin", email = "aymeric.augustin@m4x.org" }, @@ -19,7 +19,6 @@ classifiers = [ "Operating System :: OS Independent", "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index b1beb3e00..23b1a348a 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -4,8 +4,9 @@ import logging import os import urllib.parse +from collections.abc import AsyncIterator, Generator, Sequence from types import TracebackType -from typing import Any, AsyncIterator, Callable, Generator, Sequence +from typing import Any, Callable from ..client import ClientProtocol, backoff from ..datastructures import HeadersLike @@ -492,7 +493,7 @@ async def __aexit__( # async for ... in connect(...): async def __aiter__(self) -> AsyncIterator[ClientConnection]: - delays: Generator[float, None, None] | None = None + delays: Generator[float] | None = None while True: try: async with self as protocol: diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 6af61a4a9..702e69995 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -8,16 +8,9 @@ import struct import sys import uuid +from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterable, Mapping from types import TracebackType -from typing import ( - Any, - AsyncIterable, - AsyncIterator, - Awaitable, - Iterable, - Mapping, - cast, -) +from typing import Any, cast from ..exceptions import ( ConcurrencyError, diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index c2b4afd67..e3ec5062f 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -3,14 +3,8 @@ import asyncio import codecs import collections -from typing import ( - Any, - AsyncIterator, - Callable, - Generic, - Iterable, - TypeVar, -) +from collections.abc import AsyncIterator, Iterable +from typing import Any, Callable, Generic, TypeVar from ..exceptions import ConcurrencyError from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index 19dae44b7..e11dd91f1 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -6,17 +6,9 @@ import logging import socket import sys +from collections.abc import Awaitable, Generator, Iterable, Sequence from types import TracebackType -from typing import ( - Any, - Awaitable, - Callable, - Generator, - Iterable, - Sequence, - Tuple, - cast, -) +from typing import Any, Callable, cast from ..exceptions import InvalidHeader from ..extensions.base import ServerExtensionFactory @@ -905,9 +897,9 @@ def basic_auth( if credentials is not None: if is_credentials(credentials): - credentials_list = [cast(Tuple[str, str], credentials)] + credentials_list = [cast(tuple[str, str], credentials)] elif isinstance(credentials, Iterable): - credentials_list = list(cast(Iterable[Tuple[str, str]], credentials)) + credentials_list = list(cast(Iterable[tuple[str, str]], credentials)) if not all(is_credentials(item) for item in credentials_list): raise TypeError(f"invalid credentials argument: {credentials}") else: diff --git a/src/websockets/client.py b/src/websockets/client.py index e5f294986..bce82d66b 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -3,7 +3,8 @@ import os import random import warnings -from typing import Any, Generator, Sequence +from collections.abc import Generator, Sequence +from typing import Any from .datastructures import Headers, MultipleValuesError from .exceptions import ( @@ -313,7 +314,7 @@ def send_request(self, request: Request) -> None: self.writes.append(request.serialize()) - def parse(self) -> Generator[None, None, None]: + def parse(self) -> Generator[None]: if self.state is CONNECTING: try: response = yield from Response.parse( @@ -374,7 +375,7 @@ def backoff( min_delay: float = BACKOFF_MIN_DELAY, max_delay: float = BACKOFF_MAX_DELAY, factor: float = BACKOFF_FACTOR, -) -> Generator[float, None, None]: +) -> Generator[float]: """ Generate a series of backoff delays between reconnection attempts. diff --git a/src/websockets/datastructures.py b/src/websockets/datastructures.py index 106d6f393..77b6f86fa 100644 --- a/src/websockets/datastructures.py +++ b/src/websockets/datastructures.py @@ -1,15 +1,7 @@ from __future__ import annotations -from typing import ( - Any, - Iterable, - Iterator, - Mapping, - MutableMapping, - Protocol, - Tuple, - Union, -) +from collections.abc import Iterable, Iterator, Mapping, MutableMapping +from typing import Any, Protocol, Union __all__ = ["Headers", "HeadersLike", "MultipleValuesError"] @@ -179,8 +171,7 @@ def __getitem__(self, key: str) -> str: ... HeadersLike = Union[ Headers, Mapping[str, str], - # Change to tuple[str, str] when dropping Python < 3.9. - Iterable[Tuple[str, str]], + Iterable[tuple[str, str]], SupportsKeysAndGetItem, ] """ diff --git a/src/websockets/extensions/base.py b/src/websockets/extensions/base.py index 75bae6b77..42dd6c5fa 100644 --- a/src/websockets/extensions/base.py +++ b/src/websockets/extensions/base.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Sequence +from collections.abc import Sequence from ..frames import Frame from ..typing import ExtensionName, ExtensionParameter diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index 25d2c1c45..f962b65fb 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -2,7 +2,8 @@ import dataclasses import zlib -from typing import Any, Sequence +from collections.abc import Sequence +from typing import Any from .. import frames from ..exceptions import ( diff --git a/src/websockets/frames.py b/src/websockets/frames.py index a63bdc3b6..dace2c902 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -6,7 +6,8 @@ import os import secrets import struct -from typing import Callable, Generator, Sequence +from collections.abc import Generator, Sequence +from typing import Callable from .exceptions import PayloadTooBig, ProtocolError diff --git a/src/websockets/headers.py b/src/websockets/headers.py index 9103018a0..e05948a1f 100644 --- a/src/websockets/headers.py +++ b/src/websockets/headers.py @@ -4,7 +4,8 @@ import binascii import ipaddress import re -from typing import Callable, Sequence, TypeVar, cast +from collections.abc import Sequence +from typing import Callable, TypeVar, cast from .exceptions import InvalidHeaderFormat, InvalidHeaderValue from .typing import ( diff --git a/src/websockets/http11.py b/src/websockets/http11.py index 47cef7a9b..af542c77b 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -5,7 +5,8 @@ import re import sys import warnings -from typing import Callable, Generator +from collections.abc import Generator +from typing import Callable from .datastructures import Headers from .exceptions import SecurityError diff --git a/src/websockets/imports.py b/src/websockets/imports.py index bb80e4eac..c63fb212e 100644 --- a/src/websockets/imports.py +++ b/src/websockets/imports.py @@ -1,7 +1,8 @@ from __future__ import annotations import warnings -from typing import Any, Iterable +from collections.abc import Iterable +from typing import Any __all__ = ["lazy_import"] diff --git a/src/websockets/legacy/auth.py b/src/websockets/legacy/auth.py index 4d030e5e2..a262fcd79 100644 --- a/src/websockets/legacy/auth.py +++ b/src/websockets/legacy/auth.py @@ -3,7 +3,8 @@ import functools import hmac import http -from typing import Any, Awaitable, Callable, Iterable, Tuple, cast +from collections.abc import Awaitable, Iterable +from typing import Any, Callable, cast from ..datastructures import Headers from ..exceptions import InvalidHeader @@ -13,8 +14,7 @@ __all__ = ["BasicAuthWebSocketServerProtocol", "basic_auth_protocol_factory"] -# Change to tuple[str, str] when dropping Python < 3.9. -Credentials = Tuple[str, str] +Credentials = tuple[str, str] def is_credentials(value: Any) -> bool: diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index ec4c2ff64..116445e25 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -7,15 +7,9 @@ import random import urllib.parse import warnings +from collections.abc import AsyncIterator, Generator, Sequence from types import TracebackType -from typing import ( - Any, - AsyncIterator, - Callable, - Generator, - Sequence, - cast, -) +from typing import Any, Callable, cast from ..asyncio.compatibility import asyncio_timeout from ..datastructures import Headers, HeadersLike diff --git a/src/websockets/legacy/framing.py b/src/websockets/legacy/framing.py index 4c2f8c23f..4ec194ed7 100644 --- a/src/websockets/legacy/framing.py +++ b/src/websockets/legacy/framing.py @@ -1,7 +1,8 @@ from __future__ import annotations import struct -from typing import Any, Awaitable, Callable, NamedTuple, Sequence +from collections.abc import Awaitable, Sequence +from typing import Any, Callable, NamedTuple from .. import extensions, frames from ..exceptions import PayloadTooBig, ProtocolError diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 998e390d4..cedde6200 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -11,17 +11,8 @@ import time import uuid import warnings -from typing import ( - Any, - AsyncIterable, - AsyncIterator, - Awaitable, - Callable, - Deque, - Iterable, - Mapping, - cast, -) +from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterable, Mapping +from typing import Any, Callable, Deque, cast from ..asyncio.compatibility import asyncio_timeout from ..datastructures import Headers diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 2cb9b1abb..9326b6100 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -8,18 +8,9 @@ import logging import socket import warnings +from collections.abc import Awaitable, Generator, Iterable, Sequence from types import TracebackType -from typing import ( - Any, - Awaitable, - Callable, - Generator, - Iterable, - Sequence, - Tuple, - Union, - cast, -) +from typing import Any, Callable, Union, cast from ..asyncio.compatibility import asyncio_timeout from ..datastructures import Headers, HeadersLike, MultipleValuesError @@ -59,8 +50,7 @@ # Change to HeadersLike | ... when dropping Python < 3.10. HeadersLikeOrCallable = Union[HeadersLike, Callable[[str, Headers], HeadersLike]] -# Change to tuple[...] when dropping Python < 3.9. -HTTPResponse = Tuple[StatusLike, HeadersLike, bytes] +HTTPResponse = tuple[StatusLike, HeadersLike, bytes] class WebSocketServerProtocol(WebSocketCommonProtocol): diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 8751ebdb4..091b4a23a 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -3,7 +3,8 @@ import enum import logging import uuid -from typing import Generator, Union +from collections.abc import Generator +from typing import Union from .exceptions import ( ConnectionClosed, @@ -529,7 +530,7 @@ def close_expected(self) -> bool: # Private methods for receiving data. - def parse(self) -> Generator[None, None, None]: + def parse(self) -> Generator[None]: """ Parse incoming data into frames. @@ -600,7 +601,7 @@ def parse(self) -> Generator[None, None, None]: yield raise AssertionError("parse() shouldn't step after error") - def discard(self) -> Generator[None, None, None]: + def discard(self) -> Generator[None]: """ Discard incoming data. diff --git a/src/websockets/server.py b/src/websockets/server.py index 006d5bdd5..9fe970619 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -5,7 +5,8 @@ import email.utils import http import warnings -from typing import Any, Callable, Generator, Sequence, cast +from collections.abc import Generator, Sequence +from typing import Any, Callable, cast from .datastructures import Headers, MultipleValuesError from .exceptions import ( @@ -555,7 +556,7 @@ def send_response(self, response: Response) -> None: self.parser = self.discard() next(self.parser) # start coroutine - def parse(self) -> Generator[None, None, None]: + def parse(self) -> Generator[None]: if self.state is CONNECTING: try: request = yield from Request.parse( diff --git a/src/websockets/streams.py b/src/websockets/streams.py index 956f139d4..f52e6193a 100644 --- a/src/websockets/streams.py +++ b/src/websockets/streams.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Generator +from collections.abc import Generator class StreamReader: diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index d1e20a757..5e1ba6d84 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -4,7 +4,8 @@ import ssl as ssl_module import threading import warnings -from typing import Any, Sequence +from collections.abc import Sequence +from typing import Any from ..client import ClientProtocol from ..datastructures import HeadersLike diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 97588870e..8c5df9592 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -7,8 +7,9 @@ import struct import threading import uuid +from collections.abc import Iterable, Iterator, Mapping from types import TracebackType -from typing import Any, Iterable, Iterator, Mapping +from typing import Any from ..exceptions import ( ConcurrencyError, @@ -239,8 +240,7 @@ def recv_streaming(self) -> Iterator[Data]: """ try: - for frame in self.recv_messages.get_iter(): - yield frame + yield from self.recv_messages.get_iter() except EOFError: # Wait for the protocol state to be CLOSED before accessing close_exc. self.recv_events_thread.join() diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index 8d090538f..b96cd6880 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -3,7 +3,8 @@ import codecs import queue import threading -from typing import Iterator, cast +from collections.abc import Iterator +from typing import cast from ..exceptions import ConcurrencyError from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame @@ -150,8 +151,7 @@ def get_iter(self) -> Iterator[Data]: chunks = self.chunks self.chunks = [] self.chunks_queue = cast( - # Remove quotes around type when dropping Python < 3.9. - "queue.SimpleQueue[Data | None]", + queue.SimpleQueue[Data | None], queue.SimpleQueue(), ) diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 1b7cbb4b4..464c4a173 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -10,8 +10,9 @@ import sys import threading import warnings +from collections.abc import Iterable, Sequence from types import TracebackType -from typing import Any, Callable, Iterable, Sequence, Tuple, cast +from typing import Any, Callable, cast from ..exceptions import InvalidHeader from ..extensions.base import ServerExtensionFactory @@ -663,9 +664,9 @@ def basic_auth( if credentials is not None: if is_credentials(credentials): - credentials_list = [cast(Tuple[str, str], credentials)] + credentials_list = [cast(tuple[str, str], credentials)] elif isinstance(credentials, Iterable): - credentials_list = list(cast(Iterable[Tuple[str, str]], credentials)) + credentials_list = list(cast(Iterable[tuple[str, str]], credentials)) if not all(is_credentials(item) for item in credentials_list): raise TypeError(f"invalid credentials argument: {credentials}") else: diff --git a/src/websockets/typing.py b/src/websockets/typing.py index 447fe79da..0a37141c6 100644 --- a/src/websockets/typing.py +++ b/src/websockets/typing.py @@ -3,7 +3,7 @@ import http import logging import typing -from typing import Any, List, NewType, Optional, Tuple, Union +from typing import Any, NewType, Optional, Union __all__ = [ @@ -56,16 +56,14 @@ ExtensionName = NewType("ExtensionName", str) """Name of a WebSocket extension.""" -# Change to tuple[str, Optional[str]] when dropping Python < 3.9. # Change to tuple[str, str | None] when dropping Python < 3.10. -ExtensionParameter = Tuple[str, Optional[str]] +ExtensionParameter = tuple[str, Optional[str]] """Parameter of a WebSocket extension.""" # Private types -# Change to tuple[.., list[...]] when dropping Python < 3.9. -ExtensionHeader = Tuple[ExtensionName, List[ExtensionParameter]] +ExtensionHeader = tuple[ExtensionName, list[ExtensionParameter]] """Extension in a ``Sec-WebSocket-Extensions`` header.""" diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index 999ef1b71..9354a6e0a 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -177,10 +177,6 @@ def process_request(connection, request): self.assertEqual(iterations, 5) self.assertEqual(successful, 2) - @unittest.skipUnless( - hasattr(http.HTTPStatus, "IM_A_TEAPOT"), - "test requires Python 3.9", - ) async def test_reconnect_with_custom_process_exception(self): """Client runs process_exception to tell if errors are retryable or fatal.""" iteration = 0 @@ -214,10 +210,6 @@ def process_exception(exc): "🫖 💔 ☕️", ) - @unittest.skipUnless( - hasattr(http.HTTPStatus, "IM_A_TEAPOT"), - "test requires Python 3.9", - ) async def test_reconnect_with_custom_process_exception_raising_exception(self): """Client supports raising an exception in process_exception.""" diff --git a/tox.ini b/tox.ini index cba9b290b..0bcec5ded 100644 --- a/tox.ini +++ b/tox.ini @@ -1,6 +1,5 @@ [tox] env_list = - py38 py39 py310 py311 From e44a1eadf3c287335fd10c381ba5ccb8bd1ff4ce Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 22 Sep 2024 12:00:24 +0200 Subject: [PATCH 632/762] Document why packagers mustn't run the test suite. Refs #1509, #1496, #1427, #1426, #1081, #1026, perhaps others. --- docs/project/contributing.rst | 26 ++++++++++++++++++++++++-- tests/utils.py | 18 +++++++++++++++--- 2 files changed, 39 insertions(+), 5 deletions(-) diff --git a/docs/project/contributing.rst b/docs/project/contributing.rst index 020ed7ad8..3988c028a 100644 --- a/docs/project/contributing.rst +++ b/docs/project/contributing.rst @@ -17,8 +17,8 @@ apologies. I know I can mess up. I can't expect you to tell me, but if you choose to do so, I'll do my best to handle criticism constructively. -- Aymeric)* -Contributions -------------- +Contributing +------------ Bug reports, patches and suggestions are welcome! @@ -34,6 +34,28 @@ websockets. .. _issue: https://github.com/python-websockets/websockets/issues/new .. _pull request: https://github.com/python-websockets/websockets/compare/ +Packaging +--------- + +Some distributions package websockets so that it can be installed with the +system package manager rather than with pip, possibly in a virtualenv. + +If you're packaging websockets for a distribution, you must use `releases +published on PyPI`_ as input. You may check `SLSA attestations on GitHub`_. + +.. _releases published on PyPI: https://pypi.org/project/websockets/#files +.. _SLSA attestations on GitHub: https://github.com/python-websockets/websockets/attestations + +You mustn't rely on the git repository as input. Specifically, you mustn't +attempt to run the main test suite. It isn't treated as a deliverable of the +project. It doesn't do what you think it does. It's designed for the needs of +developers, not packagers. + +On a typical build farm for a distribution, tests that exercise timeouts will +fail randomly. Indeed, the test suite is optimized for running very fast, with a +tolerable level of flakiness, on a high-end laptop without noisy neighbors. This +isn't your context. + Questions --------- diff --git a/tests/utils.py b/tests/utils.py index 960439135..639fb7fe5 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -9,6 +9,8 @@ import unittest import warnings +from websockets.version import released + # Generate TLS certificate with: # $ openssl req -x509 -config test_localhost.cnf -days 15340 -newkey rsa:2048 \ @@ -39,9 +41,19 @@ DATE = email.utils.formatdate(usegmt=True) -# Unit for timeouts. May be increased on slow machines by setting the -# WEBSOCKETS_TESTS_TIMEOUT_FACTOR environment variable. -MS = 0.001 * float(os.environ.get("WEBSOCKETS_TESTS_TIMEOUT_FACTOR", "1")) +# Unit for timeouts. May be increased in slow or noisy environments by setting +# the WEBSOCKETS_TESTS_TIMEOUT_FACTOR environment variable. + +# Downstream distributors insist on running the test suite despites my pleas to +# the contrary. They do it on build farms with unstable performance, leading to +# flakiness, and then they file bugs. Make tests 100x slower to avoid flakiness. + +MS = 0.001 * float( + os.environ.get( + "WEBSOCKETS_TESTS_TIMEOUT_FACTOR", + "100" if released else "1", + ) +) # PyPy, asyncio's debug mode, and coverage penalize performance of this # test suite. Increase timeouts to reduce the risk of spurious failures. From 4f4e64442e84763ca634f67ba0062c3fb8c985c6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 22 Sep 2024 12:46:55 +0200 Subject: [PATCH 633/762] Restore compatibility with Python 3.9. It was broken in 44ccee17. --- src/websockets/sync/messages.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index b96cd6880..997fa98df 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -151,7 +151,8 @@ def get_iter(self) -> Iterator[Data]: chunks = self.chunks self.chunks = [] self.chunks_queue = cast( - queue.SimpleQueue[Data | None], + # Remove quotes around type when dropping Python < 3.10. + "queue.SimpleQueue[Data | None]", queue.SimpleQueue(), ) From ddafc6682a3f6d0bed9d88dfbecd15dd246973bf Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 22 Sep 2024 20:25:10 +0200 Subject: [PATCH 634/762] Split "getting support" from "contributing". Also add a page about financial contributions. --- docs/project/contributing.rst | 31 ---------------------- docs/project/index.rst | 4 ++- docs/project/sponsoring.rst | 11 ++++++++ docs/project/support.rst | 49 +++++++++++++++++++++++++++++++++++ 4 files changed, 63 insertions(+), 32 deletions(-) create mode 100644 docs/project/sponsoring.rst create mode 100644 docs/project/support.rst diff --git a/docs/project/contributing.rst b/docs/project/contributing.rst index 3988c028a..6ecd175f8 100644 --- a/docs/project/contributing.rst +++ b/docs/project/contributing.rst @@ -55,34 +55,3 @@ On a typical build farm for a distribution, tests that exercise timeouts will fail randomly. Indeed, the test suite is optimized for running very fast, with a tolerable level of flakiness, on a high-end laptop without noisy neighbors. This isn't your context. - -Questions ---------- - -GitHub issues aren't a good medium for handling questions. There are better -places to ask questions, for example Stack Overflow. - -If you want to ask a question anyway, please make sure that: - -- it's a question about websockets and not about :mod:`asyncio`; -- it isn't answered in the documentation; -- it wasn't asked already. - -A good question can be written as a suggestion to improve the documentation. - -Cryptocurrency users --------------------- - -websockets appears to be quite popular for interfacing with Bitcoin or other -cryptocurrency trackers. I'm strongly opposed to Bitcoin's carbon footprint. - -I'm aware of efforts to build proof-of-stake models. I'll care once the total -energy consumption of all cryptocurrencies drops to a non-bullshit level. - -You already negated all of humanity's efforts to develop renewable energy. -Please stop heating the planet where my children will have to live. - -Since websockets is released under an open-source license, you can use it for -any purpose you like. However, I won't spend any of my time to help you. - -I will summarily close issues related to Bitcoin or cryptocurrency in any way. diff --git a/docs/project/index.rst b/docs/project/index.rst index 459146345..56c98196a 100644 --- a/docs/project/index.rst +++ b/docs/project/index.rst @@ -8,5 +8,7 @@ This is about websockets-the-project rather than websockets-the-software. changelog contributing - license + sponsoring For enterprise + support + license diff --git a/docs/project/sponsoring.rst b/docs/project/sponsoring.rst new file mode 100644 index 000000000..77a4fd1d8 --- /dev/null +++ b/docs/project/sponsoring.rst @@ -0,0 +1,11 @@ +Sponsoring +========== + +You may sponsor the development of websockets through: + +* `GitHub Sponsors`_ +* `Open Collective`_ +* :doc:`Tidelift ` + +.. _GitHub Sponsors: https://github.com/sponsors/python-websockets +.. _Open Collective: https://opencollective.com/websockets diff --git a/docs/project/support.rst b/docs/project/support.rst new file mode 100644 index 000000000..21aad6e02 --- /dev/null +++ b/docs/project/support.rst @@ -0,0 +1,49 @@ +Getting support +=============== + +.. admonition:: There are no free support channels. + :class: tip + + websockets is an open-source project. It's primarily maintained by one + person as a hobby. + + For this reason, the focus is on flawless code and self-service + documentation, not support. + +Enterprise +---------- + +websockets is maintained with high standards, making it suitable for enterprise +use cases. Additional guarantees are available via :doc:`Tidelift `. +If you're using it in a professional setting, consider subscribing. + +Questions +--------- + +GitHub issues aren't a good medium for handling questions. There are better +places to ask questions, for example Stack Overflow. + +If you want to ask a question anyway, please make sure that: + +- it's a question about websockets and not about :mod:`asyncio`; +- it isn't answered in the documentation; +- it wasn't asked already. + +A good question can be written as a suggestion to improve the documentation. + +Cryptocurrency users +-------------------- + +websockets appears to be quite popular for interfacing with Bitcoin or other +cryptocurrency trackers. I'm strongly opposed to Bitcoin's carbon footprint. + +I'm aware of efforts to build proof-of-stake models. I'll care once the total +energy consumption of all cryptocurrencies drops to a non-bullshit level. + +You already negated all of humanity's efforts to develop renewable energy. +Please stop heating the planet where my children will have to live. + +Since websockets is released under an open-source license, you can use it for +any purpose you like. However, I won't spend any of my time to help you. + +I will summarily close issues related to cryptocurrency in any way. From a942fcc13f47d9a5bdcd93b6f84da21e1d185e63 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Sep 2024 22:05:54 +0200 Subject: [PATCH 635/762] Switch convenience aliases to the new implementation. --- docs/howto/upgrade.rst | 29 ++++++++++++++--------------- docs/project/changelog.rst | 11 +++++++++++ docs/reference/index.rst | 2 +- src/websockets/__init__.py | 38 ++++++++++++++++++++------------------ 4 files changed, 46 insertions(+), 34 deletions(-) diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index f3e42591e..5b1b8e4a2 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -3,8 +3,8 @@ Upgrade to the new :mod:`asyncio` implementation .. currentmodule:: websockets -The new :mod:`asyncio` implementation is a rewrite of the original -implementation of websockets. +The new :mod:`asyncio` implementation, which is now the default, is a rewrite of +the original implementation of websockets. It provides a very similar API. However, there are a few differences. @@ -70,9 +70,8 @@ For context, the ``websockets`` package is structured as follows: * The new implementation is found in the ``websockets.asyncio`` package. * The original implementation was moved to the ``websockets.legacy`` package. -* The ``websockets`` package provides aliases for convenience. Currently, they - point to the original implementation. They will be updated to point to the new - implementation when it feels mature. +* The ``websockets`` package provides aliases for convenience. They were + switched to the new implementation in version 14.0. * The ``websockets.client`` and ``websockets.server`` packages provide aliases for backwards-compatibility with earlier versions of websockets. They will be deprecated together with the original implementation. @@ -90,12 +89,12 @@ Client APIs +-------------------------------------------------------------------+-----------------------------------------------------+ | Legacy :mod:`asyncio` implementation | New :mod:`asyncio` implementation | +===================================================================+=====================================================+ -| ``websockets.connect()`` |br| | :func:`websockets.asyncio.client.connect` | -| ``websockets.client.connect()`` |br| | | +| ``websockets.connect()`` *(before 14.0)* |br| | ``websockets.connect()`` *(since 14.0)* |br| | +| ``websockets.client.connect()`` |br| | :func:`websockets.asyncio.client.connect` | | :func:`websockets.legacy.client.connect` | | +-------------------------------------------------------------------+-----------------------------------------------------+ -| ``websockets.unix_connect()`` |br| | :func:`websockets.asyncio.client.unix_connect` | -| ``websockets.client.unix_connect()`` |br| | | +| ``websockets.unix_connect()`` *(before 14.0)* |br| | ``websockets.unix_connect()`` *(since 14.0)* |br| | +| ``websockets.client.unix_connect()`` |br| | :func:`websockets.asyncio.client.unix_connect` | | :func:`websockets.legacy.client.unix_connect` | | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.WebSocketClientProtocol`` |br| | :class:`websockets.asyncio.client.ClientConnection` | @@ -109,12 +108,12 @@ Server APIs +-------------------------------------------------------------------+-----------------------------------------------------+ | Legacy :mod:`asyncio` implementation | New :mod:`asyncio` implementation | +===================================================================+=====================================================+ -| ``websockets.serve()`` |br| | :func:`websockets.asyncio.server.serve` | -| ``websockets.server.serve()`` |br| | | +| ``websockets.serve()`` *(before 14.0)* |br| | ``websockets.serve()`` *(since 14.0)* |br| | +| ``websockets.server.serve()`` |br| | :func:`websockets.asyncio.server.serve` | | :func:`websockets.legacy.server.serve` | | +-------------------------------------------------------------------+-----------------------------------------------------+ -| ``websockets.unix_serve()`` |br| | :func:`websockets.asyncio.server.unix_serve` | -| ``websockets.server.unix_serve()`` |br| | | +| ``websockets.unix_serve()`` *(before 14.0)* |br| | ``websockets.unix_serve()`` *(since 14.0)* |br| | +| ``websockets.server.unix_serve()`` |br| | :func:`websockets.asyncio.server.unix_serve` | | :func:`websockets.legacy.server.unix_serve` | | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.WebSocketServer`` |br| | :class:`websockets.asyncio.server.Server` | @@ -125,8 +124,8 @@ Server APIs | ``websockets.server.WebSocketServerProtocol`` |br| | | | :class:`websockets.legacy.server.WebSocketServerProtocol` | | +-------------------------------------------------------------------+-----------------------------------------------------+ -| ``websockets.broadcast()`` |br| | :func:`websockets.asyncio.server.broadcast` | -| :func:`websockets.legacy.server.broadcast()` | | +| ``websockets.broadcast()`` *(before 14.0)* |br| | ``websockets.broadcast()`` *(since 14.0)* |br| | +| :func:`websockets.legacy.server.broadcast()` | :func:`websockets.asyncio.server.broadcast` | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.BasicAuthWebSocketServerProtocol`` |br| | See below :ref:`how to migrate ` to | | ``websockets.auth.BasicAuthWebSocketServerProtocol`` |br| | :func:`websockets.asyncio.server.basic_auth`. | diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 5f07fc09f..45e4ef0cc 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -40,6 +40,17 @@ Backwards-incompatible changes websockets 13.1 is the last version supporting Python 3.8. +.. admonition:: The new :mod:`asyncio` implementation is now the default. + :class: caution + + The following aliases in the ``websockets`` package were switched to the new + :mod:`asyncio` implementation:: + + from websockets import connect, unix_connext + from websockets import broadcast, serve, unix_serve + + If you're using any of them, then you must follow the :doc:`upgrade guide + <../howto/upgrade>` immediately. .. _13.1: diff --git a/docs/reference/index.rst b/docs/reference/index.rst index 77b538b78..ed2341cc6 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -98,5 +98,5 @@ guarantees of behavior or backwards-compatibility for private APIs. Convenience imports ------------------- -For convenience, many public APIs can be imported directly from the +For convenience, some public APIs can be imported directly from the ``websockets`` package. diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index 54591e9fd..036e71c23 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -7,6 +7,14 @@ __all__ = [ + # .asyncio.client + "connect", + "unix_connect", + # .asyncio.server + "basic_auth", + "broadcast", + "serve", + "unix_serve", # .client "ClientProtocol", # .datastructures @@ -41,8 +49,6 @@ "basic_auth_protocol_factory", # .legacy.client "WebSocketClientProtocol", - "connect", - "unix_connect", # .legacy.exceptions "AbortHandshake", "InvalidMessage", @@ -53,9 +59,6 @@ # .legacy.server "WebSocketServer", "WebSocketServerProtocol", - "broadcast", - "serve", - "unix_serve", # .server "ServerProtocol", # .typing @@ -70,6 +73,8 @@ # When type checking, import non-deprecated aliases eagerly. Else, import on demand. if typing.TYPE_CHECKING: + from .asyncio.client import connect, unix_connect + from .asyncio.server import basic_auth, broadcast, serve, unix_serve from .client import ClientProtocol from .datastructures import Headers, HeadersLike, MultipleValuesError from .exceptions import ( @@ -100,7 +105,7 @@ BasicAuthWebSocketServerProtocol, basic_auth_protocol_factory, ) - from .legacy.client import WebSocketClientProtocol, connect, unix_connect + from .legacy.client import WebSocketClientProtocol from .legacy.exceptions import ( AbortHandshake, InvalidMessage, @@ -108,13 +113,7 @@ RedirectHandshake, ) from .legacy.protocol import WebSocketCommonProtocol - from .legacy.server import ( - WebSocketServer, - WebSocketServerProtocol, - broadcast, - serve, - unix_serve, - ) + from .legacy.server import WebSocketServer, WebSocketServerProtocol from .server import ServerProtocol from .typing import ( Data, @@ -129,6 +128,14 @@ lazy_import( globals(), aliases={ + # .asyncio.client + "connect": ".asyncio.client", + "unix_connect": ".asyncio.client", + # .asyncio.server + "basic_auth": ".asyncio.server", + "broadcast": ".asyncio.server", + "serve": ".asyncio.server", + "unix_serve": ".asyncio.server", # .client "ClientProtocol": ".client", # .datastructures @@ -163,8 +170,6 @@ "basic_auth_protocol_factory": ".legacy.auth", # .legacy.client "WebSocketClientProtocol": ".legacy.client", - "connect": ".legacy.client", - "unix_connect": ".legacy.client", # .legacy.exceptions "AbortHandshake": ".legacy.exceptions", "InvalidMessage": ".legacy.exceptions", @@ -175,9 +180,6 @@ # .legacy.server "WebSocketServer": ".legacy.server", "WebSocketServerProtocol": ".legacy.server", - "broadcast": ".legacy.server", - "serve": ".legacy.server", - "unix_serve": ".legacy.server", # .server "ServerProtocol": ".server", # .typing From 8d055ebf383520f46c1f0b6d40742e3ef8c3d723 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Sep 2024 22:40:09 +0200 Subject: [PATCH 636/762] Deprecate aliases pointing to the legacy implementation. --- docs/howto/upgrade.rst | 3 +- docs/project/changelog.rst | 6 +++ src/websockets/__init__.py | 63 ++++++++---------------------- src/websockets/auth.py | 10 ++++- src/websockets/client.py | 20 ++++++---- src/websockets/exceptions.py | 41 ++++++------------- src/websockets/server.py | 22 +++++++---- tests/legacy/test_auth.py | 2 +- tests/legacy/test_client_server.py | 2 +- tests/test_exports.py | 30 ++++++++++---- 10 files changed, 99 insertions(+), 100 deletions(-) diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index 5b1b8e4a2..bdaefd768 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -71,7 +71,8 @@ For context, the ``websockets`` package is structured as follows: * The new implementation is found in the ``websockets.asyncio`` package. * The original implementation was moved to the ``websockets.legacy`` package. * The ``websockets`` package provides aliases for convenience. They were - switched to the new implementation in version 14.0. + switched to the new implementation in version 14.0 or deprecated when there + isn't an equivalent API. * The ``websockets.client`` and ``websockets.server`` packages provide aliases for backwards-compatibility with earlier versions of websockets. They will be deprecated together with the original implementation. diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 45e4ef0cc..30e89a69c 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -52,6 +52,12 @@ Backwards-incompatible changes If you're using any of them, then you must follow the :doc:`upgrade guide <../howto/upgrade>` immediately. +.. admonition:: The legacy :mod:`asyncio` implementation is now deprecated. + :class: caution + + Aliases for deprecated API were removed from ``__all__``. As a consequence, + they cannot be imported e.g. with ``from websockets import *`` anymore. + .. _13.1: 13.1 diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index 036e71c23..531ce49f7 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -43,22 +43,6 @@ "ProtocolError", "SecurityError", "WebSocketException", - "WebSocketProtocolError", - # .legacy.auth - "BasicAuthWebSocketServerProtocol", - "basic_auth_protocol_factory", - # .legacy.client - "WebSocketClientProtocol", - # .legacy.exceptions - "AbortHandshake", - "InvalidMessage", - "InvalidStatusCode", - "RedirectHandshake", - # .legacy.protocol - "WebSocketCommonProtocol", - # .legacy.server - "WebSocketServer", - "WebSocketServerProtocol", # .server "ServerProtocol", # .typing @@ -99,21 +83,7 @@ ProtocolError, SecurityError, WebSocketException, - WebSocketProtocolError, ) - from .legacy.auth import ( - BasicAuthWebSocketServerProtocol, - basic_auth_protocol_factory, - ) - from .legacy.client import WebSocketClientProtocol - from .legacy.exceptions import ( - AbortHandshake, - InvalidMessage, - InvalidStatusCode, - RedirectHandshake, - ) - from .legacy.protocol import WebSocketCommonProtocol - from .legacy.server import WebSocketServer, WebSocketServerProtocol from .server import ServerProtocol from .typing import ( Data, @@ -164,22 +134,6 @@ "ProtocolError": ".exceptions", "SecurityError": ".exceptions", "WebSocketException": ".exceptions", - "WebSocketProtocolError": ".exceptions", - # .legacy.auth - "BasicAuthWebSocketServerProtocol": ".legacy.auth", - "basic_auth_protocol_factory": ".legacy.auth", - # .legacy.client - "WebSocketClientProtocol": ".legacy.client", - # .legacy.exceptions - "AbortHandshake": ".legacy.exceptions", - "InvalidMessage": ".legacy.exceptions", - "InvalidStatusCode": ".legacy.exceptions", - "RedirectHandshake": ".legacy.exceptions", - # .legacy.protocol - "WebSocketCommonProtocol": ".legacy.protocol", - # .legacy.server - "WebSocketServer": ".legacy.server", - "WebSocketServerProtocol": ".legacy.server", # .server "ServerProtocol": ".server", # .typing @@ -197,5 +151,22 @@ "handshake": ".legacy", "parse_uri": ".uri", "WebSocketURI": ".uri", + # deprecated in 14.0 + # .legacy.auth + "BasicAuthWebSocketServerProtocol": ".legacy.auth", + "basic_auth_protocol_factory": ".legacy.auth", + # .legacy.client + "WebSocketClientProtocol": ".legacy.client", + # .legacy.exceptions + "AbortHandshake": ".legacy.exceptions", + "InvalidMessage": ".legacy.exceptions", + "InvalidStatusCode": ".legacy.exceptions", + "RedirectHandshake": ".legacy.exceptions", + "WebSocketProtocolError": ".legacy.exceptions", + # .legacy.protocol + "WebSocketCommonProtocol": ".legacy.protocol", + # .legacy.server + "WebSocketServer": ".legacy.server", + "WebSocketServerProtocol": ".legacy.server", }, ) diff --git a/src/websockets/auth.py b/src/websockets/auth.py index b792e02f5..1e0002cee 100644 --- a/src/websockets/auth.py +++ b/src/websockets/auth.py @@ -1,6 +1,12 @@ from __future__ import annotations -# See #940 for why lazy_import isn't used here for backwards compatibility. -# See #1400 for why listing compatibility imports in __all__ helps PyCharm. +import warnings + from .legacy.auth import * from .legacy.auth import __all__ # noqa: F401 + + +warnings.warn( # deprecated in 14.0 + "websockets.auth is deprecated", + DeprecationWarning, +) diff --git a/src/websockets/client.py b/src/websockets/client.py index bce82d66b..8b66900a8 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -27,6 +27,7 @@ parse_upgrade, ) from .http11 import Request, Response +from .imports import lazy_import from .protocol import CLIENT, CONNECTING, OPEN, Protocol, State from .typing import ( ConnectionOption, @@ -40,13 +41,7 @@ from .utils import accept_key, generate_key -# See #940 for why lazy_import isn't used here for backwards compatibility. -# See #1400 for why listing compatibility imports in __all__ helps PyCharm. -from .legacy.client import * # isort:skip # noqa: I001 -from .legacy.client import __all__ as legacy__all__ - - -__all__ = ["ClientProtocol"] + legacy__all__ +__all__ = ["ClientProtocol"] class ClientProtocol(Protocol): @@ -392,3 +387,14 @@ def backoff( delay *= factor while True: yield max_delay + + +lazy_import( + globals(), + deprecated_aliases={ + # deprecated in 14.0 + "WebSocketClientProtocol": ".legacy.client", + "connect": ".legacy.client", + "unix_connect": ".legacy.client", + }, +) diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index d723f2fec..7681736a4 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -31,7 +31,6 @@ from __future__ import annotations -import typing import warnings from .imports import lazy_import @@ -45,9 +44,7 @@ "InvalidURI", "InvalidHandshake", "SecurityError", - "InvalidMessage", "InvalidStatus", - "InvalidStatusCode", "InvalidHeader", "InvalidHeaderFormat", "InvalidHeaderValue", @@ -57,10 +54,7 @@ "DuplicateParameter", "InvalidParameterName", "InvalidParameterValue", - "AbortHandshake", - "RedirectHandshake", "ProtocolError", - "WebSocketProtocolError", "PayloadTooBig", "InvalidState", "ConcurrencyError", @@ -366,27 +360,18 @@ class ConcurrencyError(WebSocketException, RuntimeError): """ -# When type checking, import non-deprecated aliases eagerly. Else, import on demand. -if typing.TYPE_CHECKING: - from .legacy.exceptions import ( - AbortHandshake, - InvalidMessage, - InvalidStatusCode, - RedirectHandshake, - ) - - WebSocketProtocolError = ProtocolError -else: - lazy_import( - globals(), - aliases={ - "AbortHandshake": ".legacy.exceptions", - "InvalidMessage": ".legacy.exceptions", - "InvalidStatusCode": ".legacy.exceptions", - "RedirectHandshake": ".legacy.exceptions", - "WebSocketProtocolError": ".legacy.exceptions", - }, - ) - # At the bottom to break import cycles created by type annotations. from . import frames, http11 # noqa: E402 + + +lazy_import( + globals(), + deprecated_aliases={ + # deprecated in 14.0 + "AbortHandshake": ".legacy.exceptions", + "InvalidMessage": ".legacy.exceptions", + "InvalidStatusCode": ".legacy.exceptions", + "RedirectHandshake": ".legacy.exceptions", + "WebSocketProtocolError": ".legacy.exceptions", + }, +) diff --git a/src/websockets/server.py b/src/websockets/server.py index 9fe970619..527db8990 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -27,6 +27,7 @@ parse_upgrade, ) from .http11 import Request, Response +from .imports import lazy_import from .protocol import CONNECTING, OPEN, SERVER, Protocol, State from .typing import ( ConnectionOption, @@ -40,13 +41,7 @@ from .utils import accept_key -# See #940 for why lazy_import isn't used here for backwards compatibility. -# See #1400 for why listing compatibility imports in __all__ helps PyCharm. -from .legacy.server import * # isort:skip # noqa: I001 -from .legacy.server import __all__ as legacy__all__ - - -__all__ = ["ServerProtocol"] + legacy__all__ +__all__ = ["ServerProtocol"] class ServerProtocol(Protocol): @@ -586,3 +581,16 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: DeprecationWarning, ) super().__init__(*args, **kwargs) + + +lazy_import( + globals(), + deprecated_aliases={ + # deprecated in 14.0 + "WebSocketServer": ".legacy.server", + "WebSocketServerProtocol": ".legacy.server", + "broadcast": ".legacy.server", + "serve": ".legacy.server", + "unix_serve": ".legacy.server", + }, +) diff --git a/tests/legacy/test_auth.py b/tests/legacy/test_auth.py index 3754bcf3a..dabd4212a 100644 --- a/tests/legacy/test_auth.py +++ b/tests/legacy/test_auth.py @@ -2,10 +2,10 @@ import unittest import urllib.error -from websockets.exceptions import InvalidStatusCode from websockets.headers import build_authorization_basic from websockets.legacy.auth import * from websockets.legacy.auth import is_credentials +from websockets.legacy.exceptions import InvalidStatusCode from .test_client_server import ClientServerTestsMixin, with_client, with_server from .utils import AsyncioTestCase diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 2f3ba9b77..502ab68e7 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -21,7 +21,6 @@ ConnectionClosed, InvalidHandshake, InvalidHeader, - InvalidStatusCode, NegotiationError, ) from websockets.extensions.permessage_deflate import ( @@ -32,6 +31,7 @@ from websockets.frames import CloseCode from websockets.http11 import USER_AGENT from websockets.legacy.client import * +from websockets.legacy.exceptions import InvalidStatusCode from websockets.legacy.handshake import build_response from websockets.legacy.http import read_response from websockets.legacy.server import * diff --git a/tests/test_exports.py b/tests/test_exports.py index 67a1a6f99..93b0684f7 100644 --- a/tests/test_exports.py +++ b/tests/test_exports.py @@ -1,30 +1,46 @@ import unittest import websockets -import websockets.auth +import websockets.asyncio.client +import websockets.asyncio.server import websockets.client import websockets.datastructures import websockets.exceptions -import websockets.legacy.protocol import websockets.server import websockets.typing import websockets.uri combined_exports = ( - websockets.auth.__all__ + [] + + websockets.asyncio.client.__all__ + + websockets.asyncio.server.__all__ + websockets.client.__all__ + websockets.datastructures.__all__ + websockets.exceptions.__all__ - + websockets.legacy.protocol.__all__ + websockets.server.__all__ + websockets.typing.__all__ ) +# These API are intentionally not re-exported by the top-level module. +missing_reexports = [ + # websockets.asyncio.client + "ClientConnection", + # websockets.asyncio.server + "ServerConnection", + "Server", +] + class ExportsTests(unittest.TestCase): - def test_top_level_module_reexports_all_submodule_exports(self): - self.assertEqual(set(combined_exports), set(websockets.__all__)) + def test_top_level_module_reexports_submodule_exports(self): + self.assertEqual( + set(combined_exports), + set(websockets.__all__ + missing_reexports), + ) def test_submodule_exports_are_globally_unique(self): - self.assertEqual(len(set(combined_exports)), len(combined_exports)) + self.assertEqual( + len(set(combined_exports)), + len(combined_exports), + ) From d62e423744b3e20ef6405019f96a63faa21612a2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 22 Sep 2024 12:10:15 +0200 Subject: [PATCH 637/762] Deprecate the legacy asyncio implementation. --- docs/howto/upgrade.rst | 21 ++++++--------- docs/index.rst | 43 ++++++++++++++++--------------- docs/project/changelog.rst | 3 +++ docs/reference/index.rst | 30 +++++++++++---------- docs/reference/legacy/client.rst | 6 +++++ docs/reference/legacy/common.rst | 6 +++++ docs/reference/legacy/server.rst | 6 +++++ docs/topics/design.rst | 2 ++ docs/topics/index.rst | 1 - src/websockets/auth.py | 12 ++++++--- src/websockets/http.py | 7 ++++- src/websockets/legacy/__init__.py | 11 ++++++++ tests/legacy/__init__.py | 9 +++++++ tests/maxi_cov.py | 2 -- tests/test_auth.py | 14 ++++++++++ 15 files changed, 118 insertions(+), 55 deletions(-) create mode 100644 tests/test_auth.py diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index bdaefd768..02d4c6f01 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -27,15 +27,9 @@ respectively. .. admonition:: What will happen to the original implementation? :class: hint - The original implementation is now considered legacy. - - The next steps are: - - 1. Deprecating it once the new implementation is considered sufficiently - robust. - 2. Maintaining it for five years per the :ref:`backwards-compatibility - policy `. - 3. Removing it. This is expected to happen around 2030. + The original implementation is deprecated. It will be maintained for five + years after deprecation according to the :ref:`backwards-compatibility + policy `. Then, by 2030, it will be removed. .. _deprecated APIs: @@ -69,13 +63,14 @@ Import paths For context, the ``websockets`` package is structured as follows: * The new implementation is found in the ``websockets.asyncio`` package. -* The original implementation was moved to the ``websockets.legacy`` package. +* The original implementation was moved to the ``websockets.legacy`` package + and deprecated. * The ``websockets`` package provides aliases for convenience. They were switched to the new implementation in version 14.0 or deprecated when there - isn't an equivalent API. + wasn't an equivalent API. * The ``websockets.client`` and ``websockets.server`` packages provide aliases - for backwards-compatibility with earlier versions of websockets. They will - be deprecated together with the original implementation. + for backwards-compatibility with earlier versions of websockets. They were + deprecated. To upgrade to the new :mod:`asyncio` implementation, change import paths as shown in the tables below. diff --git a/docs/index.rst b/docs/index.rst index b8cd300e3..f9576f2dc 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -28,30 +28,13 @@ with a focus on correctness, simplicity, robustness, and performance. It supports several network I/O and control flow paradigms. -1. The primary implementation builds upon :mod:`asyncio`, Python's standard - asynchronous I/O framework. It provides an elegant coroutine-based API. It's - ideal for servers that handle many clients concurrently. - - .. admonition:: As of version :ref:`13.0`, there is a new :mod:`asyncio` - implementation. - :class: important - - The historical implementation in ``websockets.legacy`` traces its roots to - early versions of websockets. Although it's stable and robust, it is now - considered legacy. - - The new implementation in ``websockets.asyncio`` is a rewrite on top of - the Sans-I/O implementation. It adds a few features that were impossible - to implement within the original design. - - The new implementation provides all features of the historical - implementation, and a few more. If you're using the historical - implementation, you should :doc:`ugrade to the new implementation - `. It's usually straightforward. +1. The default implementation builds upon :mod:`asyncio`, Python's built-in + asynchronous I/O library. It provides an elegant coroutine-based API. It's + ideal for servers that handle many client connections. 2. The :mod:`threading` implementation is a good alternative for clients, especially if you aren't familiar with :mod:`asyncio`. It may also be used - for servers that don't need to serve many clients. + for servers that handle few client connections. 3. The `Sans-I/O`_ implementation is designed for integrating in third-party libraries, typically application servers, in addition being used internally @@ -59,6 +42,24 @@ It supports several network I/O and control flow paradigms. .. _Sans-I/O: https://sans-io.readthedocs.io/ +Refer to the :doc:`feature support matrices ` for the full +list of features provided by each implementation. + +.. admonition:: The :mod:`asyncio` implementation was rewritten. + :class: tip + + The new implementation in ``websockets.asyncio`` builds upon the Sans-I/O + implementation. It adds features that were impossible to provide in the + original design. It was introduced in version 13.0. + + The historical implementation in ``websockets.legacy`` traces its roots to + early versions of websockets. While it's stable and robust, it was deprecated + in version 14.0 and it will be removed by 2030. + + The new implementation provides the same features as the historical + implementation, and then some. If you're using the historical implementation, + you should :doc:`ugrade to the new implementation `. + Here's an echo server using the :mod:`asyncio` API: .. literalinclude:: ../example/echo.py diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 30e89a69c..c8d854ba4 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -55,6 +55,9 @@ Backwards-incompatible changes .. admonition:: The legacy :mod:`asyncio` implementation is now deprecated. :class: caution + The :doc:`upgrade guide <../howto/upgrade>` provides complete instructions + to migrate your application. + Aliases for deprecated API were removed from ``__all__``. As a consequence, they cannot be imported e.g. with ``from websockets import *`` anymore. diff --git a/docs/reference/index.rst b/docs/reference/index.rst index ed2341cc6..c78a3c095 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -13,12 +13,12 @@ Check which implementations support which features and known limitations. features -:mod:`asyncio` (new) --------------------- +:mod:`asyncio` +-------------- It's ideal for servers that handle many clients concurrently. -It's a rewrite of the legacy :mod:`asyncio` implementation. +This is the default implementation. .. toctree:: :titlesonly: @@ -26,17 +26,6 @@ It's a rewrite of the legacy :mod:`asyncio` implementation. asyncio/server asyncio/client -:mod:`asyncio` (legacy) ------------------------ - -This is the historical implementation. - -.. toctree:: - :titlesonly: - - legacy/server - legacy/client - :mod:`threading` ---------------- @@ -62,6 +51,19 @@ application servers. sansio/server sansio/client +:mod:`asyncio` (legacy) +----------------------- + +This is the historical implementation. + +It is deprecated and will be removed. + +.. toctree:: + :titlesonly: + + legacy/server + legacy/client + Extensions ---------- diff --git a/docs/reference/legacy/client.rst b/docs/reference/legacy/client.rst index fca45d218..a798409f0 100644 --- a/docs/reference/legacy/client.rst +++ b/docs/reference/legacy/client.rst @@ -1,6 +1,12 @@ Client (legacy :mod:`asyncio`) ============================== +.. admonition:: The legacy :mod:`asyncio` implementation is deprecated. + :class: caution + + The :doc:`upgrade guide <../../howto/upgrade>` provides complete instructions + to migrate your application. + .. automodule:: websockets.legacy.client Opening a connection diff --git a/docs/reference/legacy/common.rst b/docs/reference/legacy/common.rst index aee774479..45c56fccd 100644 --- a/docs/reference/legacy/common.rst +++ b/docs/reference/legacy/common.rst @@ -3,6 +3,12 @@ Both sides (legacy :mod:`asyncio`) ================================== +.. admonition:: The legacy :mod:`asyncio` implementation is deprecated. + :class: caution + + The :doc:`upgrade guide <../../howto/upgrade>` provides complete instructions + to migrate your application. + .. automodule:: websockets.legacy.protocol .. autoclass:: WebSocketCommonProtocol(*, logger=None, ping_interval=20, ping_timeout=20, close_timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16) diff --git a/docs/reference/legacy/server.rst b/docs/reference/legacy/server.rst index b6c383ce7..3c1d19fc6 100644 --- a/docs/reference/legacy/server.rst +++ b/docs/reference/legacy/server.rst @@ -1,6 +1,12 @@ Server (legacy :mod:`asyncio`) ============================== +.. admonition:: The legacy :mod:`asyncio` implementation is deprecated. + :class: caution + + The :doc:`upgrade guide <../../howto/upgrade>` provides complete instructions + to migrate your application. + .. automodule:: websockets.legacy.server Starting a server diff --git a/docs/topics/design.rst b/docs/topics/design.rst index d2fd18d0c..b73ace517 100644 --- a/docs/topics/design.rst +++ b/docs/topics/design.rst @@ -1,3 +1,5 @@ +:orphan: + Design (legacy :mod:`asyncio`) ============================== diff --git a/docs/topics/index.rst b/docs/topics/index.rst index a2b8ca879..616753c6c 100644 --- a/docs/topics/index.rst +++ b/docs/topics/index.rst @@ -12,7 +12,6 @@ Get a deeper understanding of how websockets is built and why. broadcast compression keepalive - design memory security performance diff --git a/src/websockets/auth.py b/src/websockets/auth.py index 1e0002cee..98e62af3c 100644 --- a/src/websockets/auth.py +++ b/src/websockets/auth.py @@ -2,11 +2,17 @@ import warnings -from .legacy.auth import * -from .legacy.auth import __all__ # noqa: F401 + +with warnings.catch_warnings(): + # Suppress redundant DeprecationWarning raised by websockets.legacy. + warnings.filterwarnings("ignore", category=DeprecationWarning) + from .legacy.auth import * + from .legacy.auth import __all__ # noqa: F401 warnings.warn( # deprecated in 14.0 - "websockets.auth is deprecated", + "websockets.auth, an alias for websockets.legacy.auth, is deprecated; " + "see https://websockets.readthedocs.io/en/stable/howto/upgrade.html " + "for upgrade instructions", DeprecationWarning, ) diff --git a/src/websockets/http.py b/src/websockets/http.py index 0ff5598c7..0d860e537 100644 --- a/src/websockets/http.py +++ b/src/websockets/http.py @@ -3,7 +3,12 @@ import warnings from .datastructures import Headers, MultipleValuesError # noqa: F401 -from .legacy.http import read_request, read_response # noqa: F401 + + +with warnings.catch_warnings(): + # Suppress redundant DeprecationWarning raised by websockets.legacy. + warnings.filterwarnings("ignore", category=DeprecationWarning) + from .legacy.http import read_request, read_response # noqa: F401 warnings.warn( # deprecated in 9.0 - 2021-09-01 diff --git a/src/websockets/legacy/__init__.py b/src/websockets/legacy/__init__.py index e69de29bb..84f870f3a 100644 --- a/src/websockets/legacy/__init__.py +++ b/src/websockets/legacy/__init__.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +import warnings + + +warnings.warn( # deprecated in 14.0 + "websockets.legacy is deprecated; " + "see https://websockets.readthedocs.io/en/stable/howto/upgrade.html " + "for upgrade instructions", + DeprecationWarning, +) diff --git a/tests/legacy/__init__.py b/tests/legacy/__init__.py index e69de29bb..035834a89 100644 --- a/tests/legacy/__init__.py +++ b/tests/legacy/__init__.py @@ -0,0 +1,9 @@ +from __future__ import annotations + +import warnings + + +with warnings.catch_warnings(): + # Suppress DeprecationWarning raised by websockets.legacy. + warnings.filterwarnings("ignore", category=DeprecationWarning) + import websockets.legacy # noqa: F401 diff --git a/tests/maxi_cov.py b/tests/maxi_cov.py index 83686c3d3..8ccef7d39 100755 --- a/tests/maxi_cov.py +++ b/tests/maxi_cov.py @@ -9,7 +9,6 @@ UNMAPPED_SRC_FILES = [ - "websockets/auth.py", "websockets/typing.py", "websockets/version.py", ] @@ -105,7 +104,6 @@ def get_ignored_files(src_dir="src"): # or websockets (import locations). "*/websockets/asyncio/async_timeout.py", "*/websockets/asyncio/compatibility.py", - "*/websockets/auth.py", # This approach isn't applicable to the test suite of the legacy # implementation, due to the huge test_client_server test module. "*/websockets/legacy/*", diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 000000000..16c00c1b9 --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,14 @@ +from .utils import DeprecationTestCase + + +class BackwardsCompatibilityTests(DeprecationTestCase): + def test_headers_class(self): + with self.assertDeprecationWarning( + "websockets.auth, an alias for websockets.legacy.auth, is deprecated; " + "see https://websockets.readthedocs.io/en/stable/howto/upgrade.html " + "for upgrade instructions", + ): + from websockets.auth import ( + BasicAuthWebSocketServerProtocol, # noqa: F401 + basic_auth_protocol_factory, # noqa: F401 + ) From a0b20f081d7ae48409c6b79ed16bbc261d5109f9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 22 Sep 2024 21:16:26 +0200 Subject: [PATCH 638/762] Document that only asyncio supports keepalive. Fix #1508. --- docs/topics/keepalive.rst | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/docs/topics/keepalive.rst b/docs/topics/keepalive.rst index 91f11fb11..458fa3d05 100644 --- a/docs/topics/keepalive.rst +++ b/docs/topics/keepalive.rst @@ -1,6 +1,11 @@ Keepalive and latency ===================== +.. admonition:: This guide applies only to the :mod:`asyncio` implementation. + :class: tip + + The :mod:`threading` implementation doesn't provide keepalive yet. + .. currentmodule:: websockets Long-lived connections @@ -31,14 +36,8 @@ based on WebSocket Ping_ and Pong_ frames, which are designed for this purpose. .. _Ping: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 .. _Pong: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 -It loops through these steps: - -1. Wait 20 seconds. -2. Send a Ping frame. -3. Receive a corresponding Pong frame within 20 seconds. - -If the Pong frame isn't received, websockets considers the connection broken and -closes it. +It sends a Ping frame every 20 seconds. It expects a Pong frame in return within +20 seconds. Else, it considers the connection broken and terminates it. This mechanism serves three purposes: From baadc33364131ff7236249c9077ae10b561395b6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 23 Sep 2024 23:52:25 +0200 Subject: [PATCH 639/762] Profile and optimize the permessage-deflate extension. dataclasses.replace is surprisingly expensive. zlib functions make up the bulk of the cost now. --- experiments/compression/corpus.py | 2 +- experiments/profiling/compression.py | 45 +++++++++++++++++++ .../extensions/permessage_deflate.py | 29 +++++++++--- 3 files changed, 68 insertions(+), 8 deletions(-) create mode 100644 experiments/profiling/compression.py diff --git a/experiments/compression/corpus.py b/experiments/compression/corpus.py index da5661dfa..56e262114 100644 --- a/experiments/compression/corpus.py +++ b/experiments/compression/corpus.py @@ -47,6 +47,6 @@ def main(corpus): if __name__ == "__main__": if len(sys.argv) < 2: - print(f"Usage: {sys.argv[0]} [directory]") + print(f"Usage: {sys.argv[0]} ") sys.exit(2) main(pathlib.Path(sys.argv[1])) diff --git a/experiments/profiling/compression.py b/experiments/profiling/compression.py new file mode 100644 index 000000000..1ece1f10e --- /dev/null +++ b/experiments/profiling/compression.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python + +""" +Profile the permessage-deflate extension. + +Usage:: + $ pip install line_profiler + $ python experiments/compression/corpus.py experiments/compression/corpus + $ PYTHONPATH=src python -m kernprof \ + --line-by-line \ + --prof-mod src/websockets/extensions/permessage_deflate.py \ + --view \ + experiments/profiling/compression.py experiments/compression/corpus 12 5 6 + +""" + +import pathlib +import sys + +from websockets.extensions.permessage_deflate import PerMessageDeflate +from websockets.frames import OP_TEXT, Frame + + +def compress_and_decompress(corpus, max_window_bits, memory_level, level): + extension = PerMessageDeflate( + remote_no_context_takeover=False, + local_no_context_takeover=False, + remote_max_window_bits=max_window_bits, + local_max_window_bits=max_window_bits, + compress_settings={"memLevel": memory_level, "level": level}, + ) + for data in corpus: + frame = Frame(OP_TEXT, data) + frame = extension.encode(frame) + frame = extension.decode(frame) + + +if __name__ == "__main__": + if len(sys.argv) < 2 or not pathlib.Path(sys.argv[1]).is_dir(): + print(f"Usage: {sys.argv[0]} [] []") + corpus = [file.read_bytes() for file in pathlib.Path(sys.argv[1]).iterdir()] + max_window_bits = int(sys.argv[2]) if len(sys.argv) > 2 else 12 + memory_level = int(sys.argv[3]) if len(sys.argv) > 3 else 5 + level = int(sys.argv[4]) if len(sys.argv) > 4 else 6 + compress_and_decompress(corpus, max_window_bits, memory_level, level) diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index f962b65fb..21df804fd 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -1,6 +1,5 @@ from __future__ import annotations -import dataclasses import zlib from collections.abc import Sequence from typing import Any @@ -120,7 +119,6 @@ def decode( else: if not frame.rsv1: return frame - frame = dataclasses.replace(frame, rsv1=False) if not frame.fin: self.decode_cont_data = True @@ -146,7 +144,15 @@ def decode( if frame.fin and self.remote_no_context_takeover: del self.decoder - return dataclasses.replace(frame, data=data) + return frames.Frame( + frame.opcode, + data, + frame.fin, + # Unset the rsv1 flag on the first frame of a compressed message. + False, + frame.rsv2, + frame.rsv3, + ) def encode(self, frame: frames.Frame) -> frames.Frame: """ @@ -161,8 +167,6 @@ def encode(self, frame: frames.Frame) -> frames.Frame: # data" flag similar to "decode continuation data" at this time. if frame.opcode is not frames.OP_CONT: - # Set the rsv1 flag on the first frame of a compressed message. - frame = dataclasses.replace(frame, rsv1=True) # Re-initialize per-message decoder. if self.local_no_context_takeover: self.encoder = zlib.compressobj( @@ -172,14 +176,25 @@ def encode(self, frame: frames.Frame) -> frames.Frame: # Compress data. data = self.encoder.compress(frame.data) + self.encoder.flush(zlib.Z_SYNC_FLUSH) - if frame.fin and data.endswith(_EMPTY_UNCOMPRESSED_BLOCK): + if frame.fin and data[-4:] == _EMPTY_UNCOMPRESSED_BLOCK: + # Making a copy is faster than memoryview(a)[:-4] until about 2kB. + # On larger messages, it's slower but profiling shows that it's + # marginal compared to compress() and flush(). Keep it simple. data = data[:-4] # Allow garbage collection of the encoder if it won't be reused. if frame.fin and self.local_no_context_takeover: del self.encoder - return dataclasses.replace(frame, data=data) + return frames.Frame( + frame.opcode, + data, + frame.fin, + # Set the rsv1 flag on the first frame of a compressed message. + frame.opcode is not frames.OP_CONT, + frame.rsv2, + frame.rsv3, + ) def _build_parameters( From 524dd4afa8dc2a0d17ab08f20f592af03c165db1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 26 Sep 2024 23:11:28 +0200 Subject: [PATCH 640/762] Avoid making a copy of large frames. This isn't very significant compared to the cost of compression. It can make a real difference for decompression. --- docs/project/changelog.rst | 14 +++++++++ .../extensions/permessage_deflate.py | 30 ++++++++++++------- src/websockets/frames.py | 8 ++--- tests/extensions/test_permessage_deflate.py | 11 +++++++ 4 files changed, 49 insertions(+), 14 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index c8d854ba4..f5b4812bd 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -61,6 +61,20 @@ Backwards-incompatible changes Aliases for deprecated API were removed from ``__all__``. As a consequence, they cannot be imported e.g. with ``from websockets import *`` anymore. +.. admonition:: :attr:`Frame.data ` is now a bytes-like object. + :class: note + + In addition to :class:`bytes`, it may be a :class:`bytearray` or a + :class:`memoryview`. + + If you wrote an :class:`extension ` that relies on + methods not provided by these new types, you may need to update your code. + +Improvements +............ + +* Sending or receiving large compressed frames is now faster. + .. _13.1: 13.1 diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index 21df804fd..ed16937d8 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -129,16 +129,22 @@ def decode( # Uncompress data. Protect against zip bombs by preventing zlib from # decompressing more than max_length bytes (except when the limit is # disabled with max_size = None). - data = frame.data - if frame.fin: - data += _EMPTY_UNCOMPRESSED_BLOCK + if frame.fin and len(frame.data) < 2044: + # Profiling shows that appending four bytes, which makes a copy, is + # faster than calling decompress() again when data is less than 2kB. + data = bytes(frame.data) + _EMPTY_UNCOMPRESSED_BLOCK + else: + data = frame.data max_length = 0 if max_size is None else max_size try: data = self.decoder.decompress(data, max_length) + if self.decoder.unconsumed_tail: + raise PayloadTooBig(f"over size limit (? > {max_size} bytes)") + if frame.fin and len(frame.data) >= 2044: + # This cannot generate additional data. + self.decoder.decompress(_EMPTY_UNCOMPRESSED_BLOCK) except zlib.error as exc: raise ProtocolError("decompression failed") from exc - if self.decoder.unconsumed_tail: - raise PayloadTooBig(f"over size limit (? > {max_size} bytes)") # Allow garbage collection of the decoder if it won't be reused. if frame.fin and self.remote_no_context_takeover: @@ -176,11 +182,15 @@ def encode(self, frame: frames.Frame) -> frames.Frame: # Compress data. data = self.encoder.compress(frame.data) + self.encoder.flush(zlib.Z_SYNC_FLUSH) - if frame.fin and data[-4:] == _EMPTY_UNCOMPRESSED_BLOCK: - # Making a copy is faster than memoryview(a)[:-4] until about 2kB. - # On larger messages, it's slower but profiling shows that it's - # marginal compared to compress() and flush(). Keep it simple. - data = data[:-4] + if frame.fin: + # Sync flush generates between 5 or 6 bytes, ending with the bytes + # 0x00 0x00 0xff 0xff, which must be removed. + assert data[-4:] == _EMPTY_UNCOMPRESSED_BLOCK + # Making a copy is faster than memoryview(a)[:-4] until 2kB. + if len(data) < 2048: + data = data[:-4] + else: + data = memoryview(data)[:-4] # Allow garbage collection of the encoder if it won't be reused. if frame.fin and self.local_no_context_takeover: diff --git a/src/websockets/frames.py b/src/websockets/frames.py index dace2c902..5fadf3c2d 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -7,7 +7,7 @@ import secrets import struct from collections.abc import Generator, Sequence -from typing import Callable +from typing import Callable, Union from .exceptions import PayloadTooBig, ProtocolError @@ -139,7 +139,7 @@ class Frame: """ opcode: Opcode - data: bytes + data: Union[bytes, bytearray, memoryview] fin: bool = True rsv1: bool = False rsv2: bool = False @@ -160,7 +160,7 @@ def __str__(self) -> str: if self.opcode is OP_TEXT: # Decoding only the beginning and the end is needlessly hard. # Decode the entire payload then elide later if necessary. - data = repr(self.data.decode()) + data = repr(bytes(self.data).decode()) elif self.opcode is OP_BINARY: # We'll show at most the first 16 bytes and the last 8 bytes. # Encode just what we need, plus two dummy bytes to elide later. @@ -178,7 +178,7 @@ def __str__(self) -> str: # binary. If self.data is a memoryview, it has no decode() method, # which raises AttributeError. try: - data = repr(self.data.decode()) + data = repr(bytes(self.data).decode()) coding = "text" except (UnicodeDecodeError, AttributeError): binary = self.data diff --git a/tests/extensions/test_permessage_deflate.py b/tests/extensions/test_permessage_deflate.py index ee09813c4..76cd48623 100644 --- a/tests/extensions/test_permessage_deflate.py +++ b/tests/extensions/test_permessage_deflate.py @@ -1,4 +1,5 @@ import dataclasses +import os import unittest from websockets.exceptions import ( @@ -167,6 +168,16 @@ def test_encode_decode_fragmented_binary_frame(self): self.assertEqual(dec_frame1, frame1) self.assertEqual(dec_frame2, frame2) + def test_encode_decode_large_frame(self): + # There is a separate code path that avoids copying data + # when frames are larger than 2kB. Test it for coverage. + frame = Frame(OP_BINARY, os.urandom(4096)) + + enc_frame = self.extension.encode(frame) + dec_frame = self.extension.decode(enc_frame) + + self.assertEqual(dec_frame, frame) + def test_no_decode_text_frame(self): frame = Frame(OP_TEXT, "café".encode()) From 07dc56443acd8aaac4d79db6a30e450d0073137d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 27 Sep 2024 21:44:36 +0200 Subject: [PATCH 641/762] Add asyncio & threading examples to the homepage. Fix #1437. --- docs/index.rst | 19 +++++++++++++++---- example/{ => asyncio}/echo.py | 12 +++++++++--- example/asyncio/hello.py | 17 +++++++++++++++++ example/ruff.toml | 2 ++ example/sync/echo.py | 19 +++++++++++++++++++ example/{ => sync}/hello.py | 10 +++++++--- 6 files changed, 69 insertions(+), 10 deletions(-) rename example/{ => asyncio}/echo.py (51%) create mode 100755 example/asyncio/hello.py create mode 100644 example/ruff.toml create mode 100755 example/sync/echo.py rename example/{ => sync}/hello.py (66%) diff --git a/docs/index.rst b/docs/index.rst index f9576f2dc..de14fa2d0 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -60,13 +60,24 @@ list of features provided by each implementation. implementation, and then some. If you're using the historical implementation, you should :doc:`ugrade to the new implementation `. -Here's an echo server using the :mod:`asyncio` API: +Here's an echo server and corresponding client. -.. literalinclude:: ../example/echo.py +.. tab:: asyncio -Here's a client using the :mod:`threading` API: + .. literalinclude:: ../example/asyncio/echo.py -.. literalinclude:: ../example/hello.py +.. tab:: threading + + .. literalinclude:: ../example/sync/echo.py + +.. tab:: asyncio + :new-set: + + .. literalinclude:: ../example/asyncio/hello.py + +.. tab:: threading + + .. literalinclude:: ../example/sync/hello.py Don't worry about the opening and closing handshakes, pings and pongs, or any other behavior described in the WebSocket specification. websockets takes care diff --git a/example/echo.py b/example/asyncio/echo.py similarity index 51% rename from example/echo.py rename to example/asyncio/echo.py index b952a5cfb..28d877be7 100755 --- a/example/echo.py +++ b/example/asyncio/echo.py @@ -1,14 +1,20 @@ #!/usr/bin/env python +"""Echo server using the asyncio API.""" + import asyncio from websockets.asyncio.server import serve + async def echo(websocket): async for message in websocket: await websocket.send(message) + async def main(): - async with serve(echo, "localhost", 8765): - await asyncio.get_running_loop().create_future() # run forever + async with serve(echo, "localhost", 8765) as server: + await server.serve_forever() + -asyncio.run(main()) +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/asyncio/hello.py b/example/asyncio/hello.py new file mode 100755 index 000000000..6e4518497 --- /dev/null +++ b/example/asyncio/hello.py @@ -0,0 +1,17 @@ +#!/usr/bin/env python + +"""Client using the asyncio API.""" + +import asyncio +from websockets.asyncio.client import connect + + +async def hello(): + async with connect("ws://localhost:8765") as websocket: + await websocket.send("Hello world!") + message = await websocket.recv() + print(message) + + +if __name__ == "__main__": + asyncio.run(hello()) diff --git a/example/ruff.toml b/example/ruff.toml new file mode 100644 index 000000000..13ae36c08 --- /dev/null +++ b/example/ruff.toml @@ -0,0 +1,2 @@ +[lint.isort] +no-sections = true diff --git a/example/sync/echo.py b/example/sync/echo.py new file mode 100755 index 000000000..4b47db1ba --- /dev/null +++ b/example/sync/echo.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python + +"""Echo server using the threading API.""" + +from websockets.sync.server import serve + + +def echo(websocket): + for message in websocket: + websocket.send(message) + + +def main(): + with serve(echo, "localhost", 8765) as server: + server.serve_forever() + + +if __name__ == "__main__": + main() diff --git a/example/hello.py b/example/sync/hello.py similarity index 66% rename from example/hello.py rename to example/sync/hello.py index a3ce0699e..bb4cd3ffd 100755 --- a/example/hello.py +++ b/example/sync/hello.py @@ -1,12 +1,16 @@ #!/usr/bin/env python -import asyncio +"""Client using the threading API.""" + from websockets.sync.client import connect + def hello(): with connect("ws://localhost:8765") as websocket: websocket.send("Hello world!") message = websocket.recv() - print(f"Received: {message}") + print(message) + -hello() +if __name__ == "__main__": + hello() From a5c8943a99a8625836b150ca3559923ecf79bcd0 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 27 Sep 2024 21:53:49 +0200 Subject: [PATCH 642/762] Add asyncio & threading client & server examples. They're convenient to tweak to reproduce issues. --- example/asyncio/client.py | 22 ++++++++++++++++++++++ example/asyncio/server.py | 25 +++++++++++++++++++++++++ example/sync/client.py | 20 ++++++++++++++++++++ example/sync/server.py | 24 ++++++++++++++++++++++++ 4 files changed, 91 insertions(+) create mode 100644 example/asyncio/client.py create mode 100644 example/asyncio/server.py create mode 100644 example/sync/client.py create mode 100644 example/sync/server.py diff --git a/example/asyncio/client.py b/example/asyncio/client.py new file mode 100644 index 000000000..e3562642d --- /dev/null +++ b/example/asyncio/client.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python + +"""Client example using the asyncio API.""" + +import asyncio + +from websockets.asyncio.client import connect + + +async def hello(): + async with connect("ws://localhost:8765") as websocket: + name = input("What's your name? ") + + await websocket.send(name) + print(f">>> {name}") + + greeting = await websocket.recv() + print(f"<<< {greeting}") + + +if __name__ == "__main__": + asyncio.run(hello()) diff --git a/example/asyncio/server.py b/example/asyncio/server.py new file mode 100644 index 000000000..574e053bf --- /dev/null +++ b/example/asyncio/server.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python + +"""Server example using the asyncio API.""" + +import asyncio +from websockets.asyncio.server import serve + + +async def hello(websocket): + name = await websocket.recv() + print(f"<<< {name}") + + greeting = f"Hello {name}!" + + await websocket.send(greeting) + print(f">>> {greeting}") + + +async def main(): + async with serve(hello, "localhost", 8765) as server: + await server.serve_forever() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/sync/client.py b/example/sync/client.py new file mode 100644 index 000000000..c0d633c7b --- /dev/null +++ b/example/sync/client.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python + +"""Client example using the threading API.""" + +from websockets.sync.client import connect + + +def hello(): + with connect("ws://localhost:8765") as websocket: + name = input("What's your name? ") + + websocket.send(name) + print(f">>> {name}") + + greeting = websocket.recv() + print(f"<<< {greeting}") + + +if __name__ == "__main__": + hello() diff --git a/example/sync/server.py b/example/sync/server.py new file mode 100644 index 000000000..030049f81 --- /dev/null +++ b/example/sync/server.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python + +"""Server example using the threading API.""" + +from websockets.sync.server import serve + + +def hello(websocket): + name = websocket.recv() + print(f"<<< {name}") + + greeting = f"Hello {name}!" + + websocket.send(greeting) + print(f">>> {greeting}") + + +def main(): + with serve(hello, "localhost", 8765) as server: + server.serve_forever() + + +if __name__ == "__main__": + main() From 21987f96ad93f8c8bbf0b8ea99f3a18a52335730 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 27 Sep 2024 23:06:00 +0200 Subject: [PATCH 643/762] Migrate authentication experiment to new asyncio. --- docs/topics/authentication.rst | 118 ++++++-------- experiments/authentication/app.py | 194 ++++++++++-------------- experiments/authentication/script.js | 5 +- experiments/authentication/test.js | 2 - experiments/authentication/user_info.js | 2 +- 5 files changed, 127 insertions(+), 194 deletions(-) diff --git a/docs/topics/authentication.rst b/docs/topics/authentication.rst index 86d2e2587..e2de4332e 100644 --- a/docs/topics/authentication.rst +++ b/docs/topics/authentication.rst @@ -1,13 +1,13 @@ Authentication ============== -The WebSocket protocol was designed for creating web applications that need -bidirectional communication between clients running in browsers and servers. +The WebSocket protocol is designed for creating web applications that require +bidirectional communication between browsers and servers. In most practical use cases, WebSocket servers need to authenticate clients in order to route communications appropriately and securely. -:rfc:`6455` stays elusive when it comes to authentication: +:rfc:`6455` remains elusive when it comes to authentication: This protocol doesn't prescribe any particular way that servers can authenticate clients during the WebSocket handshake. The WebSocket @@ -26,8 +26,8 @@ System design Consider a setup where the WebSocket server is separate from the HTTP server. -Most servers built with websockets to complement a web application adopt this -design because websockets doesn't aim at supporting HTTP. +Most servers built with websockets adopt this design because they're a component +in a web application and websockets doesn't aim at supporting HTTP. The following diagram illustrates the authentication flow. @@ -82,8 +82,8 @@ WebSocket server. credentials would be a session identifier or a serialized, signed session. Unfortunately, when the WebSocket server runs on a different domain from - the web application, this idea bumps into the `Same-Origin Policy`_. For - security reasons, setting a cookie on a different origin is impossible. + the web application, this idea hits the wall of the `Same-Origin Policy`_. + For security reasons, setting a cookie on a different origin is impossible. The proper workaround consists in: @@ -108,13 +108,11 @@ WebSocket server. Letting the browser perform HTTP Basic Auth is a nice idea in theory. - In practice it doesn't work due to poor support in browsers. + In practice it doesn't work due to browser support limitations: - As of May 2021: + * Chrome behaves as expected. - * Chrome 90 behaves as expected. - - * Firefox 88 caches credentials too aggressively. + * Firefox caches credentials too aggressively. When connecting again to the same server with new credentials, it reuses the old credentials, which may be expired, resulting in an HTTP 401. Then @@ -123,7 +121,7 @@ WebSocket server. When tokens are short-lived or single-use, this bug produces an interesting effect: every other WebSocket connection fails. - * Safari 14 ignores credentials entirely. + * Safari behaves as expected. Two other options are off the table: @@ -142,8 +140,10 @@ Two other options are off the table: While this is suggested by the RFC, installing a TLS certificate is too far from the mainstream experience of browser users. This could make sense in - high security contexts. I hope developers working on such projects don't - take security advice from the documentation of random open source projects. + high security contexts. + + I hope that developers working on projects in this category don't take + security advice from the documentation of random open source projects :-) Let's experiment! ----------------- @@ -185,6 +185,8 @@ connection: .. code-block:: python + from websockets.frames import CloseCode + async def first_message_handler(websocket): token = await websocket.recv() user = get_user(token) @@ -212,24 +214,16 @@ the user. If authentication fails, it returns an HTTP 401: .. code-block:: python - from websockets.legacy.server import WebSocketServerProtocol - - class QueryParamProtocol(WebSocketServerProtocol): - async def process_request(self, path, headers): - token = get_query_parameter(path, "token") - if token is None: - return http.HTTPStatus.UNAUTHORIZED, [], b"Missing token\n" - - user = get_user(token) - if user is None: - return http.HTTPStatus.UNAUTHORIZED, [], b"Invalid token\n" + async def query_param_auth(connection, request): + token = get_query_param(request.path, "token") + if token is None: + return connection.respond(http.HTTPStatus.UNAUTHORIZED, "Missing token\n") - self.user = user + user = get_user(token) + if user is None: + return connection.respond(http.HTTPStatus.UNAUTHORIZED, "Invalid token\n") - async def query_param_handler(websocket): - user = websocket.user - - ... + connection.username = user Cookie ...... @@ -260,27 +254,19 @@ the user. If authentication fails, it returns an HTTP 401: .. code-block:: python - from websockets.legacy.server import WebSocketServerProtocol - - class CookieProtocol(WebSocketServerProtocol): - async def process_request(self, path, headers): - # Serve iframe on non-WebSocket requests - ... - - token = get_cookie(headers.get("Cookie", ""), "token") - if token is None: - return http.HTTPStatus.UNAUTHORIZED, [], b"Missing token\n" - - user = get_user(token) - if user is None: - return http.HTTPStatus.UNAUTHORIZED, [], b"Invalid token\n" + async def cookie_auth(connection, request): + # Serve iframe on non-WebSocket requests + ... - self.user = user + token = get_cookie(request.headers.get("Cookie", ""), "token") + if token is None: + return connection.respond(http.HTTPStatus.UNAUTHORIZED, "Missing token\n") - async def cookie_handler(websocket): - user = websocket.user + user = get_user(token) + if user is None: + return connection.respond(http.HTTPStatus.UNAUTHORIZED, "Invalid token\n") - ... + connection.username = user User information ................ @@ -303,24 +289,12 @@ the user. If authentication fails, it returns an HTTP 401: .. code-block:: python - from websockets.legacy.auth import BasicAuthWebSocketServerProtocol - - class UserInfoProtocol(BasicAuthWebSocketServerProtocol): - async def check_credentials(self, username, password): - if username != "token": - return False - - user = get_user(password) - if user is None: - return False + from websockets.asyncio.server import basic_auth as websockets_basic_auth - self.user = user - return True + def check_credentials(username, password): + return username == get_user(password) - async def user_info_handler(websocket): - user = websocket.user - - ... + basic_auth = websockets_basic_auth(check_credentials=check_credentials) Machine-to-machine authentication --------------------------------- @@ -334,11 +308,9 @@ To authenticate a websockets client with HTTP Basic Authentication .. code-block:: python - from websockets.legacy.client import connect + from websockets.asyncio.client import connect - async with connect( - f"wss://{username}:{password}@example.com" - ) as websocket: + async with connect(f"wss://{username}:{password}@.../") as websocket: ... (You must :func:`~urllib.parse.quote` ``username`` and ``password`` if they @@ -349,10 +321,8 @@ To authenticate a websockets client with HTTP Bearer Authentication .. code-block:: python - from websockets.legacy.client import connect + from websockets.asyncio.client import connect - async with connect( - "wss://example.com", - extra_headers={"Authorization": f"Bearer {token}"} - ) as websocket: + headers = {"Authorization": f"Bearer {token}"} + async with connect("wss://.../", additional_headers=headers) as websocket: ... diff --git a/experiments/authentication/app.py b/experiments/authentication/app.py index e3b2cf1f6..0bdd7fd2f 100644 --- a/experiments/authentication/app.py +++ b/experiments/authentication/app.py @@ -1,6 +1,7 @@ #!/usr/bin/env python import asyncio +import email.utils import http import http.cookies import pathlib @@ -8,9 +9,10 @@ import urllib.parse import uuid +from websockets.asyncio.server import basic_auth as websockets_basic_auth, serve +from websockets.datastructures import Headers from websockets.frames import CloseCode -from websockets.legacy.auth import BasicAuthWebSocketServerProtocol -from websockets.legacy.server import WebSocketServerProtocol, serve +from websockets.http11 import Response # User accounts database @@ -49,7 +51,19 @@ def get_query_param(path, key): return values[0] -# Main HTTP server +# WebSocket handler + + +async def handler(websocket): + try: + user = websocket.username + except AttributeError: + return + + await websocket.send(f"Hello {user}!") + message = await websocket.recv() + assert message == f"Goodbye {user}." + CONTENT_TYPES = { ".css": "text/css", @@ -59,9 +73,10 @@ def get_query_param(path, key): } -async def serve_html(path, request_headers): - user = get_query_param(path, "user") - path = urllib.parse.urlparse(path).path +async def serve_html(connection, request): + """Basic HTTP server implemented as a process_request hook.""" + user = get_query_param(request.path, "user") + path = urllib.parse.urlparse(request.path).path if path == "/": if user is None: page = "index.html" @@ -76,147 +91,96 @@ async def serve_html(path, request_headers): pass else: if template.is_file(): - headers = {"Content-Type": CONTENT_TYPES[template.suffix]} body = template.read_bytes() if user is not None: token = create_token(user) body = body.replace(b"TOKEN", token.encode()) - return http.HTTPStatus.OK, headers, body - - return http.HTTPStatus.NOT_FOUND, {}, b"Not found\n" - + headers = Headers( + { + "Date": email.utils.formatdate(usegmt=True), + "Connection": "close", + "Content-Length": str(len(body)), + "Content-Type": CONTENT_TYPES[template.suffix], + } + ) + return Response(200, "OK", headers, body) -async def noop_handler(websocket): - pass - - -# Send credentials as the first message in the WebSocket connection + return connection.respond(http.HTTPStatus.NOT_FOUND, "Not found\n") async def first_message_handler(websocket): + """Handler that sends credentials in the first WebSocket message.""" token = await websocket.recv() user = get_user(token) if user is None: await websocket.close(CloseCode.INTERNAL_ERROR, "authentication failed") return - await websocket.send(f"Hello {user}!") - message = await websocket.recv() - assert message == f"Goodbye {user}." - - -# Add credentials to the WebSocket URI in a query parameter - - -class QueryParamProtocol(WebSocketServerProtocol): - async def process_request(self, path, headers): - token = get_query_param(path, "token") - if token is None: - return http.HTTPStatus.UNAUTHORIZED, [], b"Missing token\n" - - user = get_user(token) - if user is None: - return http.HTTPStatus.UNAUTHORIZED, [], b"Invalid token\n" - - self.user = user - - -async def query_param_handler(websocket): - user = websocket.user - - await websocket.send(f"Hello {user}!") - message = await websocket.recv() - assert message == f"Goodbye {user}." - - -# Set a cookie on the domain of the WebSocket URI - - -class CookieProtocol(WebSocketServerProtocol): - async def process_request(self, path, headers): - if "Upgrade" not in headers: - template = pathlib.Path(__file__).with_name(path[1:]) - headers = {"Content-Type": CONTENT_TYPES[template.suffix]} - body = template.read_bytes() - return http.HTTPStatus.OK, headers, body - - token = get_cookie(headers.get("Cookie", ""), "token") - if token is None: - return http.HTTPStatus.UNAUTHORIZED, [], b"Missing token\n" - - user = get_user(token) - if user is None: - return http.HTTPStatus.UNAUTHORIZED, [], b"Invalid token\n" - - self.user = user + websocket.username = user + await handler(websocket) -async def cookie_handler(websocket): - user = websocket.user +async def query_param_auth(connection, request): + """Authenticate user from token in query parameter.""" + token = get_query_param(request.path, "token") + if token is None: + return connection.respond(http.HTTPStatus.UNAUTHORIZED, "Missing token\n") - await websocket.send(f"Hello {user}!") - message = await websocket.recv() - assert message == f"Goodbye {user}." - - -# Adding credentials to the WebSocket URI in user information - - -class UserInfoProtocol(BasicAuthWebSocketServerProtocol): - async def check_credentials(self, username, password): - if username != "token": - return False - - user = get_user(password) - if user is None: - return False + user = get_user(token) + if user is None: + return connection.respond(http.HTTPStatus.UNAUTHORIZED, "Invalid token\n") + + connection.username = user + + +async def cookie_auth(connection, request): + """Authenticate user from token in cookie.""" + if "Upgrade" not in request.headers: + template = pathlib.Path(__file__).with_name(request.path[1:]) + body = template.read_bytes() + headers = Headers( + { + "Date": email.utils.formatdate(usegmt=True), + "Connection": "close", + "Content-Length": str(len(body)), + "Content-Type": CONTENT_TYPES[template.suffix], + } + ) + return Response(200, "OK", headers, body) + + token = get_cookie(request.headers.get("Cookie", ""), "token") + if token is None: + return connection.respond(http.HTTPStatus.UNAUTHORIZED, "Missing token\n") - self.user = user - return True + user = get_user(token) + if user is None: + return connection.respond(http.HTTPStatus.UNAUTHORIZED, "Invalid token\n") + connection.username = user -async def user_info_handler(websocket): - user = websocket.user - await websocket.send(f"Hello {user}!") - message = await websocket.recv() - assert message == f"Goodbye {user}." +def check_credentials(username, password): + """Authenticate user with HTTP Basic Auth.""" + return username == get_user(password) -# Start all five servers +basic_auth = websockets_basic_auth(check_credentials=check_credentials) async def main(): + """Start one HTTP server and four WebSocket servers.""" # Set the stop condition when receiving SIGINT or SIGTERM. loop = asyncio.get_running_loop() stop = loop.create_future() loop.add_signal_handler(signal.SIGINT, stop.set_result, None) loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - async with serve( - noop_handler, - host="", - port=8000, - process_request=serve_html, - ), serve( - first_message_handler, - host="", - port=8001, - ), serve( - query_param_handler, - host="", - port=8002, - create_protocol=QueryParamProtocol, - ), serve( - cookie_handler, - host="", - port=8003, - create_protocol=CookieProtocol, - ), serve( - user_info_handler, - host="", - port=8004, - create_protocol=UserInfoProtocol, + async with ( + serve(handler, host="", port=8000, process_request=serve_html), + serve(first_message_handler, host="", port=8001), + serve(handler, host="", port=8002, process_request=query_param_auth), + serve(handler, host="", port=8003, process_request=cookie_auth), + serve(handler, host="", port=8004, process_request=basic_auth), ): print("Running on http://localhost:8000/") await stop diff --git a/experiments/authentication/script.js b/experiments/authentication/script.js index ec4e5e670..01dd5b168 100644 --- a/experiments/authentication/script.js +++ b/experiments/authentication/script.js @@ -1,4 +1,5 @@ -var token = window.parent.token; +var token = window.parent.token, + user = window.parent.user; function getExpectedEvents() { return [ @@ -7,7 +8,7 @@ function getExpectedEvents() { }, { type: "message", - data: `Hello ${window.parent.user}!`, + data: `Hello ${user}!`, }, { type: "close", diff --git a/experiments/authentication/test.js b/experiments/authentication/test.js index 428830ff3..e05ca697e 100644 --- a/experiments/authentication/test.js +++ b/experiments/authentication/test.js @@ -1,6 +1,4 @@ -// for connecting to WebSocket servers var token = document.body.dataset.token; -// for test assertions only const params = new URLSearchParams(window.location.search); var user = params.get("user"); diff --git a/experiments/authentication/user_info.js b/experiments/authentication/user_info.js index 1dab2ce4c..bc9a3f148 100644 --- a/experiments/authentication/user_info.js +++ b/experiments/authentication/user_info.js @@ -1,5 +1,5 @@ window.addEventListener("DOMContentLoaded", () => { - const uri = `ws://token:${token}@localhost:8004/`; + const uri = `ws://${user}:${token}@localhost:8004/`; const websocket = new WebSocket(uri); websocket.onmessage = ({ data }) => { From bc4b8f2776cd4da9aee2e67f66764bf26a5ad09e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 30 Sep 2024 21:07:28 +0200 Subject: [PATCH 644/762] Add option to force sending text or binary frames. Fix #1515. --- src/websockets/asyncio/connection.py | 124 +++++++++++++++------------ tests/asyncio/test_connection.py | 72 ++++++++++++++-- 2 files changed, 134 insertions(+), 62 deletions(-) diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 702e69995..12871e4b3 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -251,12 +251,13 @@ async def recv(self, decode: bool | None = None) -> Data: You may override this behavior with the ``decode`` argument: - * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames - and return a bytestring (:class:`bytes`). This may be useful to - optimize performance when decoding isn't needed. + * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames and + return a bytestring (:class:`bytes`). This improves performance + when decoding isn't needed, for example if the message contains + JSON and you're using a JSON library that expects a bytestring. * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames - and return a string (:class:`str`). This is useful for servers - that send binary frames instead of text frames. + and return a string (:class:`str`). This may be useful for + servers that send binary frames instead of text frames. Raises: ConnectionClosed: When the connection is closed. @@ -333,7 +334,11 @@ async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data "is already running recv or recv_streaming" ) from None - async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> None: + async def send( + self, + message: Data | Iterable[Data] | AsyncIterable[Data], + text: bool | None = None, + ) -> None: """ Send a message. @@ -344,6 +349,17 @@ async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> No .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + You may override this behavior with the ``text`` argument: + + * Set ``text=True`` to send a bytestring or bytes-like object + (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) as a + Text_ frame. This improves performance when the message is already + UTF-8 encoded, for example if the message contains JSON and you're + using a JSON library that produces a bytestring. + * Set ``text=False`` to send a string (:class:`str`) in a Binary_ + frame. This may be useful for servers that expect binary frames + instead of text frames. + :meth:`send` also accepts an iterable or an asynchronous iterable of strings, bytestrings, or bytes-like objects to enable fragmentation_. Each item is treated as a message fragment and sent in its own frame. @@ -393,12 +409,20 @@ async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> No # strings and bytes-like objects are iterable. if isinstance(message, str): - async with self.send_context(): - self.protocol.send_text(message.encode()) + if text is False: + async with self.send_context(): + self.protocol.send_binary(message.encode()) + else: + async with self.send_context(): + self.protocol.send_text(message.encode()) elif isinstance(message, BytesLike): - async with self.send_context(): - self.protocol.send_binary(message) + if text is True: + async with self.send_context(): + self.protocol.send_text(message) + else: + async with self.send_context(): + self.protocol.send_binary(message) # Catch a common mistake -- passing a dict to send(). @@ -419,36 +443,32 @@ async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> No try: # First fragment. if isinstance(chunk, str): - text = True - async with self.send_context(): - self.protocol.send_text( - chunk.encode(), - fin=False, - ) + if text is False: + async with self.send_context(): + self.protocol.send_binary(chunk.encode(), fin=False) + else: + async with self.send_context(): + self.protocol.send_text(chunk.encode(), fin=False) + encode = True elif isinstance(chunk, BytesLike): - text = False - async with self.send_context(): - self.protocol.send_binary( - chunk, - fin=False, - ) + if text is True: + async with self.send_context(): + self.protocol.send_text(chunk, fin=False) + else: + async with self.send_context(): + self.protocol.send_binary(chunk, fin=False) + encode = False else: raise TypeError("iterable must contain bytes or str") # Other fragments for chunk in chunks: - if isinstance(chunk, str) and text: + if isinstance(chunk, str) and encode: async with self.send_context(): - self.protocol.send_continuation( - chunk.encode(), - fin=False, - ) - elif isinstance(chunk, BytesLike) and not text: + self.protocol.send_continuation(chunk.encode(), fin=False) + elif isinstance(chunk, BytesLike) and not encode: async with self.send_context(): - self.protocol.send_continuation( - chunk, - fin=False, - ) + self.protocol.send_continuation(chunk, fin=False) else: raise TypeError("iterable must contain uniform types") @@ -481,36 +501,32 @@ async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> No try: # First fragment. if isinstance(chunk, str): - text = True - async with self.send_context(): - self.protocol.send_text( - chunk.encode(), - fin=False, - ) + if text is False: + async with self.send_context(): + self.protocol.send_binary(chunk.encode(), fin=False) + else: + async with self.send_context(): + self.protocol.send_text(chunk.encode(), fin=False) + encode = True elif isinstance(chunk, BytesLike): - text = False - async with self.send_context(): - self.protocol.send_binary( - chunk, - fin=False, - ) + if text is True: + async with self.send_context(): + self.protocol.send_text(chunk, fin=False) + else: + async with self.send_context(): + self.protocol.send_binary(chunk, fin=False) + encode = False else: raise TypeError("async iterable must contain bytes or str") # Other fragments async for chunk in achunks: - if isinstance(chunk, str) and text: + if isinstance(chunk, str) and encode: async with self.send_context(): - self.protocol.send_continuation( - chunk.encode(), - fin=False, - ) - elif isinstance(chunk, BytesLike) and not text: + self.protocol.send_continuation(chunk.encode(), fin=False) + elif isinstance(chunk, BytesLike) and not encode: async with self.send_context(): - self.protocol.send_continuation( - chunk, - fin=False, - ) + self.protocol.send_continuation(chunk, fin=False) else: raise TypeError("async iterable must contain uniform types") diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 70d9dad63..563cf2b17 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -190,13 +190,13 @@ async def test_recv_binary(self): await self.remote_connection.send(b"\x01\x02\xfe\xff") self.assertEqual(await self.connection.recv(), b"\x01\x02\xfe\xff") - async def test_recv_encoded_text(self): - """recv receives an UTF-8 encoded text message.""" + async def test_recv_text_as_bytes(self): + """recv receives a text message as bytes.""" await self.remote_connection.send("😀") self.assertEqual(await self.connection.recv(decode=False), "😀".encode()) - async def test_recv_decoded_binary(self): - """recv receives an UTF-8 decoded binary message.""" + async def test_recv_binary_as_text(self): + """recv receives a binary message as a str.""" await self.remote_connection.send("😀".encode()) self.assertEqual(await self.connection.recv(decode=True), "😀") @@ -304,16 +304,16 @@ async def test_recv_streaming_binary(self): [b"\x01\x02\xfe\xff"], ) - async def test_recv_streaming_encoded_text(self): - """recv_streaming receives an UTF-8 encoded text message.""" + async def test_recv_streaming_text_as_bytes(self): + """recv_streaming receives a text message as bytes.""" await self.remote_connection.send("😀") self.assertEqual( await alist(self.connection.recv_streaming(decode=False)), ["😀".encode()], ) - async def test_recv_streaming_decoded_binary(self): - """recv_streaming receives a UTF-8 decoded binary message.""" + async def test_recv_streaming_binary_as_str(self): + """recv_streaming receives a binary message as a str.""" await self.remote_connection.send("😀".encode()) self.assertEqual( await alist(self.connection.recv_streaming(decode=True)), @@ -438,6 +438,16 @@ async def test_send_binary(self): await self.connection.send(b"\x01\x02\xfe\xff") self.assertEqual(await self.remote_connection.recv(), b"\x01\x02\xfe\xff") + async def test_send_binary_from_str(self): + """send sends a binary message from a str.""" + await self.connection.send("😀", text=False) + self.assertEqual(await self.remote_connection.recv(), "😀".encode()) + + async def test_send_text_from_bytes(self): + """send sends a text message from bytes.""" + await self.connection.send("😀".encode(), text=True) + self.assertEqual(await self.remote_connection.recv(), "😀") + async def test_send_fragmented_text(self): """send sends a fragmented text message.""" await self.connection.send(["😀", "😀"]) @@ -456,6 +466,24 @@ async def test_send_fragmented_binary(self): [b"\x01\x02", b"\xfe\xff", b""], ) + async def test_send_fragmented_binary_from_str(self): + """send sends a fragmented binary message from a str.""" + await self.connection.send(["😀", "😀"], text=False) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀".encode(), "😀".encode(), b""], + ) + + async def test_send_fragmented_text_from_bytes(self): + """send sends a fragmented text message from bytes.""" + await self.connection.send(["😀".encode(), "😀".encode()], text=True) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀", "😀", ""], + ) + async def test_send_async_fragmented_text(self): """send sends a fragmented text message asynchronously.""" @@ -484,6 +512,34 @@ async def fragments(): [b"\x01\x02", b"\xfe\xff", b""], ) + async def test_send_async_fragmented_binary_from_str(self): + """send sends a fragmented binary message from a str asynchronously.""" + + async def fragments(): + yield "😀" + yield "😀" + + await self.connection.send(fragments(), text=False) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀".encode(), "😀".encode(), b""], + ) + + async def test_send_async_fragmented_text_from_bytes(self): + """send sends a fragmented text message from bytes asynchronously.""" + + async def fragments(): + yield "😀".encode() + yield "😀".encode() + + await self.connection.send(fragments(), text=True) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀", "😀", ""], + ) + async def test_send_connection_closed_ok(self): """send raises ConnectionClosedOK after a normal closure.""" await self.remote_connection.close() From 7fdd932c6b29a9ef4db46b2141371f42207b4f00 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Oct 2024 14:21:00 +0200 Subject: [PATCH 645/762] Review & update gitignore. --- .gitignore | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index d8e6697a8..291bf1fb6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,16 +1,16 @@ *.pyc *.so .coverage -.direnv +.direnv/ .envrc .idea/ -.mypy_cache -.tox +.mypy_cache/ +.tox/ +.vscode/ build/ compliance/reports/ -experiments/compression/corpus/ dist/ docs/_build/ +experiments/compression/corpus/ htmlcov/ -MANIFEST -websockets.egg-info/ +src/websockets.egg-info/ From c5985d5c4192390b2da58ae97015e3ab1ba41cd2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 13 Oct 2024 16:21:16 +0200 Subject: [PATCH 646/762] Update compliance test suite. --- Makefile | 4 +- compliance/README.rst | 74 ++++++++++++++++++---------- compliance/asyncio/client.py | 52 +++++++++++++++++++ compliance/asyncio/server.py | 32 ++++++++++++ compliance/config/fuzzingclient.json | 11 +++++ compliance/config/fuzzingserver.json | 7 +++ compliance/fuzzingclient.json | 11 ----- compliance/fuzzingserver.json | 12 ----- compliance/sync/client.py | 51 +++++++++++++++++++ compliance/sync/server.py | 31 ++++++++++++ compliance/test_client.py | 48 ------------------ compliance/test_server.py | 29 ----------- 12 files changed, 234 insertions(+), 128 deletions(-) create mode 100644 compliance/asyncio/client.py create mode 100644 compliance/asyncio/server.py create mode 100644 compliance/config/fuzzingclient.json create mode 100644 compliance/config/fuzzingserver.json delete mode 100644 compliance/fuzzingclient.json delete mode 100644 compliance/fuzzingserver.json create mode 100644 compliance/sync/client.py create mode 100644 compliance/sync/server.py delete mode 100644 compliance/test_client.py delete mode 100644 compliance/test_server.py diff --git a/Makefile b/Makefile index fd36d0367..06bfe9edc 100644 --- a/Makefile +++ b/Makefile @@ -8,8 +8,8 @@ build: python setup.py build_ext --inplace style: - ruff format src tests - ruff check --fix src tests + ruff format compliance src tests + ruff check --fix compliance src tests types: mypy --strict src diff --git a/compliance/README.rst b/compliance/README.rst index 8570f9176..c7c7c93b4 100644 --- a/compliance/README.rst +++ b/compliance/README.rst @@ -4,47 +4,69 @@ Autobahn Testsuite General information and installation instructions are available at https://github.com/crossbario/autobahn-testsuite. -To improve performance, you should compile the C extension first:: +Running the test suite +---------------------- + +All commands below must be run from the root directory of the repository. + +To get acceptable performance, compile the C extension first: + +.. code-block:: console $ python setup.py build_ext --inplace -Running the test suite ----------------------- +Run each command in a different shell. Testing takes several minutes to complete +— wstest is the bottleneck. When clients finish, stop servers with Ctrl-C. + +You can exclude slow tests by modifying the configuration files as follows:: + + "exclude-cases": ["9.*", "12.*", "13.*"] -All commands below must be run from the directory containing this file. +The test server and client applications shouldn't display any exceptions. -To test the server:: +To test the servers: - $ PYTHONPATH=.. python test_server.py - $ wstest -m fuzzingclient +.. code-block:: console -To test the client:: + $ PYTHONPATH=src python compliance/asyncio/server.py + $ PYTHONPATH=src python compliance/sync/server.py - $ wstest -m fuzzingserver - $ PYTHONPATH=.. python test_client.py + $ docker run --interactive --tty --rm \ + --volume "${PWD}/compliance/config:/config" \ + --volume "${PWD}/compliance/reports:/reports" \ + --name fuzzingclient \ + crossbario/autobahn-testsuite \ + wstest --mode fuzzingclient --spec /config/fuzzingclient.json -Run the first command in a shell. Run the second command in another shell. -It should take about ten minutes to complete — wstest is the bottleneck. -Then kill the first one with Ctrl-C. + $ open reports/servers/index.html -The test client or server shouldn't display any exceptions. The results are -stored in reports/clients/index.html. +To test the clients: -Note that the Autobahn software only supports Python 2, while ``websockets`` -only supports Python 3; you need two different environments. +.. code-block:: console + $ docker run --interactive --tty --rm \ + --volume "${PWD}/compliance/config:/config" \ + --volume "${PWD}/compliance/reports:/reports" \ + --publish 9001:9001 \ + --name fuzzingserver \ + crossbario/autobahn-testsuite \ + wstest --mode fuzzingserver --spec /config/fuzzingserver.json + + $ PYTHONPATH=src python compliance/asyncio/client.py + $ PYTHONPATH=src python compliance/sync/client.py + + $ open reports/clients/index.html Conformance notes ----------------- Some test cases are more strict than the RFC. Given the implementation of the -library and the test echo client or server, ``websockets`` gets a "Non-Strict" -in these cases. - -In 3.2, 3.3, 4.1.3, 4.1.4, 4.2.3, 4.2.4, and 5.15 ``websockets`` notices the -protocol error and closes the connection before it has had a chance to echo -the previous frame. +library and the test client and server applications, websockets passes with a +"Non-Strict" result in these cases. -In 6.4.3 and 6.4.4, even though it uses an incremental decoder, ``websockets`` -doesn't notice the invalid utf-8 fast enough to get a "Strict" pass. These -tests are more strict than the RFC. +In 3.2, 3.3, 4.1.3, 4.1.4, 4.2.3, 4.2.4, and 5.15 websockets notices the +protocol error and closes the connection at the library level before the +application gets a chance to echo the previous frame. +In 6.4.3 and 6.4.4, even though it uses an incremental decoder, websockets +doesn't notice the invalid utf-8 fast enough to get a "Strict" pass. These tests +are more strict than the RFC. diff --git a/compliance/asyncio/client.py b/compliance/asyncio/client.py new file mode 100644 index 000000000..5b0bfb3ae --- /dev/null +++ b/compliance/asyncio/client.py @@ -0,0 +1,52 @@ +import asyncio +import json +import logging + +from websockets.asyncio.client import connect +from websockets.exceptions import WebSocketException + + +logging.basicConfig(level=logging.WARNING) + +SERVER = "ws://127.0.0.1:9001" + + +async def get_case_count(): + async with connect(f"{SERVER}/getCaseCount") as ws: + return json.loads(await ws.recv()) + + +async def run_case(case): + async with connect( + f"{SERVER}/runCase?case={case}", + user_agent_header="websockets.asyncio", + max_size=2**25, + ) as ws: + async for msg in ws: + await ws.send(msg) + + +async def update_reports(): + async with connect(f"{SERVER}/updateReports", open_timeout=60): + pass + + +async def main(): + cases = await get_case_count() + for case in range(1, cases + 1): + print(f"Running test case {case:03d} / {cases}... ", end="\t") + try: + await run_case(case) + except WebSocketException as exc: + print(f"ERROR: {type(exc).__name__}: {exc}") + except Exception as exc: + print(f"FAIL: {type(exc).__name__}: {exc}") + else: + print("OK") + print("Ran {cases} test cases") + await update_reports() + print("Updated reports") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/compliance/asyncio/server.py b/compliance/asyncio/server.py new file mode 100644 index 000000000..cff2728c9 --- /dev/null +++ b/compliance/asyncio/server.py @@ -0,0 +1,32 @@ +import asyncio +import logging + +from websockets.asyncio.server import serve + + +logging.basicConfig(level=logging.WARNING) + +HOST, PORT = "0.0.0.0", 9002 + + +async def echo(ws): + async for msg in ws: + await ws.send(msg) + + +async def main(): + async with serve( + echo, + HOST, + PORT, + server_header="websockets.sync", + max_size=2**25, + ) as server: + try: + await server.serve_forever() + except KeyboardInterrupt: + pass + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/compliance/config/fuzzingclient.json b/compliance/config/fuzzingclient.json new file mode 100644 index 000000000..37f8f3a35 --- /dev/null +++ b/compliance/config/fuzzingclient.json @@ -0,0 +1,11 @@ + +{ + "servers": [{ + "url": "ws://host.docker.internal:9002" + }, { + "url": "ws://host.docker.internal:9003" + }], + "outdir": "./reports/servers", + "cases": ["*"], + "exclude-cases": [] +} diff --git a/compliance/config/fuzzingserver.json b/compliance/config/fuzzingserver.json new file mode 100644 index 000000000..1bcb5959d --- /dev/null +++ b/compliance/config/fuzzingserver.json @@ -0,0 +1,7 @@ + +{ + "url": "ws://localhost:9001", + "outdir": "./reports/clients", + "cases": ["*"], + "exclude-cases": [] +} diff --git a/compliance/fuzzingclient.json b/compliance/fuzzingclient.json deleted file mode 100644 index 202ff49a0..000000000 --- a/compliance/fuzzingclient.json +++ /dev/null @@ -1,11 +0,0 @@ - -{ - "options": {"failByDrop": false}, - "outdir": "./reports/servers", - - "servers": [{"agent": "websockets", "url": "ws://localhost:8642", "options": {"version": 18}}], - - "cases": ["*"], - "exclude-cases": [], - "exclude-agent-cases": {} -} diff --git a/compliance/fuzzingserver.json b/compliance/fuzzingserver.json deleted file mode 100644 index 1bdb42723..000000000 --- a/compliance/fuzzingserver.json +++ /dev/null @@ -1,12 +0,0 @@ - -{ - "url": "ws://localhost:8642", - - "options": {"failByDrop": false}, - "outdir": "./reports/clients", - "webport": 8080, - - "cases": ["*"], - "exclude-cases": [], - "exclude-agent-cases": {} -} diff --git a/compliance/sync/client.py b/compliance/sync/client.py new file mode 100644 index 000000000..e585496f3 --- /dev/null +++ b/compliance/sync/client.py @@ -0,0 +1,51 @@ +import json +import logging + +from websockets.exceptions import WebSocketException +from websockets.sync.client import connect + + +logging.basicConfig(level=logging.WARNING) + +SERVER = "ws://127.0.0.1:9001" + + +def get_case_count(): + with connect(f"{SERVER}/getCaseCount") as ws: + return json.loads(ws.recv()) + + +def run_case(case): + with connect( + f"{SERVER}/runCase?case={case}", + user_agent_header="websockets.sync", + max_size=2**25, + ) as ws: + for msg in ws: + ws.send(msg) + + +def update_reports(): + with connect(f"{SERVER}/updateReports", open_timeout=60): + pass + + +def main(): + cases = get_case_count() + for case in range(1, cases + 1): + print(f"Running test case {case:03d} / {cases}... ", end="\t") + try: + run_case(case) + except WebSocketException as exc: + print(f"ERROR: {type(exc).__name__}: {exc}") + except Exception as exc: + print(f"FAIL: {type(exc).__name__}: {exc}") + else: + print("OK") + print("Ran {cases} test cases") + update_reports() + print("Updated reports") + + +if __name__ == "__main__": + main() diff --git a/compliance/sync/server.py b/compliance/sync/server.py new file mode 100644 index 000000000..c3cb4d989 --- /dev/null +++ b/compliance/sync/server.py @@ -0,0 +1,31 @@ +import logging + +from websockets.sync.server import serve + + +logging.basicConfig(level=logging.WARNING) + +HOST, PORT = "0.0.0.0", 9003 + + +def echo(ws): + for msg in ws: + ws.send(msg) + + +def main(): + with serve( + echo, + HOST, + PORT, + server_header="websockets.asyncio", + max_size=2**25, + ) as server: + try: + server.serve_forever() + except KeyboardInterrupt: + pass + + +if __name__ == "__main__": + main() diff --git a/compliance/test_client.py b/compliance/test_client.py deleted file mode 100644 index 8e22569fd..000000000 --- a/compliance/test_client.py +++ /dev/null @@ -1,48 +0,0 @@ -import asyncio -import json -import logging -import urllib.parse - -from websockets.asyncio.client import connect - - -logging.basicConfig(level=logging.WARNING) - -# Uncomment this line to make only websockets more verbose. -# logging.getLogger('websockets').setLevel(logging.DEBUG) - - -SERVER = "ws://127.0.0.1:8642" -AGENT = "websockets" - - -async def get_case_count(server): - uri = f"{server}/getCaseCount" - async with connect(uri) as ws: - msg = ws.recv() - return json.loads(msg) - - -async def run_case(server, case, agent): - uri = f"{server}/runCase?case={case}&agent={agent}" - async with connect(uri, max_size=2 ** 25, max_queue=1) as ws: - async for msg in ws: - await ws.send(msg) - - -async def update_reports(server, agent): - uri = f"{server}/updateReports?agent={agent}" - async with connect(uri): - pass - - -async def run_tests(server, agent): - cases = await get_case_count(server) - for case in range(1, cases + 1): - print(f"Running test case {case} out of {cases}", end="\r") - await run_case(server, case, agent) - print(f"Ran {cases} test cases ") - await update_reports(server, agent) - - -asyncio.run(run_tests(SERVER, urllib.parse.quote(AGENT))) diff --git a/compliance/test_server.py b/compliance/test_server.py deleted file mode 100644 index 39176e902..000000000 --- a/compliance/test_server.py +++ /dev/null @@ -1,29 +0,0 @@ -import asyncio -import logging - -from websockets.asyncio.server import serve - - -logging.basicConfig(level=logging.WARNING) - -# Uncomment this line to make only websockets more verbose. -# logging.getLogger('websockets').setLevel(logging.DEBUG) - - -HOST, PORT = "127.0.0.1", 8642 - - -async def echo(ws): - async for msg in ws: - await ws.send(msg) - - -async def main(): - with serve(echo, HOST, PORT, max_size=2 ** 25, max_queue=1): - try: - await asyncio.get_running_loop().create_future() # run forever - except KeyboardInterrupt: - pass - - -asyncio.run(main()) From b2f0a7647f1402c84a8dabb391c3ca7371975eb3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 13 Oct 2024 21:32:30 +0200 Subject: [PATCH 647/762] Add option to force sending text or binary frames. This adds the same functionality to the threading implemetation as bc4b8f2 did to the asyncio implementation. Refs #1515. --- src/websockets/asyncio/connection.py | 43 +++++++++++--------- src/websockets/sync/connection.py | 61 +++++++++++++++++----------- tests/sync/test_connection.py | 28 +++++++++++++ 3 files changed, 90 insertions(+), 42 deletions(-) diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 12871e4b3..3b81e386b 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -409,19 +409,17 @@ async def send( # strings and bytes-like objects are iterable. if isinstance(message, str): - if text is False: - async with self.send_context(): + async with self.send_context(): + if text is False: self.protocol.send_binary(message.encode()) - else: - async with self.send_context(): + else: self.protocol.send_text(message.encode()) elif isinstance(message, BytesLike): - if text is True: - async with self.send_context(): + async with self.send_context(): + if text is True: self.protocol.send_text(message) - else: - async with self.send_context(): + else: self.protocol.send_binary(message) # Catch a common mistake -- passing a dict to send(). @@ -443,19 +441,17 @@ async def send( try: # First fragment. if isinstance(chunk, str): - if text is False: - async with self.send_context(): + async with self.send_context(): + if text is False: self.protocol.send_binary(chunk.encode(), fin=False) - else: - async with self.send_context(): + else: self.protocol.send_text(chunk.encode(), fin=False) encode = True elif isinstance(chunk, BytesLike): - if text is True: - async with self.send_context(): + async with self.send_context(): + if text is True: self.protocol.send_text(chunk, fin=False) - else: - async with self.send_context(): + else: self.protocol.send_binary(chunk, fin=False) encode = False else: @@ -480,7 +476,10 @@ async def send( # We're half-way through a fragmented message and we can't # complete it. This makes the connection unusable. async with self.send_context(): - self.protocol.fail(1011, "error in fragmented message") + self.protocol.fail( + CloseCode.INTERNAL_ERROR, + "error in fragmented message", + ) raise finally: @@ -538,7 +537,10 @@ async def send( # We're half-way through a fragmented message and we can't # complete it. This makes the connection unusable. async with self.send_context(): - self.protocol.fail(1011, "error in fragmented message") + self.protocol.fail( + CloseCode.INTERNAL_ERROR, + "error in fragmented message", + ) raise finally: @@ -568,7 +570,10 @@ async def close(self, code: int = 1000, reason: str = "") -> None: # to terminate after calling a method that sends a close frame. async with self.send_context(): if self.fragmented_send_waiter is not None: - self.protocol.fail(1011, "close during fragmented message") + self.protocol.fail( + CloseCode.INTERNAL_ERROR, + "close during fragmented message", + ) else: self.protocol.send_close(code, reason) except ConnectionClosed: diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 8c5df9592..3f4cac09f 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -251,7 +251,11 @@ def recv_streaming(self) -> Iterator[Data]: "is already running recv or recv_streaming" ) from None - def send(self, message: Data | Iterable[Data]) -> None: + def send( + self, + message: Data | Iterable[Data], + text: bool | None = None, + ) -> None: """ Send a message. @@ -262,6 +266,17 @@ def send(self, message: Data | Iterable[Data]) -> None: .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + You may override this behavior with the ``text`` argument: + + * Set ``text=True`` to send a bytestring or bytes-like object + (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) as a + Text_ frame. This improves performance when the message is already + UTF-8 encoded, for example if the message contains JSON and you're + using a JSON library that produces a bytestring. + * Set ``text=False`` to send a string (:class:`str`) in a Binary_ + frame. This may be useful for servers that expect binary frames + instead of text frames. + :meth:`send` also accepts an iterable of strings, bytestrings, or bytes-like objects to enable fragmentation_. Each item is treated as a message fragment and sent in its own frame. All items must be of the @@ -300,7 +315,10 @@ def send(self, message: Data | Iterable[Data]) -> None: "cannot call send while another thread " "is already running send" ) - self.protocol.send_text(message.encode()) + if text is False: + self.protocol.send_binary(message.encode()) + else: + self.protocol.send_text(message.encode()) elif isinstance(message, BytesLike): with self.send_context(): @@ -309,7 +327,10 @@ def send(self, message: Data | Iterable[Data]) -> None: "cannot call send while another thread " "is already running send" ) - self.protocol.send_binary(message) + if text is True: + self.protocol.send_text(message) + else: + self.protocol.send_binary(message) # Catch a common mistake -- passing a dict to send(). @@ -328,7 +349,6 @@ def send(self, message: Data | Iterable[Data]) -> None: try: # First fragment. if isinstance(chunk, str): - text = True with self.send_context(): if self.send_in_progress: raise ConcurrencyError( @@ -336,12 +356,12 @@ def send(self, message: Data | Iterable[Data]) -> None: "is already running send" ) self.send_in_progress = True - self.protocol.send_text( - chunk.encode(), - fin=False, - ) + if text is False: + self.protocol.send_binary(chunk.encode(), fin=False) + else: + self.protocol.send_text(chunk.encode(), fin=False) + encode = True elif isinstance(chunk, BytesLike): - text = False with self.send_context(): if self.send_in_progress: raise ConcurrencyError( @@ -349,29 +369,24 @@ def send(self, message: Data | Iterable[Data]) -> None: "is already running send" ) self.send_in_progress = True - self.protocol.send_binary( - chunk, - fin=False, - ) + if text is True: + self.protocol.send_text(chunk, fin=False) + else: + self.protocol.send_binary(chunk, fin=False) + encode = False else: raise TypeError("data iterable must contain bytes or str") # Other fragments for chunk in chunks: - if isinstance(chunk, str) and text: + if isinstance(chunk, str) and encode: with self.send_context(): assert self.send_in_progress - self.protocol.send_continuation( - chunk.encode(), - fin=False, - ) - elif isinstance(chunk, BytesLike) and not text: + self.protocol.send_continuation(chunk.encode(), fin=False) + elif isinstance(chunk, BytesLike) and not encode: with self.send_context(): assert self.send_in_progress - self.protocol.send_continuation( - chunk, - fin=False, - ) + self.protocol.send_continuation(chunk, fin=False) else: raise TypeError("data iterable must contain uniform types") diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index 16f92e164..87333fd35 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -308,6 +308,16 @@ def test_send_binary(self): self.connection.send(b"\x01\x02\xfe\xff") self.assertEqual(self.remote_connection.recv(), b"\x01\x02\xfe\xff") + def test_send_binary_from_str(self): + """send sends a binary message from a str.""" + self.connection.send("😀", text=False) + self.assertEqual(self.remote_connection.recv(), "😀".encode()) + + def test_send_text_from_bytes(self): + """send sends a text message from bytes.""" + self.connection.send("😀".encode(), text=True) + self.assertEqual(self.remote_connection.recv(), "😀") + def test_send_fragmented_text(self): """send sends a fragmented text message.""" self.connection.send(["😀", "😀"]) @@ -326,6 +336,24 @@ def test_send_fragmented_binary(self): [b"\x01\x02", b"\xfe\xff", b""], ) + def test_send_fragmented_binary_from_str(self): + """send sends a fragmented binary message from a str.""" + self.connection.send(["😀", "😀"], text=False) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + list(self.remote_connection.recv_streaming()), + ["😀".encode(), "😀".encode(), b""], + ) + + def test_send_fragmented_text_from_bytes(self): + """send sends a fragmented text message from bytes.""" + self.connection.send(["😀".encode(), "😀".encode()], text=True) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + list(self.remote_connection.recv_streaming()), + ["😀", "😀", ""], + ) + def test_send_connection_closed_ok(self): """send raises ConnectionClosedOK after a normal closure.""" self.remote_connection.close() From e5182c95a3332535a034c409d59463afbd760f0c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 13 Oct 2024 21:45:17 +0200 Subject: [PATCH 648/762] Blind fix for coverage failing in GitHub Actions. It doesn't fail locally. --- pyproject.toml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6a0ab8d7c..4e26c757e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,15 +60,19 @@ source = [ [tool.coverage.report] exclude_lines = [ + "pragma: no cover", "except ImportError:", "if self.debug:", "if sys.platform != \"win32\":", "if typing.TYPE_CHECKING:", - "pragma: no cover", "raise AssertionError", "self.fail\\(\".*\"\\)", "@unittest.skip", ] +partial_branches = [ + "pragma: no branch", + "with self.assertRaises\\(.*\\)", +] [tool.ruff] target-version = "py312" From 1387c976833956ea4d44ed0bd541fe648a065ed7 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 25 Oct 2024 13:26:52 +0200 Subject: [PATCH 649/762] Rewrite sync Assembler to improve performance. Previously, a latch was used to synchronize the user thread reading messages and the background thread reading from the network. This required two thread switches per message. Now, the background thread writes messages to queue, from which the user thread reads. This allows passing several frames at each thread switch, reducing the overhead. With this server code: async def test(websocket): for i in range(int(await websocket.recv())): await websocket.send(f"{{\"iteration\": {i}}}") async with serve(test, "localhost", 8765) as server: await server.serve_forever() and this client code: with connect("ws://localhost:8765", compression=None) as websocket: websocket.send("1_000_000") for message in websocket: pass an unscientific benchmark (running it on my laptop) shows a 2.5x speedup, going from 11 seconds to 4.4 seconds. Setting a very large recv_bufsize and max_size doesn't yield significant further improvement. Flow control was tested by inserting debug logs in maybe_pause/resume() and by measuring the wait for the recv_flow_control lock. It showed the expected behavior of pausing and unpausing coupled with some wait time. The new implementation mirrors the asyncio implementation and gains the option to prevent or force decoding of frames. Fix #1376 for the threading implementation. --- docs/project/changelog.rst | 16 +- src/websockets/asyncio/client.py | 2 +- src/websockets/asyncio/messages.py | 59 +++-- src/websockets/asyncio/server.py | 2 +- src/websockets/sync/client.py | 12 +- src/websockets/sync/connection.py | 79 ++++-- src/websockets/sync/messages.py | 333 ++++++++++++------------ src/websockets/sync/server.py | 12 +- tests/asyncio/test_connection.py | 19 +- tests/asyncio/test_messages.py | 12 +- tests/sync/test_connection.py | 106 +++++--- tests/sync/test_messages.py | 400 ++++++++++++++--------------- 12 files changed, 585 insertions(+), 467 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index f5b4812bd..410671239 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -70,10 +70,21 @@ Backwards-incompatible changes If you wrote an :class:`extension ` that relies on methods not provided by these new types, you may need to update your code. +New features +............ + +* Added an option to receive text frames as :class:`bytes`, without decoding, + in the :mod:`threading` implementation; also binary frames as :class:`str`. + +* Added an option to send :class:`bytes` as a text frame in the :mod:`asyncio` + and :mod:`threading` implementations, as well as :class:`str` a binary frame. + Improvements ............ -* Sending or receiving large compressed frames is now faster. +* The :mod:`threading` implementation receives messages faster. + +* Sending or receiving large compressed messages is now faster. .. _13.1: @@ -198,6 +209,9 @@ New features * Validated compatibility with Python 3.12 and 3.13. +* Added an option to receive text frames as :class:`bytes`, without decoding, + in the :mod:`asyncio` implementation; also binary frames as :class:`str`. + * Added :doc:`environment variables <../reference/variables>` to configure debug logs, the ``Server`` and ``User-Agent`` headers, as well as security limits. diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 23b1a348a..0c8bedc5d 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -45,7 +45,7 @@ class ClientConnection(Connection): closed with any other code. The ``ping_interval``, ``ping_timeout``, ``close_timeout``, ``max_queue``, - and ``write_limit`` arguments the same meaning as in :func:`connect`. + and ``write_limit`` arguments have the same meaning as in :func:`connect`. Args: protocol: Sans-I/O connection. diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index e3ec5062f..09be22ba2 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -60,6 +60,7 @@ def reset(self, items: Iterable[T]) -> None: self.queue.extend(items) def abort(self) -> None: + """Close the queue, raising EOFError in get() if necessary.""" if self.get_waiter is not None and not self.get_waiter.done(): self.get_waiter.set_exception(EOFError("stream of frames ended")) # Clear the queue to avoid storing unnecessary data in memory. @@ -89,7 +90,7 @@ def __init__( # pragma: no cover pause: Callable[[], Any] = lambda: None, resume: Callable[[], Any] = lambda: None, ) -> None: - # Queue of incoming messages. Each item is a queue of frames. + # Queue of incoming frames. self.frames: SimpleQueue[Frame] = SimpleQueue() # We cannot put a hard limit on the size of the queue because a single @@ -140,36 +141,35 @@ async def get(self, decode: bool | None = None) -> Data: if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") - # Locking with get_in_progress ensures only one coroutine can get here. self.get_in_progress = True - # First frame + # Locking with get_in_progress prevents concurrent execution until + # get() fetches a complete message or is cancelled. + try: + # First frame frame = await self.frames.get() - except asyncio.CancelledError: - self.get_in_progress = False - raise - self.maybe_resume() - assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY - if decode is None: - decode = frame.opcode is OP_TEXT - frames = [frame] - - # Following frames, for fragmented messages - while not frame.fin: - try: - frame = await self.frames.get() - except asyncio.CancelledError: - # Put frames already received back into the queue - # so that future calls to get() can return them. - self.frames.reset(frames) - self.get_in_progress = False - raise self.maybe_resume() - assert frame.opcode is OP_CONT - frames.append(frame) - - self.get_in_progress = False + assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY + if decode is None: + decode = frame.opcode is OP_TEXT + frames = [frame] + + # Following frames, for fragmented messages + while not frame.fin: + try: + frame = await self.frames.get() + except asyncio.CancelledError: + # Put frames already received back into the queue + # so that future calls to get() can return them. + self.frames.reset(frames) + raise + self.maybe_resume() + assert frame.opcode is OP_CONT + frames.append(frame) + + finally: + self.get_in_progress = False data = b"".join(frame.data for frame in frames) if decode: @@ -207,9 +207,14 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") - # Locking with get_in_progress ensures only one coroutine can get here. self.get_in_progress = True + # Locking with get_in_progress prevents concurrent execution until + # get_iter() fetches a complete message or is cancelled. + + # If get_iter() raises an exception e.g. in decoder.decode(), + # get_in_progress remains set and the connection becomes unusable. + # First frame try: frame = await self.frames.get() diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index e11dd91f1..a6ae5996d 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -55,7 +55,7 @@ class ServerConnection(Connection): closed with any other code. The ``ping_interval``, ``ping_timeout``, ``close_timeout``, ``max_queue``, - and ``write_limit`` arguments the same meaning as in :func:`serve`. + and ``write_limit`` arguments have the same meaning as in :func:`serve`. Args: protocol: Sans-I/O connection. diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 5e1ba6d84..42daa32ea 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -40,10 +40,12 @@ class ClientConnection(Connection): :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is closed with any other code. + The ``close_timeout`` and ``max_queue`` arguments have the same meaning as + in :func:`connect`. + Args: socket: Socket connected to a WebSocket server. protocol: Sans-I/O connection. - close_timeout: Timeout for closing the connection in seconds. """ @@ -53,6 +55,7 @@ def __init__( protocol: ClientProtocol, *, close_timeout: float | None = 10, + max_queue: int | tuple[int, int | None] = 16, ) -> None: self.protocol: ClientProtocol self.response_rcvd = threading.Event() @@ -60,6 +63,7 @@ def __init__( socket, protocol, close_timeout=close_timeout, + max_queue=max_queue, ) def handshake( @@ -135,6 +139,7 @@ def connect( close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, + max_queue: int | tuple[int, int | None] = 16, # Logging logger: LoggerLike | None = None, # Escape hatch for advanced customization @@ -183,6 +188,10 @@ def connect( :obj:`None` disables the timeout. max_size: Maximum size of incoming messages in bytes. :obj:`None` disables the limit. + max_queue: High-water mark of the buffer where frames are received. + It defaults to 16 frames. The low-water mark defaults to ``max_queue + // 4``. You may pass a ``(high, low)`` tuple to set the high-water + and low-water marks. logger: Logger for this client. It defaults to ``logging.getLogger("websockets.client")``. See the :doc:`logging guide <../../topics/logging>` for details. @@ -287,6 +296,7 @@ def connect( sock, protocol, close_timeout=close_timeout, + max_queue=max_queue, ) except Exception: if sock is not None: diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 3f4cac09f..3ab9f4937 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -49,10 +49,14 @@ def __init__( protocol: Protocol, *, close_timeout: float | None = 10, + max_queue: int | tuple[int, int | None] = 16, ) -> None: self.socket = socket self.protocol = protocol self.close_timeout = close_timeout + if isinstance(max_queue, int): + max_queue = (max_queue, None) + self.max_queue = max_queue # Inject reference to this instance in the protocol's logger. self.protocol.logger = logging.LoggerAdapter( @@ -76,8 +80,15 @@ def __init__( # Mutex serializing interactions with the protocol. self.protocol_mutex = threading.Lock() + # Lock stopping reads when the assembler buffer is full. + self.recv_flow_control = threading.Lock() + # Assembler turning frames into messages and serializing reads. - self.recv_messages = Assembler() + self.recv_messages = Assembler( + *self.max_queue, + pause=self.recv_flow_control.acquire, + resume=self.recv_flow_control.release, + ) # Whether we are busy sending a fragmented message. self.send_in_progress = False @@ -88,6 +99,10 @@ def __init__( # Mapping of ping IDs to pong waiters, in chronological order. self.ping_waiters: dict[bytes, threading.Event] = {} + # Exception raised in recv_events, to be chained to ConnectionClosed + # in the user thread in order to show why the TCP connection dropped. + self.recv_exc: BaseException | None = None + # Receiving events from the socket. This thread is marked as daemon to # allow creating a connection in a non-daemon thread and using it in a # daemon thread. This mustn't prevent the interpreter from exiting. @@ -97,10 +112,6 @@ def __init__( ) self.recv_events_thread.start() - # Exception raised in recv_events, to be chained to ConnectionClosed - # in the user thread in order to show why the TCP connection dropped. - self.recv_exc: BaseException | None = None - # Public attributes @property @@ -172,7 +183,7 @@ def __iter__(self) -> Iterator[Data]: except ConnectionClosedOK: return - def recv(self, timeout: float | None = None) -> Data: + def recv(self, timeout: float | None = None, decode: bool | None = None) -> Data: """ Receive the next message. @@ -191,6 +202,11 @@ def recv(self, timeout: float | None = None) -> Data: If the message is fragmented, wait until all fragments are received, reassemble them, and return the whole message. + Args: + timeout: Timeout for receiving a message in seconds. + decode: Set this flag to override the default behavior of returning + :class:`str` or :class:`bytes`. See below for details. + Returns: A string (:class:`str`) for a Text_ frame or a bytestring (:class:`bytes`) for a Binary_ frame. @@ -198,6 +214,16 @@ def recv(self, timeout: float | None = None) -> Data: .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + You may override this behavior with the ``decode`` argument: + + * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames and + return a bytestring (:class:`bytes`). This improves performance + when decoding isn't needed, for example if the message contains + JSON and you're using a JSON library that expects a bytestring. + * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames + and return a string (:class:`str`). This may be useful for + servers that send binary frames instead of text frames. + Raises: ConnectionClosed: When the connection is closed. ConcurrencyError: If two threads call :meth:`recv` or @@ -205,7 +231,7 @@ def recv(self, timeout: float | None = None) -> Data: """ try: - return self.recv_messages.get(timeout) + return self.recv_messages.get(timeout, decode) except EOFError: # Wait for the protocol state to be CLOSED before accessing close_exc. self.recv_events_thread.join() @@ -216,16 +242,23 @@ def recv(self, timeout: float | None = None) -> Data: "is already running recv or recv_streaming" ) from None - def recv_streaming(self) -> Iterator[Data]: + def recv_streaming(self, decode: bool | None = None) -> Iterator[Data]: """ Receive the next message frame by frame. - If the message is fragmented, yield each fragment as it is received. - The iterator must be fully consumed, or else the connection will become + This method is designed for receiving fragmented messages. It returns an + iterator that yields each fragment as it is received. This iterator must + be fully consumed. Else, future calls to :meth:`recv` or + :meth:`recv_streaming` will raise + :exc:`~websockets.exceptions.ConcurrencyError`, making the connection unusable. :meth:`recv_streaming` raises the same exceptions as :meth:`recv`. + Args: + decode: Set this flag to override the default behavior of returning + :class:`str` or :class:`bytes`. See below for details. + Returns: An iterator of strings (:class:`str`) for a Text_ frame or bytestrings (:class:`bytes`) for a Binary_ frame. @@ -233,6 +266,15 @@ def recv_streaming(self) -> Iterator[Data]: .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + You may override this behavior with the ``decode`` argument: + + * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames + and return bytestrings (:class:`bytes`). This may be useful to + optimize performance when decoding isn't needed. + * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames + and return strings (:class:`str`). This is useful for servers + that send binary frames instead of text frames. + Raises: ConnectionClosed: When the connection is closed. ConcurrencyError: If two threads call :meth:`recv` or @@ -240,7 +282,7 @@ def recv_streaming(self) -> Iterator[Data]: """ try: - yield from self.recv_messages.get_iter() + yield from self.recv_messages.get_iter(decode) except EOFError: # Wait for the protocol state to be CLOSED before accessing close_exc. self.recv_events_thread.join() @@ -571,8 +613,9 @@ def recv_events(self) -> None: try: while True: try: - if self.close_deadline is not None: - self.socket.settimeout(self.close_deadline.timeout()) + with self.recv_flow_control: + if self.close_deadline is not None: + self.socket.settimeout(self.close_deadline.timeout()) data = self.socket.recv(self.recv_bufsize) except Exception as exc: if self.debug: @@ -622,13 +665,9 @@ def recv_events(self) -> None: # Given that automatic responses write small amounts of data, # this should be uncommon, so we don't handle the edge case. - try: - for event in events: - # This may raise EOFError if the closing handshake - # times out while a message is waiting to be read. - self.process_event(event) - except EOFError: - break + for event in events: + # This isn't expected to raise an exception. + self.process_event(event) # Breaking out of the while True: ... loop means that we believe # that the socket doesn't work anymore. diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index 997fa98df..983b114dc 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -3,12 +3,12 @@ import codecs import queue import threading -from collections.abc import Iterator -from typing import cast +from typing import Any, Callable, Iterable, Iterator from ..exceptions import ConcurrencyError from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame from ..typing import Data +from .utils import Deadline __all__ = ["Assembler"] @@ -20,47 +20,83 @@ class Assembler: """ Assemble messages from frames. + :class:`Assembler` expects only data frames. The stream of frames must + respect the protocol; if it doesn't, the behavior is undefined. + + Args: + pause: Called when the buffer of frames goes above the high water mark; + should pause reading from the network. + resume: Called when the buffer of frames goes below the low water mark; + should resume reading from the network. + """ - def __init__(self) -> None: + def __init__( + self, + high: int = 16, + low: int | None = None, + pause: Callable[[], Any] = lambda: None, + resume: Callable[[], Any] = lambda: None, + ) -> None: # Serialize reads and writes -- except for reads via synchronization # primitives provided by the threading and queue modules. self.mutex = threading.Lock() - # We create a latch with two events to synchronize the production of - # frames and the consumption of messages (or frames) without a buffer. - # This design requires a switch between the library thread and the user - # thread for each message; that shouldn't be a performance bottleneck. - - # put() sets this event to tell get() that a message can be fetched. - self.message_complete = threading.Event() - # get() sets this event to let put() that the message was fetched. - self.message_fetched = threading.Event() + # Queue of incoming frames. + self.frames: queue.SimpleQueue[Frame | None] = queue.SimpleQueue() + + # We cannot put a hard limit on the size of the queue because a single + # call to Protocol.data_received() could produce thousands of frames, + # which must be buffered. Instead, we pause reading when the buffer goes + # above the high limit and we resume when it goes under the low limit. + if low is None: + low = high // 4 + if low < 0: + raise ValueError("low must be positive or equal to zero") + if high < low: + raise ValueError("high must be greater than or equal to low") + self.high, self.low = high, low + self.pause = pause + self.resume = resume + self.paused = False # This flag prevents concurrent calls to get() by user code. self.get_in_progress = False - # This flag prevents concurrent calls to put() by library code. - self.put_in_progress = False - - # Decoder for text frames, None for binary frames. - self.decoder: codecs.IncrementalDecoder | None = None - - # Buffer of frames belonging to the same message. - self.chunks: list[Data] = [] - - # When switching from "buffering" to "streaming", we use a thread-safe - # queue for transferring frames from the writing thread (library code) - # to the reading thread (user code). We're buffering when chunks_queue - # is None and streaming when it's a SimpleQueue. None is a sentinel - # value marking the end of the message, superseding message_complete. - - # Stream data from frames belonging to the same message. - self.chunks_queue: queue.SimpleQueue[Data | None] | None = None # This flag marks the end of the connection. self.closed = False - def get(self, timeout: float | None = None) -> Data: + def get_next_frame(self, timeout: float | None = None) -> Frame: + # Helper to factor out the logic for getting the next frame from the + # queue, while handling timeouts and reaching the end of the stream. + try: + frame = self.frames.get(timeout=timeout) + except queue.Empty: + raise TimeoutError(f"timed out in {timeout:.1f}s") from None + if frame is None: + raise EOFError("stream of frames ended") + return frame + + def reset_queue(self, frames: Iterable[Frame]) -> None: + # Helper to put frames back into the queue after they were fetched. + # This happens only when the queue is empty. However, by the time + # we acquire self.mutex, put() may have added items in the queue. + # Therefore, we must handle the case where the queue is not empty. + frame: Frame | None + with self.mutex: + queued = [] + try: + while True: + queued.append(self.frames.get_nowait()) + except queue.Empty: + pass + for frame in frames: + self.frames.put(frame) + # This loop runs only when a race condition occurs. + for frame in queued: # pragma: no cover + self.frames.put(frame) + + def get(self, timeout: float | None = None, decode: bool | None = None) -> Data: """ Read the next message. @@ -73,11 +109,14 @@ def get(self, timeout: float | None = None) -> Data: Args: timeout: If a timeout is provided and elapses before a complete message is received, :meth:`get` raises :exc:`TimeoutError`. + decode: :obj:`False` disables UTF-8 decoding of text frames and + returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of + binary frames and returns :class:`str`. Raises: EOFError: If the stream of frames has ended. - ConcurrencyError: If two threads run :meth:`get` or :meth:`get_iter` - concurrently. + ConcurrencyError: If two coroutines run :meth:`get` or + :meth:`get_iter` concurrently. TimeoutError: If a timeout is provided and elapses before a complete message is received. @@ -89,40 +128,45 @@ def get(self, timeout: float | None = None) -> Data: if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") + # Locking with get_in_progress ensures only one thread can get here. self.get_in_progress = True - # If the message_complete event isn't set yet, release the lock to - # allow put() to run and eventually set it. - # Locking with get_in_progress ensures only one thread can get here. - completed = self.message_complete.wait(timeout) + try: + deadline = Deadline(timeout) + + # First frame + frame = self.get_next_frame(deadline.timeout()) + with self.mutex: + self.maybe_resume() + assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY + if decode is None: + decode = frame.opcode is OP_TEXT + frames = [frame] + + # Following frames, for fragmented messages + while not frame.fin: + try: + frame = self.get_next_frame(deadline.timeout()) + except TimeoutError: + # Put frames already received back into the queue + # so that future calls to get() can return them. + self.reset_queue(frames) + raise + with self.mutex: + self.maybe_resume() + assert frame.opcode is OP_CONT + frames.append(frame) - with self.mutex: + finally: self.get_in_progress = False - # Waiting for a complete message timed out. - if not completed: - raise TimeoutError(f"timed out in {timeout:.1f}s") - - # get() was unblocked by close() rather than put(). - if self.closed: - raise EOFError("stream of frames ended") - - assert self.message_complete.is_set() - self.message_complete.clear() - - joiner: Data = b"" if self.decoder is None else "" - # mypy cannot figure out that chunks have the proper type. - message: Data = joiner.join(self.chunks) # type: ignore + data = b"".join(frame.data for frame in frames) + if decode: + return data.decode() + else: + return data - self.chunks = [] - assert self.chunks_queue is None - - assert not self.message_fetched.is_set() - self.message_fetched.set() - - return message - - def get_iter(self) -> Iterator[Data]: + def get_iter(self, decode: bool | None = None) -> Iterator[Data]: """ Stream the next message. @@ -135,10 +179,15 @@ def get_iter(self) -> Iterator[Data]: This method only makes sense for fragmented messages. If messages aren't fragmented, use :meth:`get` instead. + Args: + decode: :obj:`False` disables UTF-8 decoding of text frames and + returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of + binary frames and returns :class:`str`. + Raises: EOFError: If the stream of frames has ended. - ConcurrencyError: If two threads run :meth:`get` or :meth:`get_iter` - concurrently. + ConcurrencyError: If two coroutines run :meth:`get` or + :meth:`get_iter` concurrently. """ with self.mutex: @@ -148,116 +197,81 @@ def get_iter(self) -> Iterator[Data]: if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") - chunks = self.chunks - self.chunks = [] - self.chunks_queue = cast( - # Remove quotes around type when dropping Python < 3.10. - "queue.SimpleQueue[Data | None]", - queue.SimpleQueue(), - ) - - # Sending None in chunk_queue supersedes setting message_complete - # when switching to "streaming". If message is already complete - # when the switch happens, put() didn't send None, so we have to. - if self.message_complete.is_set(): - self.chunks_queue.put(None) - + # Locking with get_in_progress ensures only one coroutine can get here. self.get_in_progress = True - # Locking with get_in_progress ensures only one thread can get here. - chunk: Data | None - for chunk in chunks: - yield chunk - while (chunk := self.chunks_queue.get()) is not None: - yield chunk + # Locking with get_in_progress prevents concurrent execution until + # get_iter() fetches a complete message or is cancelled. - with self.mutex: - self.get_in_progress = False + # If get_iter() raises an exception e.g. in decoder.decode(), + # get_in_progress remains set and the connection becomes unusable. - # get_iter() was unblocked by close() rather than put(). - if self.closed: - raise EOFError("stream of frames ended") - - assert self.message_complete.is_set() - self.message_complete.clear() - - assert self.chunks == [] - self.chunks_queue = None + # First frame + frame = self.get_next_frame() + with self.mutex: + self.maybe_resume() + assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY + if decode is None: + decode = frame.opcode is OP_TEXT + if decode: + decoder = UTF8Decoder() + yield decoder.decode(frame.data, frame.fin) + else: + yield frame.data + + # Following frames, for fragmented messages + while not frame.fin: + frame = self.get_next_frame() + with self.mutex: + self.maybe_resume() + assert frame.opcode is OP_CONT + if decode: + yield decoder.decode(frame.data, frame.fin) + else: + yield frame.data - assert not self.message_fetched.is_set() - self.message_fetched.set() + self.get_in_progress = False def put(self, frame: Frame) -> None: """ Add ``frame`` to the next message. - When ``frame`` is the final frame in a message, :meth:`put` waits until - the message is fetched, which can be achieved by calling :meth:`get` or - by fully consuming the return value of :meth:`get_iter`. - - :meth:`put` assumes that the stream of frames respects the protocol. If - it doesn't, the behavior is undefined. - Raises: EOFError: If the stream of frames has ended. - ConcurrencyError: If two threads run :meth:`put` concurrently. """ with self.mutex: if self.closed: raise EOFError("stream of frames ended") - if self.put_in_progress: - raise ConcurrencyError("put is already running") - - if frame.opcode is OP_TEXT: - self.decoder = UTF8Decoder(errors="strict") - elif frame.opcode is OP_BINARY: - self.decoder = None - else: - assert frame.opcode is OP_CONT - - data: Data - if self.decoder is not None: - data = self.decoder.decode(frame.data, frame.fin) - else: - data = frame.data - - if self.chunks_queue is None: - self.chunks.append(data) - else: - self.chunks_queue.put(data) - - if not frame.fin: - return - - # Message is complete. Wait until it's fetched to return. - - assert not self.message_complete.is_set() - self.message_complete.set() - - if self.chunks_queue is not None: - self.chunks_queue.put(None) - - assert not self.message_fetched.is_set() - - self.put_in_progress = True - - # Release the lock to allow get() to run and eventually set the event. - # Locking with put_in_progress ensures only one coroutine can get here. - self.message_fetched.wait() - - with self.mutex: - self.put_in_progress = False - - # put() was unblocked by close() rather than get() or get_iter(). - if self.closed: - raise EOFError("stream of frames ended") - - assert self.message_fetched.is_set() - self.message_fetched.clear() - - self.decoder = None + self.frames.put(frame) + self.maybe_pause() + + # put() and get/get_iter() call maybe_pause() and maybe_resume() while + # holding self.mutex. This guarantees that the calls interleave properly. + # Specifically, it prevents a race condition where maybe_resume() would + # run before maybe_pause(), leaving the connection incorrectly paused. + + # A race condition is possible when get/get_iter() call self.frames.get() + # without holding self.mutex. However, it's harmless — and even beneficial! + # It can only result in popping an item from the queue before maybe_resume() + # runs and skipping a pause() - resume() cycle that would otherwise occur. + + def maybe_pause(self) -> None: + """Pause the writer if queue is above the high water mark.""" + assert self.mutex.locked() + # Check for "> high" to support high = 0 + if self.frames.qsize() > self.high and not self.paused: + self.paused = True + self.pause() + + def maybe_resume(self) -> None: + """Resume the writer if queue is below the low water mark.""" + assert self.mutex.locked() + # Check for "<= low" to support low = 0 + if self.frames.qsize() <= self.low and self.paused: + self.paused = False + self.resume() def close(self) -> None: """ @@ -273,12 +287,5 @@ def close(self) -> None: self.closed = True - # Unblock get or get_iter. - if self.get_in_progress: - self.message_complete.set() - if self.chunks_queue is not None: - self.chunks_queue.put(None) - - # Unblock put(). - if self.put_in_progress: - self.message_fetched.set() + # Unblock get() or get_iter(). + self.frames.put(None) diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 464c4a173..94f76b658 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -51,10 +51,12 @@ class ServerConnection(Connection): :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is closed with any other code. + The ``close_timeout`` and ``max_queue`` arguments have the same meaning as + in :func:`serve`. + Args: socket: Socket connected to a WebSocket client. protocol: Sans-I/O connection. - close_timeout: Timeout for closing the connection in seconds. """ @@ -64,6 +66,7 @@ def __init__( protocol: ServerProtocol, *, close_timeout: float | None = 10, + max_queue: int | tuple[int, int | None] = 16, ) -> None: self.protocol: ServerProtocol self.request_rcvd = threading.Event() @@ -71,6 +74,7 @@ def __init__( socket, protocol, close_timeout=close_timeout, + max_queue=max_queue, ) self.username: str # see basic_auth() @@ -349,6 +353,7 @@ def serve( close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, + max_queue: int | tuple[int, int | None] = 16, # Logging logger: LoggerLike | None = None, # Escape hatch for advanced customization @@ -427,6 +432,10 @@ def handler(websocket): :obj:`None` disables the timeout. max_size: Maximum size of incoming messages in bytes. :obj:`None` disables the limit. + max_queue: High-water mark of the buffer where frames are received. + It defaults to 16 frames. The low-water mark defaults to ``max_queue + // 4``. You may pass a ``(high, low)`` tuple to set the high-water + and low-water marks. logger: Logger for this server. It defaults to ``logging.getLogger("websockets.server")``. See the :doc:`logging guide <../../topics/logging>` for details. @@ -548,6 +557,7 @@ def protocol_select_subprotocol( sock, protocol, close_timeout=close_timeout, + max_queue=max_queue, ) except Exception: sock.close() diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 563cf2b17..12e2bd5fa 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -793,8 +793,8 @@ async def test_close_timeout_waiting_for_connection_closed(self): self.assertIsInstance(exc.__cause__, (socket.timeout, TimeoutError)) async def test_close_does_not_wait_for_recv(self): - # The asyncio implementation has a buffer for incoming messages. Closing - # the connection discards buffered messages. This is allowed by the RFC: + # Closing the connection discards messages buffered in the assembler. + # This is allowed by the RFC: # > However, there is no guarantee that the endpoint that has already # > sent a Close frame will continue to process data. await self.remote_connection.send("😀") @@ -1075,7 +1075,10 @@ async def test_max_queue(self): async def test_max_queue_tuple(self): """max_queue parameter configures high-water mark of frames buffer.""" - connection = Connection(Protocol(self.LOCAL), max_queue=(4, 2)) + connection = Connection( + Protocol(self.LOCAL), + max_queue=(4, 2), + ) transport = Mock() connection.connection_made(transport) self.assertEqual(connection.recv_messages.high, 4) @@ -1083,14 +1086,20 @@ async def test_max_queue_tuple(self): async def test_write_limit(self): """write_limit parameter configures high-water mark of write buffer.""" - connection = Connection(Protocol(self.LOCAL), write_limit=4096) + connection = Connection( + Protocol(self.LOCAL), + write_limit=4096, + ) transport = Mock() connection.connection_made(transport) transport.set_write_buffer_limits.assert_called_once_with(4096, None) async def test_write_limits(self): """write_limit parameter configures high and low-water marks of write buffer.""" - connection = Connection(Protocol(self.LOCAL), write_limit=(4096, 2048)) + connection = Connection( + Protocol(self.LOCAL), + write_limit=(4096, 2048), + ) transport = Mock() connection.connection_made(transport) transport.set_write_buffer_limits.assert_called_once_with(4096, 2048) diff --git a/tests/asyncio/test_messages.py b/tests/asyncio/test_messages.py index d2cf25c9c..2ff929d3a 100644 --- a/tests/asyncio/test_messages.py +++ b/tests/asyncio/test_messages.py @@ -350,7 +350,7 @@ async def test_cancel_get_iter_before_first_frame(self): self.assertEqual(fragments, ["café"]) async def test_cancel_get_iter_after_first_frame(self): - """get cannot be canceled after reading the first frame.""" + """get_iter cannot be canceled after reading the first frame.""" self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) getter_task = asyncio.create_task(alist(self.assembler.get_iter())) @@ -429,7 +429,7 @@ async def test_get_fails_when_get_is_running(self): await asyncio.sleep(0) with self.assertRaises(ConcurrencyError): await self.assembler.get() - self.assembler.close() # let task terminate + self.assembler.put(Frame(OP_TEXT, b"")) # let task terminate async def test_get_fails_when_get_iter_is_running(self): """get cannot be called concurrently with get_iter.""" @@ -437,7 +437,7 @@ async def test_get_fails_when_get_iter_is_running(self): await asyncio.sleep(0) with self.assertRaises(ConcurrencyError): await self.assembler.get() - self.assembler.close() # let task terminate + self.assembler.put(Frame(OP_TEXT, b"")) # let task terminate async def test_get_iter_fails_when_get_is_running(self): """get_iter cannot be called concurrently with get.""" @@ -445,7 +445,7 @@ async def test_get_iter_fails_when_get_is_running(self): await asyncio.sleep(0) with self.assertRaises(ConcurrencyError): await alist(self.assembler.get_iter()) - self.assembler.close() # let task terminate + self.assembler.put(Frame(OP_TEXT, b"")) # let task terminate async def test_get_iter_fails_when_get_iter_is_running(self): """get_iter cannot be called concurrently.""" @@ -453,7 +453,7 @@ async def test_get_iter_fails_when_get_iter_is_running(self): await asyncio.sleep(0) with self.assertRaises(ConcurrencyError): await alist(self.assembler.get_iter()) - self.assembler.close() # let task terminate + self.assembler.put(Frame(OP_TEXT, b"")) # let task terminate # Test setting limits @@ -463,7 +463,7 @@ async def test_set_high_water_mark(self): self.assertEqual(assembler.high, 10) async def test_set_high_and_low_water_mark(self): - """high sets the high-water mark.""" + """high sets the high-water mark and low-water mark.""" assembler = Assembler(high=10, low=5) self.assertEqual(assembler.high, 10) self.assertEqual(assembler.low, 5) diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index 87333fd35..db1cc8e93 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -154,6 +154,16 @@ def test_recv_binary(self): self.remote_connection.send(b"\x01\x02\xfe\xff") self.assertEqual(self.connection.recv(), b"\x01\x02\xfe\xff") + def test_recv_text_as_bytes(self): + """recv receives a text message as bytes.""" + self.remote_connection.send("😀") + self.assertEqual(self.connection.recv(decode=False), "😀".encode()) + + def test_recv_binary_as_text(self): + """recv receives a binary message as a str.""" + self.remote_connection.send("😀".encode()) + self.assertEqual(self.connection.recv(decode=True), "😀") + def test_recv_fragmented_text(self): """recv receives a fragmented text message.""" self.remote_connection.send(["😀", "😀"]) @@ -228,6 +238,22 @@ def test_recv_streaming_binary(self): [b"\x01\x02\xfe\xff"], ) + def test_recv_streaming_text_as_bytes(self): + """recv_streaming receives a text message as bytes.""" + self.remote_connection.send("😀") + self.assertEqual( + list(self.connection.recv_streaming(decode=False)), + ["😀".encode()], + ) + + def test_recv_streaming_binary_as_str(self): + """recv_streaming receives a binary message as a str.""" + self.remote_connection.send("😀".encode()) + self.assertEqual( + list(self.connection.recv_streaming(decode=True)), + ["😀"], + ) + def test_recv_streaming_fragmented_text(self): """recv_streaming receives a fragmented text message.""" self.remote_connection.send(["😀", "😀"]) @@ -499,28 +525,17 @@ def test_close_timeout_waiting_for_connection_closed(self): # Remove socket.timeout when dropping Python < 3.10. self.assertIsInstance(exc.__cause__, (socket.timeout, TimeoutError)) - def test_close_waits_for_recv(self): - # The sync implementation doesn't have a buffer for incoming messsages. - # It requires reading incoming frames until the close frame is reached. - # This behavior — close() blocks until recv() is called — is less than - # ideal and inconsistent with the asyncio implementation. + def test_close_does_not_wait_for_recv(self): + # Closing the connection discards messages buffered in the assembler. + # This is allowed by the RFC: + # > However, there is no guarantee that the endpoint that has already + # > sent a Close frame will continue to process data. self.remote_connection.send("😀") + self.connection.close() close_thread = threading.Thread(target=self.connection.close) close_thread.start() - # Let close() initiate the closing handshake and send a close frame. - time.sleep(MS) - self.assertTrue(close_thread.is_alive()) - - # Connection isn't closed yet. - self.connection.recv() - - # Let close() receive a close frame and finish the closing handshake. - time.sleep(MS) - self.assertFalse(close_thread.is_alive()) - - # Connection is closed now. with self.assertRaises(ConnectionClosedOK) as raised: self.connection.recv() @@ -528,24 +543,6 @@ def test_close_waits_for_recv(self): self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") self.assertIsNone(exc.__cause__) - def test_close_timeout_waiting_for_recv(self): - self.remote_connection.send("😀") - - close_thread = threading.Thread(target=self.connection.close) - close_thread.start() - - # Let close() time out during the closing handshake. - time.sleep(3 * MS) - self.assertFalse(close_thread.is_alive()) - - # Connection is closed now. - with self.assertRaises(ConnectionClosedError) as raised: - self.connection.recv() - - exc = raised.exception - self.assertEqual(str(exc), "sent 1000 (OK); no close frame received") - self.assertIsInstance(exc.__cause__, TimeoutError) - def test_close_idempotency(self): """close does nothing if the connection is already closed.""" self.connection.close() @@ -724,6 +721,45 @@ def test_pong_unsupported_type(self): with self.assertRaises(TypeError): self.connection.pong([]) + # Test parameters. + + def test_close_timeout(self): + """close_timeout parameter configures close timeout.""" + socket_, remote_socket = socket.socketpair() + self.addCleanup(socket_.close) + self.addCleanup(remote_socket.close) + connection = Connection( + socket_, + Protocol(self.LOCAL), + close_timeout=42 * MS, + ) + self.assertEqual(connection.close_timeout, 42 * MS) + + def test_max_queue(self): + """max_queue parameter configures high-water mark of frames buffer.""" + socket_, remote_socket = socket.socketpair() + self.addCleanup(socket_.close) + self.addCleanup(remote_socket.close) + connection = Connection( + socket_, + Protocol(self.LOCAL), + max_queue=4, + ) + self.assertEqual(connection.recv_messages.high, 4) + + def test_max_queue_tuple(self): + """max_queue parameter configures high-water mark of frames buffer.""" + socket_, remote_socket = socket.socketpair() + self.addCleanup(socket_.close) + self.addCleanup(remote_socket.close) + connection = Connection( + socket_, + Protocol(self.LOCAL), + max_queue=(4, 2), + ) + self.assertEqual(connection.recv_messages.high, 4) + self.assertEqual(connection.recv_messages.low, 2) + # Test attributes. def test_id(self): diff --git a/tests/sync/test_messages.py b/tests/sync/test_messages.py index d44b39b88..02513894a 100644 --- a/tests/sync/test_messages.py +++ b/tests/sync/test_messages.py @@ -1,4 +1,6 @@ import time +import unittest +import unittest.mock from websockets.exceptions import ConcurrencyError from websockets.frames import OP_BINARY, OP_CONT, OP_TEXT, Frame @@ -9,66 +11,23 @@ class AssemblerTests(ThreadTestCase): - """ - Tests in this class interact a lot with hidden synchronization mechanisms: - - - get() / get_iter() and put() must run in separate threads when a final - frame is set because put() waits for get() / get_iter() to fetch the - message before returning. - - - run_in_thread() lets its target run before yielding back control on entry, - which guarantees the intended execution order of test cases. - - - run_in_thread() waits for its target to finish running before yielding - back control on exit, which allows making assertions immediately. - - - When the main thread performs actions that let another thread progress, it - must wait before making assertions, to avoid depending on scheduling. - - """ - def setUp(self): - self.assembler = Assembler() - - def tearDown(self): - """ - Check that the assembler goes back to its default state after each test. - - This removes the need for testing various sequences. - - """ - self.assertFalse(self.assembler.mutex.locked()) - self.assertFalse(self.assembler.get_in_progress) - self.assertFalse(self.assembler.put_in_progress) - if not self.assembler.closed: - self.assertFalse(self.assembler.message_complete.is_set()) - self.assertFalse(self.assembler.message_fetched.is_set()) - self.assertIsNone(self.assembler.decoder) - self.assertEqual(self.assembler.chunks, []) - self.assertIsNone(self.assembler.chunks_queue) + self.pause = unittest.mock.Mock() + self.resume = unittest.mock.Mock() + self.assembler = Assembler(high=2, low=1, pause=self.pause, resume=self.resume) # Test get def test_get_text_message_already_received(self): """get returns a text message that is already received.""" - - def putter(): - self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) - - with self.run_in_thread(putter): - message = self.assembler.get() - + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + message = self.assembler.get() self.assertEqual(message, "café") def test_get_binary_message_already_received(self): """get returns a binary message that is already received.""" - - def putter(): - self.assembler.put(Frame(OP_BINARY, b"tea")) - - with self.run_in_thread(putter): - message = self.assembler.get() - + self.assembler.put(Frame(OP_BINARY, b"tea")) + message = self.assembler.get() self.assertEqual(message, b"tea") def test_get_text_message_not_received_yet(self): @@ -99,112 +58,145 @@ def getter(): def test_get_fragmented_text_message_already_received(self): """get reassembles a fragmented a text message that is already received.""" - - def putter(): - self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) - self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) - self.assembler.put(Frame(OP_CONT, b"\xa9")) - - with self.run_in_thread(putter): - message = self.assembler.get() - + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + message = self.assembler.get() self.assertEqual(message, "café") def test_get_fragmented_binary_message_already_received(self): """get reassembles a fragmented binary message that is already received.""" - - def putter(): - self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) - self.assembler.put(Frame(OP_CONT, b"e", fin=False)) - self.assembler.put(Frame(OP_CONT, b"a")) - - with self.run_in_thread(putter): - message = self.assembler.get() - + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + message = self.assembler.get() self.assertEqual(message, b"tea") - def test_get_fragmented_text_message_being_received(self): - """get reassembles a fragmented text message that is partially received.""" + def test_get_fragmented_text_message_not_received_yet(self): + """get reassembles a fragmented text message when it is received.""" message = None def getter(): nonlocal message message = self.assembler.get() - self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) with self.run_in_thread(getter): + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) self.assembler.put(Frame(OP_CONT, b"\xa9")) self.assertEqual(message, "café") - def test_get_fragmented_binary_message_being_received(self): - """get reassembles a fragmented binary message that is partially received.""" + def test_get_fragmented_binary_message_not_received_yet(self): + """get reassembles a fragmented binary message when it is received.""" message = None def getter(): nonlocal message message = self.assembler.get() - self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) with self.run_in_thread(getter): + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) self.assertEqual(message, b"tea") - def test_get_fragmented_text_message_not_received_yet(self): - """get reassembles a fragmented text message when it is received.""" + def test_get_fragmented_text_message_being_received(self): + """get reassembles a fragmented text message that is partially received.""" message = None def getter(): nonlocal message message = self.assembler.get() + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) with self.run_in_thread(getter): - self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) self.assembler.put(Frame(OP_CONT, b"\xa9")) self.assertEqual(message, "café") - def test_get_fragmented_binary_message_not_received_yet(self): - """get reassembles a fragmented binary message when it is received.""" + def test_get_fragmented_binary_message_being_received(self): + """get reassembles a fragmented binary message that is partially received.""" message = None def getter(): nonlocal message message = self.assembler.get() + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) with self.run_in_thread(getter): - self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) self.assertEqual(message, b"tea") - # Test get_iter + def test_get_encoded_text_message(self): + """get returns a text message without UTF-8 decoding.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + message = self.assembler.get(decode=False) + self.assertEqual(message, b"caf\xc3\xa9") + + def test_get_decoded_binary_message(self): + """get returns a binary message with UTF-8 decoding.""" + self.assembler.put(Frame(OP_BINARY, b"tea")) + message = self.assembler.get(decode=True) + self.assertEqual(message, "tea") + + def test_get_resumes_reading(self): + """get resumes reading when queue goes below the high-water mark.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"more caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"water")) + + # queue is above the low-water mark + self.assembler.get() + self.resume.assert_not_called() + + # queue is at the low-water mark + self.assembler.get() + self.resume.assert_called_once_with() + + # queue is below the low-water mark + self.assembler.get() + self.resume.assert_called_once_with() + + def test_get_timeout_before_first_frame(self): + """get times out before reading the first frame.""" + with self.assertRaises(TimeoutError): + self.assembler.get(timeout=MS) - def test_get_iter_text_message_already_received(self): - """get_iter yields a text message that is already received.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) - def putter(): - self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + message = self.assembler.get() + self.assertEqual(message, "café") - with self.run_in_thread(putter): - fragments = list(self.assembler.get_iter()) + def test_get_timeout_after_first_frame(self): + """get times out after reading the first frame.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) - self.assertEqual(fragments, ["café"]) + with self.assertRaises(TimeoutError): + self.assembler.get(timeout=MS) - def test_get_iter_binary_message_already_received(self): - """get_iter yields a binary message that is already received.""" + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) - def putter(): - self.assembler.put(Frame(OP_BINARY, b"tea")) + message = self.assembler.get() + self.assertEqual(message, "café") - with self.run_in_thread(putter): - fragments = list(self.assembler.get_iter()) + # Test get_iter + def test_get_iter_text_message_already_received(self): + """get_iter yields a text message that is already received.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + fragments = list(self.assembler.get_iter()) + self.assertEqual(fragments, ["café"]) + + def test_get_iter_binary_message_already_received(self): + """get_iter yields a binary message that is already received.""" + self.assembler.put(Frame(OP_BINARY, b"tea")) + fragments = list(self.assembler.get_iter()) self.assertEqual(fragments, [b"tea"]) def test_get_iter_text_message_not_received_yet(self): @@ -212,6 +204,7 @@ def test_get_iter_text_message_not_received_yet(self): fragments = [] def getter(): + nonlocal fragments for fragment in self.assembler.get_iter(): fragments.append(fragment) @@ -225,6 +218,7 @@ def test_get_iter_binary_message_not_received_yet(self): fragments = [] def getter(): + nonlocal fragments for fragment in self.assembler.get_iter(): fragments.append(fragment) @@ -235,121 +229,112 @@ def getter(): def test_get_iter_fragmented_text_message_already_received(self): """get_iter yields a fragmented text message that is already received.""" - - def putter(): - self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) - self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) - self.assembler.put(Frame(OP_CONT, b"\xa9")) - - with self.run_in_thread(putter): - fragments = list(self.assembler.get_iter()) - + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + fragments = list(self.assembler.get_iter()) self.assertEqual(fragments, ["ca", "f", "é"]) def test_get_iter_fragmented_binary_message_already_received(self): """get_iter yields a fragmented binary message that is already received.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + fragments = list(self.assembler.get_iter()) + self.assertEqual(fragments, [b"t", b"e", b"a"]) - def putter(): - self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) - self.assembler.put(Frame(OP_CONT, b"e", fin=False)) - self.assembler.put(Frame(OP_CONT, b"a")) - - with self.run_in_thread(putter): - fragments = list(self.assembler.get_iter()) + def test_get_iter_fragmented_text_message_not_received_yet(self): + """get_iter yields a fragmented text message when it is received.""" + iterator = self.assembler.get_iter() + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assertEqual(next(iterator), "ca") + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assertEqual(next(iterator), "f") + self.assembler.put(Frame(OP_CONT, b"\xa9")) + self.assertEqual(next(iterator), "é") - self.assertEqual(fragments, [b"t", b"e", b"a"]) + def test_get_iter_fragmented_binary_message_not_received_yet(self): + """get_iter yields a fragmented binary message when it is received.""" + iterator = self.assembler.get_iter() + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assertEqual(next(iterator), b"t") + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assertEqual(next(iterator), b"e") + self.assembler.put(Frame(OP_CONT, b"a")) + self.assertEqual(next(iterator), b"a") def test_get_iter_fragmented_text_message_being_received(self): """get_iter yields a fragmented text message that is partially received.""" - fragments = [] - - def getter(): - for fragment in self.assembler.get_iter(): - fragments.append(fragment) - self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) - with self.run_in_thread(getter): - self.assertEqual(fragments, ["ca"]) - self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) - time.sleep(MS) - self.assertEqual(fragments, ["ca", "f"]) - self.assembler.put(Frame(OP_CONT, b"\xa9")) - - self.assertEqual(fragments, ["ca", "f", "é"]) + iterator = self.assembler.get_iter() + self.assertEqual(next(iterator), "ca") + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assertEqual(next(iterator), "f") + self.assembler.put(Frame(OP_CONT, b"\xa9")) + self.assertEqual(next(iterator), "é") def test_get_iter_fragmented_binary_message_being_received(self): """get_iter yields a fragmented binary message that is partially received.""" - fragments = [] - - def getter(): - for fragment in self.assembler.get_iter(): - fragments.append(fragment) - self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) - with self.run_in_thread(getter): - self.assertEqual(fragments, [b"t"]) - self.assembler.put(Frame(OP_CONT, b"e", fin=False)) - time.sleep(MS) - self.assertEqual(fragments, [b"t", b"e"]) - self.assembler.put(Frame(OP_CONT, b"a")) - - self.assertEqual(fragments, [b"t", b"e", b"a"]) - - def test_get_iter_fragmented_text_message_not_received_yet(self): - """get_iter yields a fragmented text message when it is received.""" - fragments = [] - - def getter(): - for fragment in self.assembler.get_iter(): - fragments.append(fragment) + iterator = self.assembler.get_iter() + self.assertEqual(next(iterator), b"t") + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assertEqual(next(iterator), b"e") + self.assembler.put(Frame(OP_CONT, b"a")) + self.assertEqual(next(iterator), b"a") + + def test_get_iter_encoded_text_message(self): + """get_iter yields a text message without UTF-8 decoding.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + fragments = list(self.assembler.get_iter(decode=False)) + self.assertEqual(fragments, [b"ca", b"f\xc3", b"\xa9"]) - with self.run_in_thread(getter): - self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) - time.sleep(MS) - self.assertEqual(fragments, ["ca"]) - self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) - time.sleep(MS) - self.assertEqual(fragments, ["ca", "f"]) - self.assembler.put(Frame(OP_CONT, b"\xa9")) + def test_get_iter_decoded_binary_message(self): + """get_iter yields a binary message with UTF-8 decoding.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + fragments = list(self.assembler.get_iter(decode=True)) + self.assertEqual(fragments, ["t", "e", "a"]) - self.assertEqual(fragments, ["ca", "f", "é"]) + def test_get_iter_resumes_reading(self): + """get_iter resumes reading when queue goes below the high-water mark.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) - def test_get_iter_fragmented_binary_message_not_received_yet(self): - """get_iter yields a fragmented binary message when it is received.""" - fragments = [] + iterator = self.assembler.get_iter() - def getter(): - for fragment in self.assembler.get_iter(): - fragments.append(fragment) + # queue is above the low-water mark + next(iterator) + self.resume.assert_not_called() - with self.run_in_thread(getter): - self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) - time.sleep(MS) - self.assertEqual(fragments, [b"t"]) - self.assembler.put(Frame(OP_CONT, b"e", fin=False)) - time.sleep(MS) - self.assertEqual(fragments, [b"t", b"e"]) - self.assembler.put(Frame(OP_CONT, b"a")) + # queue is at the low-water mark + next(iterator) + self.resume.assert_called_once_with() - self.assertEqual(fragments, [b"t", b"e", b"a"]) - - # Test timeouts + # queue is below the low-water mark + next(iterator) + self.resume.assert_called_once_with() - def test_get_with_timeout_completes(self): - """get returns a message when it is received before the timeout.""" + # Test put - def putter(): - self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) - - with self.run_in_thread(putter): - message = self.assembler.get(MS) + def test_put_pauses_reading(self): + """put pauses reading when queue goes above the high-water mark.""" + # queue is below the high-water mark + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.pause.assert_not_called() - self.assertEqual(message, "café") + # queue is at the high-water mark + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.pause.assert_called_once_with() - def test_get_with_timeout_times_out(self): - """get raises TimeoutError when no message is received before the timeout.""" - with self.assertRaises(TimeoutError): - self.assembler.get(MS) + # queue is above the high-water mark + self.assembler.put(Frame(OP_CONT, b"a")) + self.pause.assert_called_once_with() # Test termination @@ -373,18 +358,8 @@ def closer(): with self.run_in_thread(closer): with self.assertRaises(EOFError): - list(self.assembler.get_iter()) - - def test_put_fails_when_interrupted_by_close(self): - """put raises EOFError when close is called.""" - - def closer(): - time.sleep(2 * MS) - self.assembler.close() - - with self.run_in_thread(closer): - with self.assertRaises(EOFError): - self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + for _ in self.assembler.get_iter(): + self.fail("no fragment expected") def test_get_fails_after_close(self): """get raises EOFError after close is called.""" @@ -396,7 +371,8 @@ def test_get_iter_fails_after_close(self): """get_iter raises EOFError after close is called.""" self.assembler.close() with self.assertRaises(EOFError): - list(self.assembler.get_iter()) + for _ in self.assembler.get_iter(): + self.fail("no fragment expected") def test_put_fails_after_close(self): """put raises EOFError after close is called.""" @@ -439,13 +415,25 @@ def test_get_iter_fails_when_get_iter_is_running(self): list(self.assembler.get_iter()) self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread - def test_put_fails_when_put_is_running(self): - """put cannot be called concurrently.""" + # Test setting limits - def putter(): - self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + def test_set_high_water_mark(self): + """high sets the high-water mark.""" + assembler = Assembler(high=10) + self.assertEqual(assembler.high, 10) - with self.run_in_thread(putter): - with self.assertRaises(ConcurrencyError): - self.assembler.put(Frame(OP_BINARY, b"tea")) - self.assembler.get() # unblock other thread + def test_set_high_and_low_water_mark(self): + """high sets the high-water mark and low-water mark.""" + assembler = Assembler(high=10, low=5) + self.assertEqual(assembler.high, 10) + self.assertEqual(assembler.low, 5) + + def test_set_invalid_high_water_mark(self): + """high must be a non-negative integer.""" + with self.assertRaises(ValueError): + Assembler(high=-1) + + def test_set_invalid_low_water_mark(self): + """low must be higher than high.""" + with self.assertRaises(ValueError): + Assembler(low=10, high=5) From 6c9e3f48ff48dd4ad025f307e68d1be3c0687b4e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 25 Oct 2024 14:16:06 +0200 Subject: [PATCH 650/762] Update feature list for encode/decode params. --- docs/reference/features.rst | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/docs/reference/features.rst b/docs/reference/features.rst index 8b04034eb..576ea1025 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -43,12 +43,18 @@ Both sides +------------------------------------+--------+--------+--------+--------+ | Send a fragmented message | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ - | Receive a fragmented message frame | ✅ | ✅ | ✅ | ❌ | + | Receive a fragmented message frame | ✅ | ✅ | — | ❌ | | by frame | | | | | +------------------------------------+--------+--------+--------+--------+ | Receive a fragmented message after | ✅ | ✅ | — | ✅ | | reassembly | | | | | +------------------------------------+--------+--------+--------+--------+ + | Force sending a message as Text or | ✅ | ✅ | — | ❌ | + | Binary | | | | | + +------------------------------------+--------+--------+--------+--------+ + | Force receiving a message as | ✅ | ✅ | — | ❌ | + | :class:`bytes` or :class:`str` | | | | | + +------------------------------------+--------+--------+--------+--------+ | Send a ping | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ | Respond to pings automatically | ✅ | ✅ | ✅ | ✅ | From 8315d3cbd356fb2fbbe3fd03e189e3750a4d0399 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 25 Oct 2024 16:18:30 +0200 Subject: [PATCH 651/762] Close the connection with code 1007 on invalid UTF-8. Fix #1523. --- docs/reference/features.rst | 5 +++++ src/websockets/asyncio/connection.py | 33 +++++++++++++++++++++++----- src/websockets/asyncio/messages.py | 2 ++ src/websockets/frames.py | 1 - src/websockets/sync/connection.py | 33 +++++++++++++++++++++++----- src/websockets/sync/messages.py | 2 ++ tests/asyncio/test_connection.py | 18 +++++++++++++++ tests/sync/test_connection.py | 18 +++++++++++++++ 8 files changed, 99 insertions(+), 13 deletions(-) diff --git a/docs/reference/features.rst b/docs/reference/features.rst index 576ea1025..9187fa505 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -194,3 +194,8 @@ connection to a given IP address in a CONNECTING state. This behavior is mandated by :rfc:`6455`, section 4.1. However, :func:`~asyncio.client.connect()` isn't the right layer for enforcing this constraint. It's the caller's responsibility. + +It is possible to send or receive a text message containing invalid UTF-8 with +``send(not_utf8_bytes, text=True)`` and ``not_utf8_bytes = recv(decode=False)`` +respectively. As a side effect of disabling UTF-8 encoding and decoding, these +options also disable UTF-8 validation. diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 3b81e386b..2568249c7 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -268,14 +268,24 @@ async def recv(self, decode: bool | None = None) -> Data: try: return await self.recv_messages.get(decode) except EOFError: - # Wait for the protocol state to be CLOSED before accessing close_exc. - await asyncio.shield(self.connection_lost_waiter) - raise self.protocol.close_exc from self.recv_exc + pass + # fallthrough except ConcurrencyError: raise ConcurrencyError( "cannot call recv while another coroutine " "is already running recv or recv_streaming" ) from None + except UnicodeDecodeError as exc: + async with self.send_context(): + self.protocol.fail( + CloseCode.INVALID_DATA, + f"{exc.reason} at position {exc.start}", + ) + # fallthrough + + # Wait for the protocol state to be CLOSED before accessing close_exc. + await asyncio.shield(self.connection_lost_waiter) + raise self.protocol.close_exc from self.recv_exc async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data]: """ @@ -324,15 +334,26 @@ async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data try: async for frame in self.recv_messages.get_iter(decode): yield frame + return except EOFError: - # Wait for the protocol state to be CLOSED before accessing close_exc. - await asyncio.shield(self.connection_lost_waiter) - raise self.protocol.close_exc from self.recv_exc + pass + # fallthrough except ConcurrencyError: raise ConcurrencyError( "cannot call recv_streaming while another coroutine " "is already running recv or recv_streaming" ) from None + except UnicodeDecodeError as exc: + async with self.send_context(): + self.protocol.fail( + CloseCode.INVALID_DATA, + f"{exc.reason} at position {exc.start}", + ) + # fallthrough + + # Wait for the protocol state to be CLOSED before accessing close_exc. + await asyncio.shield(self.connection_lost_waiter) + raise self.protocol.close_exc from self.recv_exc async def send( self, diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index 09be22ba2..b57c0ca4e 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -131,6 +131,7 @@ async def get(self, decode: bool | None = None) -> Data: Raises: EOFError: If the stream of frames has ended. + UnicodeDecodeError: If a text frame contains invalid UTF-8. ConcurrencyError: If two coroutines run :meth:`get` or :meth:`get_iter` concurrently. @@ -197,6 +198,7 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: Raises: EOFError: If the stream of frames has ended. + UnicodeDecodeError: If a text frame contains invalid UTF-8. ConcurrencyError: If two coroutines run :meth:`get` or :meth:`get_iter` concurrently. diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 5fadf3c2d..0ff9f4d71 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -222,7 +222,6 @@ def parse( Raises: EOFError: If the connection is closed without a full WebSocket frame. - UnicodeDecodeError: If the frame contains invalid UTF-8. PayloadTooBig: If the frame's payload size exceeds ``max_size``. ProtocolError: If the frame contains incorrect values. diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 3ab9f4937..823b44f74 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -233,14 +233,24 @@ def recv(self, timeout: float | None = None, decode: bool | None = None) -> Data try: return self.recv_messages.get(timeout, decode) except EOFError: - # Wait for the protocol state to be CLOSED before accessing close_exc. - self.recv_events_thread.join() - raise self.protocol.close_exc from self.recv_exc + pass + # fallthrough except ConcurrencyError: raise ConcurrencyError( "cannot call recv while another thread " "is already running recv or recv_streaming" ) from None + except UnicodeDecodeError as exc: + with self.send_context(): + self.protocol.fail( + CloseCode.INVALID_DATA, + f"{exc.reason} at position {exc.start}", + ) + # fallthrough + + # Wait for the protocol state to be CLOSED before accessing close_exc. + self.recv_events_thread.join() + raise self.protocol.close_exc from self.recv_exc def recv_streaming(self, decode: bool | None = None) -> Iterator[Data]: """ @@ -283,15 +293,26 @@ def recv_streaming(self, decode: bool | None = None) -> Iterator[Data]: """ try: yield from self.recv_messages.get_iter(decode) + return except EOFError: - # Wait for the protocol state to be CLOSED before accessing close_exc. - self.recv_events_thread.join() - raise self.protocol.close_exc from self.recv_exc + pass + # fallthrough except ConcurrencyError: raise ConcurrencyError( "cannot call recv_streaming while another thread " "is already running recv or recv_streaming" ) from None + except UnicodeDecodeError as exc: + with self.send_context(): + self.protocol.fail( + CloseCode.INVALID_DATA, + f"{exc.reason} at position {exc.start}", + ) + # fallthrough + + # Wait for the protocol state to be CLOSED before accessing close_exc. + self.recv_events_thread.join() + raise self.protocol.close_exc from self.recv_exc def send( self, diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index 983b114dc..17f8dce7e 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -115,6 +115,7 @@ def get(self, timeout: float | None = None, decode: bool | None = None) -> Data: Raises: EOFError: If the stream of frames has ended. + UnicodeDecodeError: If a text frame contains invalid UTF-8. ConcurrencyError: If two coroutines run :meth:`get` or :meth:`get_iter` concurrently. TimeoutError: If a timeout is provided and elapses before a @@ -186,6 +187,7 @@ def get_iter(self, decode: bool | None = None) -> Iterator[Data]: Raises: EOFError: If the stream of frames has ended. + UnicodeDecodeError: If a text frame contains invalid UTF-8. ConcurrencyError: If two coroutines run :meth:`get` or :meth:`get_iter` concurrently. diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 12e2bd5fa..a3b65e956 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -222,6 +222,15 @@ async def test_recv_connection_closed_error(self): with self.assertRaises(ConnectionClosedError): await self.connection.recv() + async def test_recv_non_utf8_text(self): + """recv receives a non-UTF-8 text message.""" + await self.remote_connection.send(b"\x01\x02\xfe\xff", text=True) + with self.assertRaises(ConnectionClosedError): + await self.connection.recv() + await self.assertFrameSent( + Frame(Opcode.CLOSE, b"\x03\xefinvalid start byte at position 2") + ) + async def test_recv_during_recv(self): """recv raises ConcurrencyError when called concurrently.""" recv_task = asyncio.create_task(self.connection.recv()) @@ -352,6 +361,15 @@ async def test_recv_streaming_connection_closed_error(self): async for _ in self.connection.recv_streaming(): self.fail("did not raise") + async def test_recv_streaming_non_utf8_text(self): + """recv_streaming receives a non-UTF-8 text message.""" + await self.remote_connection.send(b"\x01\x02\xfe\xff", text=True) + with self.assertRaises(ConnectionClosedError): + await alist(self.connection.recv_streaming()) + await self.assertFrameSent( + Frame(Opcode.CLOSE, b"\x03\xefinvalid start byte at position 2") + ) + async def test_recv_streaming_during_recv(self): """recv_streaming raises ConcurrencyError when called concurrently with recv.""" recv_task = asyncio.create_task(self.connection.recv()) diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index db1cc8e93..abdfd3f78 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -186,6 +186,15 @@ def test_recv_connection_closed_error(self): with self.assertRaises(ConnectionClosedError): self.connection.recv() + def test_recv_non_utf8_text(self): + """recv receives a non-UTF-8 text message.""" + self.remote_connection.send(b"\x01\x02\xfe\xff", text=True) + with self.assertRaises(ConnectionClosedError): + self.connection.recv() + self.assertFrameSent( + Frame(Opcode.CLOSE, b"\x03\xefinvalid start byte at position 2") + ) + def test_recv_during_recv(self): """recv raises ConcurrencyError when called concurrently.""" recv_thread = threading.Thread(target=self.connection.recv) @@ -286,6 +295,15 @@ def test_recv_streaming_connection_closed_error(self): for _ in self.connection.recv_streaming(): self.fail("did not raise") + def test_recv_streaming_non_utf8_text(self): + """recv_streaming receives a non-UTF-8 text message.""" + self.remote_connection.send(b"\x01\x02\xfe\xff", text=True) + with self.assertRaises(ConnectionClosedError): + list(self.connection.recv_streaming()) + self.assertFrameSent( + Frame(Opcode.CLOSE, b"\x03\xefinvalid start byte at position 2") + ) + def test_recv_streaming_during_recv(self): """recv_streaming raises ConcurrencyError when called concurrently with recv.""" recv_thread = threading.Thread(target=self.connection.recv) From 0d2e246f4ee44ece6f74066cfe608e0df906e312 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 25 Oct 2024 16:36:51 +0200 Subject: [PATCH 652/762] Various fixes for the compliance test suite. --- compliance/README.rst | 17 ++++++++++++----- compliance/asyncio/client.py | 21 ++++++++++++++------- compliance/asyncio/server.py | 8 ++++++-- compliance/config/fuzzingclient.json | 2 +- compliance/config/fuzzingserver.json | 2 +- compliance/sync/client.py | 21 ++++++++++++++------- compliance/sync/server.py | 8 ++++++-- 7 files changed, 54 insertions(+), 25 deletions(-) diff --git a/compliance/README.rst b/compliance/README.rst index c7c7c93b4..ee491310f 100644 --- a/compliance/README.rst +++ b/compliance/README.rst @@ -38,7 +38,7 @@ To test the servers: crossbario/autobahn-testsuite \ wstest --mode fuzzingclient --spec /config/fuzzingclient.json - $ open reports/servers/index.html + $ open compliance/reports/servers/index.html To test the clients: @@ -54,7 +54,7 @@ To test the clients: $ PYTHONPATH=src python compliance/asyncio/client.py $ PYTHONPATH=src python compliance/sync/client.py - $ open reports/clients/index.html + $ open compliance/reports/clients/index.html Conformance notes ----------------- @@ -67,6 +67,13 @@ In 3.2, 3.3, 4.1.3, 4.1.4, 4.2.3, 4.2.4, and 5.15 websockets notices the protocol error and closes the connection at the library level before the application gets a chance to echo the previous frame. -In 6.4.3 and 6.4.4, even though it uses an incremental decoder, websockets -doesn't notice the invalid utf-8 fast enough to get a "Strict" pass. These tests -are more strict than the RFC. +In 6.4.1, 6.4.2, 6.4.3, and 6.4.4, even though it uses an incremental decoder, +websockets doesn't notice the invalid utf-8 fast enough to get a "Strict" pass. +These tests are more strict than the RFC. + +Test case 7.1.5 fails because websockets treats closing the connection in the +middle of a fragmented message as a protocol error. As a consequence, it sends +a close frame with code 1002. The test suite expects a close frame with code +1000, echoing the close code that it sent. This isn't required. RFC 6455 states +that "the endpoint typically echos the status code it received", which leaves +the possibility to send a close frame with a different status code. diff --git a/compliance/asyncio/client.py b/compliance/asyncio/client.py index 5b0bfb3ae..044ed6043 100644 --- a/compliance/asyncio/client.py +++ b/compliance/asyncio/client.py @@ -8,7 +8,9 @@ logging.basicConfig(level=logging.WARNING) -SERVER = "ws://127.0.0.1:9001" +SERVER = "ws://localhost:9001" + +AGENT = "websockets.asyncio" async def get_case_count(): @@ -18,16 +20,21 @@ async def get_case_count(): async def run_case(case): async with connect( - f"{SERVER}/runCase?case={case}", - user_agent_header="websockets.asyncio", + f"{SERVER}/runCase?case={case}&agent={AGENT}", max_size=2**25, ) as ws: - async for msg in ws: - await ws.send(msg) + try: + async for msg in ws: + await ws.send(msg) + except WebSocketException: + pass async def update_reports(): - async with connect(f"{SERVER}/updateReports", open_timeout=60): + async with connect( + f"{SERVER}/updateReports?agent={AGENT}", + open_timeout=60, + ): pass @@ -43,7 +50,7 @@ async def main(): print(f"FAIL: {type(exc).__name__}: {exc}") else: print("OK") - print("Ran {cases} test cases") + print(f"Ran {cases} test cases") await update_reports() print("Updated reports") diff --git a/compliance/asyncio/server.py b/compliance/asyncio/server.py index cff2728c9..84deb9727 100644 --- a/compliance/asyncio/server.py +++ b/compliance/asyncio/server.py @@ -2,6 +2,7 @@ import logging from websockets.asyncio.server import serve +from websockets.exceptions import WebSocketException logging.basicConfig(level=logging.WARNING) @@ -10,8 +11,11 @@ async def echo(ws): - async for msg in ws: - await ws.send(msg) + try: + async for msg in ws: + await ws.send(msg) + except WebSocketException: + pass async def main(): diff --git a/compliance/config/fuzzingclient.json b/compliance/config/fuzzingclient.json index 37f8f3a35..756ad03b6 100644 --- a/compliance/config/fuzzingclient.json +++ b/compliance/config/fuzzingclient.json @@ -5,7 +5,7 @@ }, { "url": "ws://host.docker.internal:9003" }], - "outdir": "./reports/servers", + "outdir": "/reports/servers", "cases": ["*"], "exclude-cases": [] } diff --git a/compliance/config/fuzzingserver.json b/compliance/config/fuzzingserver.json index 1bcb5959d..384caf0a2 100644 --- a/compliance/config/fuzzingserver.json +++ b/compliance/config/fuzzingserver.json @@ -1,7 +1,7 @@ { "url": "ws://localhost:9001", - "outdir": "./reports/clients", + "outdir": "/reports/clients", "cases": ["*"], "exclude-cases": [] } diff --git a/compliance/sync/client.py b/compliance/sync/client.py index e585496f3..c810e1beb 100644 --- a/compliance/sync/client.py +++ b/compliance/sync/client.py @@ -7,7 +7,9 @@ logging.basicConfig(level=logging.WARNING) -SERVER = "ws://127.0.0.1:9001" +SERVER = "ws://localhost:9001" + +AGENT = "websockets.sync" def get_case_count(): @@ -17,16 +19,21 @@ def get_case_count(): def run_case(case): with connect( - f"{SERVER}/runCase?case={case}", - user_agent_header="websockets.sync", + f"{SERVER}/runCase?case={case}&agent={AGENT}", max_size=2**25, ) as ws: - for msg in ws: - ws.send(msg) + try: + for msg in ws: + ws.send(msg) + except WebSocketException: + pass def update_reports(): - with connect(f"{SERVER}/updateReports", open_timeout=60): + with connect( + f"{SERVER}/updateReports?agent={AGENT}", + open_timeout=60, + ): pass @@ -42,7 +49,7 @@ def main(): print(f"FAIL: {type(exc).__name__}: {exc}") else: print("OK") - print("Ran {cases} test cases") + print(f"Ran {cases} test cases") update_reports() print("Updated reports") diff --git a/compliance/sync/server.py b/compliance/sync/server.py index c3cb4d989..494f56a44 100644 --- a/compliance/sync/server.py +++ b/compliance/sync/server.py @@ -1,5 +1,6 @@ import logging +from websockets.exceptions import WebSocketException from websockets.sync.server import serve @@ -9,8 +10,11 @@ def echo(ws): - for msg in ws: - ws.send(msg) + try: + for msg in ws: + ws.send(msg) + except WebSocketException: + pass def main(): From 6cea05e51d50455d66e90a1888aba9be8e8809db Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 26 Oct 2024 08:06:21 +0200 Subject: [PATCH 653/762] Support HTTP response without Content-Length. Fix #1531. --- src/websockets/asyncio/connection.py | 13 +++++++++-- src/websockets/legacy/exceptions.py | 2 +- src/websockets/sync/connection.py | 13 +++++++++-- tests/asyncio/test_client.py | 30 +++++++++++++++++++++++++ tests/sync/test_client.py | 33 +++++++++++++++++++++++++++- 5 files changed, 85 insertions(+), 6 deletions(-) diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 2568249c7..5545632d6 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -1060,13 +1060,22 @@ def eof_received(self) -> None: # Feed the end of the data stream to the connection. self.protocol.receive_eof() - # This isn't expected to generate events. - assert not self.protocol.events_received() + # This isn't expected to raise an exception. + events = self.protocol.events_received() # There is no error handling because send_data() can only write # the end of the data stream here and it shouldn't raise errors. self.send_data() + # This code path is triggered when receiving an HTTP response + # without a Content-Length header. This is the only case where + # reading until EOF generates an event; all other events have + # a known length. Ignore for coverage measurement because tests + # are in test_client.py rather than test_connection.py. + for event in events: # pragma: no cover + # This isn't expected to raise an exception. + self.process_event(event) + # The WebSocket protocol has its own closing handshake: endpoints close # the TCP or TLS connection after sending and receiving a close frame. # As a consequence, they never need to write after receiving EOF, so diff --git a/src/websockets/legacy/exceptions.py b/src/websockets/legacy/exceptions.py index 9ca9b7aff..e2279c825 100644 --- a/src/websockets/legacy/exceptions.py +++ b/src/websockets/legacy/exceptions.py @@ -50,7 +50,7 @@ def __init__( headers: datastructures.HeadersLike, body: bytes = b"", ) -> None: - # If a user passes an int instead of a HTTPStatus, fix it automatically. + # If a user passes an int instead of an HTTPStatus, fix it automatically. self.status = http.HTTPStatus(status) self.headers = datastructures.Headers(headers) self.body = body diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 823b44f74..8d1dbcf58 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -696,13 +696,22 @@ def recv_events(self) -> None: # Feed the end of the data stream to the protocol. self.protocol.receive_eof() - # This isn't expected to generate events. - assert not self.protocol.events_received() + # This isn't expected to raise an exception. + events = self.protocol.events_received() # There is no error handling because send_data() can only write # the end of the data stream here and it handles errors itself. self.send_data() + # This code path is triggered when receiving an HTTP response + # without a Content-Length header. This is the only case where + # reading until EOF generates an event; all other events have + # a known length. Ignore for coverage measurement because tests + # are in test_client.py rather than test_connection.py. + for event in events: # pragma: no cover + # This isn't expected to raise an exception. + self.process_event(event) + except Exception as exc: # This branch should never run. It's a safety net in case of bugs. self.logger.error("unexpected internal error", exc_info=True) diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index 9354a6e0a..1b89977ea 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -401,6 +401,36 @@ def close_connection(self, request): "connection closed while reading HTTP status line", ) + async def test_http_response(self): + """Client reads HTTP response.""" + + def http_response(connection, request): + return connection.respond(http.HTTPStatus.OK, "👌") + + async with serve(*args, process_request=http_response) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + + self.assertEqual(raised.exception.response.status_code, 200) + self.assertEqual(raised.exception.response.body.decode(), "👌") + + async def test_http_response_without_content_length(self): + """Client reads HTTP response without a Content-Length header.""" + + def http_response(connection, request): + response = connection.respond(http.HTTPStatus.OK, "👌") + del response.headers["Content-Length"] + return response + + async with serve(*args, process_request=http_response) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + + self.assertEqual(raised.exception.response.status_code, 200) + self.assertEqual(raised.exception.response.body.decode(), "👌") + async def test_junk_handshake(self): """Client closes the connection when receiving non-HTTP response from server.""" diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index e63d774b7..e9b0f63ad 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -1,3 +1,4 @@ +import http import logging import socket import socketserver @@ -6,7 +7,7 @@ import time import unittest -from websockets.exceptions import InvalidHandshake, InvalidURI +from websockets.exceptions import InvalidHandshake, InvalidStatus, InvalidURI from websockets.extensions.permessage_deflate import PerMessageDeflate from websockets.sync.client import * @@ -156,6 +157,36 @@ def close_connection(self, request): "connection closed while reading HTTP status line", ) + def test_http_response(self): + """Client reads HTTP response.""" + + def http_response(connection, request): + return connection.respond(http.HTTPStatus.OK, "👌") + + with run_server(process_request=http_response) as server: + with self.assertRaises(InvalidStatus) as raised: + with connect(get_uri(server)): + self.fail("did not raise") + + self.assertEqual(raised.exception.response.status_code, 200) + self.assertEqual(raised.exception.response.body.decode(), "👌") + + def test_http_response_without_content_length(self): + """Client reads HTTP response without a Content-Length header.""" + + def http_response(connection, request): + response = connection.respond(http.HTTPStatus.OK, "👌") + del response.headers["Content-Length"] + return response + + with run_server(process_request=http_response) as server: + with self.assertRaises(InvalidStatus) as raised: + with connect(get_uri(server)): + self.fail("did not raise") + + self.assertEqual(raised.exception.response.status_code, 200) + self.assertEqual(raised.exception.response.body.decode(), "👌") + def test_junk_handshake(self): """Client closes the connection when receiving non-HTTP response from server.""" From c75b1df159d83e46a2ae29069ae92789690ed22f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 26 Oct 2024 09:22:17 +0200 Subject: [PATCH 654/762] Mention the option to keep the legacy implementation. --- docs/project/changelog.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 410671239..1b3b0073c 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -52,6 +52,12 @@ Backwards-incompatible changes If you're using any of them, then you must follow the :doc:`upgrade guide <../howto/upgrade>` immediately. + Alternatively, you may stick to the legacy :mod:`asyncio` implementation for + now by importing it explicitly:: + + from websockets.legacy.client import connect, unix_connect + from websockets.legacy.server import broadcast, serve, unix_serve + .. admonition:: The legacy :mod:`asyncio` implementation is now deprecated. :class: caution From d248160b6b509c5701fd749cd2ef103d244c7631 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 26 Oct 2024 09:22:52 +0200 Subject: [PATCH 655/762] Raise ValueError for required or unacceptable arguments. This improves consistency with asyncio and within websockets. --- docs/project/changelog.rst | 13 +++++++++++++ src/websockets/asyncio/client.py | 4 ++-- src/websockets/asyncio/server.py | 4 +++- src/websockets/sync/client.py | 6 +++--- src/websockets/sync/server.py | 8 +++++--- tests/asyncio/test_client.py | 4 ++-- tests/asyncio/test_server.py | 4 ++-- tests/sync/test_client.py | 6 +++--- tests/sync/test_server.py | 8 ++++---- 9 files changed, 37 insertions(+), 20 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 1b3b0073c..576b7252e 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -76,6 +76,19 @@ Backwards-incompatible changes If you wrote an :class:`extension ` that relies on methods not provided by these new types, you may need to update your code. +.. admonition:: Several API raise :exc:`ValueError` instead of :exc:`TypeError` + on invalid arguments. + :class: note + + :func:`~asyncio.client.connect`, :func:`~asyncio.client.unix_connect`, and + :func:`~asyncio.server.basic_auth` in the :mod:`asyncio` implementation as + well as :func:`~sync.client.connect`, :func:`~sync.client.unix_connect`, + :func:`~sync.server.serve`, :func:`~sync.server.unix_serve`, and + :func:`~sync.server.basic_auth` in the :mod:`threading` implementation now + raise :exc:`ValueError` when a required argument isn't provided or an + argument that is incompatible with others is provided. + + New features ............ diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 0c8bedc5d..ff7916d39 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -352,10 +352,10 @@ def factory() -> ClientConnection: kwargs.setdefault("ssl", True) kwargs.setdefault("server_hostname", wsuri.host) if kwargs.get("ssl") is None: - raise TypeError("ssl=None is incompatible with a wss:// URI") + raise ValueError("ssl=None is incompatible with a wss:// URI") else: if kwargs.get("ssl") is not None: - raise TypeError("ssl argument is incompatible with a ws:// URI") + raise ValueError("ssl argument is incompatible with a ws:// URI") if kwargs.pop("unix", False): _, connection = await loop.create_unix_connection(factory, **kwargs) diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index a6ae5996d..180d3a5a9 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -890,10 +890,12 @@ def basic_auth( whether they're valid. Raises: TypeError: If ``credentials`` or ``check_credentials`` is wrong. + ValueError: If ``credentials`` and ``check_credentials`` are both + provided or both not provided. """ if (credentials is None) == (check_credentials is None): - raise TypeError("provide either credentials or check_credentials") + raise ValueError("provide either credentials or check_credentials") if credentials is not None: if is_credentials(credentials): diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 42daa32ea..0aada658e 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -221,7 +221,7 @@ def connect( wsuri = parse_uri(uri) if not wsuri.secure and ssl is not None: - raise TypeError("ssl argument is incompatible with a ws:// URI") + raise ValueError("ssl argument is incompatible with a ws:// URI") # Private APIs for unix_connect() unix: bool = kwargs.pop("unix", False) @@ -229,9 +229,9 @@ def connect( if unix: if path is None and sock is None: - raise TypeError("missing path argument") + raise ValueError("missing path argument") elif path is not None and sock is not None: - raise TypeError("path and sock arguments are incompatible") + raise ValueError("path and sock arguments are incompatible") if subprotocols is not None: validate_subprotocols(subprotocols) diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 94f76b658..44dbd7290 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -477,14 +477,14 @@ def handler(websocket): if sock is None: if unix: if path is None: - raise TypeError("missing path argument") + raise ValueError("missing path argument") kwargs.setdefault("family", socket.AF_UNIX) sock = socket.create_server(path, **kwargs) else: sock = socket.create_server((host, port), **kwargs) else: if path is not None: - raise TypeError("path and sock arguments are incompatible") + raise ValueError("path and sock arguments are incompatible") # Initialize TLS wrapper @@ -667,10 +667,12 @@ def basic_auth( whether they're valid. Raises: TypeError: If ``credentials`` or ``check_credentials`` is wrong. + ValueError: If ``credentials`` and ``check_credentials`` are both + provided or both not provided. """ if (credentials is None) == (check_credentials is None): - raise TypeError("provide either credentials or check_credentials") + raise ValueError("provide either credentials or check_credentials") if credentials is not None: if is_credentials(credentials): diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index 1b89977ea..231d6b8ca 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -607,7 +607,7 @@ async def test_set_server_hostname(self): class ClientUsageErrorsTests(unittest.IsolatedAsyncioTestCase): async def test_ssl_without_secure_uri(self): """Client rejects ssl when URI isn't secure.""" - with self.assertRaises(TypeError) as raised: + with self.assertRaises(ValueError) as raised: await connect("ws://localhost/", ssl=CLIENT_CONTEXT) self.assertEqual( str(raised.exception), @@ -616,7 +616,7 @@ async def test_ssl_without_secure_uri(self): async def test_secure_uri_without_ssl(self): """Client rejects no ssl when URI is secure.""" - with self.assertRaises(TypeError) as raised: + with self.assertRaises(ValueError) as raised: await connect("wss://localhost/", ssl=None) self.assertEqual( str(raised.exception), diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index 47e0148a6..1dcb8c7b7 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -717,7 +717,7 @@ async def check_credentials(username, password): async def test_without_credentials_or_check_credentials(self): """basic_auth requires either credentials or check_credentials.""" - with self.assertRaises(TypeError) as raised: + with self.assertRaises(ValueError) as raised: basic_auth() self.assertEqual( str(raised.exception), @@ -726,7 +726,7 @@ async def test_without_credentials_or_check_credentials(self): async def test_with_credentials_and_check_credentials(self): """basic_auth requires only one of credentials and check_credentials.""" - with self.assertRaises(TypeError) as raised: + with self.assertRaises(ValueError) as raised: basic_auth( credentials=("hello", "iloveyou"), check_credentials=lambda: False, # pragma: no cover diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index e9b0f63ad..9d457a912 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -311,7 +311,7 @@ def test_set_server_hostname(self): class ClientUsageErrorsTests(unittest.TestCase): def test_ssl_without_secure_uri(self): """Client rejects ssl when URI isn't secure.""" - with self.assertRaises(TypeError) as raised: + with self.assertRaises(ValueError) as raised: connect("ws://localhost/", ssl=CLIENT_CONTEXT) self.assertEqual( str(raised.exception), @@ -320,7 +320,7 @@ def test_ssl_without_secure_uri(self): def test_unix_without_path_or_sock(self): """Unix client requires path when sock isn't provided.""" - with self.assertRaises(TypeError) as raised: + with self.assertRaises(ValueError) as raised: unix_connect() self.assertEqual( str(raised.exception), @@ -331,7 +331,7 @@ def test_unix_with_path_and_sock(self): """Unix client rejects path when sock is provided.""" sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) self.addCleanup(sock.close) - with self.assertRaises(TypeError) as raised: + with self.assertRaises(ValueError) as raised: unix_connect(path="/", sock=sock) self.assertEqual( str(raised.exception), diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index 3bc6f76cd..54e49bf16 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -361,7 +361,7 @@ def test_connection(self): class ServerUsageErrorsTests(unittest.TestCase): def test_unix_without_path_or_sock(self): """Unix server requires path when sock isn't provided.""" - with self.assertRaises(TypeError) as raised: + with self.assertRaises(ValueError) as raised: unix_serve(handler) self.assertEqual( str(raised.exception), @@ -372,7 +372,7 @@ def test_unix_with_path_and_sock(self): """Unix server rejects path when sock is provided.""" sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) self.addCleanup(sock.close) - with self.assertRaises(TypeError) as raised: + with self.assertRaises(ValueError) as raised: unix_serve(handler, path="/", sock=sock) self.assertEqual( str(raised.exception), @@ -504,7 +504,7 @@ def check_credentials(username, password): def test_without_credentials_or_check_credentials(self): """basic_auth requires either credentials or check_credentials.""" - with self.assertRaises(TypeError) as raised: + with self.assertRaises(ValueError) as raised: basic_auth() self.assertEqual( str(raised.exception), @@ -513,7 +513,7 @@ def test_without_credentials_or_check_credentials(self): def test_with_credentials_and_check_credentials(self): """basic_auth requires only one of credentials and check_credentials.""" - with self.assertRaises(TypeError) as raised: + with self.assertRaises(ValueError) as raised: basic_auth( credentials=("hello", "iloveyou"), check_credentials=lambda: False, # pragma: no cover From 3b2c5223f7213b130d1469535fccbf9c3d08c4a9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 27 Oct 2024 07:59:16 +0100 Subject: [PATCH 656/762] Rewrite tips for Sans-I/O integration. --- docs/howto/sansio.rst | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/docs/howto/sansio.rst b/docs/howto/sansio.rst index d41519ff0..ca530e6a1 100644 --- a/docs/howto/sansio.rst +++ b/docs/howto/sansio.rst @@ -302,21 +302,24 @@ Tips Serialize operations .................... -The Sans-I/O layer expects to run sequentially. If your interact with it from -multiple threads or coroutines, you must ensure correct serialization. This -should happen automatically in a cooperative multitasking environment. +The Sans-I/O layer is designed to run sequentially. If you interact with it from +multiple threads or coroutines, you must ensure correct serialization. -However, you still have to make sure you don't break this property by -accident. For example, serialize writes to the network -when :meth:`~protocol.Protocol.data_to_send` returns multiple values to -prevent concurrent writes from interleaving incorrectly. +Usually, this comes for free in a cooperative multitasking environment. In a +preemptive multitasking environment, it requires mutual exclusion. -Avoid buffers -............. +Furthermore, you must serialize writes to the network. When +:meth:`~protocol.Protocol.data_to_send` returns several values, you must write +them all before starting the next write. -The Sans-I/O layer doesn't do any buffering. It makes events available in +Minimize buffers +................ + +The Sans-I/O layer doesn't perform any buffering. It makes events available in :meth:`~protocol.Protocol.events_received` as soon as they're received. -You should make incoming messages available to the application immediately and -stop further processing until the application fetches them. This will usually -result in the best performance. +You should make incoming messages available to the application immediately. + +A small buffer of incoming messages will usually result in the best performance. +It will reduce context switching between the library and the application while +ensuring that backpressure is propagated. From 018d2e5cf56ff03690bcbf271d188e76d59c62f3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 27 Oct 2024 07:59:35 +0100 Subject: [PATCH 657/762] Explain application-level keepalives. Fix #1514. --- docs/topics/keepalive.rst | 40 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/docs/topics/keepalive.rst b/docs/topics/keepalive.rst index 458fa3d05..003087fad 100644 --- a/docs/topics/keepalive.rst +++ b/docs/topics/keepalive.rst @@ -90,6 +90,46 @@ application layer. Read this `blog post `_ for a complete walk-through of this issue. +Application-level keepalive +--------------------------- + +Some servers require clients to send a keepalive message with a specific content +at regular intervals. Usually they expect Text_ frames rather than Ping_ frames, +meaning that you must send them with :attr:`~asyncio.connection.Connection.send` +rather than :attr:`~asyncio.connection.Connection.ping`. + +.. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + +In websockets, such keepalive mechanisms are considered as application-level +because they rely on data frames. That's unlike the protocol-level keepalive +based on control frames. Therefore, it's your responsibility to implement the +required behavior. + +You can run a task in the background to send keepalive messages: + +.. code-block:: python + + import itertools + import json + + from websockets import ConnectionClosed + + async def keepalive(websocket, ping_interval=30): + for ping in itertools.count(): + await asyncio.sleep(ping_interval) + try: + await websocket.send(json.dumps({"ping": ping})) + except ConnectionClosed: + break + + async def main(): + async with connect(...) as websocket: + keepalive_task = asyncio.create_task(keepalive(websocket)) + try: + ... # your application logic goes here + finally: + keepalive_task.cancel() + Latency issues -------------- From e44d79559d16661df4709c1b54150d735f85ae54 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 1 Nov 2024 10:52:16 +0000 Subject: [PATCH 658/762] Bump pypa/cibuildwheel from 2.20.0 to 2.21.3 Bumps [pypa/cibuildwheel](https://github.com/pypa/cibuildwheel) from 2.20.0 to 2.21.3. - [Release notes](https://github.com/pypa/cibuildwheel/releases) - [Changelog](https://github.com/pypa/cibuildwheel/blob/main/docs/changelog.md) - [Commits](https://github.com/pypa/cibuildwheel/compare/v2.20.0...v2.21.3) --- updated-dependencies: - dependency-name: pypa/cibuildwheel dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 863d88aa9..cc26502ca 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -53,7 +53,7 @@ jobs: with: platforms: all - name: Build wheels - uses: pypa/cibuildwheel@v2.20.0 + uses: pypa/cibuildwheel@v2.21.3 env: BUILD_EXTENSION: yes - name: Save wheels From 810bdeb7943e4cc0dcb662ef27ea764e31740e05 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 1 Nov 2024 12:35:38 +0100 Subject: [PATCH 659/762] Report size correctly in PayloadTooBig. Previously, it was reported incorrectly for fragmented messages. Fix #1522. --- docs/project/changelog.rst | 31 +++++--- src/websockets/exceptions.py | 41 +++++++++++ .../extensions/permessage_deflate.py | 3 +- src/websockets/frames.py | 2 +- src/websockets/legacy/framing.py | 2 +- src/websockets/protocol.py | 5 +- tests/test_exceptions.py | 20 +++++- tests/test_protocol.py | 72 ++++++++++++++----- 8 files changed, 143 insertions(+), 33 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 576b7252e..71b2a6960 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -67,15 +67,6 @@ Backwards-incompatible changes Aliases for deprecated API were removed from ``__all__``. As a consequence, they cannot be imported e.g. with ``from websockets import *`` anymore. -.. admonition:: :attr:`Frame.data ` is now a bytes-like object. - :class: note - - In addition to :class:`bytes`, it may be a :class:`bytearray` or a - :class:`memoryview`. - - If you wrote an :class:`extension ` that relies on - methods not provided by these new types, you may need to update your code. - .. admonition:: Several API raise :exc:`ValueError` instead of :exc:`TypeError` on invalid arguments. :class: note @@ -88,6 +79,26 @@ Backwards-incompatible changes raise :exc:`ValueError` when a required argument isn't provided or an argument that is incompatible with others is provided. +.. admonition:: :attr:`Frame.data ` is now a bytes-like object. + :class: note + + In addition to :class:`bytes`, it may be a :class:`bytearray` or a + :class:`memoryview`. + + If you wrote an :class:`extension ` that relies on + methods not provided by these new types, you may need to update your code. + +.. admonition:: The signature of :exc:`~exceptions.PayloadTooBig` changed. + :class: note + + If you wrote an extension that raises :exc:`~exceptions.PayloadTooBig` in + :meth:`~extensions.Extension.decode`, for example, you must replace:: + + PayloadTooBig(f"over size limit ({size} > {max_size} bytes)") + + with:: + + PayloadTooBig(size, max_size) New features ............ @@ -105,6 +116,8 @@ Improvements * Sending or receiving large compressed messages is now faster. +* Errors when a fragmented message is too large are clearer. + .. _13.1: 13.1 diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 7681736a4..be3d1ca5f 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -334,6 +334,47 @@ class PayloadTooBig(WebSocketException): """ + def __init__( + self, + size_or_message: int | None | str, + max_size: int | None = None, + cur_size: int | None = None, + ) -> None: + if isinstance(size_or_message, str): + assert max_size is None + assert cur_size is None + warnings.warn( # deprecated in 14.0 + "PayloadTooBig(message) is deprecated; " + "change to PayloadTooBig(size, max_size)", + DeprecationWarning, + ) + self.message: str | None = size_or_message + else: + self.message = None + self.size: int | None = size_or_message + assert max_size is not None + self.max_size: int = max_size + self.cur_size: int | None = None + self.set_current_size(cur_size) + + def __str__(self) -> str: + if self.message is not None: + return self.message + else: + message = "frame " + if self.size is not None: + message += f"with {self.size} bytes " + if self.cur_size is not None: + message += f"after reading {self.cur_size} bytes " + message += f"exceeds limit of {self.max_size} bytes" + return message + + def set_current_size(self, cur_size: int | None) -> None: + assert self.cur_size is None + if cur_size is not None: + self.max_size += cur_size + self.cur_size = cur_size + class InvalidState(WebSocketException, AssertionError): """ diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index ed16937d8..cefad4f56 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -139,7 +139,8 @@ def decode( try: data = self.decoder.decompress(data, max_length) if self.decoder.unconsumed_tail: - raise PayloadTooBig(f"over size limit (? > {max_size} bytes)") + assert max_size is not None # help mypy + raise PayloadTooBig(None, max_size) if frame.fin and len(frame.data) >= 2044: # This cannot generate additional data. self.decoder.decompress(_EMPTY_UNCOMPRESSED_BLOCK) diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 0ff9f4d71..7898c8a5d 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -252,7 +252,7 @@ def parse( data = yield from read_exact(8) (length,) = struct.unpack("!Q", data) if max_size is not None and length > max_size: - raise PayloadTooBig(f"over size limit ({length} > {max_size} bytes)") + raise PayloadTooBig(length, max_size) if mask: mask_bytes = yield from read_exact(4) diff --git a/src/websockets/legacy/framing.py b/src/websockets/legacy/framing.py index 4ec194ed7..add0c6e0e 100644 --- a/src/websockets/legacy/framing.py +++ b/src/websockets/legacy/framing.py @@ -93,7 +93,7 @@ async def read( data = await reader(8) (length,) = struct.unpack("!Q", data) if max_size is not None and length > max_size: - raise PayloadTooBig(f"over size limit ({length} > {max_size} bytes)") + raise PayloadTooBig(length, max_size) if mask: mask_bits = await reader(4) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 091b4a23a..19b813526 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -587,6 +587,7 @@ def parse(self) -> Generator[None]: self.parser_exc = exc except PayloadTooBig as exc: + exc.set_current_size(self.cur_size) self.fail(CloseCode.MESSAGE_TOO_BIG, str(exc)) self.parser_exc = exc @@ -639,9 +640,7 @@ def recv_frame(self, frame: Frame) -> None: if frame.opcode is OP_TEXT or frame.opcode is OP_BINARY: if self.cur_size is not None: raise ProtocolError("expected a continuation frame") - if frame.fin: - self.cur_size = None - else: + if not frame.fin: self.cur_size = len(frame.data) elif frame.opcode is OP_CONT: diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 8d41bf915..fef41d136 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -160,8 +160,16 @@ def test_str(self): "invalid opcode: 7", ), ( - PayloadTooBig("payload length exceeds limit: 2 > 1 bytes"), - "payload length exceeds limit: 2 > 1 bytes", + PayloadTooBig(None, 4), + "frame exceeds limit of 4 bytes", + ), + ( + PayloadTooBig(8, 4), + "frame with 8 bytes exceeds limit of 4 bytes", + ), + ( + PayloadTooBig(8, 4, 12), + "frame with 8 bytes after reading 12 bytes exceeds limit of 16 bytes", ), ( InvalidState("WebSocket connection isn't established yet"), @@ -202,3 +210,11 @@ def test_connection_closed_attributes_deprecation_defaults(self): "use Protocol.close_reason or ConnectionClosed.rcvd.reason" ): self.assertEqual(exception.reason, "") + + def test_payload_too_big_with_message(self): + with self.assertDeprecationWarning( + "PayloadTooBig(message) is deprecated; " + "change to PayloadTooBig(size, max_size)", + ): + exc = PayloadTooBig("payload length exceeds limit: 2 > 1 bytes") + self.assertEqual(str(exc), "payload length exceeds limit: 2 > 1 bytes") diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 7f1276bb2..0ae804bb3 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -265,18 +265,28 @@ def test_client_receives_text_over_size_limit(self): client = Protocol(CLIENT, max_size=3) client.receive_data(b"\x81\x04\xf0\x9f\x98\x80") self.assertIsInstance(client.parser_exc, PayloadTooBig) - self.assertEqual(str(client.parser_exc), "over size limit (4 > 3 bytes)") + self.assertEqual( + str(client.parser_exc), + "frame with 4 bytes exceeds limit of 3 bytes", + ) self.assertConnectionFailing( - client, CloseCode.MESSAGE_TOO_BIG, "over size limit (4 > 3 bytes)" + client, + CloseCode.MESSAGE_TOO_BIG, + "frame with 4 bytes exceeds limit of 3 bytes", ) def test_server_receives_text_over_size_limit(self): server = Protocol(SERVER, max_size=3) server.receive_data(b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80") self.assertIsInstance(server.parser_exc, PayloadTooBig) - self.assertEqual(str(server.parser_exc), "over size limit (4 > 3 bytes)") + self.assertEqual( + str(server.parser_exc), + "frame with 4 bytes exceeds limit of 3 bytes", + ) self.assertConnectionFailing( - server, CloseCode.MESSAGE_TOO_BIG, "over size limit (4 > 3 bytes)" + server, + CloseCode.MESSAGE_TOO_BIG, + "frame with 4 bytes exceeds limit of 3 bytes", ) def test_client_receives_text_without_size_limit(self): @@ -363,9 +373,14 @@ def test_client_receives_fragmented_text_over_size_limit(self): ) client.receive_data(b"\x80\x02\x98\x80") self.assertIsInstance(client.parser_exc, PayloadTooBig) - self.assertEqual(str(client.parser_exc), "over size limit (2 > 1 bytes)") + self.assertEqual( + str(client.parser_exc), + "frame with 2 bytes after reading 2 bytes exceeds limit of 3 bytes", + ) self.assertConnectionFailing( - client, CloseCode.MESSAGE_TOO_BIG, "over size limit (2 > 1 bytes)" + client, + CloseCode.MESSAGE_TOO_BIG, + "frame with 2 bytes after reading 2 bytes exceeds limit of 3 bytes", ) def test_server_receives_fragmented_text_over_size_limit(self): @@ -377,9 +392,14 @@ def test_server_receives_fragmented_text_over_size_limit(self): ) server.receive_data(b"\x80\x82\x00\x00\x00\x00\x98\x80") self.assertIsInstance(server.parser_exc, PayloadTooBig) - self.assertEqual(str(server.parser_exc), "over size limit (2 > 1 bytes)") + self.assertEqual( + str(server.parser_exc), + "frame with 2 bytes after reading 2 bytes exceeds limit of 3 bytes", + ) self.assertConnectionFailing( - server, CloseCode.MESSAGE_TOO_BIG, "over size limit (2 > 1 bytes)" + server, + CloseCode.MESSAGE_TOO_BIG, + "frame with 2 bytes after reading 2 bytes exceeds limit of 3 bytes", ) def test_client_receives_fragmented_text_without_size_limit(self): @@ -533,18 +553,28 @@ def test_client_receives_binary_over_size_limit(self): client = Protocol(CLIENT, max_size=3) client.receive_data(b"\x82\x04\x01\x02\xfe\xff") self.assertIsInstance(client.parser_exc, PayloadTooBig) - self.assertEqual(str(client.parser_exc), "over size limit (4 > 3 bytes)") + self.assertEqual( + str(client.parser_exc), + "frame with 4 bytes exceeds limit of 3 bytes", + ) self.assertConnectionFailing( - client, CloseCode.MESSAGE_TOO_BIG, "over size limit (4 > 3 bytes)" + client, + CloseCode.MESSAGE_TOO_BIG, + "frame with 4 bytes exceeds limit of 3 bytes", ) def test_server_receives_binary_over_size_limit(self): server = Protocol(SERVER, max_size=3) server.receive_data(b"\x82\x84\x00\x00\x00\x00\x01\x02\xfe\xff") self.assertIsInstance(server.parser_exc, PayloadTooBig) - self.assertEqual(str(server.parser_exc), "over size limit (4 > 3 bytes)") + self.assertEqual( + str(server.parser_exc), + "frame with 4 bytes exceeds limit of 3 bytes", + ) self.assertConnectionFailing( - server, CloseCode.MESSAGE_TOO_BIG, "over size limit (4 > 3 bytes)" + server, + CloseCode.MESSAGE_TOO_BIG, + "frame with 4 bytes exceeds limit of 3 bytes", ) def test_client_sends_fragmented_binary(self): @@ -615,9 +645,14 @@ def test_client_receives_fragmented_binary_over_size_limit(self): ) client.receive_data(b"\x80\x02\xfe\xff") self.assertIsInstance(client.parser_exc, PayloadTooBig) - self.assertEqual(str(client.parser_exc), "over size limit (2 > 1 bytes)") + self.assertEqual( + str(client.parser_exc), + "frame with 2 bytes after reading 2 bytes exceeds limit of 3 bytes", + ) self.assertConnectionFailing( - client, CloseCode.MESSAGE_TOO_BIG, "over size limit (2 > 1 bytes)" + client, + CloseCode.MESSAGE_TOO_BIG, + "frame with 2 bytes after reading 2 bytes exceeds limit of 3 bytes", ) def test_server_receives_fragmented_binary_over_size_limit(self): @@ -629,9 +664,14 @@ def test_server_receives_fragmented_binary_over_size_limit(self): ) server.receive_data(b"\x80\x82\x00\x00\x00\x00\xfe\xff") self.assertIsInstance(server.parser_exc, PayloadTooBig) - self.assertEqual(str(server.parser_exc), "over size limit (2 > 1 bytes)") + self.assertEqual( + str(server.parser_exc), + "frame with 2 bytes after reading 2 bytes exceeds limit of 3 bytes", + ) self.assertConnectionFailing( - server, CloseCode.MESSAGE_TOO_BIG, "over size limit (2 > 1 bytes)" + server, + CloseCode.MESSAGE_TOO_BIG, + "frame with 2 bytes after reading 2 bytes exceeds limit of 3 bytes", ) def test_client_sends_unexpected_binary(self): From cdeb882865145399ee0fb7d0e7623418916d6b78 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 3 Nov 2024 08:48:53 +0100 Subject: [PATCH 660/762] Don't log an error when process_request returns a response. Fix #1513. --- src/websockets/asyncio/client.py | 6 +- src/websockets/asyncio/server.py | 17 ++-- src/websockets/protocol.py | 29 ++++-- src/websockets/server.py | 6 -- src/websockets/sync/client.py | 6 +- src/websockets/sync/server.py | 11 ++- tests/asyncio/test_connection.py | 2 +- tests/asyncio/test_server.py | 146 ++++++++++++++++++++----------- tests/test_protocol.py | 20 ++++- 9 files changed, 163 insertions(+), 80 deletions(-) diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index ff7916d39..d276ac171 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -95,9 +95,9 @@ async def handshake( return_when=asyncio.FIRST_COMPLETED, ) - # self.protocol.handshake_exc is always set when the connection is lost - # before receiving a response, when the response cannot be parsed, or - # when the response fails the handshake. + # self.protocol.handshake_exc is set when the connection is lost before + # receiving a response, when the response cannot be parsed, or when the + # response fails the handshake. if self.protocol.handshake_exc is not None: raise self.protocol.handshake_exc diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index 180d3a5a9..15c9ba13e 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -192,10 +192,13 @@ async def handshake( self.protocol.send_response(self.response) - # self.protocol.handshake_exc is always set when the connection is lost - # before receiving a request, when the request cannot be parsed, when - # the handshake encounters an error, or when process_request or - # process_response sends an HTTP response that rejects the handshake. + # self.protocol.handshake_exc is set when the connection is lost before + # receiving a request, when the request cannot be parsed, or when the + # handshake fails, including when process_request or process_response + # raises an exception. + + # It isn't set when process_request or process_response sends an HTTP + # response that rejects the handshake. if self.protocol.handshake_exc is not None: raise self.protocol.handshake_exc @@ -360,7 +363,11 @@ async def conn_handler(self, connection: ServerConnection) -> None: connection.close_transport() return - assert connection.protocol.state is OPEN + if connection.protocol.state is not OPEN: + # process_request or process_response rejected the handshake. + connection.close_transport() + return + try: connection.start_keepalive() await self.handler(connection) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 19b813526..0f6fea250 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -518,15 +518,34 @@ def close_expected(self) -> bool: Whether the TCP connection is expected to close soon. """ - # We expect a TCP close if and only if we sent a close frame: + # During the opening handshake, when our state is CONNECTING, we expect + # a TCP close if and only if the hansdake fails. When it does, we start + # the TCP closing handshake by sending EOF with send_eof(). + + # Once the opening handshake completes successfully, we expect a TCP + # close if and only if we sent a close frame, meaning that our state + # progressed to CLOSING: + # * Normal closure: once we send a close frame, we expect a TCP close: # server waits for client to complete the TCP closing handshake; # client waits for server to initiate the TCP closing handshake. + # * Abnormal closure: we always send a close frame and the same logic # applies, except on EOFError where we don't send a close frame # because we already received the TCP close, so we don't expect it. - # We already got a TCP Close if and only if the state is CLOSED. - return self.state is CLOSING or self.handshake_exc is not None + + # If our state is CLOSED, we already received a TCP close so we don't + # expect it anymore. + + # Micro-optimization: put the most common case first + if self.state is OPEN: + return False + if self.state is CLOSING: + return True + if self.state is CLOSED: + return False + assert self.state is CONNECTING + return self.eof_sent # Private methods for receiving data. @@ -616,14 +635,14 @@ def discard(self) -> Generator[None]: # connection in the same circumstances where discard() replaces parse(). # The client closes it when it receives EOF from the server or times # out. (The latter case cannot be handled in this Sans-I/O layer.) - assert (self.state == CONNECTING or self.side is SERVER) == (self.eof_sent) + assert (self.side is SERVER or self.state is CONNECTING) == (self.eof_sent) while not (yield from self.reader.at_eof()): self.reader.discard() if self.debug: self.logger.debug("< EOF") # A server closes the TCP connection immediately, while a client # waits for the server to close the TCP connection. - if self.state != CONNECTING and self.side is CLIENT: + if self.side is CLIENT and self.state is not CONNECTING: self.send_eof() self.state = CLOSED # If discard() completes normally, execution ends here. diff --git a/src/websockets/server.py b/src/websockets/server.py index 527db8990..e3fdcc646 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -14,7 +14,6 @@ InvalidHeader, InvalidHeaderValue, InvalidOrigin, - InvalidStatus, InvalidUpgrade, NegotiationError, ) @@ -536,11 +535,6 @@ def send_response(self, response: Response) -> None: self.logger.info("connection open") else: - # handshake_exc may be already set if accept() encountered an error. - # If the connection isn't open, set handshake_exc to guarantee that - # handshake_exc is None if and only if opening handshake succeeded. - if self.handshake_exc is None: - self.handshake_exc = InvalidStatus(response) self.logger.info( "connection rejected (%d %s)", response.status_code, diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 0aada658e..54d0aef68 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -87,9 +87,9 @@ def handshake( if not self.response_rcvd.wait(timeout): raise TimeoutError("timed out during handshake") - # self.protocol.handshake_exc is always set when the connection is lost - # before receiving a response, when the response cannot be parsed, or - # when the response fails the handshake. + # self.protocol.handshake_exc is set when the connection is lost before + # receiving a response, when the response cannot be parsed, or when the + # response fails the handshake. if self.protocol.handshake_exc is not None: raise self.protocol.handshake_exc diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 44dbd7290..8601ccef9 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -170,10 +170,13 @@ def handshake( self.protocol.send_response(self.response) - # self.protocol.handshake_exc is always set when the connection is lost - # before receiving a request, when the request cannot be parsed, when - # the handshake encounters an error, or when process_request or - # process_response sends an HTTP response that rejects the handshake. + # self.protocol.handshake_exc is set when the connection is lost before + # receiving a request, when the request cannot be parsed, or when the + # handshake fails, including when process_request or process_response + # raises an exception. + + # It isn't set when process_request or process_response sends an HTTP + # response that rejects the handshake. if self.protocol.handshake_exc is not None: raise self.protocol.handshake_exc diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index a3b65e956..c98765d80 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -51,7 +51,7 @@ async def asyncTearDown(self): if sys.version_info[:2] < (3, 10): # pragma: no cover @contextlib.contextmanager - def assertNoLogs(self, logger="websockets", level=logging.ERROR): + def assertNoLogs(self, logger=None, level=None): """ No message is logged on the given logger with at least the given level. diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index 1dcb8c7b7..c817f5ef6 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -148,14 +148,17 @@ def process_request(ws, request): async def handler(ws): self.fail("handler must not run") - async with serve(handler, *args[1:], process_request=process_request) as server: - with self.assertRaises(InvalidStatus) as raised: - async with connect(get_uri(server)): - self.fail("did not raise") - self.assertEqual( - str(raised.exception), - "server rejected WebSocket connection: HTTP 403", - ) + with self.assertNoLogs("websockets", logging.ERROR): + async with serve( + handler, *args[1:], process_request=process_request + ) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 403", + ) async def test_async_process_request_returns_response(self): """Server aborts handshake if async process_request returns a response.""" @@ -166,44 +169,65 @@ async def process_request(ws, request): async def handler(ws): self.fail("handler must not run") - async with serve(handler, *args[1:], process_request=process_request) as server: - with self.assertRaises(InvalidStatus) as raised: - async with connect(get_uri(server)): - self.fail("did not raise") - self.assertEqual( - str(raised.exception), - "server rejected WebSocket connection: HTTP 403", - ) + with self.assertNoLogs("websockets", logging.ERROR): + async with serve( + handler, *args[1:], process_request=process_request + ) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 403", + ) async def test_process_request_raises_exception(self): """Server returns an error if process_request raises an exception.""" def process_request(ws, request): - raise RuntimeError + raise RuntimeError("BOOM") - async with serve(*args, process_request=process_request) as server: - with self.assertRaises(InvalidStatus) as raised: - async with connect(get_uri(server)): - self.fail("did not raise") - self.assertEqual( - str(raised.exception), - "server rejected WebSocket connection: HTTP 500", - ) + with self.assertLogs("websockets", logging.ERROR) as logs: + async with serve(*args, process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + self.assertEqual( + [record.getMessage() for record in logs.records], + ["opening handshake failed"], + ) + self.assertEqual( + [str(record.exc_info[1]) for record in logs.records], + ["BOOM"], + ) async def test_async_process_request_raises_exception(self): """Server returns an error if async process_request raises an exception.""" async def process_request(ws, request): - raise RuntimeError + raise RuntimeError("BOOM") - async with serve(*args, process_request=process_request) as server: - with self.assertRaises(InvalidStatus) as raised: - async with connect(get_uri(server)): - self.fail("did not raise") - self.assertEqual( - str(raised.exception), - "server rejected WebSocket connection: HTTP 500", - ) + with self.assertLogs("websockets", logging.ERROR) as logs: + async with serve(*args, process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + self.assertEqual( + [record.getMessage() for record in logs.records], + ["opening handshake failed"], + ) + self.assertEqual( + [str(record.exc_info[1]) for record in logs.records], + ["BOOM"], + ) async def test_process_response_returns_none(self): """Server runs process_response but keeps the handshake response.""" @@ -277,31 +301,49 @@ async def test_process_response_raises_exception(self): """Server returns an error if process_response raises an exception.""" def process_response(ws, request, response): - raise RuntimeError + raise RuntimeError("BOOM") - async with serve(*args, process_response=process_response) as server: - with self.assertRaises(InvalidStatus) as raised: - async with connect(get_uri(server)): - self.fail("did not raise") - self.assertEqual( - str(raised.exception), - "server rejected WebSocket connection: HTTP 500", - ) + with self.assertLogs("websockets", logging.ERROR) as logs: + async with serve(*args, process_response=process_response) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + self.assertEqual( + [record.getMessage() for record in logs.records], + ["opening handshake failed"], + ) + self.assertEqual( + [str(record.exc_info[1]) for record in logs.records], + ["BOOM"], + ) async def test_async_process_response_raises_exception(self): """Server returns an error if async process_response raises an exception.""" async def process_response(ws, request, response): - raise RuntimeError + raise RuntimeError("BOOM") - async with serve(*args, process_response=process_response) as server: - with self.assertRaises(InvalidStatus) as raised: - async with connect(get_uri(server)): - self.fail("did not raise") - self.assertEqual( - str(raised.exception), - "server rejected WebSocket connection: HTTP 500", - ) + with self.assertLogs("websockets", logging.ERROR) as logs: + async with serve(*args, process_response=process_response) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + self.assertEqual( + [record.getMessage() for record in logs.records], + ["opening handshake failed"], + ) + self.assertEqual( + [str(record.exc_info[1]) for record in logs.records], + ["BOOM"], + ) async def test_override_server(self): """Server can override Server header with server_header.""" diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 0ae804bb3..1c092459d 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -20,7 +20,7 @@ Frame, ) from websockets.protocol import * -from websockets.protocol import CLIENT, CLOSED, CLOSING, SERVER +from websockets.protocol import CLIENT, CLOSED, CLOSING, CONNECTING, SERVER from .extensions.utils import Rsv2Extension from .test_frames import FramesTestCase @@ -1696,6 +1696,24 @@ def test_server_fails_connection(self): server.fail(CloseCode.PROTOCOL_ERROR) self.assertTrue(server.close_expected()) + def test_client_is_connecting(self): + client = Protocol(CLIENT, state=CONNECTING) + self.assertFalse(client.close_expected()) + + def test_server_is_connecting(self): + server = Protocol(SERVER, state=CONNECTING) + self.assertFalse(server.close_expected()) + + def test_client_failed_connecting(self): + client = Protocol(CLIENT, state=CONNECTING) + client.send_eof() + self.assertTrue(client.close_expected()) + + def test_server_failed_connecting(self): + server = Protocol(SERVER, state=CONNECTING) + server.send_eof() + self.assertTrue(server.close_expected()) + class ConnectionClosedTests(ProtocolTestCase): """ From 76f6f573e2ecb279230c2bf56c07bf4d4f717147 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 3 Nov 2024 09:11:51 +0100 Subject: [PATCH 661/762] Factor out backport of assertNoLogs. Fix previous commit on Python 3.9. --- tests/asyncio/test_connection.py | 25 ++++--------------------- tests/asyncio/test_server.py | 3 ++- tests/legacy/test_protocol.py | 12 ++++++------ tests/legacy/utils.py | 23 +++-------------------- tests/utils.py | 26 ++++++++++++++++++++++++++ 5 files changed, 41 insertions(+), 48 deletions(-) diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index c98765d80..d61798afb 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -19,7 +19,7 @@ from websockets.protocol import CLIENT, SERVER, Protocol, State from ..protocol import RecordingProtocol -from ..utils import MS +from ..utils import MS, AssertNoLogsMixin from .connection import InterceptingConnection from .utils import alist @@ -28,7 +28,7 @@ # All tests run on the client side and the server side to validate this. -class ClientConnectionTests(unittest.IsolatedAsyncioTestCase): +class ClientConnectionTests(AssertNoLogsMixin, unittest.IsolatedAsyncioTestCase): LOCAL = CLIENT REMOTE = SERVER @@ -48,23 +48,6 @@ async def asyncTearDown(self): await self.remote_connection.close() await self.connection.close() - if sys.version_info[:2] < (3, 10): # pragma: no cover - - @contextlib.contextmanager - def assertNoLogs(self, logger=None, level=None): - """ - No message is logged on the given logger with at least the given level. - - """ - with self.assertLogs(logger, level) as logs: - # We want to test that no log message is emitted - # but assertLogs expects at least one log message. - logging.getLogger(logger).log(level, "dummy") - yield - - level_name = logging.getLevelName(level) - self.assertEqual(logs.output, [f"{level_name}:{logger}:dummy"]) - # Test helpers built upon RecordingProtocol and InterceptingConnection. async def assertFrameSent(self, frame): @@ -1277,7 +1260,7 @@ async def test_broadcast_skips_closed_connection(self): await self.connection.close() await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) - with self.assertNoLogs(): + with self.assertNoLogs("websockets", logging.WARNING): broadcast([self.connection], "😀") await self.assertNoFrameSent() @@ -1288,7 +1271,7 @@ async def test_broadcast_skips_closing_connection(self): await asyncio.sleep(0) await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) - with self.assertNoLogs(): + with self.assertNoLogs("websockets", logging.WARNING): broadcast([self.connection], "😀") await self.assertNoFrameSent() diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index c817f5ef6..3e289e592 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -21,6 +21,7 @@ CLIENT_CONTEXT, MS, SERVER_CONTEXT, + AssertNoLogsMixin, temp_unix_socket_path, ) from .server import ( @@ -32,7 +33,7 @@ ) -class ServerTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): +class ServerTests(EvalShellMixin, AssertNoLogsMixin, unittest.IsolatedAsyncioTestCase): async def test_connection(self): """Server receives connection from client and the handshake succeeds.""" async with serve(*args) as server: diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index de2a320b5..be2910a8f 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -938,7 +938,7 @@ def test_answer_ping_does_not_crash_if_connection_closing(self): self.receive_frame(Frame(True, OP_PING, b"test")) self.run_loop_once() - with self.assertNoLogs(): + with self.assertNoLogs("websockets", logging.ERROR): self.loop.run_until_complete(self.protocol.close()) self.loop.run_until_complete(close_task) # cleanup @@ -951,7 +951,7 @@ def test_answer_ping_does_not_crash_if_connection_closed(self): self.receive_eof() self.run_loop_once() - with self.assertNoLogs(): + with self.assertNoLogs("websockets", logging.ERROR): self.loop.run_until_complete(self.protocol.close()) def test_ignore_pong(self): @@ -1028,7 +1028,7 @@ def test_acknowledge_aborted_ping(self): pong_waiter.result() # transfer_data doesn't crash, which would be logged. - with self.assertNoLogs(): + with self.assertNoLogs("websockets", logging.ERROR): # Unclog incoming queue. self.loop.run_until_complete(self.protocol.recv()) self.loop.run_until_complete(self.protocol.recv()) @@ -1375,7 +1375,7 @@ def test_remote_close_and_connection_lost(self): self.receive_eof() self.run_loop_once() - with self.assertNoLogs(): + with self.assertNoLogs("websockets", logging.ERROR): self.loop.run_until_complete(self.protocol.close(reason="oh noes!")) self.assertConnectionClosed(CloseCode.NORMAL_CLOSURE, "close") @@ -1500,14 +1500,14 @@ def test_broadcast_two_clients(self): def test_broadcast_skips_closed_connection(self): self.close_connection() - with self.assertNoLogs(): + with self.assertNoLogs("websockets", logging.ERROR): broadcast([self.protocol], "café") self.assertNoFrameSent() def test_broadcast_skips_closing_connection(self): close_task = self.half_close_connection_local() - with self.assertNoLogs(): + with self.assertNoLogs("websockets", logging.ERROR): broadcast([self.protocol], "café") self.assertNoFrameSent() diff --git a/tests/legacy/utils.py b/tests/legacy/utils.py index 5b56050d5..1f79bb600 100644 --- a/tests/legacy/utils.py +++ b/tests/legacy/utils.py @@ -1,12 +1,12 @@ import asyncio -import contextlib import functools -import logging import sys import unittest +from ..utils import AssertNoLogsMixin -class AsyncioTestCase(unittest.TestCase): + +class AsyncioTestCase(AssertNoLogsMixin, unittest.TestCase): """ Base class for tests that sets up an isolated event loop for each test. @@ -56,23 +56,6 @@ def run_loop_once(self): self.loop.call_soon(self.loop.stop) self.loop.run_forever() - if sys.version_info[:2] < (3, 10): # pragma: no cover - - @contextlib.contextmanager - def assertNoLogs(self, logger="websockets", level=logging.ERROR): - """ - No message is logged on the given logger with at least the given level. - - """ - with self.assertLogs(logger, level) as logs: - # We want to test that no log message is emitted - # but assertLogs expects at least one log message. - logging.getLogger(logger).log(level, "dummy") - yield - - level_name = logging.getLevelName(level) - self.assertEqual(logs.output, [f"{level_name}:{logger}:dummy"]) - def assertDeprecationWarnings(self, recorded_warnings, expected_warnings): """ Check recorded deprecation warnings match a list of expected messages. diff --git a/tests/utils.py b/tests/utils.py index 639fb7fe5..77d020726 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,9 +1,11 @@ import contextlib import email.utils +import logging import os import pathlib import platform import ssl +import sys import tempfile import time import unittest @@ -113,6 +115,30 @@ def assertDeprecationWarning(self, message): self.assertEqual(str(warning.message), message) +class AssertNoLogsMixin: + """ + Backport of assertNoLogs for Python 3.9. + + """ + + if sys.version_info[:2] < (3, 10): # pragma: no cover + + @contextlib.contextmanager + def assertNoLogs(self, logger=None, level=None): + """ + No message is logged on the given logger with at least the given level. + + """ + with self.assertLogs(logger, level) as logs: + # We want to test that no log message is emitted + # but assertLogs expects at least one log message. + logging.getLogger(logger).log(level, "dummy") + yield + + level_name = logging.getLevelName(level) + self.assertEqual(logs.output, [f"{level_name}:{logger}:dummy"]) + + @contextlib.contextmanager def temp_unix_socket_path(): with tempfile.TemporaryDirectory() as temp_dir: From 0a5a79c224c9be97f79909f7218fd1da7b2acabb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 3 Nov 2024 09:47:51 +0100 Subject: [PATCH 662/762] Clean up prefixes of debug log messages. All debug messages and only debug messages should have them. --- docs/topics/logging.rst | 8 ++++++-- src/websockets/asyncio/client.py | 2 +- src/websockets/asyncio/connection.py | 6 +++--- src/websockets/legacy/client.py | 4 ++-- src/websockets/legacy/protocol.py | 8 ++++---- src/websockets/sync/connection.py | 15 ++++++++++++--- tests/legacy/test_client_server.py | 4 ++-- 7 files changed, 30 insertions(+), 17 deletions(-) diff --git a/docs/topics/logging.rst b/docs/topics/logging.rst index be5678455..ae71be265 100644 --- a/docs/topics/logging.rst +++ b/docs/topics/logging.rst @@ -218,7 +218,10 @@ Here's what websockets logs at each level. ``ERROR`` ......... -* Exceptions raised by connection handler coroutines in servers +* Exceptions raised by your code in servers + * connection handler coroutines + * ``select_subprotocol`` callbacks + * ``process_request`` and ``process_response`` callbacks * Exceptions resulting from bugs in websockets ``WARNING`` @@ -250,4 +253,5 @@ Debug messages have cute prefixes that make logs easier to scan: * ``=`` - set connection state * ``x`` - shut down connection * ``%`` - manage pings and pongs -* ``!`` - handle errors and timeouts +* ``-`` - timeout +* ``!`` - error, with a traceback diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index d276ac171..302d0b94d 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -521,7 +521,7 @@ async def __aiter__(self) -> AsyncIterator[ClientConnection]: delays = backoff() delay = next(delays) self.logger.info( - "! connect failed; reconnecting in %.1f seconds", + "connect failed; reconnecting in %.1f seconds", delay, exc_info=True, ) diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 5545632d6..c4961884c 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -785,7 +785,7 @@ async def keepalive(self) -> None: self.logger.debug("% received keepalive pong") except asyncio.TimeoutError: if self.debug: - self.logger.debug("! timed out waiting for keepalive pong") + self.logger.debug("- timed out waiting for keepalive pong") async with self.send_context(): self.protocol.fail( CloseCode.INTERNAL_ERROR, @@ -866,7 +866,7 @@ async def send_context( await self.drain() except Exception as exc: if self.debug: - self.logger.debug("error while sending data", exc_info=True) + self.logger.debug("! error while sending data", exc_info=True) # While the only expected exception here is OSError, # other exceptions would be treated identically. wait_for_close = False @@ -1042,7 +1042,7 @@ def data_received(self, data: bytes) -> None: self.send_data() except Exception as exc: if self.debug: - self.logger.debug("error while sending data", exc_info=True) + self.logger.debug("! error while sending data", exc_info=True) self.set_recv_exc(exc) if self.protocol.close_expected(): diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 116445e25..a2dc0250f 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -603,14 +603,14 @@ async def __aiter__(self) -> AsyncIterator[WebSocketClientProtocol]: if backoff_delay == self.BACKOFF_MIN: initial_delay = random.random() * self.BACKOFF_INITIAL self.logger.info( - "! connect failed; reconnecting in %.1f seconds", + "connect failed; reconnecting in %.1f seconds", initial_delay, exc_info=True, ) await asyncio.sleep(initial_delay) else: self.logger.info( - "! connect failed again; retrying in %d seconds", + "connect failed again; retrying in %d seconds", int(backoff_delay), exc_info=True, ) diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index cedde6200..bd998dfd1 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -1246,7 +1246,7 @@ async def keepalive_ping(self) -> None: self.logger.debug("% received keepalive pong") except asyncio.TimeoutError: if self.debug: - self.logger.debug("! timed out waiting for keepalive pong") + self.logger.debug("- timed out waiting for keepalive pong") self.fail_connection( CloseCode.INTERNAL_ERROR, "keepalive ping timeout", @@ -1288,7 +1288,7 @@ async def close_connection(self) -> None: if await self.wait_for_connection_lost(): return if self.debug: - self.logger.debug("! timed out waiting for TCP close") + self.logger.debug("- timed out waiting for TCP close") # Half-close the TCP connection if possible (when there's no TLS). if self.transport.can_write_eof(): @@ -1306,7 +1306,7 @@ async def close_connection(self) -> None: if await self.wait_for_connection_lost(): return if self.debug: - self.logger.debug("! timed out waiting for TCP close") + self.logger.debug("- timed out waiting for TCP close") finally: # The try/finally ensures that the transport never remains open, @@ -1332,7 +1332,7 @@ async def close_transport(self) -> None: if await self.wait_for_connection_lost(): return if self.debug: - self.logger.debug("! timed out waiting for TCP close") + self.logger.debug("- timed out waiting for TCP close") # Abort the TCP connection. Buffers are discarded. if self.debug: diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 8d1dbcf58..77f803c9b 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -640,7 +640,10 @@ def recv_events(self) -> None: data = self.socket.recv(self.recv_bufsize) except Exception as exc: if self.debug: - self.logger.debug("error while receiving data", exc_info=True) + self.logger.debug( + "! error while receiving data", + exc_info=True, + ) # When the closing handshake is initiated by our side, # recv() may block until send_context() closes the socket. # In that case, send_context() already set recv_exc. @@ -665,7 +668,10 @@ def recv_events(self) -> None: self.send_data() except Exception as exc: if self.debug: - self.logger.debug("error while sending data", exc_info=True) + self.logger.debug( + "! error while sending data", + exc_info=True, + ) # Similarly to the above, avoid overriding an exception # set by send_context(), in case of a race condition # i.e. send_context() closes the socket after recv() @@ -783,7 +789,10 @@ def send_context( self.send_data() except Exception as exc: if self.debug: - self.logger.debug("error while sending data", exc_info=True) + self.logger.debug( + "! error while sending data", + exc_info=True, + ) # While the only expected exception here is OSError, # other exceptions would be treated identically. wait_for_close = False diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 502ab68e7..375d47e29 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -1615,12 +1615,12 @@ async def run_client(): [ "connection rejected (503 Service Unavailable)", "connection closed", - "! connect failed; reconnecting in X seconds", + "connect failed; reconnecting in X seconds", ] + [ "connection rejected (503 Service Unavailable)", "connection closed", - "! connect failed again; retrying in X seconds", + "connect failed again; retrying in X seconds", ] * ((len(logs.records) - 8) // 3) + [ From 9b3595d8d4d0573e00209f9e920f6a6fab981fa9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 4 Nov 2024 22:58:18 +0100 Subject: [PATCH 663/762] Remove stack traces from INFO and WARNING logs. Fix #1501. --- src/websockets/asyncio/client.py | 6 ++++-- src/websockets/asyncio/connection.py | 9 +++++++-- src/websockets/legacy/client.py | 13 ++++++++----- src/websockets/legacy/protocol.py | 9 +++++++-- tests/asyncio/test_connection.py | 5 ++++- tests/legacy/test_client_server.py | 8 ++++++-- tests/legacy/test_protocol.py | 4 ++-- 7 files changed, 38 insertions(+), 16 deletions(-) diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 302d0b94d..b3d50c12e 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -3,6 +3,7 @@ import asyncio import logging import os +import traceback import urllib.parse from collections.abc import AsyncIterator, Generator, Sequence from types import TracebackType @@ -521,9 +522,10 @@ async def __aiter__(self) -> AsyncIterator[ClientConnection]: delays = backoff() delay = next(delays) self.logger.info( - "connect failed; reconnecting in %.1f seconds", + "connect failed; reconnecting in %.1f seconds: %s", delay, - exc_info=True, + # Remove first argument when dropping Python 3.9. + traceback.format_exception_only(type(exc), exc)[0].strip(), ) await asyncio.sleep(delay) continue diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index c4961884c..186846ef3 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -7,6 +7,7 @@ import random import struct import sys +import traceback import uuid from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterable, Mapping from types import TracebackType @@ -1180,8 +1181,12 @@ def broadcast( exceptions.append(exception) else: connection.logger.warning( - "skipped broadcast: failed to write message", - exc_info=True, + "skipped broadcast: failed to write message: %s", + traceback.format_exception_only( + # Remove first argument when dropping Python 3.9. + type(write_exception), + write_exception, + )[0].strip(), ) if raise_exceptions and exceptions: diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index a2dc0250f..555069e8c 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -5,6 +5,7 @@ import logging import os import random +import traceback import urllib.parse import warnings from collections.abc import AsyncIterator, Generator, Sequence @@ -597,22 +598,24 @@ async def __aiter__(self) -> AsyncIterator[WebSocketClientProtocol]: try: async with self as protocol: yield protocol - except Exception: + except Exception as exc: # Add a random initial delay between 0 and 5 seconds. # See 7.2.3. Recovering from Abnormal Closure in RFC 6455. if backoff_delay == self.BACKOFF_MIN: initial_delay = random.random() * self.BACKOFF_INITIAL self.logger.info( - "connect failed; reconnecting in %.1f seconds", + "connect failed; reconnecting in %.1f seconds: %s", initial_delay, - exc_info=True, + # Remove first argument when dropping Python 3.9. + traceback.format_exception_only(type(exc), exc)[0].strip(), ) await asyncio.sleep(initial_delay) else: self.logger.info( - "connect failed again; retrying in %d seconds", + "connect failed again; retrying in %d seconds: %s", int(backoff_delay), - exc_info=True, + # Remove first argument when dropping Python 3.9. + traceback.format_exception_only(type(exc), exc)[0].strip(), ) await asyncio.sleep(int(backoff_delay)) # Increase delay with truncated exponential backoff. diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index bd998dfd1..db126c01e 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -9,6 +9,7 @@ import struct import sys import time +import traceback import uuid import warnings from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterable, Mapping @@ -1624,8 +1625,12 @@ def broadcast( exceptions.append(exception) else: websocket.logger.warning( - "skipped broadcast: failed to write message", - exc_info=True, + "skipped broadcast: failed to write message: %s", + traceback.format_exception_only( + # Remove first argument when dropping Python 3.9. + type(write_exception), + write_exception, + )[0].strip(), ) if raise_exceptions and exceptions: diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index d61798afb..902b3b847 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -1337,7 +1337,10 @@ async def test_broadcast_skips_connection_failing_to_send(self): self.assertEqual( [record.getMessage() for record in logs.records], - ["skipped broadcast: failed to write message"], + [ + "skipped broadcast: failed to write message: " + "RuntimeError: Cannot call write() after write_eof()" + ], ) @unittest.skipIf( diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index 375d47e29..c13c6c92e 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -1607,6 +1607,10 @@ async def run_client(): ], ) # Iteration 3 + exc = ( + "websockets.legacy.exceptions.InvalidStatusCode: " + "server rejected WebSocket connection: HTTP 503" + ) self.assertEqual( [ re.sub(r"[0-9\.]+ seconds", "X seconds", record.getMessage()) @@ -1615,12 +1619,12 @@ async def run_client(): [ "connection rejected (503 Service Unavailable)", "connection closed", - "connect failed; reconnecting in X seconds", + f"connect failed; reconnecting in X seconds: {exc}", ] + [ "connection rejected (503 Service Unavailable)", "connection closed", - "connect failed again; retrying in X seconds", + f"connect failed again; retrying in X seconds: {exc}", ] * ((len(logs.records) - 8) // 3) + [ diff --git a/tests/legacy/test_protocol.py b/tests/legacy/test_protocol.py index be2910a8f..d30198934 100644 --- a/tests/legacy/test_protocol.py +++ b/tests/legacy/test_protocol.py @@ -1547,14 +1547,14 @@ def test_broadcast_reports_connection_sending_fragmented_text(self): def test_broadcast_skips_connection_failing_to_send(self): # Configure mock to raise an exception when writing to the network. - self.protocol.transport.write.side_effect = RuntimeError + self.protocol.transport.write.side_effect = RuntimeError("BOOM") with self.assertLogs("websockets", logging.WARNING) as logs: broadcast([self.protocol], "café") self.assertEqual( [record.getMessage() for record in logs.records], - ["skipped broadcast: failed to write message"], + ["skipped broadcast: failed to write message: RuntimeError: BOOM"], ) @unittest.skipIf( From 5f34e2741e94ac21ef99e3dc212aa152d77d1a37 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 9 Nov 2024 08:55:26 +0100 Subject: [PATCH 664/762] Fix remaining instances of shortcut imports. --- docs/topics/broadcast.rst | 2 +- docs/topics/keepalive.rst | 2 +- experiments/broadcast/server.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/topics/broadcast.rst b/docs/topics/broadcast.rst index c9699feb2..66b0819b2 100644 --- a/docs/topics/broadcast.rst +++ b/docs/topics/broadcast.rst @@ -83,7 +83,7 @@ to:: Here's a coroutine that broadcasts a message to all clients:: - from websockets import ConnectionClosed + from websockets.exceptions import ConnectionClosed async def broadcast(message): for websocket in CLIENTS.copy(): diff --git a/docs/topics/keepalive.rst b/docs/topics/keepalive.rst index 003087fad..4897de2ba 100644 --- a/docs/topics/keepalive.rst +++ b/docs/topics/keepalive.rst @@ -112,7 +112,7 @@ You can run a task in the background to send keepalive messages: import itertools import json - from websockets import ConnectionClosed + from websockets.exceptions import ConnectionClosed async def keepalive(websocket, ping_interval=30): for ping in itertools.count(): diff --git a/experiments/broadcast/server.py b/experiments/broadcast/server.py index d5b50bd71..eca55357e 100644 --- a/experiments/broadcast/server.py +++ b/experiments/broadcast/server.py @@ -6,8 +6,8 @@ import sys import time -from websockets import ConnectionClosed from websockets.asyncio.server import broadcast, serve +from websockets.exceptions import ConnectionClosed CLIENTS = set() From c57bcb743fe1128f6d28e2eaebbdd202eb3ee2eb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 9 Nov 2024 09:36:42 +0100 Subject: [PATCH 665/762] Standardize links to RFC. There were 70 links to https://datatracker.ietf.org/doc/html/ vs. 15 links https://www.rfc-editor.org/rfc/. Also :rfc:`....` links to https://datatracker.ietf.org/doc/html/ by default. While https://www.ietf.org/process/rfcs/#introduction says: > The RFC Editor website is the authoritative site for RFCs. the IETF Datatracker looks a bit better and has more information. --- docs/faq/common.rst | 4 ++-- docs/howto/extensions.rst | 2 +- docs/project/changelog.rst | 2 +- docs/reference/extensions.rst | 2 +- docs/topics/design.rst | 10 +++++----- docs/topics/keepalive.rst | 6 +++--- docs/topics/logging.rst | 4 ++-- 7 files changed, 15 insertions(+), 15 deletions(-) diff --git a/docs/faq/common.rst b/docs/faq/common.rst index 0dc4a3aeb..ba7a95932 100644 --- a/docs/faq/common.rst +++ b/docs/faq/common.rst @@ -131,8 +131,8 @@ How do I respond to pings? If you are referring to Ping_ and Pong_ frames defined in the WebSocket protocol, don't bother, because websockets handles them for you. -.. _Ping: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 -.. _Pong: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 +.. _Ping: https://datatracker.ietf.org/doc/html/rfc6455.html#section-5.5.2 +.. _Pong: https://datatracker.ietf.org/doc/html/rfc6455.html#section-5.5.3 If you are connecting to a server that defines its own heartbeat at the application level, then you need to build that logic into your application. diff --git a/docs/howto/extensions.rst b/docs/howto/extensions.rst index 3c8a7d72a..c4e9da626 100644 --- a/docs/howto/extensions.rst +++ b/docs/howto/extensions.rst @@ -7,7 +7,7 @@ During the opening handshake, WebSocket clients and servers negotiate which extensions_ will be used with which parameters. Then each frame is processed by extensions before being sent or after being received. -.. _extensions: https://www.rfc-editor.org/rfc/rfc6455.html#section-9 +.. _extensions: https://datatracker.ietf.org/doc/html/rfc6455.html#section-9 As a consequence, writing an extension requires implementing several classes: diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 71b2a6960..1056bc980 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -1487,7 +1487,7 @@ New features * Added support for providing and checking Origin_. -.. _Origin: https://www.rfc-editor.org/rfc/rfc6455.html#section-10.2 +.. _Origin: https://datatracker.ietf.org/doc/html/rfc6455.html#section-10.2 .. _2.0: diff --git a/docs/reference/extensions.rst b/docs/reference/extensions.rst index a70f1b1e5..f3da464a5 100644 --- a/docs/reference/extensions.rst +++ b/docs/reference/extensions.rst @@ -8,7 +8,7 @@ The WebSocket protocol supports extensions_. At the time of writing, there's only one `registered extension`_ with a public specification, WebSocket Per-Message Deflate. -.. _extensions: https://www.rfc-editor.org/rfc/rfc6455.html#section-9 +.. _extensions: https://datatracker.ietf.org/doc/html/rfc6455.html#section-9 .. _registered extension: https://www.iana.org/assignments/websocket/websocket.xhtml#extension-name Per-Message Deflate diff --git a/docs/topics/design.rst b/docs/topics/design.rst index b73ace517..bc14bd332 100644 --- a/docs/topics/design.rst +++ b/docs/topics/design.rst @@ -173,16 +173,16 @@ differences between a server and a client: - `closing the TCP connection`_: the server closes the connection immediately; the client waits for the server to do it. -.. _client-to-server masking: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.3 -.. _closing the TCP connection: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.1 +.. _client-to-server masking: https://datatracker.ietf.org/doc/html/rfc6455.html#section-5.3 +.. _closing the TCP connection: https://datatracker.ietf.org/doc/html/rfc6455.html#section-5.5.1 These differences are so minor that all the logic for `data framing`_, for `sending and receiving data`_ and for `closing the connection`_ is implemented in the same class, :class:`~protocol.WebSocketCommonProtocol`. -.. _data framing: https://www.rfc-editor.org/rfc/rfc6455.html#section-5 -.. _sending and receiving data: https://www.rfc-editor.org/rfc/rfc6455.html#section-6 -.. _closing the connection: https://www.rfc-editor.org/rfc/rfc6455.html#section-7 +.. _data framing: https://datatracker.ietf.org/doc/html/rfc6455.html#section-5 +.. _sending and receiving data: https://datatracker.ietf.org/doc/html/rfc6455.html#section-6 +.. _closing the connection: https://datatracker.ietf.org/doc/html/rfc6455.html#section-7 The :attr:`~protocol.WebSocketCommonProtocol.is_client` attribute tells which side a protocol instance is managing. This attribute is defined on the diff --git a/docs/topics/keepalive.rst b/docs/topics/keepalive.rst index 4897de2ba..a0467ced2 100644 --- a/docs/topics/keepalive.rst +++ b/docs/topics/keepalive.rst @@ -33,8 +33,8 @@ Keepalive in websockets To avoid these problems, websockets runs a keepalive and heartbeat mechanism based on WebSocket Ping_ and Pong_ frames, which are designed for this purpose. -.. _Ping: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 -.. _Pong: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 +.. _Ping: https://datatracker.ietf.org/doc/html/rfc6455.html#section-5.5.2 +.. _Pong: https://datatracker.ietf.org/doc/html/rfc6455.html#section-5.5.3 It sends a Ping frame every 20 seconds. It expects a Pong frame in return within 20 seconds. Else, it considers the connection broken and terminates it. @@ -98,7 +98,7 @@ at regular intervals. Usually they expect Text_ frames rather than Ping_ frames, meaning that you must send them with :attr:`~asyncio.connection.Connection.send` rather than :attr:`~asyncio.connection.Connection.ping`. -.. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 +.. _Text: https://datatracker.ietf.org/doc/html/rfc6455.html#section-5.6 In websockets, such keepalive mechanisms are considered as application-level because they rely on data frames. That's unlike the protocol-level keepalive diff --git a/docs/topics/logging.rst b/docs/topics/logging.rst index ae71be265..fff33a024 100644 --- a/docs/topics/logging.rst +++ b/docs/topics/logging.rst @@ -35,8 +35,8 @@ Instead, when running as a server, websockets logs one event when a `connection is established`_ and another event when a `connection is closed`_. -.. _connection is established: https://www.rfc-editor.org/rfc/rfc6455.html#section-4 -.. _connection is closed: https://www.rfc-editor.org/rfc/rfc6455.html#section-7.1.4 +.. _connection is established: https://datatracker.ietf.org/doc/html/rfc6455.html#section-4 +.. _connection is closed: https://datatracker.ietf.org/doc/html/rfc6455.html#section-7.1.4 By default, websockets doesn't log an event for every message. That would be excessive for many applications exchanging small messages at a fast rate. If From 178c88447c2754d2d2b0c02472868cbaef7cec52 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 9 Nov 2024 13:56:49 +0100 Subject: [PATCH 666/762] Complete and polish changelog for 14.0. --- docs/project/changelog.rst | 41 +++++++++++++++++++++++--------------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 1056bc980..161f71c7f 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -41,7 +41,7 @@ Backwards-incompatible changes websockets 13.1 is the last version supporting Python 3.8. .. admonition:: The new :mod:`asyncio` implementation is now the default. - :class: caution + :class: danger The following aliases in the ``websockets`` package were switched to the new :mod:`asyncio` implementation:: @@ -64,8 +64,8 @@ Backwards-incompatible changes The :doc:`upgrade guide <../howto/upgrade>` provides complete instructions to migrate your application. - Aliases for deprecated API were removed from ``__all__``. As a consequence, - they cannot be imported e.g. with ``from websockets import *`` anymore. + Aliases for deprecated API were removed from ``websockets.__all__``, meaning + that they cannot be imported with ``from websockets import *`` anymore. .. admonition:: Several API raise :exc:`ValueError` instead of :exc:`TypeError` on invalid arguments. @@ -83,22 +83,16 @@ Backwards-incompatible changes :class: note In addition to :class:`bytes`, it may be a :class:`bytearray` or a - :class:`memoryview`. - - If you wrote an :class:`extension ` that relies on - methods not provided by these new types, you may need to update your code. + :class:`memoryview`. If you wrote an :class:`~extensions.Extension` that + relies on methods not provided by these types, you must update your code. .. admonition:: The signature of :exc:`~exceptions.PayloadTooBig` changed. :class: note If you wrote an extension that raises :exc:`~exceptions.PayloadTooBig` in - :meth:`~extensions.Extension.decode`, for example, you must replace:: - - PayloadTooBig(f"over size limit ({size} > {max_size} bytes)") - - with:: - - PayloadTooBig(size, max_size) + :meth:`~extensions.Extension.decode`, for example, you must replace + ``PayloadTooBig(f"over size limit ({size} > {max_size} bytes)")`` with + ``PayloadTooBig(size, max_size)``. New features ............ @@ -106,8 +100,8 @@ New features * Added an option to receive text frames as :class:`bytes`, without decoding, in the :mod:`threading` implementation; also binary frames as :class:`str`. -* Added an option to send :class:`bytes` as a text frame in the :mod:`asyncio` - and :mod:`threading` implementations, as well as :class:`str` a binary frame. +* Added an option to send :class:`bytes` in a text frame in the :mod:`asyncio` + and :mod:`threading` implementations; also :class:`str` in a binary frame. Improvements ............ @@ -118,6 +112,21 @@ Improvements * Errors when a fragmented message is too large are clearer. +* Log messages at the :data:`~logging.WARNING` and :data:`~logging.INFO` levels + no longer include stack traces. + +Bug fixes +......... + +* Clients no longer crash when the server rejects the opening handshake and the + HTTP response doesn't Include a ``Content-Length`` header. + +* Returning an HTTP response in ``process_request`` or ``process_response`` + doesn't generate a log message at the :data:`~logging.ERROR` level anymore. + +* Connections are closed with code 1007 (invalid data) when receiving invalid + UTF-8 in a text frame. + .. _13.1: 13.1 From f0d20aafab027e9b99460b193dcb709872b219a5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 9 Nov 2024 13:58:10 +0100 Subject: [PATCH 667/762] Release version 14.0. --- docs/project/changelog.rst | 2 +- src/websockets/version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 161f71c7f..5aa58a09b 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -30,7 +30,7 @@ notice. 14.0 ---- -*In development* +*November 9, 2024* Backwards-incompatible changes .............................. diff --git a/src/websockets/version.py b/src/websockets/version.py index 34fc2eaef..7c64f566a 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -18,7 +18,7 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = False +released = True tag = version = commit = "14.0" From b9d74504eaeedd35cb7dc2651a18420f10e3828d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 9 Nov 2024 14:01:07 +0100 Subject: [PATCH 668/762] Start version 14.1. --- docs/project/changelog.rst | 7 +++++++ src/websockets/version.py | 4 ++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 5aa58a09b..9e6a9d113 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -25,6 +25,13 @@ fixing regressions shortly after a release. Only documented APIs are public. Undocumented, private APIs may change without notice. +.. _14.1: + +14.1 +---- + +*In development* + .. _14.0: 14.0 diff --git a/src/websockets/version.py b/src/websockets/version.py index 7c64f566a..48d2edaea 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -18,9 +18,9 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = True +released = False -tag = version = commit = "14.0" +tag = version = commit = "14.1" if not released: # pragma: no cover From e9fc77da927793d05072163d61e137dd35f97e4d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 9 Nov 2024 14:01:22 +0100 Subject: [PATCH 669/762] Add dates for deprecations. --- src/websockets/__init__.py | 2 +- src/websockets/auth.py | 2 +- src/websockets/client.py | 2 +- src/websockets/exceptions.py | 8 ++++---- src/websockets/legacy/__init__.py | 2 +- src/websockets/server.py | 2 +- 6 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index 531ce49f7..0c7e9b4c6 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -151,7 +151,7 @@ "handshake": ".legacy", "parse_uri": ".uri", "WebSocketURI": ".uri", - # deprecated in 14.0 + # deprecated in 14.0 - 2024-11-09 # .legacy.auth "BasicAuthWebSocketServerProtocol": ".legacy.auth", "basic_auth_protocol_factory": ".legacy.auth", diff --git a/src/websockets/auth.py b/src/websockets/auth.py index 98e62af3c..15b70a372 100644 --- a/src/websockets/auth.py +++ b/src/websockets/auth.py @@ -10,7 +10,7 @@ from .legacy.auth import __all__ # noqa: F401 -warnings.warn( # deprecated in 14.0 +warnings.warn( # deprecated in 14.0 - 2024-11-09 "websockets.auth, an alias for websockets.legacy.auth, is deprecated; " "see https://websockets.readthedocs.io/en/stable/howto/upgrade.html " "for upgrade instructions", diff --git a/src/websockets/client.py b/src/websockets/client.py index 8b66900a8..f6cbc9f65 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -392,7 +392,7 @@ def backoff( lazy_import( globals(), deprecated_aliases={ - # deprecated in 14.0 + # deprecated in 14.0 - 2024-11-09 "WebSocketClientProtocol": ".legacy.client", "connect": ".legacy.client", "unix_connect": ".legacy.client", diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index be3d1ca5f..f3e751971 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -113,7 +113,7 @@ def __str__(self) -> str: @property def code(self) -> int: - warnings.warn( # deprecated in 13.1 + warnings.warn( # deprecated in 13.1 - 2024-09-21 "ConnectionClosed.code is deprecated; " "use Protocol.close_code or ConnectionClosed.rcvd.code", DeprecationWarning, @@ -124,7 +124,7 @@ def code(self) -> int: @property def reason(self) -> str: - warnings.warn( # deprecated in 13.1 + warnings.warn( # deprecated in 13.1 - 2024-09-21 "ConnectionClosed.reason is deprecated; " "use Protocol.close_reason or ConnectionClosed.rcvd.reason", DeprecationWarning, @@ -343,7 +343,7 @@ def __init__( if isinstance(size_or_message, str): assert max_size is None assert cur_size is None - warnings.warn( # deprecated in 14.0 + warnings.warn( # deprecated in 14.0 - 2024-11-09 "PayloadTooBig(message) is deprecated; " "change to PayloadTooBig(size, max_size)", DeprecationWarning, @@ -408,7 +408,7 @@ class ConcurrencyError(WebSocketException, RuntimeError): lazy_import( globals(), deprecated_aliases={ - # deprecated in 14.0 + # deprecated in 14.0 - 2024-11-09 "AbortHandshake": ".legacy.exceptions", "InvalidMessage": ".legacy.exceptions", "InvalidStatusCode": ".legacy.exceptions", diff --git a/src/websockets/legacy/__init__.py b/src/websockets/legacy/__init__.py index 84f870f3a..ad9aa2506 100644 --- a/src/websockets/legacy/__init__.py +++ b/src/websockets/legacy/__init__.py @@ -3,7 +3,7 @@ import warnings -warnings.warn( # deprecated in 14.0 +warnings.warn( # deprecated in 14.0 - 2024-11-09 "websockets.legacy is deprecated; " "see https://websockets.readthedocs.io/en/stable/howto/upgrade.html " "for upgrade instructions", diff --git a/src/websockets/server.py b/src/websockets/server.py index e3fdcc646..607cc306e 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -580,7 +580,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: lazy_import( globals(), deprecated_aliases={ - # deprecated in 14.0 + # deprecated in 14.0 - 2024-11-09 "WebSocketServer": ".legacy.server", "WebSocketServerProtocol": ".legacy.server", "broadcast": ".legacy.server", From 083bcacd485a12dc1f9c6b98123bb92fafa4a9cf Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 11 Nov 2024 18:50:43 +0100 Subject: [PATCH 670/762] Import ConnectionClosed from websockets.exceptions. Fix #1539. --- docs/faq/client.rst | 3 ++- docs/faq/server.rst | 2 +- docs/intro/tutorial1.rst | 4 +++- src/websockets/asyncio/client.py | 2 +- src/websockets/legacy/client.py | 2 +- 5 files changed, 8 insertions(+), 5 deletions(-) diff --git a/docs/faq/client.rst b/docs/faq/client.rst index 0a7aab6e2..cc9856a8b 100644 --- a/docs/faq/client.rst +++ b/docs/faq/client.rst @@ -84,11 +84,12 @@ How do I reconnect when the connection drops? Use :func:`~websockets.asyncio.client.connect` as an asynchronous iterator:: from websockets.asyncio.client import connect + from websockets.exceptions import ConnectionClosed async for websocket in connect(...): try: ... - except websockets.ConnectionClosed: + except ConnectionClosed: continue Make sure you handle exceptions in the ``async for`` loop. Uncaught exceptions diff --git a/docs/faq/server.rst b/docs/faq/server.rst index 63eb5ffc6..ce7e1962d 100644 --- a/docs/faq/server.rst +++ b/docs/faq/server.rst @@ -147,7 +147,7 @@ Then, call :meth:`~ServerConnection.send`:: async def message_user(user_id, message): websocket = CONNECTIONS[user_id] # raises KeyError if user disconnected - await websocket.send(message) # may raise websockets.ConnectionClosed + await websocket.send(message) # may raise websockets.exceptions.ConnectionClosed Add error handling according to the behavior you want if the user disconnected before the message could be sent. diff --git a/docs/intro/tutorial1.rst b/docs/intro/tutorial1.rst index 6e91867c8..87074caee 100644 --- a/docs/intro/tutorial1.rst +++ b/docs/intro/tutorial1.rst @@ -271,11 +271,13 @@ spot real errors when you add functionality to the server. Catch it in the .. code-block:: python + from websockets.exceptions import ConnectionClosedOK + async def handler(websocket): while True: try: message = await websocket.recv() - except websockets.ConnectionClosedOK: + except ConnectionClosedOK: break print(message) diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index b3d50c12e..74ae70f0d 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -182,7 +182,7 @@ class connect: async for websocket in connect(...): try: ... - except websockets.ConnectionClosed: + except websockets.exceptions.ConnectionClosed: continue If the connection fails with a transient error, it is retried with diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 555069e8c..a3856b470 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -349,7 +349,7 @@ class Connect: async for websocket in connect(...): try: ... - except websockets.ConnectionClosed: + except websockets.exceptions.ConnectionClosed: continue The connection is closed automatically after each iteration of the loop. From 86bf0c5afc4c88429231ac7ba5f857fd32a02d0e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 11 Nov 2024 17:40:55 +0100 Subject: [PATCH 671/762] Remove unnecessary branch. --- src/websockets/asyncio/messages.py | 3 +-- tests/asyncio/test_messages.py | 8 -------- 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index b57c0ca4e..69636e3d0 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -43,8 +43,7 @@ def put(self, item: T) -> None: async def get(self) -> T: """Remove and return an item from the queue, waiting if necessary.""" if not self.queue: - if self.get_waiter is not None: - raise ConcurrencyError("get is already running") + assert self.get_waiter is None, "cannot call get() concurrently" self.get_waiter = self.loop.create_future() try: await self.get_waiter diff --git a/tests/asyncio/test_messages.py b/tests/asyncio/test_messages.py index 2ff929d3a..5c9ac9445 100644 --- a/tests/asyncio/test_messages.py +++ b/tests/asyncio/test_messages.py @@ -37,14 +37,6 @@ async def test_get_then_put(self): item = await getter_task self.assertEqual(item, 42) - async def test_get_concurrently(self): - """get cannot be called concurrently.""" - getter_task = asyncio.create_task(self.queue.get()) - await asyncio.sleep(0) # let the task start - with self.assertRaises(ConcurrencyError): - await self.queue.get() - getter_task.cancel() - async def test_reset(self): """reset sets the content of the queue.""" self.queue.reset([42]) From f17c11ab6b46cfe6817cbab5d92ba2626d3e87d5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 11 Nov 2024 17:42:16 +0100 Subject: [PATCH 672/762] Keep queued messages after abort. --- src/websockets/asyncio/messages.py | 2 -- tests/asyncio/test_messages.py | 7 ------- 2 files changed, 9 deletions(-) diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index 69636e3d0..678a3c14e 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -62,8 +62,6 @@ def abort(self) -> None: """Close the queue, raising EOFError in get() if necessary.""" if self.get_waiter is not None and not self.get_waiter.done(): self.get_waiter.set_exception(EOFError("stream of frames ended")) - # Clear the queue to avoid storing unnecessary data in memory. - self.queue.clear() class Assembler: diff --git a/tests/asyncio/test_messages.py b/tests/asyncio/test_messages.py index 5c9ac9445..181ffd376 100644 --- a/tests/asyncio/test_messages.py +++ b/tests/asyncio/test_messages.py @@ -51,13 +51,6 @@ async def test_abort(self): with self.assertRaises(EOFError): await getter_task - async def test_abort_clears_queue(self): - """abort clears buffered data from the queue.""" - self.queue.put(42) - self.assertEqual(len(self.queue), 1) - self.queue.abort() - self.assertEqual(len(self.queue), 0) - class AssemblerTests(unittest.IsolatedAsyncioTestCase): async def asyncSetUp(self): From bdfc8cf90301b528eebfbd0ef8a31e5a2fc7d5f7 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 11 Nov 2024 18:11:02 +0100 Subject: [PATCH 673/762] Uniformize comments. --- src/websockets/asyncio/messages.py | 8 ++++---- src/websockets/sync/messages.py | 11 +++++------ 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index 678a3c14e..814a3c03c 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -141,8 +141,8 @@ async def get(self, decode: bool | None = None) -> Data: self.get_in_progress = True - # Locking with get_in_progress prevents concurrent execution until - # get() fetches a complete message or is cancelled. + # Locking with get_in_progress prevents concurrent execution + # until get() fetches a complete message or is cancelled. try: # First frame @@ -208,8 +208,8 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: self.get_in_progress = True - # Locking with get_in_progress prevents concurrent execution until - # get_iter() fetches a complete message or is cancelled. + # Locking with get_in_progress prevents concurrent execution + # until get_iter() fetches a complete message or is cancelled. # If get_iter() raises an exception e.g. in decoder.decode(), # get_in_progress remains set and the connection becomes unusable. diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index 17f8dce7e..ce08172b2 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -128,10 +128,11 @@ def get(self, timeout: float | None = None, decode: bool | None = None) -> Data: if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") - - # Locking with get_in_progress ensures only one thread can get here. self.get_in_progress = True + # Locking with get_in_progress prevents concurrent execution + # until get() fetches a complete message or times out. + try: deadline = Deadline(timeout) @@ -198,12 +199,10 @@ def get_iter(self, decode: bool | None = None) -> Iterator[Data]: if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") - - # Locking with get_in_progress ensures only one coroutine can get here. self.get_in_progress = True - # Locking with get_in_progress prevents concurrent execution until - # get_iter() fetches a complete message or is cancelled. + # Locking with get_in_progress prevents concurrent execution + # until get_iter() fetches a complete message or times out. # If get_iter() raises an exception e.g. in decoder.decode(), # get_in_progress remains set and the connection becomes unusable. From 303483412dc5b420d09c1421792b3f8b99c323e6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 11 Nov 2024 18:30:36 +0100 Subject: [PATCH 674/762] Support recv() after the connection is closed. Fix #1538. --- docs/project/changelog.rst | 7 ++++ src/websockets/asyncio/messages.py | 20 ++++-------- src/websockets/sync/messages.py | 27 ++++++++-------- tests/asyncio/test_connection.py | 8 ++--- tests/asyncio/test_messages.py | 52 ++++++++++++++++++++++++++++++ tests/sync/test_connection.py | 19 ++++------- tests/sync/test_messages.py | 52 ++++++++++++++++++++++++++++++ 7 files changed, 142 insertions(+), 43 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 9e6a9d113..074a81c85 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -34,6 +34,13 @@ notice. .. _14.0: +Bug fixes +......... + +* Once the connection is closed, messages previously received and buffered can + be read in the :mod:`asyncio` and :mod:`threading` implementations, just like + in the legacy implementation. + 14.0 ---- diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index 814a3c03c..14ea7bf90 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -40,9 +40,11 @@ def put(self, item: T) -> None: if self.get_waiter is not None and not self.get_waiter.done(): self.get_waiter.set_result(None) - async def get(self) -> T: + async def get(self, block: bool = True) -> T: """Remove and return an item from the queue, waiting if necessary.""" if not self.queue: + if not block: + raise EOFError("stream of frames ended") assert self.get_waiter is None, "cannot call get() concurrently" self.get_waiter = self.loop.create_future() try: @@ -133,12 +135,8 @@ async def get(self, decode: bool | None = None) -> Data: :meth:`get_iter` concurrently. """ - if self.closed: - raise EOFError("stream of frames ended") - if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") - self.get_in_progress = True # Locking with get_in_progress prevents concurrent execution @@ -146,7 +144,7 @@ async def get(self, decode: bool | None = None) -> Data: try: # First frame - frame = await self.frames.get() + frame = await self.frames.get(not self.closed) self.maybe_resume() assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY if decode is None: @@ -156,7 +154,7 @@ async def get(self, decode: bool | None = None) -> Data: # Following frames, for fragmented messages while not frame.fin: try: - frame = await self.frames.get() + frame = await self.frames.get(not self.closed) except asyncio.CancelledError: # Put frames already received back into the queue # so that future calls to get() can return them. @@ -200,12 +198,8 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: :meth:`get_iter` concurrently. """ - if self.closed: - raise EOFError("stream of frames ended") - if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") - self.get_in_progress = True # Locking with get_in_progress prevents concurrent execution @@ -216,7 +210,7 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: # First frame try: - frame = await self.frames.get() + frame = await self.frames.get(not self.closed) except asyncio.CancelledError: self.get_in_progress = False raise @@ -236,7 +230,7 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: # previous fragments — we're streaming them. Canceling get_iter() # here will leave the assembler in a stuck state. Future calls to # get() or get_iter() will raise ConcurrencyError. - frame = await self.frames.get() + frame = await self.frames.get(not self.closed) self.maybe_resume() assert frame.opcode is OP_CONT if decode: diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index ce08172b2..af8635f16 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -69,10 +69,16 @@ def __init__( def get_next_frame(self, timeout: float | None = None) -> Frame: # Helper to factor out the logic for getting the next frame from the # queue, while handling timeouts and reaching the end of the stream. - try: - frame = self.frames.get(timeout=timeout) - except queue.Empty: - raise TimeoutError(f"timed out in {timeout:.1f}s") from None + if self.closed: + try: + frame = self.frames.get(block=False) + except queue.Empty: + raise EOFError("stream of frames ended") from None + else: + try: + frame = self.frames.get(block=True, timeout=timeout) + except queue.Empty: + raise TimeoutError(f"timed out in {timeout:.1f}s") from None if frame is None: raise EOFError("stream of frames ended") return frame @@ -87,7 +93,7 @@ def reset_queue(self, frames: Iterable[Frame]) -> None: queued = [] try: while True: - queued.append(self.frames.get_nowait()) + queued.append(self.frames.get(block=False)) except queue.Empty: pass for frame in frames: @@ -123,9 +129,6 @@ def get(self, timeout: float | None = None, decode: bool | None = None) -> Data: """ with self.mutex: - if self.closed: - raise EOFError("stream of frames ended") - if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") self.get_in_progress = True @@ -194,9 +197,6 @@ def get_iter(self, decode: bool | None = None) -> Iterator[Data]: """ with self.mutex: - if self.closed: - raise EOFError("stream of frames ended") - if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") self.get_in_progress = True @@ -288,5 +288,6 @@ def close(self) -> None: self.closed = True - # Unblock get() or get_iter(). - self.frames.put(None) + if self.get_in_progress: + # Unblock get() or get_iter(). + self.frames.put(None) diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 902b3b847..b1c57c8ca 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -793,14 +793,12 @@ async def test_close_timeout_waiting_for_connection_closed(self): # Remove socket.timeout when dropping Python < 3.10. self.assertIsInstance(exc.__cause__, (socket.timeout, TimeoutError)) - async def test_close_does_not_wait_for_recv(self): - # Closing the connection discards messages buffered in the assembler. - # This is allowed by the RFC: - # > However, there is no guarantee that the endpoint that has already - # > sent a Close frame will continue to process data. + async def test_close_preserves_queued_messages(self): + """close preserves messages buffered in the assembler.""" await self.remote_connection.send("😀") await self.connection.close() + self.assertEqual(await self.connection.recv(), "😀") with self.assertRaises(ConnectionClosedOK) as raised: await self.connection.recv() diff --git a/tests/asyncio/test_messages.py b/tests/asyncio/test_messages.py index 181ffd376..566f71cea 100644 --- a/tests/asyncio/test_messages.py +++ b/tests/asyncio/test_messages.py @@ -395,6 +395,58 @@ async def test_get_iter_fails_after_close(self): async for _ in self.assembler.get_iter(): self.fail("no fragment expected") + async def test_get_queued_message_after_close(self): + """get returns a message after close is called.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.close() + message = await self.assembler.get() + self.assertEqual(message, "café") + + async def test_get_iter_queued_message_after_close(self): + """get_iter yields a message after close is called.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.close() + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, ["café"]) + + async def test_get_queued_fragmented_message_after_close(self): + """get reassembles a fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + self.assembler.close() + self.assembler.close() + message = await self.assembler.get() + self.assertEqual(message, b"tea") + + async def test_get_iter_queued_fragmented_message_after_close(self): + """get_iter yields a fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + self.assembler.close() + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, [b"t", b"e", b"a"]) + + async def test_get_partially_queued_fragmented_message_after_close(self): + """get raises EOF on a partial fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.close() + with self.assertRaises(EOFError): + await self.assembler.get() + + async def test_get_iter_partially_queued_fragmented_message_after_close(self): + """get_iter yields a partial fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.close() + fragments = [] + with self.assertRaises(EOFError): + async for fragment in self.assembler.get_iter(): + fragments.append(fragment) + self.assertEqual(fragments, [b"t", b"e"]) + async def test_put_fails_after_close(self): """put raises EOFError after close is called.""" self.assembler.close() diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index abdfd3f78..408b9697a 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -543,17 +543,12 @@ def test_close_timeout_waiting_for_connection_closed(self): # Remove socket.timeout when dropping Python < 3.10. self.assertIsInstance(exc.__cause__, (socket.timeout, TimeoutError)) - def test_close_does_not_wait_for_recv(self): - # Closing the connection discards messages buffered in the assembler. - # This is allowed by the RFC: - # > However, there is no guarantee that the endpoint that has already - # > sent a Close frame will continue to process data. + def test_close_preserves_queued_messages(self): + """close preserves messages buffered in the assembler.""" self.remote_connection.send("😀") self.connection.close() - close_thread = threading.Thread(target=self.connection.close) - close_thread.start() - + self.assertEqual(self.connection.recv(), "😀") with self.assertRaises(ConnectionClosedOK) as raised: self.connection.recv() @@ -576,10 +571,10 @@ def test_close_idempotency(self): def test_close_idempotency_race_condition(self): """close waits if the connection is already closing.""" - self.connection.close_timeout = 5 * MS + self.connection.close_timeout = 6 * MS def closer(): - with self.delay_frames_rcvd(3 * MS): + with self.delay_frames_rcvd(4 * MS): self.connection.close() close_thread = threading.Thread(target=closer) @@ -591,14 +586,14 @@ def closer(): # Connection isn't closed yet. with self.assertRaises(TimeoutError): - self.connection.recv(timeout=0) + self.connection.recv(timeout=MS) self.connection.close() self.assertNoFrameSent() # Connection is closed now. with self.assertRaises(ConnectionClosedOK): - self.connection.recv(timeout=0) + self.connection.recv(timeout=MS) close_thread.join() diff --git a/tests/sync/test_messages.py b/tests/sync/test_messages.py index 02513894a..9ebe45088 100644 --- a/tests/sync/test_messages.py +++ b/tests/sync/test_messages.py @@ -374,6 +374,58 @@ def test_get_iter_fails_after_close(self): for _ in self.assembler.get_iter(): self.fail("no fragment expected") + def test_get_queued_message_after_close(self): + """get returns a message after close is called.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.close() + message = self.assembler.get() + self.assertEqual(message, "café") + + def test_get_iter_queued_message_after_close(self): + """get_iter yields a message after close is called.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.close() + fragments = list(self.assembler.get_iter()) + self.assertEqual(fragments, ["café"]) + + def test_get_queued_fragmented_message_after_close(self): + """get reassembles a fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + self.assembler.close() + self.assembler.close() + message = self.assembler.get() + self.assertEqual(message, b"tea") + + def test_get_iter_queued_fragmented_message_after_close(self): + """get_iter yields a fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + self.assembler.close() + fragments = list(self.assembler.get_iter()) + self.assertEqual(fragments, [b"t", b"e", b"a"]) + + def test_get_partially_queued_fragmented_message_after_close(self): + """get raises EOF on a partial fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.close() + with self.assertRaises(EOFError): + self.assembler.get() + + def test_get_iter_partially_queued_fragmented_message_after_close(self): + """get_iter yields a partial fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.close() + fragments = [] + with self.assertRaises(EOFError): + for fragment in self.assembler.get_iter(): + fragments.append(fragment) + self.assertEqual(fragments, [b"t", b"e"]) + def test_put_fails_after_close(self): """put raises EOFError after close is called.""" self.assembler.close() From 9a2f39fc66fd2427f904e551b1cc5f3995b02217 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 11 Nov 2024 21:13:39 +0100 Subject: [PATCH 675/762] Support max_queue=None like the legacy implementation. Fix #1540. --- docs/project/changelog.rst | 7 ++++ src/websockets/asyncio/client.py | 7 ++-- src/websockets/asyncio/connection.py | 4 +- src/websockets/asyncio/messages.py | 23 ++++++++--- src/websockets/asyncio/server.py | 7 ++-- src/websockets/sync/client.py | 7 ++-- src/websockets/sync/connection.py | 4 +- src/websockets/sync/messages.py | 25 +++++++++--- src/websockets/sync/server.py | 7 ++-- tests/asyncio/test_connection.py | 12 +++++- tests/asyncio/test_messages.py | 61 +++++++++++++++++++++++++--- tests/sync/test_connection.py | 17 +++++++- tests/sync/test_messages.py | 61 +++++++++++++++++++++++++--- 13 files changed, 200 insertions(+), 42 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 074a81c85..8e1ad81f0 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -34,6 +34,13 @@ notice. .. _14.0: +Improvements +............ + +* Supported ``max_queue=None`` in the :mod:`asyncio` and :mod:`threading` + implementations for consistency with the legacy implementation, even though + this is never a good idea. + Bug fixes ......... diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 74ae70f0d..cdd9bfac6 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -60,7 +60,7 @@ def __init__( ping_interval: float | None = 20, ping_timeout: float | None = 20, close_timeout: float | None = 10, - max_queue: int | tuple[int, int | None] = 16, + max_queue: int | None | tuple[int | None, int | None] = 16, write_limit: int | tuple[int, int | None] = 2**15, ) -> None: self.protocol: ClientProtocol @@ -222,7 +222,8 @@ class connect: max_queue: High-water mark of the buffer where frames are received. It defaults to 16 frames. The low-water mark defaults to ``max_queue // 4``. You may pass a ``(high, low)`` tuple to set the high-water - and low-water marks. + and low-water marks. If you want to disable flow control entirely, + you may set it to ``None``, although that's a bad idea. write_limit: High-water mark of write buffer in bytes. It is passed to :meth:`~asyncio.WriteTransport.set_write_buffer_limits`. It defaults to 32 KiB. You may pass a ``(high, low)`` tuple to set the @@ -283,7 +284,7 @@ def __init__( close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, - max_queue: int | tuple[int, int | None] = 16, + max_queue: int | None | tuple[int | None, int | None] = 16, write_limit: int | tuple[int, int | None] = 2**15, # Logging logger: LoggerLike | None = None, diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 186846ef3..f1dcbada6 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -56,14 +56,14 @@ def __init__( ping_interval: float | None = 20, ping_timeout: float | None = 20, close_timeout: float | None = 10, - max_queue: int | tuple[int, int | None] = 16, + max_queue: int | None | tuple[int | None, int | None] = 16, write_limit: int | tuple[int, int | None] = 2**15, ) -> None: self.protocol = protocol self.ping_interval = ping_interval self.ping_timeout = ping_timeout self.close_timeout = close_timeout - if isinstance(max_queue, int): + if isinstance(max_queue, int) or max_queue is None: max_queue = (max_queue, None) self.max_queue = max_queue if isinstance(write_limit, int): diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index 14ea7bf90..e6d1d31cc 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -84,7 +84,7 @@ class Assembler: # coverage reports incorrectly: "line NN didn't jump to the function exit" def __init__( # pragma: no cover self, - high: int = 16, + high: int | None = None, low: int | None = None, pause: Callable[[], Any] = lambda: None, resume: Callable[[], Any] = lambda: None, @@ -96,12 +96,15 @@ def __init__( # pragma: no cover # call to Protocol.data_received() could produce thousands of frames, # which must be buffered. Instead, we pause reading when the buffer goes # above the high limit and we resume when it goes under the low limit. - if low is None: + if high is not None and low is None: low = high // 4 - if low < 0: - raise ValueError("low must be positive or equal to zero") - if high < low: - raise ValueError("high must be greater than or equal to low") + if high is None and low is not None: + high = low * 4 + if high is not None and low is not None: + if low < 0: + raise ValueError("low must be positive or equal to zero") + if high < low: + raise ValueError("high must be greater than or equal to low") self.high, self.low = high, low self.pause = pause self.resume = resume @@ -256,6 +259,10 @@ def put(self, frame: Frame) -> None: def maybe_pause(self) -> None: """Pause the writer if queue is above the high water mark.""" + # Skip if flow control is disabled + if self.high is None: + return + # Check for "> high" to support high = 0 if len(self.frames) > self.high and not self.paused: self.paused = True @@ -263,6 +270,10 @@ def maybe_pause(self) -> None: def maybe_resume(self) -> None: """Resume the writer if queue is below the low water mark.""" + # Skip if flow control is disabled + if self.low is None: + return + # Check for "<= low" to support low = 0 if len(self.frames) <= self.low and self.paused: self.paused = False diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index 15c9ba13e..fdb928004 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -71,7 +71,7 @@ def __init__( ping_interval: float | None = 20, ping_timeout: float | None = 20, close_timeout: float | None = 10, - max_queue: int | tuple[int, int | None] = 16, + max_queue: int | None | tuple[int | None, int | None] = 16, write_limit: int | tuple[int, int | None] = 2**15, ) -> None: self.protocol: ServerProtocol @@ -643,7 +643,8 @@ def handler(websocket): max_queue: High-water mark of the buffer where frames are received. It defaults to 16 frames. The low-water mark defaults to ``max_queue // 4``. You may pass a ``(high, low)`` tuple to set the high-water - and low-water marks. + and low-water marks. If you want to disable flow control entirely, + you may set it to ``None``, although that's a bad idea. write_limit: High-water mark of write buffer in bytes. It is passed to :meth:`~asyncio.WriteTransport.set_write_buffer_limits`. It defaults to 32 KiB. You may pass a ``(high, low)`` tuple to set the @@ -713,7 +714,7 @@ def __init__( close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, - max_queue: int | tuple[int, int | None] = 16, + max_queue: int | None | tuple[int | None, int | None] = 16, write_limit: int | tuple[int, int | None] = 2**15, # Logging logger: LoggerLike | None = None, diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 54d0aef68..9e6da7caf 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -55,7 +55,7 @@ def __init__( protocol: ClientProtocol, *, close_timeout: float | None = 10, - max_queue: int | tuple[int, int | None] = 16, + max_queue: int | None | tuple[int | None, int | None] = 16, ) -> None: self.protocol: ClientProtocol self.response_rcvd = threading.Event() @@ -139,7 +139,7 @@ def connect( close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, - max_queue: int | tuple[int, int | None] = 16, + max_queue: int | None | tuple[int | None, int | None] = 16, # Logging logger: LoggerLike | None = None, # Escape hatch for advanced customization @@ -191,7 +191,8 @@ def connect( max_queue: High-water mark of the buffer where frames are received. It defaults to 16 frames. The low-water mark defaults to ``max_queue // 4``. You may pass a ``(high, low)`` tuple to set the high-water - and low-water marks. + and low-water marks. If you want to disable flow control entirely, + you may set it to ``None``, although that's a bad idea. logger: Logger for this client. It defaults to ``logging.getLogger("websockets.client")``. See the :doc:`logging guide <../../topics/logging>` for details. diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 77f803c9b..be3381c8a 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -49,12 +49,12 @@ def __init__( protocol: Protocol, *, close_timeout: float | None = 10, - max_queue: int | tuple[int, int | None] = 16, + max_queue: int | None | tuple[int | None, int | None] = 16, ) -> None: self.socket = socket self.protocol = protocol self.close_timeout = close_timeout - if isinstance(max_queue, int): + if isinstance(max_queue, int) or max_queue is None: max_queue = (max_queue, None) self.max_queue = max_queue diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index af8635f16..98490797f 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -33,7 +33,7 @@ class Assembler: def __init__( self, - high: int = 16, + high: int | None = None, low: int | None = None, pause: Callable[[], Any] = lambda: None, resume: Callable[[], Any] = lambda: None, @@ -49,12 +49,15 @@ def __init__( # call to Protocol.data_received() could produce thousands of frames, # which must be buffered. Instead, we pause reading when the buffer goes # above the high limit and we resume when it goes under the low limit. - if low is None: + if high is not None and low is None: low = high // 4 - if low < 0: - raise ValueError("low must be positive or equal to zero") - if high < low: - raise ValueError("high must be greater than or equal to low") + if high is None and low is not None: + high = low * 4 + if high is not None and low is not None: + if low < 0: + raise ValueError("low must be positive or equal to zero") + if high < low: + raise ValueError("high must be greater than or equal to low") self.high, self.low = high, low self.pause = pause self.resume = resume @@ -260,7 +263,12 @@ def put(self, frame: Frame) -> None: def maybe_pause(self) -> None: """Pause the writer if queue is above the high water mark.""" + # Skip if flow control is disabled + if self.high is None: + return + assert self.mutex.locked() + # Check for "> high" to support high = 0 if self.frames.qsize() > self.high and not self.paused: self.paused = True @@ -268,7 +276,12 @@ def maybe_pause(self) -> None: def maybe_resume(self) -> None: """Resume the writer if queue is below the low water mark.""" + # Skip if flow control is disabled + if self.low is None: + return + assert self.mutex.locked() + # Check for "<= low" to support low = 0 if self.frames.qsize() <= self.low and self.paused: self.paused = False diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 8601ccef9..9506d6830 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -66,7 +66,7 @@ def __init__( protocol: ServerProtocol, *, close_timeout: float | None = 10, - max_queue: int | tuple[int, int | None] = 16, + max_queue: int | None | tuple[int | None, int | None] = 16, ) -> None: self.protocol: ServerProtocol self.request_rcvd = threading.Event() @@ -356,7 +356,7 @@ def serve( close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, - max_queue: int | tuple[int, int | None] = 16, + max_queue: int | None | tuple[int | None, int | None] = 16, # Logging logger: LoggerLike | None = None, # Escape hatch for advanced customization @@ -438,7 +438,8 @@ def handler(websocket): max_queue: High-water mark of the buffer where frames are received. It defaults to 16 frames. The low-water mark defaults to ``max_queue // 4``. You may pass a ``(high, low)`` tuple to set the high-water - and low-water marks. + and low-water marks. If you want to disable flow control entirely, + you may set it to ``None``, although that's a bad idea. logger: Logger for this server. It defaults to ``logging.getLogger("websockets.server")``. See the :doc:`logging guide <../../topics/logging>` for details. diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index b1c57c8ca..8dd0a0335 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -1066,14 +1066,22 @@ async def test_close_timeout(self): self.assertEqual(connection.close_timeout, 42 * MS) async def test_max_queue(self): - """max_queue parameter configures high-water mark of frames buffer.""" + """max_queue configures high-water mark of frames buffer.""" connection = Connection(Protocol(self.LOCAL), max_queue=4) transport = Mock() connection.connection_made(transport) self.assertEqual(connection.recv_messages.high, 4) + async def test_max_queue_none(self): + """max_queue disables high-water mark of frames buffer.""" + connection = Connection(Protocol(self.LOCAL), max_queue=None) + transport = Mock() + connection.connection_made(transport) + self.assertEqual(connection.recv_messages.high, None) + self.assertEqual(connection.recv_messages.low, None) + async def test_max_queue_tuple(self): - """max_queue parameter configures high-water mark of frames buffer.""" + """max_queue configures high-water and low-water marks of frames buffer.""" connection = Connection( Protocol(self.LOCAL), max_queue=(4, 2), diff --git a/tests/asyncio/test_messages.py b/tests/asyncio/test_messages.py index 566f71cea..a90788d02 100644 --- a/tests/asyncio/test_messages.py +++ b/tests/asyncio/test_messages.py @@ -153,7 +153,7 @@ async def test_get_decoded_binary_message(self): self.assertEqual(message, "tea") async def test_get_resumes_reading(self): - """get resumes reading when queue goes below the high-water mark.""" + """get resumes reading when queue goes below the low-water mark.""" self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) self.assembler.put(Frame(OP_TEXT, b"more caf\xc3\xa9")) self.assembler.put(Frame(OP_TEXT, b"water")) @@ -170,6 +170,19 @@ async def test_get_resumes_reading(self): await self.assembler.get() self.resume.assert_called_once_with() + async def test_get_does_not_resume_reading(self): + """get does not resume reading when the low-water mark is unset.""" + self.assembler.low = None + + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"more caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"water")) + await self.assembler.get() + await self.assembler.get() + await self.assembler.get() + + self.resume.assert_not_called() + async def test_cancel_get_before_first_frame(self): """get can be canceled safely before reading the first frame.""" getter_task = asyncio.create_task(self.assembler.get()) @@ -302,7 +315,7 @@ async def test_get_iter_decoded_binary_message(self): self.assertEqual(fragments, ["t", "e", "a"]) async def test_get_iter_resumes_reading(self): - """get_iter resumes reading when queue goes below the high-water mark.""" + """get_iter resumes reading when queue goes below the low-water mark.""" self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) @@ -321,6 +334,20 @@ async def test_get_iter_resumes_reading(self): await anext(iterator) self.resume.assert_called_once_with() + async def test_get_iter_does_not_resume_reading(self): + """get_iter does not resume reading when the low-water mark is unset.""" + self.assembler.low = None + + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + iterator = aiter(self.assembler.get_iter()) + await anext(iterator) + await anext(iterator) + await anext(iterator) + + self.resume.assert_not_called() + async def test_cancel_get_iter_before_first_frame(self): """get_iter can be canceled safely before reading the first frame.""" getter_task = asyncio.create_task(alist(self.assembler.get_iter())) @@ -367,6 +394,17 @@ async def test_put_pauses_reading(self): self.assembler.put(Frame(OP_CONT, b"a")) self.pause.assert_called_once_with() + async def test_put_does_not_pause_reading(self): + """put does not pause reading when the high-water mark is unset.""" + self.assembler.high = None + + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + + self.pause.assert_not_called() + # Test termination async def test_get_fails_when_interrupted_by_close(self): @@ -495,16 +533,29 @@ async def test_get_iter_fails_when_get_iter_is_running(self): # Test setting limits async def test_set_high_water_mark(self): - """high sets the high-water mark.""" + """high sets the high-water and low-water marks.""" assembler = Assembler(high=10) self.assertEqual(assembler.high, 10) + self.assertEqual(assembler.low, 2) + + async def test_set_low_water_mark(self): + """low sets the low-water and high-water marks.""" + assembler = Assembler(low=5) + self.assertEqual(assembler.low, 5) + self.assertEqual(assembler.high, 20) - async def test_set_high_and_low_water_mark(self): - """high sets the high-water mark and low-water mark.""" + async def test_set_high_and_low_water_marks(self): + """high and low set the high-water and low-water marks.""" assembler = Assembler(high=10, low=5) self.assertEqual(assembler.high, 10) self.assertEqual(assembler.low, 5) + async def test_unset_high_and_low_water_marks(self): + """High-water and low-water marks are unset.""" + assembler = Assembler() + self.assertEqual(assembler.high, None) + self.assertEqual(assembler.low, None) + async def test_set_invalid_high_water_mark(self): """high must be a non-negative integer.""" with self.assertRaises(ValueError): diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index 408b9697a..6be490a5d 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -749,7 +749,7 @@ def test_close_timeout(self): self.assertEqual(connection.close_timeout, 42 * MS) def test_max_queue(self): - """max_queue parameter configures high-water mark of frames buffer.""" + """max_queue configures high-water mark of frames buffer.""" socket_, remote_socket = socket.socketpair() self.addCleanup(socket_.close) self.addCleanup(remote_socket.close) @@ -760,8 +760,21 @@ def test_max_queue(self): ) self.assertEqual(connection.recv_messages.high, 4) + def test_max_queue_none(self): + """max_queue disables high-water mark of frames buffer.""" + socket_, remote_socket = socket.socketpair() + self.addCleanup(socket_.close) + self.addCleanup(remote_socket.close) + connection = Connection( + socket_, + Protocol(self.LOCAL), + max_queue=None, + ) + self.assertEqual(connection.recv_messages.high, None) + self.assertEqual(connection.recv_messages.high, None) + def test_max_queue_tuple(self): - """max_queue parameter configures high-water mark of frames buffer.""" + """max_queue configures high-water and low-water marks of frames buffer.""" socket_, remote_socket = socket.socketpair() self.addCleanup(socket_.close) self.addCleanup(remote_socket.close) diff --git a/tests/sync/test_messages.py b/tests/sync/test_messages.py index 9ebe45088..d22693102 100644 --- a/tests/sync/test_messages.py +++ b/tests/sync/test_messages.py @@ -145,7 +145,7 @@ def test_get_decoded_binary_message(self): self.assertEqual(message, "tea") def test_get_resumes_reading(self): - """get resumes reading when queue goes below the high-water mark.""" + """get resumes reading when queue goes below the low-water mark.""" self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) self.assembler.put(Frame(OP_TEXT, b"more caf\xc3\xa9")) self.assembler.put(Frame(OP_TEXT, b"water")) @@ -162,6 +162,19 @@ def test_get_resumes_reading(self): self.assembler.get() self.resume.assert_called_once_with() + def test_get_does_not_resume_reading(self): + """get does not resume reading when the low-water mark is unset.""" + self.assembler.low = None + + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"more caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"water")) + self.assembler.get() + self.assembler.get() + self.assembler.get() + + self.resume.assert_not_called() + def test_get_timeout_before_first_frame(self): """get times out before reading the first frame.""" with self.assertRaises(TimeoutError): @@ -300,7 +313,7 @@ def test_get_iter_decoded_binary_message(self): self.assertEqual(fragments, ["t", "e", "a"]) def test_get_iter_resumes_reading(self): - """get_iter resumes reading when queue goes below the high-water mark.""" + """get_iter resumes reading when queue goes below the low-water mark.""" self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) @@ -319,6 +332,20 @@ def test_get_iter_resumes_reading(self): next(iterator) self.resume.assert_called_once_with() + def test_get_iter_does_not_resume_reading(self): + """get_iter does not resume reading when the low-water mark is unset.""" + self.assembler.low = None + + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + iterator = self.assembler.get_iter() + next(iterator) + next(iterator) + next(iterator) + + self.resume.assert_not_called() + # Test put def test_put_pauses_reading(self): @@ -336,6 +363,17 @@ def test_put_pauses_reading(self): self.assembler.put(Frame(OP_CONT, b"a")) self.pause.assert_called_once_with() + def test_put_does_not_pause_reading(self): + """put does not pause reading when the high-water mark is unset.""" + self.assembler.high = None + + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + + self.pause.assert_not_called() + # Test termination def test_get_fails_when_interrupted_by_close(self): @@ -470,16 +508,29 @@ def test_get_iter_fails_when_get_iter_is_running(self): # Test setting limits def test_set_high_water_mark(self): - """high sets the high-water mark.""" + """high sets the high-water and low-water marks.""" assembler = Assembler(high=10) self.assertEqual(assembler.high, 10) + self.assertEqual(assembler.low, 2) + + def test_set_low_water_mark(self): + """low sets the low-water and high-water marks.""" + assembler = Assembler(low=5) + self.assertEqual(assembler.low, 5) + self.assertEqual(assembler.high, 20) - def test_set_high_and_low_water_mark(self): - """high sets the high-water mark and low-water mark.""" + def test_set_high_and_low_water_marks(self): + """high and low set the high-water and low-water marks.""" assembler = Assembler(high=10, low=5) self.assertEqual(assembler.high, 10) self.assertEqual(assembler.low, 5) + def test_unset_high_and_low_water_marks(self): + """High-water and low-water marks are unset.""" + assembler = Assembler() + self.assertEqual(assembler.high, None) + self.assertEqual(assembler.low, None) + def test_set_invalid_high_water_mark(self): """high must be a non-negative integer.""" with self.assertRaises(ValueError): From de2e7fb8b7eaca56d633ae7d2ffffdbb212048a1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 11 Nov 2024 21:40:26 +0100 Subject: [PATCH 676/762] Add close_code and close_reason to new implementations. Also add state to threading implementation. Fix #1537. --- docs/reference/asyncio/client.rst | 7 ++++++ docs/reference/asyncio/common.rst | 7 ++++++ docs/reference/asyncio/server.rst | 7 ++++++ docs/reference/sync/client.rst | 9 +++++++ docs/reference/sync/common.rst | 9 +++++++ docs/reference/sync/server.rst | 9 +++++++ src/websockets/asyncio/connection.py | 24 ++++++++++++++++++ src/websockets/protocol.py | 21 ++++++++++------ src/websockets/sync/connection.py | 37 ++++++++++++++++++++++++++++ tests/asyncio/test_connection.py | 10 +++++++- tests/sync/test_connection.py | 14 ++++++++++- 11 files changed, 145 insertions(+), 9 deletions(-) diff --git a/docs/reference/asyncio/client.rst b/docs/reference/asyncio/client.rst index e2b0ff550..ea7b21506 100644 --- a/docs/reference/asyncio/client.rst +++ b/docs/reference/asyncio/client.rst @@ -57,3 +57,10 @@ Using a connection .. autoattribute:: response .. autoproperty:: subprotocol + + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: + + .. autoproperty:: close_code + + .. autoproperty:: close_reason diff --git a/docs/reference/asyncio/common.rst b/docs/reference/asyncio/common.rst index a58325fb9..325f20450 100644 --- a/docs/reference/asyncio/common.rst +++ b/docs/reference/asyncio/common.rst @@ -45,3 +45,10 @@ Both sides (new :mod:`asyncio`) .. autoattribute:: response .. autoproperty:: subprotocol + + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: + + .. autoproperty:: close_code + + .. autoproperty:: close_reason diff --git a/docs/reference/asyncio/server.rst b/docs/reference/asyncio/server.rst index 2fcaeb414..49bd6f072 100644 --- a/docs/reference/asyncio/server.rst +++ b/docs/reference/asyncio/server.rst @@ -79,6 +79,13 @@ Using a connection .. autoproperty:: subprotocol + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: + + .. autoproperty:: close_code + + .. autoproperty:: close_reason + Broadcast --------- diff --git a/docs/reference/sync/client.rst b/docs/reference/sync/client.rst index af1132412..2aa491f6a 100644 --- a/docs/reference/sync/client.rst +++ b/docs/reference/sync/client.rst @@ -39,6 +39,8 @@ Using a connection .. autoproperty:: remote_address + .. autoproperty:: state + The following attributes are available after the opening handshake, once the WebSocket connection is open: @@ -47,3 +49,10 @@ Using a connection .. autoattribute:: response .. autoproperty:: subprotocol + + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: + + .. autoproperty:: close_code + + .. autoproperty:: close_reason diff --git a/docs/reference/sync/common.rst b/docs/reference/sync/common.rst index 3dc6d4a50..3c03b25b6 100644 --- a/docs/reference/sync/common.rst +++ b/docs/reference/sync/common.rst @@ -31,6 +31,8 @@ Both sides (:mod:`threading`) .. autoproperty:: remote_address + .. autoproperty:: state + The following attributes are available after the opening handshake, once the WebSocket connection is open: @@ -39,3 +41,10 @@ Both sides (:mod:`threading`) .. autoattribute:: response .. autoproperty:: subprotocol + + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: + + .. autoproperty:: close_code + + .. autoproperty:: close_reason diff --git a/docs/reference/sync/server.rst b/docs/reference/sync/server.rst index 80e9c17bb..1d80450f9 100644 --- a/docs/reference/sync/server.rst +++ b/docs/reference/sync/server.rst @@ -52,6 +52,8 @@ Using a connection .. autoproperty:: remote_address + .. autoproperty:: state + The following attributes are available after the opening handshake, once the WebSocket connection is open: @@ -61,6 +63,13 @@ Using a connection .. autoproperty:: subprotocol + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: + + .. autoproperty:: close_code + + .. autoproperty:: close_reason + HTTP Basic Authentication ------------------------- diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index f1dcbada6..e5c350fe2 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -185,6 +185,30 @@ def subprotocol(self) -> Subprotocol | None: """ return self.protocol.subprotocol + @property + def close_code(self) -> int | None: + """ + State of the WebSocket connection, defined in :rfc:`6455`. + + This attribute is provided for completeness. Typical applications + shouldn't check its value. Instead, they should inspect attributes + of :exc:`~websockets.exceptions.ConnectionClosed` exceptions. + + """ + return self.protocol.close_code + + @property + def close_reason(self) -> str | None: + """ + State of the WebSocket connection, defined in :rfc:`6455`. + + This attribute is provided for completeness. Typical applications + shouldn't check its value. Instead, they should inspect attributes + of :exc:`~websockets.exceptions.ConnectionClosed` exceptions. + + """ + return self.protocol.close_reason + # Public methods async def __aenter__(self) -> Connection: diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 0f6fea250..bc64a216a 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -159,7 +159,12 @@ def state(self) -> State: """ State of the WebSocket connection. - Defined in 4.1, 4.2, 7.1.3, and 7.1.4 of :rfc:`6455`. + Defined in 4.1_, 4.2_, 7.1.3_, and 7.1.4_ of :rfc:`6455`. + + .. _4.1: https://datatracker.ietf.org/doc/html/rfc6455#section-4.1 + .. _4.2: https://datatracker.ietf.org/doc/html/rfc6455#section-4.2 + .. _7.1.3: https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.3 + .. _7.1.4: https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.4 """ return self._state @@ -173,10 +178,11 @@ def state(self, state: State) -> None: @property def close_code(self) -> int | None: """ - `WebSocket close code`_. + WebSocket close code received from the remote endpoint. + + Defined in 7.1.5_ of :rfc:`6455`. - .. _WebSocket close code: - https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.5 + .. _7.1.5: https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.5 :obj:`None` if the connection isn't closed yet. @@ -191,10 +197,11 @@ def close_code(self) -> int | None: @property def close_reason(self) -> str | None: """ - `WebSocket close reason`_. + WebSocket close reason received from the remote endpoint. + + Defined in 7.1.6_ of :rfc:`6455`. - .. _WebSocket close reason: - https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.6 + .. _7.1.6: https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.6 :obj:`None` if the connection isn't closed yet. diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index be3381c8a..d8dbf140e 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -140,6 +140,19 @@ def remote_address(self) -> Any: """ return self.socket.getpeername() + @property + def state(self) -> State: + """ + State of the WebSocket connection, defined in :rfc:`6455`. + + This attribute is provided for completeness. Typical applications + shouldn't check its value. Instead, they should call :meth:`~recv` or + :meth:`send` and handle :exc:`~websockets.exceptions.ConnectionClosed` + exceptions. + + """ + return self.protocol.state + @property def subprotocol(self) -> Subprotocol | None: """ @@ -150,6 +163,30 @@ def subprotocol(self) -> Subprotocol | None: """ return self.protocol.subprotocol + @property + def close_code(self) -> int | None: + """ + State of the WebSocket connection, defined in :rfc:`6455`. + + This attribute is provided for completeness. Typical applications + shouldn't check its value. Instead, they should inspect attributes + of :exc:`~websockets.exceptions.ConnectionClosed` exceptions. + + """ + return self.protocol.close_code + + @property + def close_reason(self) -> str | None: + """ + State of the WebSocket connection, defined in :rfc:`6455`. + + This attribute is provided for completeness. Typical applications + shouldn't check its value. Instead, they should inspect attributes + of :exc:`~websockets.exceptions.ConnectionClosed` exceptions. + + """ + return self.protocol.close_reason + # Public methods def __enter__(self) -> Connection: diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 8dd0a0335..5a0b61bf7 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -1139,7 +1139,7 @@ async def test_remote_address(self, get_extra_info): async def test_state(self): """Connection has a state attribute.""" - self.assertEqual(self.connection.state, State.OPEN) + self.assertIs(self.connection.state, State.OPEN) async def test_request(self): """Connection has a request attribute.""" @@ -1153,6 +1153,14 @@ async def test_subprotocol(self): """Connection has a subprotocol attribute.""" self.assertIsNone(self.connection.subprotocol) + async def test_close_code(self): + """Connection has a close_code attribute.""" + self.assertIsNone(self.connection.close_code) + + async def test_close_reason(self): + """Connection has a close_reason attribute.""" + self.assertIsNone(self.connection.close_reason) + # Test reporting of network errors. async def test_writing_in_data_received_fails(self): diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index 6be490a5d..4884bf13f 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -15,7 +15,7 @@ ConnectionClosedOK, ) from websockets.frames import CloseCode, Frame, Opcode -from websockets.protocol import CLIENT, SERVER, Protocol +from websockets.protocol import CLIENT, SERVER, Protocol, State from websockets.sync.connection import * from ..protocol import RecordingProtocol @@ -808,6 +808,10 @@ def test_remote_address(self, getpeername): self.assertEqual(self.connection.remote_address, ("peer", 1234)) getpeername.assert_called_with() + def test_state(self): + """Connection has a state attribute.""" + self.assertIs(self.connection.state, State.OPEN) + def test_request(self): """Connection has a request attribute.""" self.assertIsNone(self.connection.request) @@ -820,6 +824,14 @@ def test_subprotocol(self): """Connection has a subprotocol attribute.""" self.assertIsNone(self.connection.subprotocol) + def test_close_code(self): + """Connection has a close_code attribute.""" + self.assertIsNone(self.connection.close_code) + + def test_close_reason(self): + """Connection has a close_reason attribute.""" + self.assertIsNone(self.connection.close_reason) + # Test reporting of network errors. @unittest.skipUnless(sys.platform == "darwin", "works only on BSD") From 1f19487b0f55aac3f67a5f1b35209e0e3b294063 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 11 Nov 2024 22:01:58 +0100 Subject: [PATCH 677/762] Add changelog for previous commit. --- docs/project/changelog.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 8e1ad81f0..792f8ede4 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -41,6 +41,10 @@ Improvements implementations for consistency with the legacy implementation, even though this is never a good idea. +* Added ``close_code`` and ``close_reason`` attributes in the :mod:`asyncio` and + :mod:`threading` implementations for consistency with the legacy + implementation. + Bug fixes ......... From 7438b8ebee3f6ef59c15b86d627d32f98f49df8f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 12 Nov 2024 08:40:24 +0100 Subject: [PATCH 678/762] Stop testing on PyPy 3.9. PyPy v7.3.17 no longer provides PyPy 3.9. Also the test suite was flaky under PyPy 3.9. --- .github/workflows/tests.yml | 3 --- tests/sync/test_connection.py | 5 ----- 2 files changed, 8 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index beaf9d12b..5ab9c4c72 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -60,13 +60,10 @@ jobs: - "3.11" - "3.12" - "3.13" - - "pypy-3.9" - "pypy-3.10" is_main: - ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }} exclude: - - python: "pypy-3.9" - is_main: false - python: "pypy-3.10" is_main: false steps: diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index 4884bf13f..e21e310a2 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -1,6 +1,5 @@ import contextlib import logging -import platform import socket import sys import threading @@ -564,10 +563,6 @@ def test_close_idempotency(self): self.connection.close() self.assertNoFrameSent() - @unittest.skipIf( - platform.python_implementation() == "PyPy", - "this test fails randomly due to a bug in PyPy", # see #1314 for details - ) def test_close_idempotency_race_condition(self): """close waits if the connection is already closing.""" From d0015c93f49511eb8dd073caa6a3a338979f741b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 13 Nov 2024 07:51:11 +0100 Subject: [PATCH 679/762] Fix refactoring error in a78b5546. Fix #1546. --- example/faq/health_check_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/faq/health_check_server.py b/example/faq/health_check_server.py index c0fa4327f..30623a4bb 100755 --- a/example/faq/health_check_server.py +++ b/example/faq/health_check_server.py @@ -6,7 +6,7 @@ def health_check(connection, request): if request.path == "/healthz": - return connection.respond(HTTPStatus.OK, b"OK\n") + return connection.respond(HTTPStatus.OK, "OK\n") async def echo(websocket): async for message in websocket: From 0403823185b272ae389da1c9182773932b4df950 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 13 Nov 2024 07:53:45 +0100 Subject: [PATCH 680/762] Release version 14.1. --- docs/project/changelog.rst | 6 +++--- src/websockets/version.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 792f8ede4..ca6769199 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -30,9 +30,7 @@ notice. 14.1 ---- -*In development* - -.. _14.0: +*November 13, 2024* Improvements ............ @@ -52,6 +50,8 @@ Bug fixes be read in the :mod:`asyncio` and :mod:`threading` implementations, just like in the legacy implementation. +.. _14.0: + 14.0 ---- diff --git a/src/websockets/version.py b/src/websockets/version.py index 48d2edaea..f2defeff0 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -18,7 +18,7 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = False +released = True tag = version = commit = "14.1" From d8891a101d7cfc2cf7e01de078ada7c93a264ebd Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 13 Nov 2024 07:55:13 +0100 Subject: [PATCH 681/762] Start version 14.2. --- docs/project/changelog.rst | 7 +++++++ src/websockets/version.py | 4 ++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index ca6769199..9c594b653 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -25,6 +25,13 @@ fixing regressions shortly after a release. Only documented APIs are public. Undocumented, private APIs may change without notice. +.. _14.2: + +14.2 +---- + +*In development* + .. _14.1: 14.1 diff --git a/src/websockets/version.py b/src/websockets/version.py index f2defeff0..b0df10f0a 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -18,9 +18,9 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = True +released = False -tag = version = commit = "14.1" +tag = version = commit = "14.2" if not released: # pragma: no cover From 59d4dcf779fe7d2b0302083b072d8b03adce2f61 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 13 Nov 2024 23:00:22 +0100 Subject: [PATCH 682/762] Reintroduce InvalidMessage. This improves compatibility with the legacy implementation and clarifies error reporting. Fix #1548. --- docs/project/changelog.rst | 8 ++++++++ docs/reference/exceptions.rst | 4 ++-- src/websockets/__init__.py | 4 +++- src/websockets/asyncio/client.py | 6 ++++-- src/websockets/client.py | 6 +++++- src/websockets/exceptions.py | 11 +++++++++-- src/websockets/legacy/client.py | 3 ++- src/websockets/legacy/exceptions.py | 9 ++------- src/websockets/legacy/server.py | 3 ++- src/websockets/server.py | 6 +++++- tests/asyncio/test_client.py | 27 ++++++++++++++++++++------- tests/asyncio/test_server.py | 4 ++++ tests/legacy/test_exceptions.py | 4 ---- tests/sync/test_client.py | 21 ++++++++++++++++++--- tests/sync/test_server.py | 4 ++++ tests/test_client.py | 28 ++++++++++++++++++++++++---- tests/test_exceptions.py | 4 ++++ tests/test_server.py | 24 ++++++++++++++++++++---- 18 files changed, 136 insertions(+), 40 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 9c594b653..b7f4f62f9 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -32,6 +32,14 @@ notice. *In development* +Bug fixes +......... + +* Wrapped errors when reading the opening handshake request or response in + :exc:`~exceptions.InvalidMessage` so that :func:`~asyncio.client.connect` + raises :exc:`~exceptions.InvalidHandshake` or a subclass when the opening + handshake fails. + .. _14.1: 14.1 diff --git a/docs/reference/exceptions.rst b/docs/reference/exceptions.rst index 75934ef99..d6b7f0f57 100644 --- a/docs/reference/exceptions.rst +++ b/docs/reference/exceptions.rst @@ -30,6 +30,8 @@ also reported by :func:`~websockets.asyncio.server.serve` in logs. .. autoexception:: InvalidHandshake +.. autoexception:: InvalidMessage + .. autoexception:: SecurityError .. autoexception:: InvalidStatus @@ -74,8 +76,6 @@ Legacy exceptions These exceptions are only used by the legacy :mod:`asyncio` implementation. -.. autoexception:: InvalidMessage - .. autoexception:: InvalidStatusCode .. autoexception:: AbortHandshake diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index 0c7e9b4c6..c278b21d4 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -31,6 +31,7 @@ "InvalidHeader", "InvalidHeaderFormat", "InvalidHeaderValue", + "InvalidMessage", "InvalidOrigin", "InvalidParameterName", "InvalidParameterValue", @@ -71,6 +72,7 @@ InvalidHeader, InvalidHeaderFormat, InvalidHeaderValue, + InvalidMessage, InvalidOrigin, InvalidParameterName, InvalidParameterValue, @@ -122,6 +124,7 @@ "InvalidHeader": ".exceptions", "InvalidHeaderFormat": ".exceptions", "InvalidHeaderValue": ".exceptions", + "InvalidMessage": ".exceptions", "InvalidOrigin": ".exceptions", "InvalidParameterName": ".exceptions", "InvalidParameterValue": ".exceptions", @@ -159,7 +162,6 @@ "WebSocketClientProtocol": ".legacy.client", # .legacy.exceptions "AbortHandshake": ".legacy.exceptions", - "InvalidMessage": ".legacy.exceptions", "InvalidStatusCode": ".legacy.exceptions", "RedirectHandshake": ".legacy.exceptions", "WebSocketProtocolError": ".legacy.exceptions", diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index cdd9bfac6..8581c0551 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -11,7 +11,7 @@ from ..client import ClientProtocol, backoff from ..datastructures import HeadersLike -from ..exceptions import InvalidStatus, SecurityError +from ..exceptions import InvalidMessage, InvalidStatus, SecurityError from ..extensions.base import ClientExtensionFactory from ..extensions.permessage_deflate import enable_client_permessage_deflate from ..headers import validate_subprotocols @@ -147,7 +147,9 @@ def process_exception(exc: Exception) -> Exception | None: That exception will be raised, breaking out of the retry loop. """ - if isinstance(exc, (EOFError, OSError, asyncio.TimeoutError)): + if isinstance(exc, (OSError, asyncio.TimeoutError)): + return None + if isinstance(exc, InvalidMessage) and isinstance(exc.__cause__, EOFError): return None if isinstance(exc, InvalidStatus) and exc.response.status_code in [ 500, # Internal Server Error diff --git a/src/websockets/client.py b/src/websockets/client.py index f6cbc9f65..5ced05c2a 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -11,6 +11,7 @@ InvalidHandshake, InvalidHeader, InvalidHeaderValue, + InvalidMessage, InvalidStatus, InvalidUpgrade, NegotiationError, @@ -318,7 +319,10 @@ def parse(self) -> Generator[None]: self.reader.read_to_eof, ) except Exception as exc: - self.handshake_exc = exc + self.handshake_exc = InvalidMessage( + "did not receive a valid HTTP response" + ) + self.handshake_exc.__cause__ = exc self.send_eof() self.parser = self.discard() next(self.parser) # start coroutine diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index f3e751971..81fbb1efd 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -8,7 +8,7 @@ * :exc:`InvalidURI` * :exc:`InvalidHandshake` * :exc:`SecurityError` - * :exc:`InvalidMessage` (legacy) + * :exc:`InvalidMessage` * :exc:`InvalidStatus` * :exc:`InvalidStatusCode` (legacy) * :exc:`InvalidHeader` @@ -48,6 +48,7 @@ "InvalidHeader", "InvalidHeaderFormat", "InvalidHeaderValue", + "InvalidMessage", "InvalidOrigin", "InvalidUpgrade", "NegotiationError", @@ -185,6 +186,13 @@ class SecurityError(InvalidHandshake): """ +class InvalidMessage(InvalidHandshake): + """ + Raised when a handshake request or response is malformed. + + """ + + class InvalidStatus(InvalidHandshake): """ Raised when a handshake response rejects the WebSocket upgrade. @@ -410,7 +418,6 @@ class ConcurrencyError(WebSocketException, RuntimeError): deprecated_aliases={ # deprecated in 14.0 - 2024-11-09 "AbortHandshake": ".legacy.exceptions", - "InvalidMessage": ".legacy.exceptions", "InvalidStatusCode": ".legacy.exceptions", "RedirectHandshake": ".legacy.exceptions", "WebSocketProtocolError": ".legacy.exceptions", diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index a3856b470..29141f39a 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -17,6 +17,7 @@ from ..exceptions import ( InvalidHeader, InvalidHeaderValue, + InvalidMessage, NegotiationError, SecurityError, ) @@ -34,7 +35,7 @@ from ..http11 import USER_AGENT from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol from ..uri import WebSocketURI, parse_uri -from .exceptions import InvalidMessage, InvalidStatusCode, RedirectHandshake +from .exceptions import InvalidStatusCode, RedirectHandshake from .handshake import build_request, check_response from .http import read_response from .protocol import WebSocketCommonProtocol diff --git a/src/websockets/legacy/exceptions.py b/src/websockets/legacy/exceptions.py index e2279c825..78fb696fa 100644 --- a/src/websockets/legacy/exceptions.py +++ b/src/websockets/legacy/exceptions.py @@ -3,18 +3,13 @@ from .. import datastructures from ..exceptions import ( InvalidHandshake, + # InvalidMessage was incorrectly moved here in versions 14.0 and 14.1. + InvalidMessage, # noqa: F401 ProtocolError as WebSocketProtocolError, # noqa: F401 ) from ..typing import StatusLike -class InvalidMessage(InvalidHandshake): - """ - Raised when a handshake request or response is malformed. - - """ - - class InvalidStatusCode(InvalidHandshake): """ Raised when a handshake response status code is invalid. diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 9326b6100..f9d57cb99 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -17,6 +17,7 @@ from ..exceptions import ( InvalidHandshake, InvalidHeader, + InvalidMessage, InvalidOrigin, InvalidUpgrade, NegotiationError, @@ -32,7 +33,7 @@ from ..http11 import SERVER from ..protocol import State from ..typing import ExtensionHeader, LoggerLike, Origin, StatusLike, Subprotocol -from .exceptions import AbortHandshake, InvalidMessage +from .exceptions import AbortHandshake from .handshake import build_response, check_request from .http import read_request from .protocol import WebSocketCommonProtocol, broadcast diff --git a/src/websockets/server.py b/src/websockets/server.py index 607cc306e..1b663a137 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -13,6 +13,7 @@ InvalidHandshake, InvalidHeader, InvalidHeaderValue, + InvalidMessage, InvalidOrigin, InvalidUpgrade, NegotiationError, @@ -552,7 +553,10 @@ def parse(self) -> Generator[None]: self.reader.read_line, ) except Exception as exc: - self.handshake_exc = exc + self.handshake_exc = InvalidMessage( + "did not receive a valid HTTP request" + ) + self.handshake_exc.__cause__ = exc self.send_eof() self.parser = self.discard() next(self.parser) # start coroutine diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index 231d6b8ca..1773c08bd 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -12,6 +12,7 @@ from websockets.client import backoff from websockets.exceptions import ( InvalidHandshake, + InvalidMessage, InvalidStatus, InvalidURI, SecurityError, @@ -151,22 +152,24 @@ async def test_reconnect(self): iterations = 0 successful = 0 - def process_request(connection, request): + async def process_request(connection, request): nonlocal iterations iterations += 1 # Retriable errors if iterations == 1: - connection.transport.close() + await asyncio.sleep(3 * MS) elif iterations == 2: + connection.transport.close() + elif iterations == 3: return connection.respond(http.HTTPStatus.SERVICE_UNAVAILABLE, "🚒") # Fatal error - elif iterations == 5: + elif iterations == 6: return connection.respond(http.HTTPStatus.PAYMENT_REQUIRED, "💸") async with serve(*args, process_request=process_request) as server: with self.assertRaises(InvalidStatus) as raised: async with short_backoff_delay(): - async for client in connect(get_uri(server)): + async for client in connect(get_uri(server), open_timeout=3 * MS): self.assertEqual(client.protocol.state.name, "OPEN") successful += 1 @@ -174,7 +177,7 @@ def process_request(connection, request): str(raised.exception), "server rejected WebSocket connection: HTTP 402", ) - self.assertEqual(iterations, 5) + self.assertEqual(iterations, 6) self.assertEqual(successful, 2) async def test_reconnect_with_custom_process_exception(self): @@ -393,11 +396,16 @@ def close_connection(self, request): self.close_transport() async with serve(*args, process_request=close_connection) as server: - with self.assertRaises(EOFError) as raised: + with self.assertRaises(InvalidMessage) as raised: async with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), + "did not receive a valid HTTP response", + ) + self.assertIsInstance(raised.exception.__cause__, EOFError) + self.assertEqual( + str(raised.exception.__cause__), "connection closed while reading HTTP status line", ) @@ -443,11 +451,16 @@ async def junk(reader, writer): server = await asyncio.start_server(junk, "localhost", 0) host, port = get_host_port(server) async with server: - with self.assertRaises(ValueError) as raised: + with self.assertRaises(InvalidMessage) as raised: async with connect(f"ws://{host}:{port}"): self.fail("did not raise") self.assertEqual( str(raised.exception), + "did not receive a valid HTTP response", + ) + self.assertIsInstance(raised.exception.__cause__, ValueError) + self.assertEqual( + str(raised.exception.__cause__), "unsupported protocol; expected HTTP/1.1: " "220 smtp.invalid ESMTP Postfix", ) diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index 3e289e592..83885fab5 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -473,6 +473,10 @@ async def test_junk_handshake(self): ) self.assertEqual( [str(record.exc_info[1]) for record in logs.records], + ["did not receive a valid HTTP request"], + ) + self.assertEqual( + [str(record.exc_info[1].__cause__) for record in logs.records], ["invalid HTTP request line: HELO relay.invalid"], ) diff --git a/tests/legacy/test_exceptions.py b/tests/legacy/test_exceptions.py index e5d22a917..4e6ff952b 100644 --- a/tests/legacy/test_exceptions.py +++ b/tests/legacy/test_exceptions.py @@ -7,10 +7,6 @@ class ExceptionsTests(unittest.TestCase): def test_str(self): for exception, exception_str in [ - ( - InvalidMessage("malformed HTTP message"), - "malformed HTTP message", - ), ( InvalidStatusCode(403, Headers()), "server rejected WebSocket connection: HTTP 403", diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index 9d457a912..7d8170519 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -7,7 +7,12 @@ import time import unittest -from websockets.exceptions import InvalidHandshake, InvalidStatus, InvalidURI +from websockets.exceptions import ( + InvalidHandshake, + InvalidMessage, + InvalidStatus, + InvalidURI, +) from websockets.extensions.permessage_deflate import PerMessageDeflate from websockets.sync.client import * @@ -149,11 +154,16 @@ def close_connection(self, request): self.close_socket() with run_server(process_request=close_connection) as server: - with self.assertRaises(EOFError) as raised: + with self.assertRaises(InvalidMessage) as raised: with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), + "did not receive a valid HTTP response", + ) + self.assertIsInstance(raised.exception.__cause__, EOFError) + self.assertEqual( + str(raised.exception.__cause__), "connection closed while reading HTTP status line", ) @@ -203,11 +213,16 @@ def handle(self): thread = threading.Thread(target=server.serve_forever, args=(MS,)) thread.start() try: - with self.assertRaises(ValueError) as raised: + with self.assertRaises(InvalidMessage) as raised: with connect(f"ws://{host}:{port}"): self.fail("did not raise") self.assertEqual( str(raised.exception), + "did not receive a valid HTTP response", + ) + self.assertIsInstance(raised.exception.__cause__, ValueError) + self.assertEqual( + str(raised.exception.__cause__), "unsupported protocol; expected HTTP/1.1: " "220 smtp.invalid ESMTP Postfix", ) diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index 54e49bf16..9a2676437 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -311,6 +311,10 @@ def test_junk_handshake(self): ) self.assertEqual( [str(record.exc_info[1]) for record in logs.records], + ["did not receive a valid HTTP request"], + ) + self.assertEqual( + [str(record.exc_info[1].__cause__) for record in logs.records], ["invalid HTTP request line: HELO relay.invalid"], ) diff --git a/tests/test_client.py b/tests/test_client.py index 2468be85e..1edbae57d 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -8,7 +8,12 @@ from websockets.client import * from websockets.client import backoff from websockets.datastructures import Headers -from websockets.exceptions import InvalidHandshake, InvalidHeader, InvalidStatus +from websockets.exceptions import ( + InvalidHandshake, + InvalidHeader, + InvalidMessage, + InvalidStatus, +) from websockets.frames import OP_TEXT, Frame from websockets.http11 import Request, Response from websockets.protocol import CONNECTING, OPEN @@ -244,9 +249,14 @@ def test_receive_no_response(self, _generate_key): client.receive_eof() self.assertEqual(client.events_received(), []) - self.assertIsInstance(client.handshake_exc, EOFError) + self.assertIsInstance(client.handshake_exc, InvalidMessage) self.assertEqual( str(client.handshake_exc), + "did not receive a valid HTTP response", + ) + self.assertIsInstance(client.handshake_exc.__cause__, EOFError) + self.assertEqual( + str(client.handshake_exc.__cause__), "connection closed while reading HTTP status line", ) @@ -257,9 +267,14 @@ def test_receive_truncated_response(self, _generate_key): client.receive_eof() self.assertEqual(client.events_received(), []) - self.assertIsInstance(client.handshake_exc, EOFError) + self.assertIsInstance(client.handshake_exc, InvalidMessage) self.assertEqual( str(client.handshake_exc), + "did not receive a valid HTTP response", + ) + self.assertIsInstance(client.handshake_exc.__cause__, EOFError) + self.assertEqual( + str(client.handshake_exc.__cause__), "connection closed while reading HTTP headers", ) @@ -272,9 +287,14 @@ def test_receive_random_response(self, _generate_key): client.receive_data(b"250 Ok\r\n") self.assertEqual(client.events_received(), []) - self.assertIsInstance(client.handshake_exc, ValueError) + self.assertIsInstance(client.handshake_exc, InvalidMessage) self.assertEqual( str(client.handshake_exc), + "did not receive a valid HTTP response", + ) + self.assertIsInstance(client.handshake_exc.__cause__, ValueError) + self.assertEqual( + str(client.handshake_exc.__cause__), "invalid HTTP status line: 220 smtp.invalid", ) diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index fef41d136..e0518b0e0 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -91,6 +91,10 @@ def test_str(self): SecurityError("redirect from WSS to WS"), "redirect from WSS to WS", ), + ( + InvalidMessage("malformed HTTP message"), + "malformed HTTP message", + ), ( InvalidStatus(Response(401, "Unauthorized", Headers())), "server rejected WebSocket connection: HTTP 401", diff --git a/tests/test_server.py b/tests/test_server.py index 844ba64ec..5efeca2d0 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -7,6 +7,7 @@ from websockets.datastructures import Headers from websockets.exceptions import ( InvalidHeader, + InvalidMessage, InvalidOrigin, InvalidUpgrade, NegotiationError, @@ -207,9 +208,15 @@ def test_receive_no_request(self): server.receive_eof() self.assertEqual(server.events_received(), []) - self.assertIsInstance(server.handshake_exc, EOFError) + self.assertEqual(server.events_received(), []) + self.assertIsInstance(server.handshake_exc, InvalidMessage) self.assertEqual( str(server.handshake_exc), + "did not receive a valid HTTP request", + ) + self.assertIsInstance(server.handshake_exc.__cause__, EOFError) + self.assertEqual( + str(server.handshake_exc.__cause__), "connection closed while reading HTTP request line", ) @@ -220,9 +227,14 @@ def test_receive_truncated_request(self): server.receive_eof() self.assertEqual(server.events_received(), []) - self.assertIsInstance(server.handshake_exc, EOFError) + self.assertIsInstance(server.handshake_exc, InvalidMessage) self.assertEqual( str(server.handshake_exc), + "did not receive a valid HTTP request", + ) + self.assertIsInstance(server.handshake_exc.__cause__, EOFError) + self.assertEqual( + str(server.handshake_exc.__cause__), "connection closed while reading HTTP headers", ) @@ -233,10 +245,14 @@ def test_receive_junk_request(self): server.receive_data(b"MAIL FROM: \r\n") server.receive_data(b"RCPT TO: \r\n") - self.assertEqual(server.events_received(), []) - self.assertIsInstance(server.handshake_exc, ValueError) + self.assertIsInstance(server.handshake_exc, InvalidMessage) self.assertEqual( str(server.handshake_exc), + "did not receive a valid HTTP request", + ) + self.assertIsInstance(server.handshake_exc.__cause__, ValueError) + self.assertEqual( + str(server.handshake_exc.__cause__), "invalid HTTP request line: HELO relay.invalid", ) From d852df7dd6324eaee17fc848f029ada371678cbe Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sat, 7 Dec 2024 06:23:21 +0000 Subject: [PATCH 683/762] Bump actions/attest-build-provenance from 1 to 2 Bumps [actions/attest-build-provenance](https://github.com/actions/attest-build-provenance) from 1 to 2. - [Release notes](https://github.com/actions/attest-build-provenance/releases) - [Changelog](https://github.com/actions/attest-build-provenance/blob/main/RELEASE.md) - [Commits](https://github.com/actions/attest-build-provenance/compare/v1...v2) --- updated-dependencies: - dependency-name: actions/attest-build-provenance dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index cc26502ca..aa1e2e7a9 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -82,7 +82,7 @@ jobs: merge-multiple: true path: dist - name: Attest provenance - uses: actions/attest-build-provenance@v1 + uses: actions/attest-build-provenance@v2 with: subject-path: dist/* - name: Upload to PyPI From 197b3ec2c7acf3a3804a94c0c02ed8e27051b0f0 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 3 Jan 2025 22:50:31 +0100 Subject: [PATCH 684/762] Don't crash when acknowledging a cancelled ping. Fix #1566. --- src/websockets/asyncio/connection.py | 3 ++- tests/asyncio/test_connection.py | 23 ++++++++++++++++++++++- tests/sync/test_connection.py | 2 +- 3 files changed, 25 insertions(+), 3 deletions(-) diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index e5c350fe2..c34d19d58 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -750,7 +750,8 @@ def acknowledge_pings(self, data: bytes) -> None: for ping_id, (pong_waiter, ping_timestamp) in self.pong_waiters.items(): ping_ids.append(ping_id) latency = pong_timestamp - ping_timestamp - pong_waiter.set_result(latency) + if not pong_waiter.done(): + pong_waiter.set_result(latency) if ping_id == data: self.latency = latency break diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 5a0b61bf7..788a457ed 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -892,6 +892,15 @@ async def test_acknowledge_ping(self): async with asyncio_timeout(MS): await pong_waiter + async def test_acknowledge_canceled_ping(self): + """ping is acknowledged by a pong with the same payload after being canceled.""" + async with self.drop_frames_rcvd(): # drop automatic response to ping + pong_waiter = await self.connection.ping("this") + pong_waiter.cancel() + await self.remote_connection.pong("this") + with self.assertRaises(asyncio.CancelledError): + await pong_waiter + async def test_acknowledge_ping_non_matching_pong(self): """ping isn't acknowledged by a pong with a different payload.""" async with self.drop_frames_rcvd(): # drop automatic response to ping @@ -902,7 +911,7 @@ async def test_acknowledge_ping_non_matching_pong(self): await pong_waiter async def test_acknowledge_previous_ping(self): - """ping is acknowledged by a pong with the same payload as a later ping.""" + """ping is acknowledged by a pong for a later ping.""" async with self.drop_frames_rcvd(): # drop automatic response to ping pong_waiter = await self.connection.ping("this") await self.connection.ping("that") @@ -910,6 +919,18 @@ async def test_acknowledge_previous_ping(self): async with asyncio_timeout(MS): await pong_waiter + async def test_acknowledge_previous_canceled_ping(self): + """ping is acknowledged by a pong for a later ping after being canceled.""" + async with self.drop_frames_rcvd(): # drop automatic response to ping + pong_waiter = await self.connection.ping("this") + pong_waiter_2 = await self.connection.ping("that") + pong_waiter.cancel() + await self.remote_connection.pong("that") + async with asyncio_timeout(MS): + await pong_waiter_2 + with self.assertRaises(asyncio.CancelledError): + await pong_waiter + async def test_ping_duplicate_payload(self): """ping rejects the same payload until receiving the pong.""" async with self.drop_frames_rcvd(): # drop automatic response to ping diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index e21e310a2..aa445498c 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -678,7 +678,7 @@ def test_acknowledge_ping_non_matching_pong(self): self.assertFalse(pong_waiter.wait(MS)) def test_acknowledge_previous_ping(self): - """ping is acknowledged by a pong with the same payload as a later ping.""" + """ping is acknowledged by a pong for as a later ping.""" with self.drop_frames_rcvd(): # drop automatic response to ping pong_waiter = self.connection.ping("this") self.connection.ping("that") From 2abf77fb1ce67b3c5995fb63eeea6b829d96c0fb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 4 Jan 2025 22:38:44 +0100 Subject: [PATCH 685/762] Improve consistency of convenience imports. APIs available as convenience imports should have their types also available as convenienc imports. Fix #1560. --- docs/howto/upgrade.rst | 12 ++++----- src/websockets/__init__.py | 44 ++++++++++++++++++++++++++++++-- src/websockets/datastructures.py | 6 ++++- src/websockets/frames.py | 1 + src/websockets/http11.py | 7 ++++- tests/test_exports.py | 36 ++++++++++++-------------- 6 files changed, 77 insertions(+), 29 deletions(-) diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index 02d4c6f01..db6bf11f1 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -93,8 +93,8 @@ Client APIs | ``websockets.client.unix_connect()`` |br| | :func:`websockets.asyncio.client.unix_connect` | | :func:`websockets.legacy.client.unix_connect` | | +-------------------------------------------------------------------+-----------------------------------------------------+ -| ``websockets.WebSocketClientProtocol`` |br| | :class:`websockets.asyncio.client.ClientConnection` | -| ``websockets.client.WebSocketClientProtocol`` |br| | | +| ``websockets.WebSocketClientProtocol`` |br| | ``websockets.ClientConnection`` *(since 14.2)* |br| | +| ``websockets.client.WebSocketClientProtocol`` |br| | :class:`websockets.asyncio.client.ClientConnection` | | :class:`websockets.legacy.client.WebSocketClientProtocol` | | +-------------------------------------------------------------------+-----------------------------------------------------+ @@ -112,12 +112,12 @@ Server APIs | ``websockets.server.unix_serve()`` |br| | :func:`websockets.asyncio.server.unix_serve` | | :func:`websockets.legacy.server.unix_serve` | | +-------------------------------------------------------------------+-----------------------------------------------------+ -| ``websockets.WebSocketServer`` |br| | :class:`websockets.asyncio.server.Server` | -| ``websockets.server.WebSocketServer`` |br| | | +| ``websockets.WebSocketServer`` |br| | ``websockets.Server`` *(since 14.2)* |br| | +| ``websockets.server.WebSocketServer`` |br| | :class:`websockets.asyncio.server.Server` | | :class:`websockets.legacy.server.WebSocketServer` | | +-------------------------------------------------------------------+-----------------------------------------------------+ -| ``websockets.WebSocketServerProtocol`` |br| | :class:`websockets.asyncio.server.ServerConnection` | -| ``websockets.server.WebSocketServerProtocol`` |br| | | +| ``websockets.WebSocketServerProtocol`` |br| | ``websockets.ServerConnection`` *(since 14.2)* |br| | +| ``websockets.server.WebSocketServerProtocol`` |br| | :class:`websockets.asyncio.server.ServerConnection` | | :class:`websockets.legacy.server.WebSocketServerProtocol` | | +-------------------------------------------------------------------+-----------------------------------------------------+ | ``websockets.broadcast()`` *(before 14.0)* |br| | ``websockets.broadcast()`` *(since 14.0)* |br| | diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index c278b21d4..c8df54e0b 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -10,11 +10,14 @@ # .asyncio.client "connect", "unix_connect", + "ClientConnection", # .asyncio.server "basic_auth", "broadcast", "serve", "unix_serve", + "ServerConnection", + "Server", # .client "ClientProtocol", # .datastructures @@ -44,6 +47,18 @@ "ProtocolError", "SecurityError", "WebSocketException", + # .frames + "Close", + "CloseCode", + "Frame", + "Opcode", + # .http11 + "Request", + "Response", + # .protocol + "Protocol", + "Side", + "State", # .server "ServerProtocol", # .typing @@ -58,8 +73,15 @@ # When type checking, import non-deprecated aliases eagerly. Else, import on demand. if typing.TYPE_CHECKING: - from .asyncio.client import connect, unix_connect - from .asyncio.server import basic_auth, broadcast, serve, unix_serve + from .asyncio.client import ClientConnection, connect, unix_connect + from .asyncio.server import ( + Server, + ServerConnection, + basic_auth, + broadcast, + serve, + unix_serve, + ) from .client import ClientProtocol from .datastructures import Headers, HeadersLike, MultipleValuesError from .exceptions import ( @@ -86,6 +108,9 @@ SecurityError, WebSocketException, ) + from .frames import Close, CloseCode, Frame, Opcode + from .http11 import Request, Response + from .protocol import Protocol, Side, State from .server import ServerProtocol from .typing import ( Data, @@ -103,11 +128,14 @@ # .asyncio.client "connect": ".asyncio.client", "unix_connect": ".asyncio.client", + "ClientConnection": ".asyncio.client", # .asyncio.server "basic_auth": ".asyncio.server", "broadcast": ".asyncio.server", "serve": ".asyncio.server", "unix_serve": ".asyncio.server", + "ServerConnection": ".asyncio.server", + "Server": ".asyncio.server", # .client "ClientProtocol": ".client", # .datastructures @@ -137,6 +165,18 @@ "ProtocolError": ".exceptions", "SecurityError": ".exceptions", "WebSocketException": ".exceptions", + # .frames + "Close": ".frames", + "CloseCode": ".frames", + "Frame": ".frames", + "Opcode": ".frames", + # .http11 + "Request": ".http11", + "Response": ".http11", + # .protocol + "Protocol": ".protocol", + "Side": ".protocol", + "State": ".protocol", # .server "ServerProtocol": ".server", # .typing diff --git a/src/websockets/datastructures.py b/src/websockets/datastructures.py index 77b6f86fa..3c5dcbe9a 100644 --- a/src/websockets/datastructures.py +++ b/src/websockets/datastructures.py @@ -4,7 +4,11 @@ from typing import Any, Protocol, Union -__all__ = ["Headers", "HeadersLike", "MultipleValuesError"] +__all__ = [ + "Headers", + "HeadersLike", + "MultipleValuesError", +] class MultipleValuesError(LookupError): diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 7898c8a5d..ab0869d01 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -28,6 +28,7 @@ "OP_PONG", "DATA_OPCODES", "CTRL_OPCODES", + "CloseCode", "Frame", "Close", ] diff --git a/src/websockets/http11.py b/src/websockets/http11.py index af542c77b..949a320f7 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -13,7 +13,12 @@ from .version import version as websockets_version -__all__ = ["SERVER", "USER_AGENT", "Request", "Response"] +__all__ = [ + "SERVER", + "USER_AGENT", + "Request", + "Response", +] PYTHON_VERSION = "{}.{}".format(*sys.version_info) diff --git a/tests/test_exports.py b/tests/test_exports.py index 93b0684f7..88e27e69d 100644 --- a/tests/test_exports.py +++ b/tests/test_exports.py @@ -11,24 +11,22 @@ import websockets.uri -combined_exports = ( - [] - + websockets.asyncio.client.__all__ - + websockets.asyncio.server.__all__ - + websockets.client.__all__ - + websockets.datastructures.__all__ - + websockets.exceptions.__all__ - + websockets.server.__all__ - + websockets.typing.__all__ -) - -# These API are intentionally not re-exported by the top-level module. -missing_reexports = [ - # websockets.asyncio.client - "ClientConnection", - # websockets.asyncio.server - "ServerConnection", - "Server", +combined_exports = [ + name + for name in ( + [] + + websockets.asyncio.client.__all__ + + websockets.asyncio.server.__all__ + + websockets.client.__all__ + + websockets.datastructures.__all__ + + websockets.exceptions.__all__ + + websockets.frames.__all__ + + websockets.http11.__all__ + + websockets.protocol.__all__ + + websockets.server.__all__ + + websockets.typing.__all__ + ) + if not name.isupper() # filter out constants ] @@ -36,7 +34,7 @@ class ExportsTests(unittest.TestCase): def test_top_level_module_reexports_submodule_exports(self): self.assertEqual( set(combined_exports), - set(websockets.__all__ + missing_reexports), + set(websockets.__all__), ) def test_submodule_exports_are_globally_unique(self): From e6d0ea1d6b13a979924329d02fb82f79d82c7236 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 4 Jan 2025 10:18:38 +0100 Subject: [PATCH 686/762] Read chunked HTTP responses. Fix #1550. --- src/websockets/client.py | 2 +- src/websockets/http11.py | 145 ++++++++++++++++++++++++--------------- src/websockets/server.py | 2 +- tests/test_client.py | 2 +- tests/test_http11.py | 93 ++++++++++++++++++++----- tests/test_server.py | 2 +- 6 files changed, 167 insertions(+), 79 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 5ced05c2a..37e2a8b3a 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -333,7 +333,7 @@ def parse(self) -> Generator[None]: self.logger.debug("< HTTP/1.1 %d %s", code, phrase) for key, value in response.headers.raw_items(): self.logger.debug("< %s: %s", key, value) - if response.body is not None: + if response.body: self.logger.debug("< [body] (%d bytes)", len(response.body)) try: diff --git a/src/websockets/http11.py b/src/websockets/http11.py index 949a320f7..396a43f07 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -185,14 +185,14 @@ class Response: status_code: Response code. reason_phrase: Response reason. headers: Response headers. - body: Response body, if any. + body: Response body. """ status_code: int reason_phrase: str headers: Headers - body: bytes | None = None + body: bytes = b"" _exception: Exception | None = None @@ -266,36 +266,9 @@ def parse( headers = yield from parse_headers(read_line) - # https://datatracker.ietf.org/doc/html/rfc7230#section-3.3.3 - - if "Transfer-Encoding" in headers: - raise NotImplementedError("transfer codings aren't supported") - - # Since websockets only does GET requests (no HEAD, no CONNECT), all - # responses except 1xx, 204, and 304 include a message body. - if 100 <= status_code < 200 or status_code == 204 or status_code == 304: - body = None - else: - content_length: int | None - try: - # MultipleValuesError is sufficiently unlikely that we don't - # attempt to handle it. Instead we document that its parent - # class, LookupError, may be raised. - raw_content_length = headers["Content-Length"] - except KeyError: - content_length = None - else: - content_length = int(raw_content_length) - - if content_length is None: - try: - body = yield from read_to_eof(MAX_BODY_SIZE) - except RuntimeError: - raise SecurityError(f"body too large: over {MAX_BODY_SIZE} bytes") - elif content_length > MAX_BODY_SIZE: - raise SecurityError(f"body too large: {content_length} bytes") - else: - body = yield from read_exact(content_length) + body = yield from read_body( + status_code, headers, read_line, read_exact, read_to_eof + ) return cls(status_code, reason, headers, body) @@ -308,11 +281,37 @@ def serialize(self) -> bytes: # we can keep this simple. response = f"HTTP/1.1 {self.status_code} {self.reason_phrase}\r\n".encode() response += self.headers.serialize() - if self.body is not None: - response += self.body + response += self.body return response +def parse_line( + read_line: Callable[[int], Generator[None, None, bytes]], +) -> Generator[None, None, bytes]: + """ + Parse a single line. + + CRLF is stripped from the return value. + + Args: + read_line: Generator-based coroutine that reads a LF-terminated line + or raises an exception if there isn't enough data. + + Raises: + EOFError: If the connection is closed without a CRLF. + SecurityError: If the response exceeds a security limit. + + """ + try: + line = yield from read_line(MAX_LINE_LENGTH) + except RuntimeError: + raise SecurityError("line too long") + # Not mandatory but safe - https://datatracker.ietf.org/doc/html/rfc7230#section-3.5 + if not line.endswith(b"\r\n"): + raise EOFError("line without CRLF") + return line[:-2] + + def parse_headers( read_line: Callable[[int], Generator[None, None, bytes]], ) -> Generator[None, None, Headers]: @@ -364,28 +363,62 @@ def parse_headers( return headers -def parse_line( +def read_body( + status_code: int, + headers: Headers, read_line: Callable[[int], Generator[None, None, bytes]], + read_exact: Callable[[int], Generator[None, None, bytes]], + read_to_eof: Callable[[int], Generator[None, None, bytes]], ) -> Generator[None, None, bytes]: - """ - Parse a single line. - - CRLF is stripped from the return value. - - Args: - read_line: Generator-based coroutine that reads a LF-terminated line - or raises an exception if there isn't enough data. - - Raises: - EOFError: If the connection is closed without a CRLF. - SecurityError: If the response exceeds a security limit. + # https://datatracker.ietf.org/doc/html/rfc7230#section-3.3.3 + + # Since websockets only does GET requests (no HEAD, no CONNECT), all + # responses except 1xx, 204, and 304 include a message body. + if 100 <= status_code < 200 or status_code == 204 or status_code == 304: + return b"" + + # MultipleValuesError is sufficiently unlikely that we don't attempt to + # handle it when accessing headers. Instead we document that its parent + # class, LookupError, may be raised. + # Conversions from str to int are protected by sys.set_int_max_str_digits.. + + elif (coding := headers.get("Transfer-Encoding")) is not None: + if coding != "chunked": + raise NotImplementedError(f"transfer coding {coding} isn't supported") + + body = b"" + while True: + chunk_size_line = yield from parse_line(read_line) + raw_chunk_size = chunk_size_line.split(b";", 1)[0] + # Set a lower limit than default_max_str_digits; 1 EB is plenty. + if len(raw_chunk_size) > 15: + str_chunk_size = raw_chunk_size.decode(errors="backslashreplace") + raise SecurityError(f"chunk too large: 0x{str_chunk_size} bytes") + chunk_size = int(raw_chunk_size, 16) + if chunk_size == 0: + break + if len(body) + chunk_size > MAX_BODY_SIZE: + raise SecurityError( + f"chunk too large: {chunk_size} bytes after {len(body)} bytes" + ) + body += yield from read_exact(chunk_size) + if (yield from read_exact(2)) != b"\r\n": + raise ValueError("chunk without CRLF") + # Read the trailer. + yield from parse_headers(read_line) + return body + + elif (raw_content_length := headers.get("Content-Length")) is not None: + # Set a lower limit than default_max_str_digits; 1 EiB is plenty. + if len(raw_content_length) > 18: + raise SecurityError(f"body too large: {raw_content_length} bytes") + content_length = int(raw_content_length) + if content_length > MAX_BODY_SIZE: + raise SecurityError(f"body too large: {content_length} bytes") + return (yield from read_exact(content_length)) - """ - try: - line = yield from read_line(MAX_LINE_LENGTH) - except RuntimeError: - raise SecurityError("line too long") - # Not mandatory but safe - https://datatracker.ietf.org/doc/html/rfc7230#section-3.5 - if not line.endswith(b"\r\n"): - raise EOFError("line without CRLF") - return line[:-2] + else: + try: + return (yield from read_to_eof(MAX_BODY_SIZE)) + except RuntimeError: + raise SecurityError(f"body too large: over {MAX_BODY_SIZE} bytes") diff --git a/src/websockets/server.py b/src/websockets/server.py index 1b663a137..fe3c65a7d 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -525,7 +525,7 @@ def send_response(self, response: Response) -> None: self.logger.debug("> HTTP/1.1 %d %s", code, phrase) for key, value in response.headers.raw_items(): self.logger.debug("> %s: %s", key, value) - if response.body is not None: + if response.body: self.logger.debug("> [body] (%d bytes)", len(response.body)) self.writes.append(response.serialize()) diff --git a/tests/test_client.py b/tests/test_client.py index 1edbae57d..9f3ab09b2 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -204,7 +204,7 @@ def test_receive_successful_response(self, _generate_key): } ), ) - self.assertIsNone(response.body) + self.assertEqual(response.body, b"") self.assertIsNone(client.handshake_exc) def test_receive_failed_response(self, _generate_key): diff --git a/tests/test_http11.py b/tests/test_http11.py index 1fbcb3ba4..bb0d27b95 100644 --- a/tests/test_http11.py +++ b/tests/test_http11.py @@ -87,7 +87,7 @@ def test_parse_body(self): ) def test_parse_body_with_transfer_encoding(self): - self.reader.feed_data(b"GET / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n") + self.reader.feed_data(b"GET / HTTP/1.1\r\nTransfer-Encoding: compress\r\n\r\n") with self.assertRaises(NotImplementedError) as raised: next(self.parse()) self.assertEqual( @@ -151,7 +151,7 @@ def test_parse(self): self.assertEqual(response.status_code, 101) self.assertEqual(response.reason_phrase, "Switching Protocols") self.assertEqual(response.headers["Upgrade"], "websocket") - self.assertIsNone(response.body) + self.assertEqual(response.body, b"") def test_parse_empty(self): self.reader.feed_eof() @@ -215,14 +215,7 @@ def test_parse_invalid_header(self): "invalid HTTP header line: Oops", ) - def test_parse_body_with_content_length(self): - self.reader.feed_data( - b"HTTP/1.1 200 OK\r\nContent-Length: 13\r\n\r\nHello world!\n" - ) - response = self.assertGeneratorReturns(self.parse()) - self.assertEqual(response.body, b"Hello world!\n") - - def test_parse_body_without_content_length(self): + def test_parse_body(self): self.reader.feed_data(b"HTTP/1.1 200 OK\r\n\r\nHello world!\n") gen = self.parse() self.assertGeneratorRunning(gen) @@ -230,7 +223,23 @@ def test_parse_body_without_content_length(self): response = self.assertGeneratorReturns(gen) self.assertEqual(response.body, b"Hello world!\n") - def test_parse_body_with_content_length_too_long(self): + def test_parse_body_too_large(self): + self.reader.feed_data(b"HTTP/1.1 200 OK\r\n\r\n" + b"a" * 1048577) + with self.assertRaises(SecurityError) as raised: + next(self.parse()) + self.assertEqual( + str(raised.exception), + "body too large: over 1048576 bytes", + ) + + def test_parse_body_with_content_length(self): + self.reader.feed_data( + b"HTTP/1.1 200 OK\r\nContent-Length: 13\r\n\r\nHello world!\n" + ) + response = self.assertGeneratorReturns(self.parse()) + self.assertEqual(response.body, b"Hello world!\n") + + def test_parse_body_with_content_length_and_body_too_large(self): self.reader.feed_data(b"HTTP/1.1 200 OK\r\nContent-Length: 1048577\r\n\r\n") with self.assertRaises(SecurityError) as raised: next(self.parse()) @@ -239,33 +248,79 @@ def test_parse_body_with_content_length_too_long(self): "body too large: 1048577 bytes", ) - def test_parse_body_without_content_length_too_long(self): - self.reader.feed_data(b"HTTP/1.1 200 OK\r\n\r\n" + b"a" * 1048577) + def test_parse_body_with_content_length_and_body_way_too_large(self): + self.reader.feed_data( + b"HTTP/1.1 200 OK\r\nContent-Length: 1234567890123456789\r\n\r\n" + ) with self.assertRaises(SecurityError) as raised: next(self.parse()) self.assertEqual( str(raised.exception), - "body too large: over 1048576 bytes", + "body too large: 1234567890123456789 bytes", ) - def test_parse_body_with_transfer_encoding(self): - self.reader.feed_data(b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n") + def test_parse_body_with_chunked_transfer_encoding(self): + self.reader.feed_data( + b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n" + b"6\r\nHello \r\n7\r\nworld!\n\r\n0\r\n\r\n" + ) + response = self.assertGeneratorReturns(self.parse()) + self.assertEqual(response.body, b"Hello world!\n") + + def test_parse_body_with_chunked_transfer_encoding_and_chunk_without_crlf(self): + self.reader.feed_data( + b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n" + b"6\r\nHello 7\r\nworld!\n0\r\n" + ) + with self.assertRaises(ValueError) as raised: + next(self.parse()) + self.assertEqual( + str(raised.exception), + "chunk without CRLF", + ) + + def test_parse_body_with_chunked_transfer_encoding_and_chunk_too_large(self): + self.reader.feed_data( + b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n" + b"100000\r\n" + b"a" * 1048576 + b"\r\n1\r\na\r\n0\r\n\r\n" + ) + with self.assertRaises(SecurityError) as raised: + next(self.parse()) + self.assertEqual( + str(raised.exception), + "chunk too large: 1 bytes after 1048576 bytes", + ) + + def test_parse_body_with_chunked_transfer_encoding_and_chunk_way_too_large(self): + self.reader.feed_data( + b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n" + b"1234567890ABCDEF\r\n\r\n" + ) + with self.assertRaises(SecurityError) as raised: + next(self.parse()) + self.assertEqual( + str(raised.exception), + "chunk too large: 0x1234567890ABCDEF bytes", + ) + + def test_parse_body_with_unsupported_transfer_encoding(self): + self.reader.feed_data(b"HTTP/1.1 200 OK\r\nTransfer-Encoding: compress\r\n\r\n") with self.assertRaises(NotImplementedError) as raised: next(self.parse()) self.assertEqual( str(raised.exception), - "transfer codings aren't supported", + "transfer coding compress isn't supported", ) def test_parse_body_no_content(self): self.reader.feed_data(b"HTTP/1.1 204 No Content\r\n\r\n") response = self.assertGeneratorReturns(self.parse()) - self.assertIsNone(response.body) + self.assertEqual(response.body, b"") def test_parse_body_not_modified(self): self.reader.feed_data(b"HTTP/1.1 304 Not Modified\r\n\r\n") response = self.assertGeneratorReturns(self.parse()) - self.assertIsNone(response.body) + self.assertEqual(response.body, b"") def test_serialize(self): # Example from the protocol overview in RFC 6455 diff --git a/tests/test_server.py b/tests/test_server.py index 5efeca2d0..69f555689 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -281,7 +281,7 @@ def test_accept_response(self, _formatdate): } ), ) - self.assertIsNone(response.body) + self.assertEqual(response.body, b"") @patch("email.utils.formatdate", return_value=DATE) def test_reject_response(self, _formatdate): From 916f841815070a0374434b13b9c91829a7e7a522 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 10 Jan 2025 22:32:44 +0100 Subject: [PATCH 687/762] Upgrade ruff. --- src/websockets/exceptions.py | 3 +-- src/websockets/http11.py | 3 +-- src/websockets/legacy/exceptions.py | 4 +--- src/websockets/sync/connection.py | 6 ++---- tests/asyncio/test_server.py | 3 +-- tests/sync/test_server.py | 3 +-- tests/test_client.py | 3 +-- 7 files changed, 8 insertions(+), 17 deletions(-) diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 81fbb1efd..73b24debf 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -204,8 +204,7 @@ def __init__(self, response: http11.Response) -> None: def __str__(self) -> str: return ( - "server rejected WebSocket connection: " - f"HTTP {self.response.status_code:d}" + f"server rejected WebSocket connection: HTTP {self.response.status_code:d}" ) diff --git a/src/websockets/http11.py b/src/websockets/http11.py index 396a43f07..49d7b9a41 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -94,8 +94,7 @@ class Request: @property def exception(self) -> Exception | None: # pragma: no cover warnings.warn( # deprecated in 10.3 - 2022-04-17 - "Request.exception is deprecated; " - "use ServerProtocol.handshake_exc instead", + "Request.exception is deprecated; use ServerProtocol.handshake_exc instead", DeprecationWarning, ) return self._exception diff --git a/src/websockets/legacy/exceptions.py b/src/websockets/legacy/exceptions.py index 78fb696fa..29a2525b4 100644 --- a/src/websockets/legacy/exceptions.py +++ b/src/websockets/legacy/exceptions.py @@ -52,9 +52,7 @@ def __init__( def __str__(self) -> str: return ( - f"HTTP {self.status:d}, " - f"{len(self.headers)} headers, " - f"{len(self.body)} bytes" + f"HTTP {self.status:d}, {len(self.headers)} headers, {len(self.body)} bytes" ) diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index d8dbf140e..60be245b4 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -412,8 +412,7 @@ def send( with self.send_context(): if self.send_in_progress: raise ConcurrencyError( - "cannot call send while another thread " - "is already running send" + "cannot call send while another thread is already running send" ) if text is False: self.protocol.send_binary(message.encode()) @@ -424,8 +423,7 @@ def send( with self.send_context(): if self.send_in_progress: raise ConcurrencyError( - "cannot call send while another thread " - "is already running send" + "cannot call send while another thread is already running send" ) if text is True: self.protocol.send_text(message) diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index 83885fab5..edad52da5 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -59,8 +59,7 @@ async def test_connection_handler_raises_exception(self): await client.recv() self.assertEqual( str(raised.exception), - "received 1011 (internal error); " - "then sent 1011 (internal error)", + "received 1011 (internal error); then sent 1011 (internal error)", ) async def test_existing_socket(self): diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index 9a2676437..bb2ebae14 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -58,8 +58,7 @@ def test_connection_handler_raises_exception(self): client.recv() self.assertEqual( str(raised.exception), - "received 1011 (internal error); " - "then sent 1011 (internal error)", + "received 1011 (internal error); then sent 1011 (internal error)", ) def test_existing_socket(self): diff --git a/tests/test_client.py b/tests/test_client.py index 9f3ab09b2..fc9f2ec9a 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -638,8 +638,7 @@ def test_multiple_subprotocols_accepted(self): self.assertHandshakeError( client, InvalidHandshake, - "invalid Sec-WebSocket-Protocol header: " - "multiple values: superchat, chat", + "invalid Sec-WebSocket-Protocol header: multiple values: superchat, chat", ) From 668e56cf5d405d5f418d97d183e76fc36d21f857 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 10 Jan 2025 22:26:15 +0100 Subject: [PATCH 688/762] Fix Connection.recv(timeout=0) in the sync implementation. Fix #1552. --- docs/project/changelog.rst | 4 ++++ src/websockets/sync/messages.py | 13 ++++++++++--- tests/sync/test_messages.py | 6 ++++++ 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index b7f4f62f9..de6b0be21 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -35,6 +35,10 @@ notice. Bug fixes ......... +* Fixed ``connection.recv(timeout=0)`` in the :mod:`threading` implementation. + If a message is already received, it is returned. Previously, + :exc:`TimeoutError` was raised incorrectly. + * Wrapped errors when reading the opening handshake request or response in :exc:`~exceptions.InvalidMessage` so that :func:`~asyncio.client.connect` raises :exc:`~exceptions.InvalidHandshake` or a subclass when the opening diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index 98490797f..12e8b1623 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -79,7 +79,12 @@ def get_next_frame(self, timeout: float | None = None) -> Frame: raise EOFError("stream of frames ended") from None else: try: - frame = self.frames.get(block=True, timeout=timeout) + # Check for a frame that's already received if timeout <= 0. + # SimpleQueue.get() doesn't support negative timeout values. + if timeout is not None and timeout <= 0: + frame = self.frames.get(block=False) + else: + frame = self.frames.get(block=True, timeout=timeout) except queue.Empty: raise TimeoutError(f"timed out in {timeout:.1f}s") from None if frame is None: @@ -143,7 +148,7 @@ def get(self, timeout: float | None = None, decode: bool | None = None) -> Data: deadline = Deadline(timeout) # First frame - frame = self.get_next_frame(deadline.timeout()) + frame = self.get_next_frame(deadline.timeout(raise_if_elapsed=False)) with self.mutex: self.maybe_resume() assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY @@ -154,7 +159,9 @@ def get(self, timeout: float | None = None, decode: bool | None = None) -> Data: # Following frames, for fragmented messages while not frame.fin: try: - frame = self.get_next_frame(deadline.timeout()) + frame = self.get_next_frame( + deadline.timeout(raise_if_elapsed=False) + ) except TimeoutError: # Put frames already received back into the queue # so that future calls to get() can return them. diff --git a/tests/sync/test_messages.py b/tests/sync/test_messages.py index d22693102..0a94b4f85 100644 --- a/tests/sync/test_messages.py +++ b/tests/sync/test_messages.py @@ -198,6 +198,12 @@ def test_get_timeout_after_first_frame(self): message = self.assembler.get() self.assertEqual(message, "café") + def test_get_if_received(self): + """get returns a text message if it's already received.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + message = self.assembler.get(timeout=0) + self.assertEqual(message, "café") + # Test get_iter def test_get_iter_text_message_already_received(self): From b1e88fcb77f6b74b2eab6a11a4ee53e4f04937ac Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sat, 30 Nov 2024 06:42:21 +0000 Subject: [PATCH 689/762] Bump pypa/cibuildwheel from 2.21.3 to 2.22.0 Bumps [pypa/cibuildwheel](https://github.com/pypa/cibuildwheel) from 2.21.3 to 2.22.0. - [Release notes](https://github.com/pypa/cibuildwheel/releases) - [Changelog](https://github.com/pypa/cibuildwheel/blob/main/docs/changelog.md) - [Commits](https://github.com/pypa/cibuildwheel/compare/v2.21.3...v2.22.0) --- updated-dependencies: - dependency-name: pypa/cibuildwheel dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index aa1e2e7a9..0ff07c9df 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -53,7 +53,7 @@ jobs: with: platforms: all - name: Build wheels - uses: pypa/cibuildwheel@v2.21.3 + uses: pypa/cibuildwheel@v2.22.0 env: BUILD_EXTENSION: yes - name: Save wheels From 6317c00cc5af245116781ddfde518ec004de672e Mon Sep 17 00:00:00 2001 From: = Date: Fri, 10 Jan 2025 14:47:36 -0700 Subject: [PATCH 690/762] Clarify behavior of `recv(timeout=0)` behavior. Refs #1552. --- src/websockets/sync/connection.py | 7 ++++--- tests/sync/test_messages.py | 24 ++++++++++++++++++++++-- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 60be245b4..e78073242 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -232,9 +232,10 @@ def recv(self, timeout: float | None = None, decode: bool | None = None) -> Data message stream. If ``timeout`` is :obj:`None`, block until a message is received. If - ``timeout`` is set and no message is received within ``timeout`` - seconds, raise :exc:`TimeoutError`. Set ``timeout`` to ``0`` to check if - a message was already received. + ``timeout`` is set, wait up to ``timeout`` seconds for a message to be + received and return it, else raise :exc:`TimeoutError`. If ``timeout`` + is ``0`` or negative, check if a message has been received already and + return it, else raise :exc:`TimeoutError`. If the message is fragmented, wait until all fragments are received, reassemble them, and return the whole message. diff --git a/tests/sync/test_messages.py b/tests/sync/test_messages.py index 0a94b4f85..e5510af3e 100644 --- a/tests/sync/test_messages.py +++ b/tests/sync/test_messages.py @@ -198,12 +198,32 @@ def test_get_timeout_after_first_frame(self): message = self.assembler.get() self.assertEqual(message, "café") - def test_get_if_received(self): - """get returns a text message if it's already received.""" + def test_get_timeout_0_message_already_received(self): + """get(timeout=0) returns a message that is already received.""" self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) message = self.assembler.get(timeout=0) self.assertEqual(message, "café") + def test_get_timeout_0_message_not_received_yet(self): + """get(timeout=0) times out when no message is already received.""" + with self.assertRaises(TimeoutError): + self.assembler.get(timeout=0) + + def test_get_timeout_0_fragmented_message_already_received(self): + """get(timeout=0) returns a fragmented message that is already received.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + message = self.assembler.get(timeout=0) + self.assertEqual(message, "café") + + def test_get_timeout_0_fragmented_message_partially_received(self): + """get(timeout=0) times out when a fragmented message is partially received.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + with self.assertRaises(TimeoutError): + self.assembler.get(timeout=0) + # Test get_iter def test_get_iter_text_message_already_received(self): From 031ec31b70adc527836c5565a7809724fb888c9c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 11 Jan 2025 10:11:59 +0100 Subject: [PATCH 691/762] Prevent close() from blocking when reading is paused. Fix #1555. --- docs/project/changelog.rst | 11 +++++++---- src/websockets/asyncio/messages.py | 2 +- src/websockets/sync/messages.py | 7 ++++++- tests/sync/test_messages.py | 12 ++++++++++++ 4 files changed, 26 insertions(+), 6 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index de6b0be21..83a8e6962 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -35,15 +35,18 @@ notice. Bug fixes ......... -* Fixed ``connection.recv(timeout=0)`` in the :mod:`threading` implementation. - If a message is already received, it is returned. Previously, - :exc:`TimeoutError` was raised incorrectly. - * Wrapped errors when reading the opening handshake request or response in :exc:`~exceptions.InvalidMessage` so that :func:`~asyncio.client.connect` raises :exc:`~exceptions.InvalidHandshake` or a subclass when the opening handshake fails. +* Fixed :meth:`~sync.connection.Connection.recv` with ``timeout=0`` in the + :mod:`threading` implementation. If a message is already received, it is + returned. Previously, :exc:`TimeoutError` was raised incorrectly. + +* Prevented :meth:`~sync.connection.Connection.close` from blocking when + receive buffers are saturated in the :mod:`threading` implementation. + .. _14.1: 14.1 diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index e6d1d31cc..c10072467 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -283,7 +283,7 @@ def close(self) -> None: """ End the stream of frames. - Callling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`, + Calling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`, or :meth:`put` is safe. They will raise :exc:`EOFError`. """ diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index 12e8b1623..dfabedd65 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -298,7 +298,7 @@ def close(self) -> None: """ End the stream of frames. - Callling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`, + Calling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`, or :meth:`put` is safe. They will raise :exc:`EOFError`. """ @@ -311,3 +311,8 @@ def close(self) -> None: if self.get_in_progress: # Unblock get() or get_iter(). self.frames.put(None) + + if self.paused: + # Unblock recv_events(). + self.paused = False + self.resume() diff --git a/tests/sync/test_messages.py b/tests/sync/test_messages.py index e5510af3e..e42784094 100644 --- a/tests/sync/test_messages.py +++ b/tests/sync/test_messages.py @@ -496,6 +496,18 @@ def test_put_fails_after_close(self): with self.assertRaises(EOFError): self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + def test_close_resumes_reading(self): + """close unblocks reading when queue is above the high-water mark.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"more caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"water")) + + # queue is at the high-water mark + assert self.assembler.paused + + self.assembler.close() + self.resume.assert_called_once_with() + def test_close_is_idempotent(self): """close can be called multiple times safely.""" self.assembler.close() From e7a098e1a0d5dffcac5f0600703c4ec2de0be48a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 12 Jan 2025 09:00:51 +0100 Subject: [PATCH 692/762] Prevent AssertionError in the recv_events thread. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit close_socket() was interacting with the protocol, namely calling protocol.receive_of(), without locking the mutex. This created the possibility of a race condition. If two threads called receive_eof() concurrently, the second one could return before the first one finished running it. This led to receive_eof() returning (in the second thread) before the connection state was CLOSED, breaking an invariant. This race condition could be triggered reliably by shutting down the network (e.g., turning wifi off), closing the connection, and waiting for the timeout. Then, close() calls close_socket() — this happens in the `raise_close_exc` branch of send_context(). This unblocks the read in recv_events() which calls close_socket() in the `finally:` branch. Fix #1558. --- src/websockets/sync/connection.py | 5 +++-- tests/sync/test_client.py | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index e78073242..06ea00efc 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -923,8 +923,9 @@ def close_socket(self) -> None: # Calling protocol.receive_eof() is safe because it's idempotent. # This guarantees that the protocol state becomes CLOSED. - self.protocol.receive_eof() - assert self.protocol.state is CLOSED + with self.protocol_mutex: + self.protocol.receive_eof() + assert self.protocol.state is CLOSED # Abort recv() with a ConnectionClosed exception. self.recv_messages.close() diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index 7d8170519..7ab8f4dd4 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -151,7 +151,8 @@ def test_connection_closed_during_handshake(self): """Client reads EOF before receiving handshake response from server.""" def close_connection(self, request): - self.close_socket() + self.socket.shutdown(socket.SHUT_RDWR) + self.socket.close() with run_server(process_request=close_connection) as server: with self.assertRaises(InvalidMessage) as raised: From 613f3f0ef83a0c80ae49a42766fb634295216c5c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Jan 2025 18:58:20 +0100 Subject: [PATCH 693/762] Prevent close() from blocking when reading is paused. Closing the transport normally is achieved with transport.write_eof(). Closing it in abnormal situations relied on transport.close(). However, that didn't lead to connection_lost() when reading is paused. Replacing it with transport.abort() ensures that buffers are dropped (which is what we want in abnormal situations) and connection_lost() called quickly. Fix #1555 (for real!) --- docs/project/changelog.rst | 5 +++-- src/websockets/asyncio/client.py | 4 ++-- src/websockets/asyncio/connection.py | 10 +--------- src/websockets/asyncio/server.py | 6 +++--- tests/asyncio/test_client.py | 2 +- 5 files changed, 10 insertions(+), 17 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 83a8e6962..5a2e1bc24 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -44,8 +44,9 @@ Bug fixes :mod:`threading` implementation. If a message is already received, it is returned. Previously, :exc:`TimeoutError` was raised incorrectly. -* Prevented :meth:`~sync.connection.Connection.close` from blocking when - receive buffers are saturated in the :mod:`threading` implementation. +* Prevented :meth:`~asyncio.connection.Connection.close` from blocking when + receive buffers are saturated in the :mod:`asyncio` and :mod:`threading` + implementations. .. _14.1: diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 8581c0551..f05f546d3 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -445,7 +445,7 @@ async def __await_impl__(self) -> ClientConnection: try: await self.connection.handshake(*self.handshake_args) except asyncio.CancelledError: - self.connection.close_transport() + self.connection.transport.abort() raise except Exception as exc: # Always close the connection even though keep-alive is @@ -454,7 +454,7 @@ async def __await_impl__(self) -> ClientConnection: # protocol. In the current design of connect(), there is # no easy way to reuse the network connection that works # in every case nor to reinitialize the protocol. - self.connection.close_transport() + self.connection.transport.abort() uri_or_exc = self.process_redirect(exc) # Response is a valid redirect; follow it. diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index c34d19d58..e2e587e7c 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -928,7 +928,7 @@ async def send_context( # If an error occurred, close the transport to terminate the connection and # raise an exception. if raise_close_exc: - self.close_transport() + self.transport.abort() # Wait for the protocol state to be CLOSED before accessing close_exc. await asyncio.shield(self.connection_lost_waiter) raise self.protocol.close_exc from original_exc @@ -969,14 +969,6 @@ def set_recv_exc(self, exc: BaseException | None) -> None: if self.recv_exc is None: self.recv_exc = exc - def close_transport(self) -> None: - """ - Close transport and message assembler. - - """ - self.transport.close() - self.recv_messages.close() - # asyncio.Protocol methods # Connection callbacks diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index fdb928004..49d6f8910 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -356,16 +356,16 @@ async def conn_handler(self, connection: ServerConnection) -> None: self.server_header, ) except asyncio.CancelledError: - connection.close_transport() + connection.transport.abort() raise except Exception: connection.logger.error("opening handshake failed", exc_info=True) - connection.close_transport() + connection.transport.abort() return if connection.protocol.state is not OPEN: # process_request or process_response rejected the handshake. - connection.close_transport() + connection.transport.abort() return try: diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index 1773c08bd..4db4c038c 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -393,7 +393,7 @@ async def test_connection_closed_during_handshake(self): """Client reads EOF before receiving handshake response from server.""" def close_connection(self, request): - self.close_transport() + self.transport.close() async with serve(*args, process_request=close_connection) as server: with self.assertRaises(InvalidMessage) as raised: From 7e617b2a57177885926b3a4a8a092621ab719a00 Mon Sep 17 00:00:00 2001 From: dan005 Date: Mon, 13 Jan 2025 23:22:20 +0500 Subject: [PATCH 694/762] Add regex support in `ServerProtocol(origins=...)`. --- src/websockets/asyncio/server.py | 10 +++++---- src/websockets/server.py | 21 +++++++++++++----- src/websockets/sync/server.py | 10 +++++---- tests/test_server.py | 37 ++++++++++++++++++++++++++++++++ 4 files changed, 65 insertions(+), 13 deletions(-) diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index 49d6f8910..080ea3f16 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -4,6 +4,7 @@ import hmac import http import logging +import re import socket import sys from collections.abc import Awaitable, Generator, Iterable, Sequence @@ -599,9 +600,10 @@ def handler(websocket): See :meth:`~asyncio.loop.create_server` for details. port: TCP port the server listens on. See :meth:`~asyncio.loop.create_server` for details. - origins: Acceptable values of the ``Origin`` header, for defending - against Cross-Site WebSocket Hijacking attacks. Include :obj:`None` - in the list if the lack of an origin is acceptable. + origins: Acceptable values of the ``Origin`` header, including regular + expressions, for defending against Cross-Site WebSocket Hijacking + attacks. Include :obj:`None` in the list if the lack of an origin + is acceptable. extensions: List of supported extensions, in order in which they should be negotiated and run. subprotocols: List of supported subprotocols, in order of decreasing @@ -681,7 +683,7 @@ def __init__( port: int | None = None, *, # WebSocket - origins: Sequence[Origin | None] | None = None, + origins: Sequence[Origin | re.Pattern[str] | None] | None = None, extensions: Sequence[ServerExtensionFactory] | None = None, subprotocols: Sequence[Subprotocol] | None = None, select_subprotocol: ( diff --git a/src/websockets/server.py b/src/websockets/server.py index fe3c65a7d..67082ed72 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -4,6 +4,7 @@ import binascii import email.utils import http +import re import warnings from collections.abc import Generator, Sequence from typing import Any, Callable, cast @@ -49,9 +50,9 @@ class ServerProtocol(Protocol): Sans-I/O implementation of a WebSocket server connection. Args: - origins: Acceptable values of the ``Origin`` header; include - :obj:`None` in the list if the lack of an origin is acceptable. - This is useful for defending against Cross-Site WebSocket + origins: Acceptable values of the ``Origin`` header, including regular + expressions; include :obj:`None` in the list if the lack of an origin + is acceptable. This is useful for defending against Cross-Site WebSocket Hijacking attacks. extensions: List of supported extensions, in order in which they should be tried. @@ -73,7 +74,7 @@ class ServerProtocol(Protocol): def __init__( self, *, - origins: Sequence[Origin | None] | None = None, + origins: Sequence[Origin | re.Pattern[str] | None] | None = None, extensions: Sequence[ServerExtensionFactory] | None = None, subprotocols: Sequence[Subprotocol] | None = None, select_subprotocol: ( @@ -309,7 +310,17 @@ def process_origin(self, headers: Headers) -> Origin | None: if origin is not None: origin = cast(Origin, origin) if self.origins is not None: - if origin not in self.origins: + valid = False + for acceptable_origin_or_regex in self.origins: + if isinstance(acceptable_origin_or_regex, re.Pattern): + # `str(origin)` is needed for compatibility + # between `Pattern.match(string=...)` and `origin`. + valid = acceptable_origin_or_regex.match(str(origin)) is not None + else: + valid = acceptable_origin_or_regex == origin + if valid: + break + if not valid: raise InvalidOrigin(origin) return origin diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 9506d6830..c14e558ac 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -4,6 +4,7 @@ import http import logging import os +import re import selectors import socket import ssl as ssl_module @@ -325,7 +326,7 @@ def serve( sock: socket.socket | None = None, ssl: ssl_module.SSLContext | None = None, # WebSocket - origins: Sequence[Origin | None] | None = None, + origins: Sequence[Origin | re.Pattern[str] | None] | None = None, extensions: Sequence[ServerExtensionFactory] | None = None, subprotocols: Sequence[Subprotocol] | None = None, select_subprotocol: ( @@ -399,9 +400,10 @@ def handler(websocket): You may call :func:`socket.create_server` to create a suitable TCP socket. ssl: Configuration for enabling TLS on the connection. - origins: Acceptable values of the ``Origin`` header, for defending - against Cross-Site WebSocket Hijacking attacks. Include :obj:`None` - in the list if the lack of an origin is acceptable. + origins: Acceptable values of the ``Origin`` header, including regular + expressions, for defending against Cross-Site WebSocket Hijacking + attacks. Include :obj:`None` in the list if the lack of an origin + is acceptable. extensions: List of supported extensions, in order in which they should be negotiated and run. subprotocols: List of supported subprotocols, in order of decreasing diff --git a/tests/test_server.py b/tests/test_server.py index 69f555689..dd5e0d09a 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,5 +1,6 @@ import http import logging +import re import sys import unittest from unittest.mock import patch @@ -623,6 +624,42 @@ def test_unsupported_origin(self): "invalid Origin header: https://original.example.com", ) + def test_supported_origin_by_regex(self): + """ + Handshake succeeds when checking origins and the origin is supported + by a regular expression. + """ + server = ServerProtocol( + origins=["https://example.com", re.compile(r"https://other.*")] + ) + request = make_request() + request.headers["Origin"] = "https://other.example.com" + response = server.accept(request) + server.send_response(response) + + self.assertHandshakeSuccess(server) + self.assertEqual(server.origin, "https://other.example.com") + + def test_unsupported_origin_by_regex(self): + """ + Handshake succeeds when checking origins and the origin is unsupported + by a regular expression. + """ + server = ServerProtocol( + origins=["https://example.com", re.compile(r"https://other.*")] + ) + request = make_request() + request.headers["Origin"] = "https://original.example.com" + response = server.accept(request) + server.send_response(response) + + self.assertEqual(response.status_code, 403) + self.assertHandshakeError( + server, + InvalidOrigin, + "invalid Origin header: https://original.example.com", + ) + def test_no_origin_accepted(self): """Handshake succeeds when the lack of an origin is accepted.""" server = ServerProtocol(origins=[None]) From 7de24bd087e6e51901b6fd5d28bdb6a32404de69 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Jan 2025 21:19:10 +0100 Subject: [PATCH 695/762] Improve previous commit. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Require fullmatch instead of match — this avoids a vulnerability. * Shorten code and tweak to match my preferred style. * Add changelog. --- docs/project/changelog.rst | 6 ++++++ src/websockets/asyncio/server.py | 9 ++++---- src/websockets/server.py | 25 +++++++++++---------- src/websockets/sync/server.py | 9 ++++---- tests/test_server.py | 37 +++++++++++++++++++++----------- 5 files changed, 52 insertions(+), 34 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 5a2e1bc24..74fac904f 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -32,6 +32,12 @@ notice. *In development* +New features +............ + +* Added support for regular expressions in the ``origins`` argument of + :func:`~asyncio.server.serve`. + Bug fixes ......... diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index 080ea3f16..ebe45c2a9 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -600,10 +600,11 @@ def handler(websocket): See :meth:`~asyncio.loop.create_server` for details. port: TCP port the server listens on. See :meth:`~asyncio.loop.create_server` for details. - origins: Acceptable values of the ``Origin`` header, including regular - expressions, for defending against Cross-Site WebSocket Hijacking - attacks. Include :obj:`None` in the list if the lack of an origin - is acceptable. + origins: Acceptable values of the ``Origin`` header, for defending + against Cross-Site WebSocket Hijacking attacks. Values can be + :class:`str` to test for an exact match or regular expressions + compiled by :func:`re.compile` to test against a pattern. Include + :obj:`None` in the list if the lack of an origin is acceptable. extensions: List of supported extensions, in order in which they should be negotiated and run. subprotocols: List of supported subprotocols, in order of decreasing diff --git a/src/websockets/server.py b/src/websockets/server.py index 67082ed72..90e6c9921 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -50,9 +50,11 @@ class ServerProtocol(Protocol): Sans-I/O implementation of a WebSocket server connection. Args: - origins: Acceptable values of the ``Origin`` header, including regular - expressions; include :obj:`None` in the list if the lack of an origin - is acceptable. This is useful for defending against Cross-Site WebSocket + origins: Acceptable values of the ``Origin`` header. Values can be + :class:`str` to test for an exact match or regular expressions + compiled by :func:`re.compile` to test against a pattern. Include + :obj:`None` in the list if the lack of an origin is acceptable. + This is useful for defending against Cross-Site WebSocket Hijacking attacks. extensions: List of supported extensions, in order in which they should be tried. @@ -310,17 +312,14 @@ def process_origin(self, headers: Headers) -> Origin | None: if origin is not None: origin = cast(Origin, origin) if self.origins is not None: - valid = False - for acceptable_origin_or_regex in self.origins: - if isinstance(acceptable_origin_or_regex, re.Pattern): - # `str(origin)` is needed for compatibility - # between `Pattern.match(string=...)` and `origin`. - valid = acceptable_origin_or_regex.match(str(origin)) is not None - else: - valid = acceptable_origin_or_regex == origin - if valid: + for origin_or_regex in self.origins: + if origin_or_regex == origin or ( + isinstance(origin_or_regex, re.Pattern) + and origin is not None + and origin_or_regex.fullmatch(origin) is not None + ): break - if not valid: + else: raise InvalidOrigin(origin) return origin diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index c14e558ac..50a2f3c06 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -400,10 +400,11 @@ def handler(websocket): You may call :func:`socket.create_server` to create a suitable TCP socket. ssl: Configuration for enabling TLS on the connection. - origins: Acceptable values of the ``Origin`` header, including regular - expressions, for defending against Cross-Site WebSocket Hijacking - attacks. Include :obj:`None` in the list if the lack of an origin - is acceptable. + origins: Acceptable values of the ``Origin`` header, for defending + against Cross-Site WebSocket Hijacking attacks. Values can be + :class:`str` to test for an exact match or regular expressions + compiled by :func:`re.compile` to test against a pattern. Include + :obj:`None` in the list if the lack of an origin is acceptable. extensions: List of supported extensions, in order in which they should be negotiated and run. subprotocols: List of supported subprotocols, in order of decreasing diff --git a/tests/test_server.py b/tests/test_server.py index dd5e0d09a..9f328ded5 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -608,7 +608,7 @@ def test_supported_origin(self): self.assertEqual(server.origin, "https://other.example.com") def test_unsupported_origin(self): - """Handshake succeeds when checking origins and the origin is unsupported.""" + """Handshake fails when checking origins and the origin is unsupported.""" server = ServerProtocol( origins=["https://example.com", "https://other.example.com"] ) @@ -624,13 +624,10 @@ def test_unsupported_origin(self): "invalid Origin header: https://original.example.com", ) - def test_supported_origin_by_regex(self): - """ - Handshake succeeds when checking origins and the origin is supported - by a regular expression. - """ + def test_supported_origin_regex(self): + """Handshake succeeds when checking origins and the origin is supported.""" server = ServerProtocol( - origins=["https://example.com", re.compile(r"https://other.*")] + origins=[re.compile(r"https://(?!original)[a-z]+\.example\.com")] ) request = make_request() request.headers["Origin"] = "https://other.example.com" @@ -640,13 +637,10 @@ def test_supported_origin_by_regex(self): self.assertHandshakeSuccess(server) self.assertEqual(server.origin, "https://other.example.com") - def test_unsupported_origin_by_regex(self): - """ - Handshake succeeds when checking origins and the origin is unsupported - by a regular expression. - """ + def test_unsupported_origin_regex(self): + """Handshake fails when checking origins and the origin is unsupported.""" server = ServerProtocol( - origins=["https://example.com", re.compile(r"https://other.*")] + origins=[re.compile(r"https://(?!original)[a-z]+\.example\.com")] ) request = make_request() request.headers["Origin"] = "https://original.example.com" @@ -660,6 +654,23 @@ def test_unsupported_origin_by_regex(self): "invalid Origin header: https://original.example.com", ) + def test_partial_match_origin_regex(self): + """Handshake fails when checking origins and the origin a partial match.""" + server = ServerProtocol( + origins=[re.compile(r"https://(?!original)[a-z]+\.example\.com")] + ) + request = make_request() + request.headers["Origin"] = "https://other.example.com.hacked" + response = server.accept(request) + server.send_response(response) + + self.assertEqual(response.status_code, 403) + self.assertHandshakeError( + server, + InvalidOrigin, + "invalid Origin header: https://other.example.com.hacked", + ) + def test_no_origin_accepted(self): """Handshake succeeds when the lack of an origin is accepted.""" server = ServerProtocol(origins=[None]) From 17e309a830cabfe5e335bf96cb3795c45f053fbc Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Jan 2025 21:27:28 +0100 Subject: [PATCH 696/762] Mention another symptom in the changelog. Refs #1527. --- docs/project/changelog.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 74fac904f..5f616d0e8 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -51,8 +51,8 @@ Bug fixes returned. Previously, :exc:`TimeoutError` was raised incorrectly. * Prevented :meth:`~asyncio.connection.Connection.close` from blocking when - receive buffers are saturated in the :mod:`asyncio` and :mod:`threading` - implementations. + the network becomes unavailable or when receive buffers are saturated in + the :mod:`asyncio` and :mod:`threading` implementations. .. _14.1: From c8242bbb3ab7e8054a2cbae2bb88cb649bf730f4 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Jan 2025 21:38:01 +0100 Subject: [PATCH 697/762] Add changelog for #1566. --- docs/project/changelog.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 5f616d0e8..73b28b7b4 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -50,6 +50,9 @@ Bug fixes :mod:`threading` implementation. If a message is already received, it is returned. Previously, :exc:`TimeoutError` was raised incorrectly. +* Fixed a crash in the :mod:`asyncio` implementation when cancelling a ping + then receiving the corresponding pong. + * Prevented :meth:`~asyncio.connection.Connection.close` from blocking when the network becomes unavailable or when receive buffers are saturated in the :mod:`asyncio` and :mod:`threading` implementations. From 624a36cc9c1ea1369971387d7bb23533b2797350 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Jan 2025 21:42:16 +0100 Subject: [PATCH 698/762] Release version 14.2. --- docs/project/changelog.rst | 2 +- src/websockets/version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 73b28b7b4..bfb356c62 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -30,7 +30,7 @@ notice. 14.2 ---- -*In development* +*January 19, 2025* New features ............ diff --git a/src/websockets/version.py b/src/websockets/version.py index b0df10f0a..4b11b6fe9 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -18,7 +18,7 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = False +released = True tag = version = commit = "14.2" From 4af270d035bccd0bd282727c0abfdb7ed2cc0205 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Jan 2025 21:57:27 +0100 Subject: [PATCH 699/762] Start version 14.3. --- docs/project/changelog.rst | 7 +++++++ src/websockets/version.py | 4 ++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index bfb356c62..4ad8f5532 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -25,6 +25,13 @@ fixing regressions shortly after a release. Only documented APIs are public. Undocumented, private APIs may change without notice. +.. _14.3: + +14.3 +---- + +*In development* + .. _14.2: 14.2 diff --git a/src/websockets/version.py b/src/websockets/version.py index 4b11b6fe9..ca9a9115b 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -18,9 +18,9 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = True +released = False -tag = version = commit = "14.2" +tag = version = commit = "14.3" if not released: # pragma: no cover From 328a20cbb049fe738f77d13f274c29821646b1de Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 19 Jan 2025 21:57:36 +0100 Subject: [PATCH 700/762] Prepare upgrade to cibuildwheel 3. --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 4e26c757e..d3128f8ec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,9 @@ Documentation = "https://websockets.readthedocs.io/" Funding = "https://tidelift.com/subscription/pkg/pypi-websockets?utm_source=pypi-websockets&utm_medium=referral&utm_campaign=readme" Tracker = "https://github.com/python-websockets/websockets/issues" +[tool.cibuildwheel] +enable = ["pypy"] + # On a macOS runner, build Intel, Universal, and Apple Silicon wheels. [tool.cibuildwheel.macos] archs = ["x86_64", "universal2", "arm64"] From 7a8b757c56777509bf183bf973407de943d0d973 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 20 Jan 2025 19:50:41 +0100 Subject: [PATCH 701/762] Rename and reorder for consistency. --- src/websockets/asyncio/connection.py | 16 ++++++++-------- src/websockets/sync/connection.py | 24 ++++++++++++------------ 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index e2e587e7c..91bc0dda5 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -101,6 +101,14 @@ def __init__( # Protect sending fragmented messages. self.fragmented_send_waiter: asyncio.Future[None] | None = None + # Exception raised while reading from the connection, to be chained to + # ConnectionClosed in order to show why the TCP connection dropped. + self.recv_exc: BaseException | None = None + + # Completed when the TCP connection is closed and the WebSocket + # connection state becomes CLOSED. + self.connection_lost_waiter: asyncio.Future[None] = self.loop.create_future() + # Mapping of ping IDs to pong waiters, in chronological order. self.pong_waiters: dict[bytes, tuple[asyncio.Future[float], float]] = {} @@ -120,14 +128,6 @@ def __init__( # Task that sends keepalive pings. None when ping_interval is None. self.keepalive_task: asyncio.Task[None] | None = None - # Exception raised while reading from the connection, to be chained to - # ConnectionClosed in order to show why the TCP connection dropped. - self.recv_exc: BaseException | None = None - - # Completed when the TCP connection is closed and the WebSocket - # connection state becomes CLOSED. - self.connection_lost_waiter: asyncio.Future[None] = self.loop.create_future() - # Adapted from asyncio.FlowControlMixin self.paused: bool = False self.drain_waiters: collections.deque[asyncio.Future[None]] = ( diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 06ea00efc..653310c35 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -90,14 +90,11 @@ def __init__( resume=self.recv_flow_control.release, ) - # Whether we are busy sending a fragmented message. - self.send_in_progress = False - # Deadline for the closing handshake. self.close_deadline: Deadline | None = None - # Mapping of ping IDs to pong waiters, in chronological order. - self.ping_waiters: dict[bytes, threading.Event] = {} + # Whether we are busy sending a fragmented message. + self.send_in_progress = False # Exception raised in recv_events, to be chained to ConnectionClosed # in the user thread in order to show why the TCP connection dropped. @@ -112,6 +109,9 @@ def __init__( ) self.recv_events_thread.start() + # Mapping of ping IDs to pong waiters, in chronological order. + self.pong_waiters: dict[bytes, threading.Event] = {} + # Public attributes @property @@ -581,15 +581,15 @@ def ping(self, data: Data | None = None) -> threading.Event: with self.send_context(): # Protect against duplicates if a payload is explicitly set. - if data in self.ping_waiters: + if data in self.pong_waiters: raise ConcurrencyError("already waiting for a pong with the same data") # Generate a unique random payload otherwise. - while data is None or data in self.ping_waiters: + while data is None or data in self.pong_waiters: data = struct.pack("!I", random.getrandbits(32)) pong_waiter = threading.Event() - self.ping_waiters[data] = pong_waiter + self.pong_waiters[data] = pong_waiter self.protocol.send_ping(data) return pong_waiter @@ -641,22 +641,22 @@ def acknowledge_pings(self, data: bytes) -> None: """ with self.protocol_mutex: # Ignore unsolicited pong. - if data not in self.ping_waiters: + if data not in self.pong_waiters: return # Sending a pong for only the most recent ping is legal. # Acknowledge all previous pings too in that case. ping_id = None ping_ids = [] - for ping_id, ping in self.ping_waiters.items(): + for ping_id, ping in self.pong_waiters.items(): ping_ids.append(ping_id) ping.set() if ping_id == data: break else: raise AssertionError("solicited pong not found in pings") - # Remove acknowledged pings from self.ping_waiters. + # Remove acknowledged pings from self.pong_waiters. for ping_id in ping_ids: - del self.ping_waiters[ping_id] + del self.pong_waiters[ping_id] def recv_events(self) -> None: """ From 0fac3829353905c2079e05c36958a294b2b49cf1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 21 Jan 2025 22:35:08 +0100 Subject: [PATCH 702/762] Add latency measurement to the threading implementation. --- docs/project/changelog.rst | 5 +++++ docs/reference/features.rst | 3 +-- docs/reference/sync/client.rst | 2 ++ docs/reference/sync/common.rst | 2 ++ docs/reference/sync/server.rst | 2 ++ src/websockets/sync/connection.py | 23 +++++++++++++++++++---- 6 files changed, 31 insertions(+), 6 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 4ad8f5532..867231241 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -32,6 +32,11 @@ notice. *In development* +New features +............ + +* Added latency measurement to the :mod:`threading` implementation. + .. _14.2: 14.2 diff --git a/docs/reference/features.rst b/docs/reference/features.rst index 9187fa505..1135bf829 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -65,10 +65,9 @@ Both sides +------------------------------------+--------+--------+--------+--------+ | Heartbeat | ✅ | ❌ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ - | Measure latency | ✅ | ❌ | — | ✅ | + | Measure latency | ✅ | ✅ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ | Perform the closing handshake | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ | Enforce closing timeout | ✅ | ✅ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ | Report close codes and reasons | ✅ | ✅ | ✅ | ❌ | diff --git a/docs/reference/sync/client.rst b/docs/reference/sync/client.rst index 2aa491f6a..89316c997 100644 --- a/docs/reference/sync/client.rst +++ b/docs/reference/sync/client.rst @@ -39,6 +39,8 @@ Using a connection .. autoproperty:: remote_address + .. autoproperty:: latency + .. autoproperty:: state The following attributes are available after the opening handshake, diff --git a/docs/reference/sync/common.rst b/docs/reference/sync/common.rst index 3c03b25b6..d44ff55b6 100644 --- a/docs/reference/sync/common.rst +++ b/docs/reference/sync/common.rst @@ -31,6 +31,8 @@ Both sides (:mod:`threading`) .. autoproperty:: remote_address + .. autoattribute:: latency + .. autoproperty:: state The following attributes are available after the opening handshake, diff --git a/docs/reference/sync/server.rst b/docs/reference/sync/server.rst index 1d80450f9..c3d0e8f25 100644 --- a/docs/reference/sync/server.rst +++ b/docs/reference/sync/server.rst @@ -52,6 +52,8 @@ Using a connection .. autoproperty:: remote_address + .. autoproperty:: latency + .. autoproperty:: state The following attributes are available after the opening handshake, diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 653310c35..b0fbf45b6 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -6,6 +6,7 @@ import socket import struct import threading +import time import uuid from collections.abc import Iterable, Iterator, Mapping from types import TracebackType @@ -110,7 +111,16 @@ def __init__( self.recv_events_thread.start() # Mapping of ping IDs to pong waiters, in chronological order. - self.pong_waiters: dict[bytes, threading.Event] = {} + self.pong_waiters: dict[bytes, tuple[threading.Event, float]] = {} + + self.latency: float = 0 + """ + Latency of the connection, in seconds. + + Latency is defined as the round-trip time of the connection. It is + measured by sending a Ping frame and waiting for a matching Pong frame. + Before the first measurement, :attr:`latency` is ``0``. + """ # Public attributes @@ -589,7 +599,7 @@ def ping(self, data: Data | None = None) -> threading.Event: data = struct.pack("!I", random.getrandbits(32)) pong_waiter = threading.Event() - self.pong_waiters[data] = pong_waiter + self.pong_waiters[data] = (pong_waiter, time.monotonic()) self.protocol.send_ping(data) return pong_waiter @@ -643,17 +653,22 @@ def acknowledge_pings(self, data: bytes) -> None: # Ignore unsolicited pong. if data not in self.pong_waiters: return + + pong_timestamp = time.monotonic() + # Sending a pong for only the most recent ping is legal. # Acknowledge all previous pings too in that case. ping_id = None ping_ids = [] - for ping_id, ping in self.pong_waiters.items(): + for ping_id, (pong_waiter, ping_timestamp) in self.pong_waiters.items(): ping_ids.append(ping_id) - ping.set() + pong_waiter.set() if ping_id == data: + self.latency = pong_timestamp - ping_timestamp break else: raise AssertionError("solicited pong not found in pings") + # Remove acknowledged pings from self.pong_waiters. for ping_id in ping_ids: del self.pong_waiters[ping_id] From fc7b151fdfbc092a8d2062ef522d374074153cfe Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 21 Jan 2025 22:32:30 +0100 Subject: [PATCH 703/762] Add option to set pong waiters on connection close. --- src/websockets/sync/connection.py | 38 +++++++++++++++++++++++++++---- tests/sync/test_connection.py | 9 ++++++++ 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index b0fbf45b6..5270c1fab 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -111,7 +111,7 @@ def __init__( self.recv_events_thread.start() # Mapping of ping IDs to pong waiters, in chronological order. - self.pong_waiters: dict[bytes, tuple[threading.Event, float]] = {} + self.pong_waiters: dict[bytes, tuple[threading.Event, float, bool]] = {} self.latency: float = 0 """ @@ -554,7 +554,11 @@ def close(self, code: int = CloseCode.NORMAL_CLOSURE, reason: str = "") -> None: # They mean that the connection is closed, which was the goal. pass - def ping(self, data: Data | None = None) -> threading.Event: + def ping( + self, + data: Data | None = None, + ack_on_close: bool = False, + ) -> threading.Event: """ Send a Ping_. @@ -566,6 +570,12 @@ def ping(self, data: Data | None = None) -> threading.Event: Args: data: Payload of the ping. A :class:`str` will be encoded to UTF-8. If ``data`` is :obj:`None`, the payload is four random bytes. + ack_on_close: when this option is :obj:`True`, the event will also + be set when the connection is closed. While this avoids getting + stuck waiting for a pong that will never arrive, it requires + checking that the state of the connection is still ``OPEN`` to + confirm that a pong was received, rather than the connection + being closed. Returns: An event that will be set when the corresponding pong is received. @@ -599,7 +609,7 @@ def ping(self, data: Data | None = None) -> threading.Event: data = struct.pack("!I", random.getrandbits(32)) pong_waiter = threading.Event() - self.pong_waiters[data] = (pong_waiter, time.monotonic()) + self.pong_waiters[data] = (pong_waiter, time.monotonic(), ack_on_close) self.protocol.send_ping(data) return pong_waiter @@ -660,7 +670,11 @@ def acknowledge_pings(self, data: bytes) -> None: # Acknowledge all previous pings too in that case. ping_id = None ping_ids = [] - for ping_id, (pong_waiter, ping_timestamp) in self.pong_waiters.items(): + for ping_id, ( + pong_waiter, + ping_timestamp, + _ack_on_close, + ) in self.pong_waiters.items(): ping_ids.append(ping_id) pong_waiter.set() if ping_id == data: @@ -673,6 +687,19 @@ def acknowledge_pings(self, data: bytes) -> None: for ping_id in ping_ids: del self.pong_waiters[ping_id] + def acknowledge_pending_pings(self) -> None: + """ + Acknowledge pending pings when the connection is closed. + + """ + assert self.protocol.state is CLOSED + + for pong_waiter, _ping_timestamp, ack_on_close in self.pong_waiters.values(): + if ack_on_close: + pong_waiter.set() + + self.pong_waiters.clear() + def recv_events(self) -> None: """ Read incoming data from the socket and process events. @@ -944,3 +971,6 @@ def close_socket(self) -> None: # Abort recv() with a ConnectionClosed exception. self.recv_messages.close() + + # Acknowledge pings sent with the ack_on_close option. + self.acknowledge_pending_pings() diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index aa445498c..ee4aec5f6 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -685,6 +685,15 @@ def test_acknowledge_previous_ping(self): self.remote_connection.pong("that") self.assertTrue(pong_waiter.wait(MS)) + def test_acknowledge_ping_on_close(self): + """ping with ack_on_close is acknowledged when the connection is closed.""" + with self.drop_frames_rcvd(): # drop automatic response to ping + pong_waiter_ack_on_close = self.connection.ping("this", ack_on_close=True) + pong_waiter = self.connection.ping("that") + self.connection.close() + self.assertTrue(pong_waiter_ack_on_close.wait(MS)) + self.assertFalse(pong_waiter.wait(MS)) + def test_ping_duplicate_payload(self): """ping rejects the same payload until receiving the pong.""" with self.drop_frames_rcvd(): # drop automatic response to ping From 8f12d8fd16d8945cbb7ad2d37962c47c16cb804f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 21 Jan 2025 22:21:31 +0100 Subject: [PATCH 704/762] Add keepalive to the threading implementation. --- docs/project/changelog.rst | 3 +- docs/reference/features.rst | 4 +- docs/topics/keepalive.rst | 5 -- src/websockets/asyncio/connection.py | 17 +++-- src/websockets/sync/client.py | 17 ++++- src/websockets/sync/connection.py | 63 +++++++++++++++ src/websockets/sync/server.py | 17 ++++- tests/asyncio/test_connection.py | 26 +++---- tests/sync/test_client.py | 15 ++++ tests/sync/test_connection.py | 110 +++++++++++++++++++++++++++ tests/sync/test_server.py | 21 +++++ 11 files changed, 267 insertions(+), 31 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 867231241..67c16ba9e 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -35,7 +35,8 @@ notice. New features ............ -* Added latency measurement to the :mod:`threading` implementation. +* Added :doc:`keepalive and latency measurement <../topics/keepalive>` to the + :mod:`threading` implementation. .. _14.2: diff --git a/docs/reference/features.rst b/docs/reference/features.rst index 1135bf829..6ba42f66b 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -61,9 +61,9 @@ Both sides +------------------------------------+--------+--------+--------+--------+ | Send a pong | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ - | Keepalive | ✅ | ❌ | — | ✅ | + | Keepalive | ✅ | ✅ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ - | Heartbeat | ✅ | ❌ | — | ✅ | + | Heartbeat | ✅ | ✅ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ | Measure latency | ✅ | ✅ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ diff --git a/docs/topics/keepalive.rst b/docs/topics/keepalive.rst index a0467ced2..e63c2f8f5 100644 --- a/docs/topics/keepalive.rst +++ b/docs/topics/keepalive.rst @@ -1,11 +1,6 @@ Keepalive and latency ===================== -.. admonition:: This guide applies only to the :mod:`asyncio` implementation. - :class: tip - - The :mod:`threading` implementation doesn't provide keepalive yet. - .. currentmodule:: websockets Long-lived connections diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 91bc0dda5..75c43fa8a 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -686,8 +686,7 @@ async def ping(self, data: Data | None = None) -> Awaitable[float]: pong_waiter = self.loop.create_future() # The event loop's default clock is time.monotonic(). Its resolution # is a bit low on Windows (~16ms). This is improved in Python 3.13. - ping_timestamp = self.loop.time() - self.pong_waiters[data] = (pong_waiter, ping_timestamp) + self.pong_waiters[data] = (pong_waiter, self.loop.time()) self.protocol.send_ping(data) return pong_waiter @@ -792,13 +791,19 @@ async def keepalive(self) -> None: latency = 0.0 try: while True: - # If self.ping_timeout > latency > self.ping_interval, pings - # will be sent immediately after receiving pongs. The period - # will be longer than self.ping_interval. + # If self.ping_timeout > latency > self.ping_interval, + # pings will be sent immediately after receiving pongs. + # The period will be longer than self.ping_interval. await asyncio.sleep(self.ping_interval - latency) - self.logger.debug("% sending keepalive ping") + # This cannot raise ConnectionClosed when the connection is + # closing because ping(), via send_context(), waits for the + # connection to be closed before raising ConnectionClosed. + # However, connection_lost() cancels keepalive_task before + # it gets a chance to resume excuting. pong_waiter = await self.ping() + if self.debug: + self.logger.debug("% sent keepalive ping") if self.ping_timeout is not None: try: diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 9e6da7caf..8325721b7 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -40,8 +40,8 @@ class ClientConnection(Connection): :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is closed with any other code. - The ``close_timeout`` and ``max_queue`` arguments have the same meaning as - in :func:`connect`. + The ``ping_interval``, ``ping_timeout``, ``close_timeout``, and + ``max_queue`` arguments have the same meaning as in :func:`connect`. Args: socket: Socket connected to a WebSocket server. @@ -54,6 +54,8 @@ def __init__( socket: socket.socket, protocol: ClientProtocol, *, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, close_timeout: float | None = 10, max_queue: int | None | tuple[int | None, int | None] = 16, ) -> None: @@ -62,6 +64,8 @@ def __init__( super().__init__( socket, protocol, + ping_interval=ping_interval, + ping_timeout=ping_timeout, close_timeout=close_timeout, max_queue=max_queue, ) @@ -136,6 +140,8 @@ def connect( compression: str | None = "deflate", # Timeouts open_timeout: float | None = 10, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, @@ -184,6 +190,10 @@ def connect( :doc:`compression guide <../../topics/compression>` for details. open_timeout: Timeout for opening the connection in seconds. :obj:`None` disables the timeout. + ping_interval: Interval between keepalive pings in seconds. + :obj:`None` disables keepalive. + ping_timeout: Timeout for keepalive pings in seconds. + :obj:`None` disables timeouts. close_timeout: Timeout for closing the connection in seconds. :obj:`None` disables the timeout. max_size: Maximum size of incoming messages in bytes. @@ -296,6 +306,8 @@ def connect( connection = create_connection( sock, protocol, + ping_interval=ping_interval, + ping_timeout=ping_timeout, close_timeout=close_timeout, max_queue=max_queue, ) @@ -315,6 +327,7 @@ def connect( connection.recv_events_thread.join() raise + connection.start_keepalive() return connection diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 5270c1fab..07f0543e4 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -49,11 +49,15 @@ def __init__( socket: socket.socket, protocol: Protocol, *, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, close_timeout: float | None = 10, max_queue: int | None | tuple[int | None, int | None] = 16, ) -> None: self.socket = socket self.protocol = protocol + self.ping_interval = ping_interval + self.ping_timeout = ping_timeout self.close_timeout = close_timeout if isinstance(max_queue, int) or max_queue is None: max_queue = (max_queue, None) @@ -120,8 +124,15 @@ def __init__( Latency is defined as the round-trip time of the connection. It is measured by sending a Ping frame and waiting for a matching Pong frame. Before the first measurement, :attr:`latency` is ``0``. + + By default, websockets enables a :ref:`keepalive ` mechanism + that sends Ping frames automatically at regular intervals. You can also + send Ping frames and measure latency with :meth:`ping`. """ + # Thread that sends keepalive pings. None when ping_interval is None. + self.keepalive_thread: threading.Thread | None = None + # Public attributes @property @@ -700,6 +711,58 @@ def acknowledge_pending_pings(self) -> None: self.pong_waiters.clear() + def keepalive(self) -> None: + """ + Send a Ping frame and wait for a Pong frame at regular intervals. + + """ + assert self.ping_interval is not None + try: + while True: + # If self.ping_timeout > self.latency > self.ping_interval, + # pings will be sent immediately after receiving pongs. + # The period will be longer than self.ping_interval. + self.recv_events_thread.join(self.ping_interval - self.latency) + if not self.recv_events_thread.is_alive(): + break + + try: + pong_waiter = self.ping(ack_on_close=True) + except ConnectionClosed: + break + if self.debug: + self.logger.debug("% sent keepalive ping") + + if self.ping_timeout is not None: + # + if pong_waiter.wait(self.ping_timeout): + if self.debug: + self.logger.debug("% received keepalive pong") + else: + if self.debug: + self.logger.debug("- timed out waiting for keepalive pong") + with self.send_context(): + self.protocol.fail( + CloseCode.INTERNAL_ERROR, + "keepalive ping timeout", + ) + break + except Exception: + self.logger.error("keepalive ping failed", exc_info=True) + + def start_keepalive(self) -> None: + """ + Run :meth:`keepalive` in a thread, unless keepalive is disabled. + + """ + if self.ping_interval is not None: + # This thread is marked as daemon like self.recv_events_thread. + self.keepalive_thread = threading.Thread( + target=self.keepalive, + daemon=True, + ) + self.keepalive_thread.start() + def recv_events(self) -> None: """ Read incoming data from the socket and process events. diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 50a2f3c06..643f9b44b 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -52,8 +52,8 @@ class ServerConnection(Connection): :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is closed with any other code. - The ``close_timeout`` and ``max_queue`` arguments have the same meaning as - in :func:`serve`. + The ``ping_interval``, ``ping_timeout``, ``close_timeout``, and + ``max_queue`` arguments have the same meaning as in :func:`serve`. Args: socket: Socket connected to a WebSocket client. @@ -66,6 +66,8 @@ def __init__( socket: socket.socket, protocol: ServerProtocol, *, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, close_timeout: float | None = 10, max_queue: int | None | tuple[int | None, int | None] = 16, ) -> None: @@ -74,6 +76,8 @@ def __init__( super().__init__( socket, protocol, + ping_interval=ping_interval, + ping_timeout=ping_timeout, close_timeout=close_timeout, max_queue=max_queue, ) @@ -354,6 +358,8 @@ def serve( compression: str | None = "deflate", # Timeouts open_timeout: float | None = 10, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, @@ -434,6 +440,10 @@ def handler(websocket): :doc:`compression guide <../../topics/compression>` for details. open_timeout: Timeout for opening connections in seconds. :obj:`None` disables the timeout. + ping_interval: Interval between keepalive pings in seconds. + :obj:`None` disables keepalive. + ping_timeout: Timeout for keepalive pings in seconds. + :obj:`None` disables timeouts. close_timeout: Timeout for closing connections in seconds. :obj:`None` disables the timeout. max_size: Maximum size of incoming messages in bytes. @@ -563,6 +573,8 @@ def protocol_select_subprotocol( connection = create_connection( sock, protocol, + ping_interval=ping_interval, + ping_timeout=ping_timeout, close_timeout=close_timeout, max_queue=max_queue, ) @@ -590,6 +602,7 @@ def protocol_select_subprotocol( assert connection.protocol.state is OPEN try: + connection.start_keepalive() handler(connection) except Exception: connection.logger.error("connection handler failed", exc_info=True) diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 788a457ed..b53c97030 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -1010,7 +1010,7 @@ async def test_keepalive_times_out(self, getrandbits): self.connection.start_keepalive() # 4 ms: keepalive() sends a ping frame. await asyncio.sleep(4 * MS) - # Exiting the context manager sleeps for MS. + # Exiting the context manager sleeps for 1 ms. # 4.x ms: a pong frame is dropped. # 6 ms: no pong frame is received; the connection is closed. await asyncio.sleep(2 * MS) @@ -1026,9 +1026,9 @@ async def test_keepalive_ignores_timeout(self, getrandbits): getrandbits.return_value = 1918987876 self.connection.start_keepalive() # 4 ms: keepalive() sends a ping frame. - await asyncio.sleep(4 * MS) - # Exiting the context manager sleeps for MS. # 4.x ms: a pong frame is dropped. + await asyncio.sleep(4 * MS) + # Exiting the context manager sleeps for 1 ms. # 6 ms: no pong frame is received; the connection remains open. await asyncio.sleep(2 * MS) # 7 ms: check that the connection is still open. @@ -1036,7 +1036,7 @@ async def test_keepalive_ignores_timeout(self, getrandbits): async def test_keepalive_terminates_while_sleeping(self): """keepalive task terminates while waiting to send a ping.""" - self.connection.ping_interval = 2 * MS + self.connection.ping_interval = 3 * MS self.connection.start_keepalive() await asyncio.sleep(MS) await self.connection.close() @@ -1044,15 +1044,15 @@ async def test_keepalive_terminates_while_sleeping(self): async def test_keepalive_terminates_while_waiting_for_pong(self): """keepalive task terminates while waiting to receive a pong.""" - self.connection.ping_interval = 2 * MS - self.connection.ping_timeout = 2 * MS + self.connection.ping_interval = MS + self.connection.ping_timeout = 3 * MS async with self.drop_frames_rcvd(): self.connection.start_keepalive() - # 2 ms: keepalive() sends a ping frame. - await asyncio.sleep(2 * MS) - # Exiting the context manager sleeps for MS. - # 2.x ms: a pong frame is dropped. - # 3 ms: close the connection before ping_timeout elapses. + # 1 ms: keepalive() sends a ping frame. + # 1.x ms: a pong frame is dropped. + await asyncio.sleep(MS) + # Exiting the context manager sleeps for 1 ms. + # 2 ms: close the connection before ping_timeout elapses. await self.connection.close() self.assertTrue(self.connection.keepalive_task.done()) @@ -1062,9 +1062,9 @@ async def test_keepalive_reports_errors(self): async with self.drop_frames_rcvd(): self.connection.start_keepalive() # 2 ms: keepalive() sends a ping frame. - await asyncio.sleep(2 * MS) - # Exiting the context manager sleeps for MS. # 2.x ms: a pong frame is dropped. + await asyncio.sleep(2 * MS) + # Exiting the context manager sleeps for 1 ms. # 3 ms: inject a fault: raise an exception in the pending pong waiter. pong_waiter = next(iter(self.connection.pong_waiters.values()))[0] with self.assertLogs("websockets", logging.ERROR) as logs: diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index 7ab8f4dd4..1669a0e84 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -76,6 +76,21 @@ def test_disable_compression(self): with connect(get_uri(server), compression=None) as client: self.assertEqual(client.protocol.extensions, []) + def test_keepalive_is_enabled(self): + """Client enables keepalive and measures latency by default.""" + with run_server() as server: + with connect(get_uri(server), ping_interval=MS) as client: + self.assertEqual(client.latency, 0) + time.sleep(2 * MS) + self.assertGreater(client.latency, 0) + + def test_disable_keepalive(self): + """Client disables keepalive.""" + with run_server() as server: + with connect(get_uri(server), ping_interval=None) as client: + time.sleep(2 * MS) + self.assertEqual(client.latency, 0) + def test_logger(self): """Client accepts a logger argument.""" logger = logging.getLogger("test") diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index ee4aec5f6..f191d8944 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -738,6 +738,116 @@ def test_pong_unsupported_type(self): with self.assertRaises(TypeError): self.connection.pong([]) + # Test keepalive. + + @patch("random.getrandbits") + def test_keepalive(self, getrandbits): + """keepalive sends pings at ping_interval and measures latency.""" + self.connection.ping_interval = 2 * MS + getrandbits.return_value = 1918987876 + self.connection.start_keepalive() + self.assertEqual(self.connection.latency, 0) + # 2 ms: keepalive() sends a ping frame. + # 2.x ms: a pong frame is received. + time.sleep(3 * MS) + # 3 ms: check that the ping frame was sent. + self.assertFrameSent(Frame(Opcode.PING, b"rand")) + self.assertGreater(self.connection.latency, 0) + self.assertLess(self.connection.latency, MS) + + def test_disable_keepalive(self): + """keepalive is disabled when ping_interval is None.""" + self.connection.ping_interval = None + self.connection.start_keepalive() + time.sleep(3 * MS) + self.assertNoFrameSent() + + @patch("random.getrandbits") + def test_keepalive_times_out(self, getrandbits): + """keepalive closes the connection if ping_timeout elapses.""" + self.connection.ping_interval = 4 * MS + self.connection.ping_timeout = 2 * MS + with self.drop_frames_rcvd(): + getrandbits.return_value = 1918987876 + self.connection.start_keepalive() + # 4 ms: keepalive() sends a ping frame. + time.sleep(4 * MS) + # Exiting the context manager sleeps for 1 ms. + # 4.x ms: a pong frame is dropped. + # 6 ms: no pong frame is received; the connection is closed. + time.sleep(2 * MS) + # 7 ms: check that the connection is closed. + self.assertEqual(self.connection.state, State.CLOSED) + + @patch("random.getrandbits") + def test_keepalive_ignores_timeout(self, getrandbits): + """keepalive ignores timeouts if ping_timeout isn't set.""" + self.connection.ping_interval = 4 * MS + self.connection.ping_timeout = None + with self.drop_frames_rcvd(): + getrandbits.return_value = 1918987876 + self.connection.start_keepalive() + # 4 ms: keepalive() sends a ping frame. + time.sleep(4 * MS) + # Exiting the context manager sleeps for 1 ms. + # 4.x ms: a pong frame is dropped. + # 6 ms: no pong frame is received; the connection remains open. + time.sleep(2 * MS) + # 7 ms: check that the connection is still open. + self.assertEqual(self.connection.state, State.OPEN) + + def test_keepalive_terminates_while_sleeping(self): + """keepalive task terminates while waiting to send a ping.""" + self.connection.ping_interval = 3 * MS + self.connection.start_keepalive() + time.sleep(MS) + self.connection.close() + self.connection.keepalive_thread.join(MS) + self.assertFalse(self.connection.keepalive_thread.is_alive()) + + def test_keepalive_terminates_when_sending_ping_fails(self): + """keepalive task terminates when sending a ping fails.""" + self.connection.ping_interval = 1 * MS + self.connection.start_keepalive() + with self.drop_eof_rcvd(), self.drop_frames_rcvd(): + self.connection.close() + self.assertFalse(self.connection.keepalive_thread.is_alive()) + + def test_keepalive_terminates_while_waiting_for_pong(self): + """keepalive task terminates while waiting to receive a pong.""" + self.connection.ping_interval = MS + self.connection.ping_timeout = 4 * MS + with self.drop_frames_rcvd(): + self.connection.start_keepalive() + # 1 ms: keepalive() sends a ping frame. + # 1.x ms: a pong frame is dropped. + time.sleep(MS) + # Exiting the context manager sleeps for 1 ms. + # 2 ms: close the connection before ping_timeout elapses. + self.connection.close() + self.connection.keepalive_thread.join(MS) + self.assertFalse(self.connection.keepalive_thread.is_alive()) + + def test_keepalive_reports_errors(self): + """keepalive reports unexpected errors in logs.""" + self.connection.ping_interval = 2 * MS + with self.drop_frames_rcvd(): + self.connection.start_keepalive() + # 2 ms: keepalive() sends a ping frame. + # 2.x ms: a pong frame is dropped. + with self.assertLogs("websockets", logging.ERROR) as logs: + with patch("threading.Event.wait", side_effect=Exception("BOOM")): + time.sleep(3 * MS) + # Exiting the context manager sleeps for 1 ms. + self.assertEqual( + [record.getMessage() for record in logs.records], + ["keepalive ping failed"], + ) + self.assertEqual( + [str(record.exc_info[1]) for record in logs.records], + ["BOOM"], + ) + # Test parameters. def test_close_timeout(self): diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index bb2ebae14..8e83f2a81 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -236,6 +236,27 @@ def test_disable_compression(self): with connect(get_uri(server)) as client: self.assertEval(client, "ws.protocol.extensions", "[]") + def test_keepalive_is_enabled(self): + """Server enables keepalive and measures latency.""" + with run_server(ping_interval=MS) as server: + with connect(get_uri(server)) as client: + client.send("ws.latency") + latency = eval(client.recv()) + self.assertEqual(latency, 0) + time.sleep(2 * MS) + client.send("ws.latency") + latency = eval(client.recv()) + self.assertGreater(latency, 0) + + def test_disable_keepalive(self): + """Server disables keepalive.""" + with run_server(ping_interval=None) as server: + with connect(get_uri(server)) as client: + time.sleep(2 * MS) + client.send("ws.latency") + latency = eval(client.recv()) + self.assertEqual(latency, 0) + def test_logger(self): """Server accepts a logger argument.""" logger = logging.getLogger("test") From ee0d7441940131cc30b0f5a9a2afc905bb490564 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 24 Jan 2025 21:15:08 +0100 Subject: [PATCH 705/762] Avoid shadowing an import with another import. --- pyproject.toml | 2 +- src/websockets/__init__.py | 5 +++-- src/websockets/typing.py | 5 ++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d3128f8ec..4044de0f2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,7 +67,7 @@ exclude_lines = [ "except ImportError:", "if self.debug:", "if sys.platform != \"win32\":", - "if typing.TYPE_CHECKING:", + "if TYPE_CHECKING:", "raise AssertionError", "self.fail\\(\".*\"\\)", "@unittest.skip", diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index c8df54e0b..1d0abe5cd 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -1,6 +1,7 @@ from __future__ import annotations -import typing +# Importing the typing module would conflict with websockets.typing. +from typing import TYPE_CHECKING from .imports import lazy_import from .version import version as __version__ # noqa: F401 @@ -72,7 +73,7 @@ ] # When type checking, import non-deprecated aliases eagerly. Else, import on demand. -if typing.TYPE_CHECKING: +if TYPE_CHECKING: from .asyncio.client import ClientConnection, connect, unix_connect from .asyncio.server import ( Server, diff --git a/src/websockets/typing.py b/src/websockets/typing.py index 0a37141c6..f10481b8b 100644 --- a/src/websockets/typing.py +++ b/src/websockets/typing.py @@ -2,8 +2,7 @@ import http import logging -import typing -from typing import Any, NewType, Optional, Union +from typing import TYPE_CHECKING, Any, NewType, Optional, Union __all__ = [ @@ -31,7 +30,7 @@ # Change to logging.Logger | ... when dropping Python < 3.10. -if typing.TYPE_CHECKING: +if TYPE_CHECKING: LoggerLike = Union[logging.Logger, logging.LoggerAdapter[Any]] """Types accepted where a :class:`~logging.Logger` is expected.""" else: # remove this branch when dropping support for Python < 3.11 From bba423e510cc422e27dfd77a95771d208e3e766a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 23 Jan 2025 21:22:14 +0100 Subject: [PATCH 706/762] Add type overloads for recv and recv_streaming. Fix #1578. --- docs/project/changelog.rst | 6 +++++ pyproject.toml | 1 + src/websockets/asyncio/connection.py | 20 ++++++++++++++++- src/websockets/asyncio/messages.py | 20 ++++++++++++++++- src/websockets/sync/connection.py | 33 +++++++++++++++++++++++++++- src/websockets/sync/messages.py | 29 +++++++++++++++++++++++- 6 files changed, 105 insertions(+), 4 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 67c16ba9e..7f341d942 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -38,6 +38,12 @@ New features * Added :doc:`keepalive and latency measurement <../topics/keepalive>` to the :mod:`threading` implementation. +Improvements +............ + +* Added type overloads for the ``decode`` argument of + :meth:`~asyncio.connection.Connection.recv`. This may simplify static typing. + .. _14.2: 14.2 diff --git a/pyproject.toml b/pyproject.toml index 4044de0f2..c0d9fcfd9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,6 +70,7 @@ exclude_lines = [ "if TYPE_CHECKING:", "raise AssertionError", "self.fail\\(\".*\"\\)", + "@overload", "@unittest.skip", ] partial_branches = [ diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 75c43fa8a..79429923e 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -11,7 +11,7 @@ import uuid from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterable, Mapping from types import TracebackType -from typing import Any, cast +from typing import Any, Literal, cast, overload from ..exceptions import ( ConcurrencyError, @@ -243,6 +243,15 @@ async def __aiter__(self) -> AsyncIterator[Data]: except ConnectionClosedOK: return + @overload + async def recv(self, decode: Literal[True]) -> str: ... + + @overload + async def recv(self, decode: Literal[False]) -> bytes: ... + + @overload + async def recv(self, decode: bool | None = None) -> Data: ... + async def recv(self, decode: bool | None = None) -> Data: """ Receive the next message. @@ -312,6 +321,15 @@ async def recv(self, decode: bool | None = None) -> Data: await asyncio.shield(self.connection_lost_waiter) raise self.protocol.close_exc from self.recv_exc + @overload + def recv_streaming(self, decode: Literal[True]) -> AsyncIterator[str]: ... + + @overload + def recv_streaming(self, decode: Literal[False]) -> AsyncIterator[bytes]: ... + + @overload + def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data]: ... + async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data]: """ Receive the next message frame by frame. diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index c10072467..581870037 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -4,7 +4,7 @@ import codecs import collections from collections.abc import AsyncIterator, Iterable -from typing import Any, Callable, Generic, TypeVar +from typing import Any, Callable, Generic, Literal, TypeVar, overload from ..exceptions import ConcurrencyError from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame @@ -116,6 +116,15 @@ def __init__( # pragma: no cover # This flag marks the end of the connection. self.closed = False + @overload + async def get(self, decode: Literal[True]) -> str: ... + + @overload + async def get(self, decode: Literal[False]) -> bytes: ... + + @overload + async def get(self, decode: bool | None = None) -> Data: ... + async def get(self, decode: bool | None = None) -> Data: """ Read the next message. @@ -176,6 +185,15 @@ async def get(self, decode: bool | None = None) -> Data: else: return data + @overload + def get_iter(self, decode: Literal[True]) -> AsyncIterator[str]: ... + + @overload + def get_iter(self, decode: Literal[False]) -> AsyncIterator[bytes]: ... + + @overload + def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: ... + async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: """ Stream the next message. diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 07f0543e4..0c517cc64 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -10,7 +10,7 @@ import uuid from collections.abc import Iterable, Iterator, Mapping from types import TracebackType -from typing import Any +from typing import Any, Literal, overload from ..exceptions import ( ConcurrencyError, @@ -241,6 +241,28 @@ def __iter__(self) -> Iterator[Data]: except ConnectionClosedOK: return + # This overload structure is required to avoid the error: + # "parameter without a default follows parameter with a default" + + @overload + def recv(self, timeout: float | None, decode: Literal[True]) -> str: ... + + @overload + def recv(self, timeout: float | None, decode: Literal[False]) -> bytes: ... + + @overload + def recv(self, timeout: float | None = None, *, decode: Literal[True]) -> str: ... + + @overload + def recv( + self, timeout: float | None = None, *, decode: Literal[False] + ) -> bytes: ... + + @overload + def recv( + self, timeout: float | None = None, decode: bool | None = None + ) -> Data: ... + def recv(self, timeout: float | None = None, decode: bool | None = None) -> Data: """ Receive the next message. @@ -311,6 +333,15 @@ def recv(self, timeout: float | None = None, decode: bool | None = None) -> Data self.recv_events_thread.join() raise self.protocol.close_exc from self.recv_exc + @overload + def recv_streaming(self, decode: Literal[True]) -> Iterator[str]: ... + + @overload + def recv_streaming(self, decode: Literal[False]) -> Iterator[bytes]: ... + + @overload + def recv_streaming(self, decode: bool | None = None) -> Iterator[Data]: ... + def recv_streaming(self, decode: bool | None = None) -> Iterator[Data]: """ Receive the next message frame by frame. diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index dfabedd65..c619e78a1 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -3,7 +3,7 @@ import codecs import queue import threading -from typing import Any, Callable, Iterable, Iterator +from typing import Any, Callable, Iterable, Iterator, Literal, overload from ..exceptions import ConcurrencyError from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame @@ -110,6 +110,24 @@ def reset_queue(self, frames: Iterable[Frame]) -> None: for frame in queued: # pragma: no cover self.frames.put(frame) + # This overload structure is required to avoid the error: + # "parameter without a default follows parameter with a default" + + @overload + def get(self, timeout: float | None, decode: Literal[True]) -> str: ... + + @overload + def get(self, timeout: float | None, decode: Literal[False]) -> bytes: ... + + @overload + def get(self, timeout: float | None = None, *, decode: Literal[True]) -> str: ... + + @overload + def get(self, timeout: float | None = None, *, decode: Literal[False]) -> bytes: ... + + @overload + def get(self, timeout: float | None = None, decode: bool | None = None) -> Data: ... + def get(self, timeout: float | None = None, decode: bool | None = None) -> Data: """ Read the next message. @@ -181,6 +199,15 @@ def get(self, timeout: float | None = None, decode: bool | None = None) -> Data: else: return data + @overload + def get_iter(self, decode: Literal[True]) -> Iterator[str]: ... + + @overload + def get_iter(self, decode: Literal[False]) -> Iterator[bytes]: ... + + @overload + def get_iter(self, decode: bool | None = None) -> Iterator[Data]: ... + def get_iter(self, decode: bool | None = None) -> Iterator[Data]: """ Stream the next message. From c42672beb30e4c65a9f22ea8e0e68ee86b483a7c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 25 Jan 2025 09:26:50 +0100 Subject: [PATCH 707/762] Narrow down type of client_max_window_bits. --- src/websockets/extensions/permessage_deflate.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index cefad4f56..8e74cb282 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -2,7 +2,7 @@ import zlib from collections.abc import Sequence -from typing import Any +from typing import Any, Literal from .. import frames from ..exceptions import ( @@ -212,7 +212,7 @@ def _build_parameters( server_no_context_takeover: bool, client_no_context_takeover: bool, server_max_window_bits: int | None, - client_max_window_bits: int | bool | None, + client_max_window_bits: int | Literal[True] | None, ) -> list[ExtensionParameter]: """ Build a list of ``(name, value)`` pairs for some compression parameters. @@ -234,7 +234,7 @@ def _build_parameters( def _extract_parameters( params: Sequence[ExtensionParameter], *, is_server: bool -) -> tuple[bool, bool, int | None, int | bool | None]: +) -> tuple[bool, bool, int | None, int | Literal[True] | None]: """ Extract compression parameters from a list of ``(name, value)`` pairs. @@ -245,7 +245,7 @@ def _extract_parameters( server_no_context_takeover: bool = False client_no_context_takeover: bool = False server_max_window_bits: int | None = None - client_max_window_bits: int | bool | None = None + client_max_window_bits: int | Literal[True] | None = None for name, value in params: if name == "server_no_context_takeover": @@ -324,7 +324,7 @@ def __init__( server_no_context_takeover: bool = False, client_no_context_takeover: bool = False, server_max_window_bits: int | None = None, - client_max_window_bits: int | bool | None = True, + client_max_window_bits: int | Literal[True] | None = True, compress_settings: dict[str, Any] | None = None, ) -> None: """ From a1bf0008fdcc39499cf1a270c5491e1ca1d4335c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 25 Jan 2025 09:35:17 +0100 Subject: [PATCH 708/762] Rename the wsuri variable. The name looked confusing. --- docs/howto/sansio.rst | 8 ++++---- src/websockets/asyncio/client.py | 28 ++++++++++++++-------------- src/websockets/client.py | 16 +++++++--------- src/websockets/sync/client.py | 12 ++++++------ tests/asyncio/test_client.py | 4 ++-- 5 files changed, 33 insertions(+), 35 deletions(-) diff --git a/docs/howto/sansio.rst b/docs/howto/sansio.rst index ca530e6a1..27abcdabd 100644 --- a/docs/howto/sansio.rst +++ b/docs/howto/sansio.rst @@ -28,16 +28,16 @@ If you're building a client, parse the URI you'd like to connect to:: from websockets.uri import parse_uri - wsuri = parse_uri("ws://example.com/") + uri = parse_uri("ws://example.com/") -Open a TCP connection to ``(wsuri.host, wsuri.port)`` and perform a TLS -handshake if ``wsuri.secure`` is :obj:`True`. +Open a TCP connection to ``(uri.host, uri.port)`` and perform a TLS handshake +if ``uri.secure`` is :obj:`True`. Initialize a :class:`~client.ClientProtocol`:: from websockets.client import ClientProtocol - protocol = ClientProtocol(wsuri) + protocol = ClientProtocol(uri) Create a WebSocket handshake request with :meth:`~client.ClientProtocol.connect` and send it diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index f05f546d3..237bb273a 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -311,10 +311,10 @@ def __init__( if create_connection is None: create_connection = ClientConnection - def protocol_factory(wsuri: WebSocketURI) -> ClientConnection: + def protocol_factory(uri: WebSocketURI) -> ClientConnection: # This is a protocol in the Sans-I/O implementation of websockets. protocol = ClientProtocol( - wsuri, + uri, origin=origin, extensions=extensions, subprotocols=subprotocols, @@ -346,15 +346,15 @@ async def create_connection(self) -> ClientConnection: """Create TCP or Unix connection.""" loop = asyncio.get_running_loop() - wsuri = parse_uri(self.uri) + ws_uri = parse_uri(self.uri) kwargs = self.connection_kwargs.copy() def factory() -> ClientConnection: - return self.protocol_factory(wsuri) + return self.protocol_factory(ws_uri) - if wsuri.secure: + if ws_uri.secure: kwargs.setdefault("ssl", True) - kwargs.setdefault("server_hostname", wsuri.host) + kwargs.setdefault("server_hostname", ws_uri.host) if kwargs.get("ssl") is None: raise ValueError("ssl=None is incompatible with a wss:// URI") else: @@ -365,8 +365,8 @@ def factory() -> ClientConnection: _, connection = await loop.create_unix_connection(factory, **kwargs) else: if kwargs.get("sock") is None: - kwargs.setdefault("host", wsuri.host) - kwargs.setdefault("port", wsuri.port) + kwargs.setdefault("host", ws_uri.host) + kwargs.setdefault("port", ws_uri.port) _, connection = await loop.create_connection(factory, **kwargs) return connection @@ -392,9 +392,9 @@ def process_redirect(self, exc: Exception) -> Exception | str: ): return exc - old_wsuri = parse_uri(self.uri) + old_ws_uri = parse_uri(self.uri) new_uri = urllib.parse.urljoin(self.uri, exc.response.headers["Location"]) - new_wsuri = parse_uri(new_uri) + new_ws_uri = parse_uri(new_uri) # If connect() received a socket, it is closed and cannot be reused. if self.connection_kwargs.get("sock") is not None: @@ -403,14 +403,14 @@ def process_redirect(self, exc: Exception) -> Exception | str: ) # TLS downgrade is forbidden. - if old_wsuri.secure and not new_wsuri.secure: + if old_ws_uri.secure and not new_ws_uri.secure: return SecurityError(f"cannot follow redirect to non-secure URI {new_uri}") # Apply restrictions to cross-origin redirects. if ( - old_wsuri.secure != new_wsuri.secure - or old_wsuri.host != new_wsuri.host - or old_wsuri.port != new_wsuri.port + old_ws_uri.secure != new_ws_uri.secure + or old_ws_uri.host != new_ws_uri.host + or old_ws_uri.port != new_ws_uri.port ): # Cross-origin redirects on Unix sockets don't quite make sense. if self.connection_kwargs.get("unix", False): diff --git a/src/websockets/client.py b/src/websockets/client.py index 37e2a8b3a..4c22dfc12 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -50,7 +50,7 @@ class ClientProtocol(Protocol): Sans-I/O implementation of a WebSocket client connection. Args: - wsuri: URI of the WebSocket server, parsed + uri: URI of the WebSocket server, parsed with :func:`~websockets.uri.parse_uri`. origin: Value of the ``Origin`` header. This is useful when connecting to a server that validates the ``Origin`` header to defend against @@ -70,7 +70,7 @@ class ClientProtocol(Protocol): def __init__( self, - wsuri: WebSocketURI, + uri: WebSocketURI, *, origin: Origin | None = None, extensions: Sequence[ClientExtensionFactory] | None = None, @@ -85,7 +85,7 @@ def __init__( max_size=max_size, logger=logger, ) - self.wsuri = wsuri + self.uri = uri self.origin = origin self.available_extensions = extensions self.available_subprotocols = subprotocols @@ -105,12 +105,10 @@ def connect(self) -> Request: """ headers = Headers() - headers["Host"] = build_host( - self.wsuri.host, self.wsuri.port, self.wsuri.secure - ) + headers["Host"] = build_host(self.uri.host, self.uri.port, self.uri.secure) - if self.wsuri.user_info: - headers["Authorization"] = build_authorization_basic(*self.wsuri.user_info) + if self.uri.user_info: + headers["Authorization"] = build_authorization_basic(*self.uri.user_info) if self.origin is not None: headers["Origin"] = self.origin @@ -133,7 +131,7 @@ def connect(self) -> Request: protocol_header = build_subprotocol(self.available_subprotocols) headers["Sec-WebSocket-Protocol"] = protocol_header - return Request(self.wsuri.resource_name, headers) + return Request(self.uri.resource_name, headers) def process_response(self, response: Response) -> None: """ diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 8325721b7..e95036595 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -230,8 +230,8 @@ def connect( DeprecationWarning, ) - wsuri = parse_uri(uri) - if not wsuri.secure and ssl is not None: + ws_uri = parse_uri(uri) + if not ws_uri.secure and ssl is not None: raise ValueError("ssl argument is incompatible with a ws:// URI") # Private APIs for unix_connect() @@ -271,7 +271,7 @@ def connect( sock.connect(path) else: kwargs.setdefault("timeout", deadline.timeout()) - sock = socket.create_connection((wsuri.host, wsuri.port), **kwargs) + sock = socket.create_connection((ws_uri.host, ws_uri.port), **kwargs) sock.settimeout(None) # Disable Nagle algorithm @@ -281,11 +281,11 @@ def connect( # Initialize TLS wrapper and perform TLS handshake - if wsuri.secure: + if ws_uri.secure: if ssl is None: ssl = ssl_module.create_default_context() if server_hostname is None: - server_hostname = wsuri.host + server_hostname = ws_uri.host sock.settimeout(deadline.timeout()) sock = ssl.wrap_socket(sock, server_hostname=server_hostname) sock.settimeout(None) @@ -293,7 +293,7 @@ def connect( # Initialize WebSocket protocol protocol = ClientProtocol( - wsuri, + ws_uri, origin=origin, extensions=extensions, subprotocols=subprotocols, diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index 4db4c038c..dd60805c1 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -248,7 +248,7 @@ def redirect(connection, request): async with serve(*args, process_request=redirect) as server: async with connect(get_uri(server) + "/redirect") as client: - self.assertEqual(client.protocol.wsuri.path, "/") + self.assertEqual(client.protocol.uri.path, "/") async def test_cross_origin_redirect(self): """Client follows redirect to a secure URI on a different origin.""" @@ -297,7 +297,7 @@ def redirect(connection, request): async with connect( "ws://overridden/redirect", host=host, port=port ) as client: - self.assertEqual(client.protocol.wsuri.path, "/") + self.assertEqual(client.protocol.uri.path, "/") async def test_cross_origin_redirect_with_explicit_host_port(self): """Client doesn't follow cross-origin redirect with an explicit host / port.""" From bd92cf70a2fc83a85ce1f20e9766cd207a11afad Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 25 Jan 2025 09:41:23 +0100 Subject: [PATCH 709/762] Group parameters more logically. --- src/websockets/asyncio/client.py | 9 +++++---- src/websockets/asyncio/server.py | 9 +++++---- src/websockets/sync/client.py | 9 +++++---- src/websockets/sync/server.py | 9 +++++---- tests/asyncio/test_client.py | 30 +++++++++++++++--------------- tests/asyncio/test_server.py | 32 ++++++++++++++++---------------- tests/sync/test_client.py | 30 +++++++++++++++--------------- tests/sync/test_server.py | 32 ++++++++++++++++---------------- 8 files changed, 82 insertions(+), 78 deletions(-) diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 237bb273a..bde0beeea 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -200,14 +200,14 @@ class connect: should be negotiated and run. subprotocols: List of supported subprotocols, in order of decreasing preference. + compression: The "permessage-deflate" extension is enabled by default. + Set ``compression`` to :obj:`None` to disable it. See the + :doc:`compression guide <../../topics/compression>` for details. additional_headers (HeadersLike | None): Arbitrary HTTP headers to add to the handshake request. user_agent_header: Value of the ``User-Agent`` request header. It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to :obj:`None` removes the header. - compression: The "permessage-deflate" extension is enabled by default. - Set ``compression`` to :obj:`None` to disable it. See the - :doc:`compression guide <../../topics/compression>` for details. process_exception: When reconnecting automatically, tell whether an error is transient or fatal. The default behavior is defined by :func:`process_exception`. Refer to its documentation for details. @@ -275,9 +275,10 @@ def __init__( origin: Origin | None = None, extensions: Sequence[ClientExtensionFactory] | None = None, subprotocols: Sequence[Subprotocol] | None = None, + compression: str | None = "deflate", + # HTTP additional_headers: HeadersLike | None = None, user_agent_header: str | None = USER_AGENT, - compression: str | None = "deflate", process_exception: Callable[[Exception], Exception | None] = process_exception, # Timeouts open_timeout: float | None = 10, diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index ebe45c2a9..2e2b78782 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -617,6 +617,9 @@ def handler(websocket): it has the same behavior as the :meth:`ServerProtocol.select_subprotocol ` method. + compression: The "permessage-deflate" extension is enabled by default. + Set ``compression`` to :obj:`None` to disable it. See the + :doc:`compression guide <../../topics/compression>` for details. process_request: Intercept the request during the opening handshake. Return an HTTP response to force the response or :obj:`None` to continue normally. When you force an HTTP 101 Continue response, the @@ -630,9 +633,6 @@ def handler(websocket): server_header: Value of the ``Server`` response header. It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to :obj:`None` removes the header. - compression: The "permessage-deflate" extension is enabled by default. - Set ``compression`` to :obj:`None` to disable it. See the - :doc:`compression guide <../../topics/compression>` for details. open_timeout: Timeout for opening connections in seconds. :obj:`None` disables the timeout. ping_interval: Interval between keepalive pings in seconds. @@ -694,6 +694,8 @@ def __init__( ] | None ) = None, + compression: str | None = "deflate", + # HTTP process_request: ( Callable[ [ServerConnection, Request], @@ -709,7 +711,6 @@ def __init__( | None ) = None, server_header: str | None = SERVER, - compression: str | None = "deflate", # Timeouts open_timeout: float | None = 10, ping_interval: float | None = 20, diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index e95036595..da2b88591 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -135,9 +135,10 @@ def connect( origin: Origin | None = None, extensions: Sequence[ClientExtensionFactory] | None = None, subprotocols: Sequence[Subprotocol] | None = None, + compression: str | None = "deflate", + # HTTP additional_headers: HeadersLike | None = None, user_agent_header: str | None = USER_AGENT, - compression: str | None = "deflate", # Timeouts open_timeout: float | None = 10, ping_interval: float | None = 20, @@ -180,14 +181,14 @@ def connect( should be negotiated and run. subprotocols: List of supported subprotocols, in order of decreasing preference. + compression: The "permessage-deflate" extension is enabled by default. + Set ``compression`` to :obj:`None` to disable it. See the + :doc:`compression guide <../../topics/compression>` for details. additional_headers (HeadersLike | None): Arbitrary HTTP headers to add to the handshake request. user_agent_header: Value of the ``User-Agent`` request header. It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to :obj:`None` removes the header. - compression: The "permessage-deflate" extension is enabled by default. - Set ``compression`` to :obj:`None` to disable it. See the - :doc:`compression guide <../../topics/compression>` for details. open_timeout: Timeout for opening the connection in seconds. :obj:`None` disables the timeout. ping_interval: Interval between keepalive pings in seconds. diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 643f9b44b..2b753b2c5 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -340,6 +340,8 @@ def serve( ] | None ) = None, + compression: str | None = "deflate", + # HTTP process_request: ( Callable[ [ServerConnection, Request], @@ -355,7 +357,6 @@ def serve( | None ) = None, server_header: str | None = SERVER, - compression: str | None = "deflate", # Timeouts open_timeout: float | None = 10, ping_interval: float | None = 20, @@ -423,6 +424,9 @@ def handler(websocket): it has the same behavior as the :meth:`ServerProtocol.select_subprotocol ` method. + compression: The "permessage-deflate" extension is enabled by default. + Set ``compression`` to :obj:`None` to disable it. See the + :doc:`compression guide <../../topics/compression>` for details. process_request: Intercept the request during the opening handshake. Return an HTTP response to force the response. Return :obj:`None` to continue normally. When you force an HTTP 101 Continue response, the @@ -435,9 +439,6 @@ def handler(websocket): server_header: Value of the ``Server`` response header. It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to :obj:`None` removes the header. - compression: The "permessage-deflate" extension is enabled by default. - Set ``compression`` to :obj:`None` to disable it. See the - :doc:`compression guide <../../topics/compression>` for details. open_timeout: Timeout for opening connections in seconds. :obj:`None` disables the timeout. ping_interval: Interval between keepalive pings in seconds. diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index dd60805c1..f05bfc699 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -76,6 +76,21 @@ async def test_existing_socket(self): async with connect("ws://invalid/", sock=sock) as client: self.assertEqual(client.protocol.state.name, "OPEN") + async def test_compression_is_enabled(self): + """Client enables compression by default.""" + async with serve(*args) as server: + async with connect(get_uri(server)) as client: + self.assertEqual( + [type(ext) for ext in client.protocol.extensions], + [PerMessageDeflate], + ) + + async def test_disable_compression(self): + """Client disables compression.""" + async with serve(*args) as server: + async with connect(get_uri(server), compression=None) as client: + self.assertEqual(client.protocol.extensions, []) + async def test_additional_headers(self): """Client can set additional headers with additional_headers.""" async with serve(*args) as server: @@ -96,21 +111,6 @@ async def test_remove_user_agent(self): async with connect(get_uri(server), user_agent_header=None) as client: self.assertNotIn("User-Agent", client.request.headers) - async def test_compression_is_enabled(self): - """Client enables compression by default.""" - async with serve(*args) as server: - async with connect(get_uri(server)) as client: - self.assertEqual( - [type(ext) for ext in client.protocol.extensions], - [PerMessageDeflate], - ) - - async def test_disable_compression(self): - """Client disables compression.""" - async with serve(*args) as server: - async with connect(get_uri(server), compression=None) as client: - self.assertEqual(client.protocol.extensions, []) - async def test_keepalive_is_enabled(self): """Client enables keepalive and measures latency by default.""" async with serve(*args) as server: diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index edad52da5..5c66dc727 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -117,6 +117,22 @@ def select_subprotocol(ws, subprotocols): "server rejected WebSocket connection: HTTP 500", ) + async def test_compression_is_enabled(self): + """Server enables compression by default.""" + async with serve(*args) as server: + async with connect(get_uri(server)) as client: + await self.assertEval( + client, + "[type(ext).__name__ for ext in ws.protocol.extensions]", + "['PerMessageDeflate']", + ) + + async def test_disable_compression(self): + """Server disables compression.""" + async with serve(*args, compression=None) as server: + async with connect(get_uri(server)) as client: + await self.assertEval(client, "ws.protocol.extensions", "[]") + async def test_process_request_returns_none(self): """Server runs process_request and continues the handshake.""" @@ -359,22 +375,6 @@ async def test_remove_server(self): client, "'Server' in ws.response.headers", "False" ) - async def test_compression_is_enabled(self): - """Server enables compression by default.""" - async with serve(*args) as server: - async with connect(get_uri(server)) as client: - await self.assertEval( - client, - "[type(ext).__name__ for ext in ws.protocol.extensions]", - "['PerMessageDeflate']", - ) - - async def test_disable_compression(self): - """Server disables compression.""" - async with serve(*args, compression=None) as server: - async with connect(get_uri(server)) as client: - await self.assertEval(client, "ws.protocol.extensions", "[]") - async def test_keepalive_is_enabled(self): """Server enables keepalive and measures latency.""" async with serve(*args, ping_interval=MS) as server: diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index 1669a0e84..736a84c98 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -41,6 +41,21 @@ def test_existing_socket(self): with connect("ws://invalid/", sock=sock) as client: self.assertEqual(client.protocol.state.name, "OPEN") + def test_compression_is_enabled(self): + """Client enables compression by default.""" + with run_server() as server: + with connect(get_uri(server)) as client: + self.assertEqual( + [type(ext) for ext in client.protocol.extensions], + [PerMessageDeflate], + ) + + def test_disable_compression(self): + """Client disables compression.""" + with run_server() as server: + with connect(get_uri(server), compression=None) as client: + self.assertEqual(client.protocol.extensions, []) + def test_additional_headers(self): """Client can set additional headers with additional_headers.""" with run_server() as server: @@ -61,21 +76,6 @@ def test_remove_user_agent(self): with connect(get_uri(server), user_agent_header=None) as client: self.assertNotIn("User-Agent", client.request.headers) - def test_compression_is_enabled(self): - """Client enables compression by default.""" - with run_server() as server: - with connect(get_uri(server)) as client: - self.assertEqual( - [type(ext) for ext in client.protocol.extensions], - [PerMessageDeflate], - ) - - def test_disable_compression(self): - """Client disables compression.""" - with run_server() as server: - with connect(get_uri(server), compression=None) as client: - self.assertEqual(client.protocol.extensions, []) - def test_keepalive_is_enabled(self): """Client enables keepalive and measures latency by default.""" with run_server() as server: diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index 8e83f2a81..f59671efd 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -115,6 +115,22 @@ def select_subprotocol(ws, subprotocols): "server rejected WebSocket connection: HTTP 500", ) + def test_compression_is_enabled(self): + """Server enables compression by default.""" + with run_server() as server: + with connect(get_uri(server)) as client: + self.assertEval( + client, + "[type(ext).__name__ for ext in ws.protocol.extensions]", + "['PerMessageDeflate']", + ) + + def test_disable_compression(self): + """Server disables compression.""" + with run_server(compression=None) as server: + with connect(get_uri(server)) as client: + self.assertEval(client, "ws.protocol.extensions", "[]") + def test_process_request_returns_none(self): """Server runs process_request and continues the handshake.""" @@ -220,22 +236,6 @@ def test_remove_server(self): with connect(get_uri(server)) as client: self.assertEval(client, "'Server' in ws.response.headers", "False") - def test_compression_is_enabled(self): - """Server enables compression by default.""" - with run_server() as server: - with connect(get_uri(server)) as client: - self.assertEval( - client, - "[type(ext).__name__ for ext in ws.protocol.extensions]", - "['PerMessageDeflate']", - ) - - def test_disable_compression(self): - """Server disables compression.""" - with run_server(compression=None) as server: - with connect(get_uri(server)) as client: - self.assertEval(client, "ws.protocol.extensions", "[]") - def test_keepalive_is_enabled(self): """Server enables keepalive and measures latency.""" with run_server(ping_interval=MS) as server: From 919cd92149edc3395314282e8c5799b52d7c4bc2 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 25 Jan 2025 09:48:53 +0100 Subject: [PATCH 710/762] Remove some excess vertical space. --- src/websockets/client.py | 21 ++++----------------- src/websockets/server.py | 21 ++------------------- 2 files changed, 6 insertions(+), 36 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index 4c22dfc12..9ea21c39c 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -104,33 +104,26 @@ def connect(self) -> Request: """ headers = Headers() - headers["Host"] = build_host(self.uri.host, self.uri.port, self.uri.secure) - if self.uri.user_info: headers["Authorization"] = build_authorization_basic(*self.uri.user_info) - if self.origin is not None: headers["Origin"] = self.origin - headers["Upgrade"] = "websocket" headers["Connection"] = "Upgrade" headers["Sec-WebSocket-Key"] = self.key headers["Sec-WebSocket-Version"] = "13" - if self.available_extensions is not None: - extensions_header = build_extension( + headers["Sec-WebSocket-Extensions"] = build_extension( [ (extension_factory.name, extension_factory.get_request_params()) for extension_factory in self.available_extensions ] ) - headers["Sec-WebSocket-Extensions"] = extensions_header - if self.available_subprotocols is not None: - protocol_header = build_subprotocol(self.available_subprotocols) - headers["Sec-WebSocket-Protocol"] = protocol_header - + headers["Sec-WebSocket-Protocol"] = build_subprotocol( + self.available_subprotocols + ) return Request(self.uri.resource_name, headers) def process_response(self, response: Response) -> None: @@ -153,7 +146,6 @@ def process_response(self, response: Response) -> None: connection: list[ConnectionOption] = sum( [parse_connection(value) for value in headers.get_all("Connection")], [] ) - if not any(value.lower() == "upgrade" for value in connection): raise InvalidUpgrade( "Connection", ", ".join(connection) if connection else None @@ -162,7 +154,6 @@ def process_response(self, response: Response) -> None: upgrade: list[UpgradeProtocol] = sum( [parse_upgrade(value) for value in headers.get_all("Upgrade")], [] ) - # For compatibility with non-strict implementations, ignore case when # checking the Upgrade header. It's supposed to be 'WebSocket'. if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"): @@ -174,12 +165,10 @@ def process_response(self, response: Response) -> None: raise InvalidHeader("Sec-WebSocket-Accept") from None except MultipleValuesError: raise InvalidHeader("Sec-WebSocket-Accept", "multiple values") from None - if s_w_accept != accept_key(self.key): raise InvalidHeaderValue("Sec-WebSocket-Accept", s_w_accept) self.extensions = self.process_extensions(headers) - self.subprotocol = self.process_subprotocol(headers) def process_extensions(self, headers: Headers) -> list[Extension]: @@ -279,7 +268,6 @@ def process_subprotocol(self, headers: Headers) -> Subprotocol | None: parsed_subprotocols: Sequence[Subprotocol] = sum( [parse_subprotocol(header_value) for header_value in subprotocols], [] ) - if len(parsed_subprotocols) > 1: raise InvalidHeader( "Sec-WebSocket-Protocol", @@ -287,7 +275,6 @@ def process_subprotocol(self, headers: Headers) -> Subprotocol | None: ) subprotocol = parsed_subprotocols[0] - if subprotocol not in self.available_subprotocols: raise NegotiationError(f"unsupported subprotocol: {subprotocol}") diff --git a/src/websockets/server.py b/src/websockets/server.py index 90e6c9921..174441203 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -190,19 +190,14 @@ def accept(self, request: Request) -> Response: ) headers = Headers() - headers["Date"] = email.utils.formatdate(usegmt=True) - headers["Upgrade"] = "websocket" headers["Connection"] = "Upgrade" headers["Sec-WebSocket-Accept"] = accept_header - if extensions_header is not None: headers["Sec-WebSocket-Extensions"] = extensions_header - if protocol_header is not None: headers["Sec-WebSocket-Protocol"] = protocol_header - return Response(101, "Switching Protocols", headers) def process_request( @@ -234,7 +229,6 @@ def process_request( connection: list[ConnectionOption] = sum( [parse_connection(value) for value in headers.get_all("Connection")], [] ) - if not any(value.lower() == "upgrade" for value in connection): raise InvalidUpgrade( "Connection", ", ".join(connection) if connection else None @@ -243,7 +237,6 @@ def process_request( upgrade: list[UpgradeProtocol] = sum( [parse_upgrade(value) for value in headers.get_all("Upgrade")], [] ) - # For compatibility with non-strict implementations, ignore case when # checking the Upgrade header. The RFC always uses "websocket", except # in section 11.2. (IANA registration) where it uses "WebSocket". @@ -256,13 +249,13 @@ def process_request( raise InvalidHeader("Sec-WebSocket-Key") from None except MultipleValuesError: raise InvalidHeader("Sec-WebSocket-Key", "multiple values") from None - try: raw_key = base64.b64decode(key.encode(), validate=True) except binascii.Error as exc: raise InvalidHeaderValue("Sec-WebSocket-Key", key) from exc if len(raw_key) != 16: raise InvalidHeaderValue("Sec-WebSocket-Key", key) + accept_header = accept_key(key) try: version = headers["Sec-WebSocket-Version"] @@ -270,23 +263,14 @@ def process_request( raise InvalidHeader("Sec-WebSocket-Version") from None except MultipleValuesError: raise InvalidHeader("Sec-WebSocket-Version", "multiple values") from None - if version != "13": raise InvalidHeaderValue("Sec-WebSocket-Version", version) - accept_header = accept_key(key) - self.origin = self.process_origin(headers) - extensions_header, self.extensions = self.process_extensions(headers) - protocol_header = self.subprotocol = self.process_subprotocol(headers) - return ( - accept_header, - extensions_header, - protocol_header, - ) + return (accept_header, extensions_header, protocol_header) def process_origin(self, headers: Headers) -> Origin | None: """ @@ -426,7 +410,6 @@ def process_subprotocol(self, headers: Headers) -> Subprotocol | None: ], [], ) - return self.select_subprotocol(subprotocols) def select_subprotocol( From 1e62b3c384a9fc345c661ebdd66a2a1a0fbbff52 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 25 Jan 2025 13:59:46 +0100 Subject: [PATCH 711/762] Disable PyPy in CI. Refs #1581. --- .github/workflows/tests.yml | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 5ab9c4c72..ca73cd499 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -60,12 +60,13 @@ jobs: - "3.11" - "3.12" - "3.13" - - "pypy-3.10" +# Disable PyPy per https://github.com/python-websockets/websockets/issues/1581 +# - "pypy-3.10" is_main: - ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }} - exclude: - - python: "pypy-3.10" - is_main: false +# exclude: +# - python: "pypy-3.10" +# is_main: false steps: - name: Check out repository uses: actions/checkout@v4 From bd1796625de9a087229bb805acae86b414a4449c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 25 Jan 2025 14:19:29 +0100 Subject: [PATCH 712/762] Fix names of environment variables in docs. --- docs/reference/variables.rst | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/docs/reference/variables.rst b/docs/reference/variables.rst index a24773074..a55057a0d 100644 --- a/docs/reference/variables.rst +++ b/docs/reference/variables.rst @@ -54,29 +54,30 @@ Reconnection Reconnection attempts are spaced out with truncated exponential backoff. -.. envvar:: BACKOFF_INITIAL_DELAY +.. envvar:: WEBSOCKETS_BACKOFF_INITIAL_DELAY The first attempt is delayed by a random amount of time between ``0`` and - ``BACKOFF_INITIAL_DELAY`` seconds. + ``WEBSOCKETS_BACKOFF_INITIAL_DELAY`` seconds. The default value is ``5.0`` seconds. -.. envvar:: BACKOFF_MIN_DELAY +.. envvar:: WEBSOCKETS_BACKOFF_MIN_DELAY - The second attempt is delayed by ``BACKOFF_MIN_DELAY`` seconds. + The second attempt is delayed by ``WEBSOCKETS_BACKOFF_MIN_DELAY`` seconds. The default value is ``3.1`` seconds. -.. envvar:: BACKOFF_FACTOR +.. envvar:: WEBSOCKETS_BACKOFF_FACTOR - After the second attempt, the delay is multiplied by ``BACKOFF_FACTOR`` - between each attempt. + After the second attempt, the delay is multiplied by + ``WEBSOCKETS_BACKOFF_FACTOR`` between each attempt. The default value is ``1.618``. -.. envvar:: BACKOFF_MAX_DELAY +.. envvar:: WEBSOCKETS_BACKOFF_MAX_DELAY - The delay between attempts is capped at ``BACKOFF_MAX_DELAY`` seconds. + The delay between attempts is capped at ``WEBSOCKETS_BACKOFF_MAX_DELAY`` + seconds. The default value is ``90.0`` seconds. From 387197f98c3bd605338ed898c4d05a15d4ee1bfc Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 26 Jan 2025 15:08:47 +0100 Subject: [PATCH 713/762] Mitigate test flakiness. --- tests/asyncio/test_server.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index 5c66dc727..38c0315a1 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -542,7 +542,8 @@ async def test_close_server_keeps_handlers_running(self): async with asyncio_timeout(2 * MS): await server.wait_closed() - async with asyncio_timeout(3 * MS): + # Set a large timeout here, else the test becomes flaky. + async with asyncio_timeout(5 * MS): await server.wait_closed() From dfea9b311283e10511c897b9e0253571943a9123 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 26 Jan 2025 18:30:45 +0100 Subject: [PATCH 714/762] Standardize mocking style to @patch in tests. --- tests/asyncio/test_connection.py | 35 ++++++----------- tests/legacy/test_client_server.py | 42 +++++++++----------- tests/sync/test_connection.py | 31 ++++++--------- tests/test_frames.py | 7 +--- tests/test_protocol.py | 61 +++++++++++++++--------------- tests/test_server.py | 6 ++- 6 files changed, 79 insertions(+), 103 deletions(-) diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index b53c97030..dc4539948 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -866,10 +866,9 @@ async def test_wait_closed(self): # Test ping. - @patch("random.getrandbits") + @patch("random.getrandbits", return_value=1918987876) async def test_ping(self, getrandbits): """ping sends a ping frame with a random payload.""" - getrandbits.return_value = 1918987876 await self.connection.ping() getrandbits.assert_called_once_with(32) await self.assertFrameSent(Frame(Opcode.PING, b"rand")) @@ -978,11 +977,10 @@ async def test_pong_unsupported_type(self): # Test keepalive. - @patch("random.getrandbits") + @patch("random.getrandbits", return_value=1918987876) async def test_keepalive(self, getrandbits): """keepalive sends pings at ping_interval and measures latency.""" self.connection.ping_interval = 2 * MS - getrandbits.return_value = 1918987876 self.connection.start_keepalive() self.assertEqual(self.connection.latency, 0) # 2 ms: keepalive() sends a ping frame. @@ -1000,13 +998,12 @@ async def test_disable_keepalive(self): await asyncio.sleep(3 * MS) await self.assertNoFrameSent() - @patch("random.getrandbits") + @patch("random.getrandbits", return_value=1918987876) async def test_keepalive_times_out(self, getrandbits): """keepalive closes the connection if ping_timeout elapses.""" self.connection.ping_interval = 4 * MS self.connection.ping_timeout = 2 * MS async with self.drop_frames_rcvd(): - getrandbits.return_value = 1918987876 self.connection.start_keepalive() # 4 ms: keepalive() sends a ping frame. await asyncio.sleep(4 * MS) @@ -1017,13 +1014,12 @@ async def test_keepalive_times_out(self, getrandbits): # 7 ms: check that the connection is closed. self.assertEqual(self.connection.state, State.CLOSED) - @patch("random.getrandbits") + @patch("random.getrandbits", return_value=1918987876) async def test_keepalive_ignores_timeout(self, getrandbits): """keepalive ignores timeouts if ping_timeout isn't set.""" self.connection.ping_interval = 4 * MS self.connection.ping_timeout = None async with self.drop_frames_rcvd(): - getrandbits.return_value = 1918987876 self.connection.start_keepalive() # 4 ms: keepalive() sends a ping frame. # 4.x ms: a pong frame is dropped. @@ -1142,17 +1138,13 @@ async def test_logger(self): """Connection has a logger attribute.""" self.assertIsInstance(self.connection.logger, logging.LoggerAdapter) - @unittest.mock.patch( - "asyncio.BaseTransport.get_extra_info", return_value=("sock", 1234) - ) + @patch("asyncio.BaseTransport.get_extra_info", return_value=("sock", 1234)) async def test_local_address(self, get_extra_info): """Connection provides a local_address attribute.""" self.assertEqual(self.connection.local_address, ("sock", 1234)) get_extra_info.assert_called_with("sockname") - @unittest.mock.patch( - "asyncio.BaseTransport.get_extra_info", return_value=("peer", 1234) - ) + @patch("asyncio.BaseTransport.get_extra_info", return_value=("peer", 1234)) async def test_remote_address(self, get_extra_info): """Connection provides a remote_address attribute.""" self.assertEqual(self.connection.remote_address, ("peer", 1234)) @@ -1213,12 +1205,11 @@ async def test_writing_in_send_context_fails(self): # Test safety nets — catching all exceptions in case of bugs. - @patch("websockets.protocol.Protocol.events_received") + # Inject a fault in a random call in data_received(). + # This test is tightly coupled to the implementation. + @patch("websockets.protocol.Protocol.events_received", side_effect=AssertionError) async def test_unexpected_failure_in_data_received(self, events_received): """Unexpected internal error in data_received() is correctly reported.""" - # Inject a fault in a random call in data_received(). - # This test is tightly coupled to the implementation. - events_received.side_effect = AssertionError # Receive a message to trigger the fault. await self.remote_connection.send("😀") @@ -1229,13 +1220,11 @@ async def test_unexpected_failure_in_data_received(self, events_received): self.assertEqual(str(exc), "no close frame received or sent") self.assertIsInstance(exc.__cause__, AssertionError) - @patch("websockets.protocol.Protocol.send_text") + # Inject a fault in a random call in send_context(). + # This test is tightly coupled to the implementation. + @patch("websockets.protocol.Protocol.send_text", side_effect=AssertionError) async def test_unexpected_failure_in_send_context(self, send_text): """Unexpected internal error in send_context() is correctly reported.""" - # Inject a fault in a random call in send_context(). - # This test is tightly coupled to the implementation. - send_text.side_effect = AssertionError - # Send a message to trigger the fault. # The connection closed exception reports the injected fault. with self.assertRaises(ConnectionClosedError) as raised: diff --git a/tests/legacy/test_client_server.py b/tests/legacy/test_client_server.py index c13c6c92e..2354db022 100644 --- a/tests/legacy/test_client_server.py +++ b/tests/legacy/test_client_server.py @@ -10,10 +10,10 @@ import ssl import sys import unittest -import unittest.mock import urllib.error import urllib.request import warnings +from unittest.mock import patch from websockets.asyncio.compatibility import asyncio_timeout from websockets.datastructures import Headers @@ -968,7 +968,7 @@ def test_extension_order(self): ) @with_server(extensions=[ServerNoOpExtensionFactory()]) - @unittest.mock.patch.object(WebSocketServerProtocol, "process_extensions") + @patch.object(WebSocketServerProtocol, "process_extensions") def test_extensions_error(self, _process_extensions): _process_extensions.return_value = "x-no-op", [NoOpExtension()] @@ -978,7 +978,7 @@ def test_extensions_error(self, _process_extensions): ) @with_server(extensions=[ServerNoOpExtensionFactory()]) - @unittest.mock.patch.object(WebSocketServerProtocol, "process_extensions") + @patch.object(WebSocketServerProtocol, "process_extensions") def test_extensions_error_no_extensions(self, _process_extensions): _process_extensions.return_value = "x-no-op", [NoOpExtension()] @@ -1051,7 +1051,7 @@ def test_subprotocol_not_requested(self): self.assertEqual(self.client.subprotocol, None) @with_server(subprotocols=["superchat"]) - @unittest.mock.patch.object(WebSocketServerProtocol, "process_subprotocol") + @patch.object(WebSocketServerProtocol, "process_subprotocol") def test_subprotocol_error(self, _process_subprotocol): _process_subprotocol.return_value = "superchat" @@ -1060,7 +1060,7 @@ def test_subprotocol_error(self, _process_subprotocol): self.run_loop_once() @with_server(subprotocols=["superchat"]) - @unittest.mock.patch.object(WebSocketServerProtocol, "process_subprotocol") + @patch.object(WebSocketServerProtocol, "process_subprotocol") def test_subprotocol_error_no_subprotocols(self, _process_subprotocol): _process_subprotocol.return_value = "superchat" @@ -1069,7 +1069,7 @@ def test_subprotocol_error_no_subprotocols(self, _process_subprotocol): self.run_loop_once() @with_server(subprotocols=["superchat", "chat"]) - @unittest.mock.patch.object(WebSocketServerProtocol, "process_subprotocol") + @patch.object(WebSocketServerProtocol, "process_subprotocol") def test_subprotocol_error_two_subprotocols(self, _process_subprotocol): _process_subprotocol.return_value = "superchat, chat" @@ -1078,7 +1078,7 @@ def test_subprotocol_error_two_subprotocols(self, _process_subprotocol): self.run_loop_once() @with_server() - @unittest.mock.patch("websockets.legacy.server.read_request") + @patch("websockets.legacy.server.read_request") def test_server_receives_malformed_request(self, _read_request): _read_request.side_effect = ValueError("read_request failed") @@ -1086,7 +1086,7 @@ def test_server_receives_malformed_request(self, _read_request): self.start_client() @with_server() - @unittest.mock.patch("websockets.legacy.client.read_response") + @patch("websockets.legacy.client.read_response") def test_client_receives_malformed_response(self, _read_response): _read_response.side_effect = ValueError("read_response failed") @@ -1095,7 +1095,7 @@ def test_client_receives_malformed_response(self, _read_response): self.run_loop_once() @with_server() - @unittest.mock.patch("websockets.legacy.client.build_request") + @patch("websockets.legacy.client.build_request") def test_client_sends_invalid_handshake_request(self, _build_request): def wrong_build_request(headers): return "42" @@ -1106,7 +1106,7 @@ def wrong_build_request(headers): self.start_client() @with_server() - @unittest.mock.patch("websockets.legacy.server.build_response") + @patch("websockets.legacy.server.build_response") def test_server_sends_invalid_handshake_response(self, _build_response): def wrong_build_response(headers, key): return build_response(headers, "42") @@ -1117,7 +1117,7 @@ def wrong_build_response(headers, key): self.start_client() @with_server() - @unittest.mock.patch("websockets.legacy.client.read_response") + @patch("websockets.legacy.client.read_response") def test_server_does_not_switch_protocols(self, _read_response): async def wrong_read_response(stream): status_code, reason, headers = await read_response(stream) @@ -1130,9 +1130,7 @@ async def wrong_read_response(stream): self.run_loop_once() @with_server() - @unittest.mock.patch( - "websockets.legacy.server.WebSocketServerProtocol.process_request" - ) + @patch("websockets.legacy.server.WebSocketServerProtocol.process_request") def test_server_error_in_handshake(self, _process_request): _process_request.side_effect = Exception("process_request crashed") @@ -1156,7 +1154,7 @@ async def cancelled_client(): sock.send(b"") # socket is closed @with_server() - @unittest.mock.patch("websockets.legacy.server.WebSocketServerProtocol.send") + @patch("websockets.legacy.server.WebSocketServerProtocol.send") def test_server_handler_crashes(self, send): send.side_effect = ValueError("send failed") @@ -1169,7 +1167,7 @@ def test_server_handler_crashes(self, send): self.assertEqual(self.client.close_code, CloseCode.INTERNAL_ERROR) @with_server() - @unittest.mock.patch("websockets.legacy.server.WebSocketServerProtocol.close") + @patch("websockets.legacy.server.WebSocketServerProtocol.close") def test_server_close_crashes(self, close): close.side_effect = ValueError("close failed") @@ -1183,7 +1181,7 @@ def test_server_close_crashes(self, close): @with_server() @with_client() - @unittest.mock.patch.object(WebSocketClientProtocol, "handshake") + @patch.object(WebSocketClientProtocol, "handshake") def test_client_closes_connection_before_handshake(self, handshake): # We have mocked the handshake() method to prevent the client from # performing the opening handshake. Force it to close the connection. @@ -1254,12 +1252,8 @@ def test_invalid_status_error_during_client_connect(self): self.assertEqual(exception.status_code, 403) @with_server() - @unittest.mock.patch( - "websockets.legacy.server.WebSocketServerProtocol.write_http_response" - ) - @unittest.mock.patch( - "websockets.legacy.server.WebSocketServerProtocol.read_http_request" - ) + @patch("websockets.legacy.server.WebSocketServerProtocol.write_http_response") + @patch("websockets.legacy.server.WebSocketServerProtocol.read_http_request") def test_connection_error_during_opening_handshake( self, _read_http_request, _write_http_response ): @@ -1277,7 +1271,7 @@ def test_connection_error_during_opening_handshake( _write_http_response.assert_not_called() @with_server() - @unittest.mock.patch("websockets.legacy.server.WebSocketServerProtocol.close") + @patch("websockets.legacy.server.WebSocketServerProtocol.close") def test_connection_error_during_closing_handshake(self, close): close.side_effect = ConnectionError diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index f191d8944..be7ff36f4 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -645,10 +645,9 @@ def fragments(): # Test ping. - @patch("random.getrandbits") + @patch("random.getrandbits", return_value=1918987876) def test_ping(self, getrandbits): """ping sends a ping frame with a random payload.""" - getrandbits.return_value = 1918987876 self.connection.ping() getrandbits.assert_called_once_with(32) self.assertFrameSent(Frame(Opcode.PING, b"rand")) @@ -740,11 +739,10 @@ def test_pong_unsupported_type(self): # Test keepalive. - @patch("random.getrandbits") + @patch("random.getrandbits", return_value=1918987876) def test_keepalive(self, getrandbits): """keepalive sends pings at ping_interval and measures latency.""" self.connection.ping_interval = 2 * MS - getrandbits.return_value = 1918987876 self.connection.start_keepalive() self.assertEqual(self.connection.latency, 0) # 2 ms: keepalive() sends a ping frame. @@ -762,13 +760,12 @@ def test_disable_keepalive(self): time.sleep(3 * MS) self.assertNoFrameSent() - @patch("random.getrandbits") + @patch("random.getrandbits", return_value=1918987876) def test_keepalive_times_out(self, getrandbits): """keepalive closes the connection if ping_timeout elapses.""" self.connection.ping_interval = 4 * MS self.connection.ping_timeout = 2 * MS with self.drop_frames_rcvd(): - getrandbits.return_value = 1918987876 self.connection.start_keepalive() # 4 ms: keepalive() sends a ping frame. time.sleep(4 * MS) @@ -779,13 +776,12 @@ def test_keepalive_times_out(self, getrandbits): # 7 ms: check that the connection is closed. self.assertEqual(self.connection.state, State.CLOSED) - @patch("random.getrandbits") + @patch("random.getrandbits", return_value=1918987876) def test_keepalive_ignores_timeout(self, getrandbits): """keepalive ignores timeouts if ping_timeout isn't set.""" self.connection.ping_interval = 4 * MS self.connection.ping_timeout = None with self.drop_frames_rcvd(): - getrandbits.return_value = 1918987876 self.connection.start_keepalive() # 4 ms: keepalive() sends a ping frame. time.sleep(4 * MS) @@ -910,13 +906,13 @@ def test_logger(self): """Connection has a logger attribute.""" self.assertIsInstance(self.connection.logger, logging.LoggerAdapter) - @unittest.mock.patch("socket.socket.getsockname", return_value=("sock", 1234)) + @patch("socket.socket.getsockname", return_value=("sock", 1234)) def test_local_address(self, getsockname): """Connection provides a local_address attribute.""" self.assertEqual(self.connection.local_address, ("sock", 1234)) getsockname.assert_called_with() - @unittest.mock.patch("socket.socket.getpeername", return_value=("peer", 1234)) + @patch("socket.socket.getpeername", return_value=("peer", 1234)) def test_remote_address(self, getpeername): """Connection provides a remote_address attribute.""" self.assertEqual(self.connection.remote_address, ("peer", 1234)) @@ -984,12 +980,11 @@ def test_writing_in_send_context_fails(self): # Test safety nets — catching all exceptions in case of bugs. - @patch("websockets.protocol.Protocol.events_received") + # Inject a fault in a random call in recv_events(). + # This test is tightly coupled to the implementation. + @patch("websockets.protocol.Protocol.events_received", side_effect=AssertionError) def test_unexpected_failure_in_recv_events(self, events_received): """Unexpected internal error in recv_events() is correctly reported.""" - # Inject a fault in a random call in recv_events(). - # This test is tightly coupled to the implementation. - events_received.side_effect = AssertionError # Receive a message to trigger the fault. self.remote_connection.send("😀") @@ -1000,13 +995,11 @@ def test_unexpected_failure_in_recv_events(self, events_received): self.assertEqual(str(exc), "no close frame received or sent") self.assertIsInstance(exc.__cause__, AssertionError) - @patch("websockets.protocol.Protocol.send_text") + # Inject a fault in a random call in send_context(). + # This test is tightly coupled to the implementation. + @patch("websockets.protocol.Protocol.send_text", side_effect=AssertionError) def test_unexpected_failure_in_send_context(self, send_text): """Unexpected internal error in send_context() is correctly reported.""" - # Inject a fault in a random call in send_context(). - # This test is tightly coupled to the implementation. - send_text.side_effect = AssertionError - # Send a message to trigger the fault. # The connection closed exception reports the injected fault. with self.assertRaises(ConnectionClosedError) as raised: diff --git a/tests/test_frames.py b/tests/test_frames.py index 857b41fe4..1c372b5de 100644 --- a/tests/test_frames.py +++ b/tests/test_frames.py @@ -1,7 +1,7 @@ import codecs import dataclasses import unittest -import unittest.mock +from unittest.mock import patch from websockets.exceptions import PayloadTooBig, ProtocolError from websockets.frames import * @@ -12,9 +12,6 @@ class FramesTestCase(GeneratorTestCase): - def enforce_mask(self, mask): - return unittest.mock.patch("secrets.token_bytes", return_value=mask) - def parse(self, data, mask, max_size=None, extensions=None): """ Parse a frame from a bytestring. @@ -41,7 +38,7 @@ def assertFrameData(self, frame, data, mask, extensions=None): # Make masking deterministic by reusing the same "random" mask. # This has an effect only when mask is True. mask_bytes = data[2:6] if mask else b"" - with self.enforce_mask(mask_bytes): + with patch("secrets.token_bytes", return_value=mask_bytes): serialized = frame.serialize(mask=mask, extensions=extensions) self.assertEqual(serialized, data) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 1c092459d..9e2d65041 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -1,5 +1,6 @@ import logging -import unittest.mock +import unittest +from unittest.mock import patch from websockets.exceptions import ( ConnectionClosedError, @@ -106,7 +107,7 @@ class MaskingTests(ProtocolTestCase): def test_client_sends_masked_frame(self): client = Protocol(CLIENT) - with self.enforce_mask(b"\x00\xff\x00\xff"): + with patch("secrets.token_bytes", return_value=b"\x00\xff\x00\xff"): client.send_text(b"Spam", True) self.assertEqual(client.data_to_send(), [self.masked_text_frame_data]) @@ -191,7 +192,7 @@ def test_client_sends_continuation_after_sending_close(self): # Since it isn't possible to send a close frame in a fragmented # message (see test_client_send_close_in_fragmented_message), in fact, # this is the same test as test_client_sends_unexpected_continuation. - with self.enforce_mask(b"\x00\x00\x00\x00"): + with patch("secrets.token_bytes", return_value=b"\x00\x00\x00\x00"): client.send_close(CloseCode.GOING_AWAY) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) with self.assertRaises(ProtocolError) as raised: @@ -234,7 +235,7 @@ class TextTests(ProtocolTestCase): def test_client_sends_text(self): client = Protocol(CLIENT) - with self.enforce_mask(b"\x00\x00\x00\x00"): + with patch("secrets.token_bytes", return_value=b"\x00\x00\x00\x00"): client.send_text("😀".encode()) self.assertEqual( client.data_to_send(), [b"\x81\x84\x00\x00\x00\x00\xf0\x9f\x98\x80"] @@ -307,15 +308,15 @@ def test_server_receives_text_without_size_limit(self): def test_client_sends_fragmented_text(self): client = Protocol(CLIENT) - with self.enforce_mask(b"\x00\x00\x00\x00"): + with patch("secrets.token_bytes", return_value=b"\x00\x00\x00\x00"): client.send_text("😀".encode()[:2], fin=False) self.assertEqual(client.data_to_send(), [b"\x01\x82\x00\x00\x00\x00\xf0\x9f"]) - with self.enforce_mask(b"\x00\x00\x00\x00"): + with patch("secrets.token_bytes", return_value=b"\x00\x00\x00\x00"): client.send_continuation("😀😀".encode()[2:6], fin=False) self.assertEqual( client.data_to_send(), [b"\x00\x84\x00\x00\x00\x00\x98\x80\xf0\x9f"] ) - with self.enforce_mask(b"\x00\x00\x00\x00"): + with patch("secrets.token_bytes", return_value=b"\x00\x00\x00\x00"): client.send_continuation("😀".encode()[2:], fin=True) self.assertEqual(client.data_to_send(), [b"\x80\x82\x00\x00\x00\x00\x98\x80"]) @@ -482,7 +483,7 @@ def test_server_receives_unexpected_text(self): def test_client_sends_text_after_sending_close(self): client = Protocol(CLIENT) - with self.enforce_mask(b"\x00\x00\x00\x00"): + with patch("secrets.token_bytes", return_value=b"\x00\x00\x00\x00"): client.send_close(CloseCode.GOING_AWAY) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) with self.assertRaises(InvalidState) as raised: @@ -522,7 +523,7 @@ class BinaryTests(ProtocolTestCase): def test_client_sends_binary(self): client = Protocol(CLIENT) - with self.enforce_mask(b"\x00\x00\x00\x00"): + with patch("secrets.token_bytes", return_value=b"\x00\x00\x00\x00"): client.send_binary(b"\x01\x02\xfe\xff") self.assertEqual( client.data_to_send(), [b"\x82\x84\x00\x00\x00\x00\x01\x02\xfe\xff"] @@ -579,15 +580,15 @@ def test_server_receives_binary_over_size_limit(self): def test_client_sends_fragmented_binary(self): client = Protocol(CLIENT) - with self.enforce_mask(b"\x00\x00\x00\x00"): + with patch("secrets.token_bytes", return_value=b"\x00\x00\x00\x00"): client.send_binary(b"\x01\x02", fin=False) self.assertEqual(client.data_to_send(), [b"\x02\x82\x00\x00\x00\x00\x01\x02"]) - with self.enforce_mask(b"\x00\x00\x00\x00"): + with patch("secrets.token_bytes", return_value=b"\x00\x00\x00\x00"): client.send_continuation(b"\xee\xff\x01\x02", fin=False) self.assertEqual( client.data_to_send(), [b"\x00\x84\x00\x00\x00\x00\xee\xff\x01\x02"] ) - with self.enforce_mask(b"\x00\x00\x00\x00"): + with patch("secrets.token_bytes", return_value=b"\x00\x00\x00\x00"): client.send_continuation(b"\xee\xff", fin=True) self.assertEqual(client.data_to_send(), [b"\x80\x82\x00\x00\x00\x00\xee\xff"]) @@ -718,7 +719,7 @@ def test_server_receives_unexpected_binary(self): def test_client_sends_binary_after_sending_close(self): client = Protocol(CLIENT) - with self.enforce_mask(b"\x00\x00\x00\x00"): + with patch("secrets.token_bytes", return_value=b"\x00\x00\x00\x00"): client.send_close(CloseCode.GOING_AWAY) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) with self.assertRaises(InvalidState) as raised: @@ -806,7 +807,7 @@ def test_close_reason_not_available_yet(self): def test_client_sends_close(self): client = Protocol(CLIENT) - with self.enforce_mask(b"\x3c\x3c\x3c\x3c"): + with patch("secrets.token_bytes", return_value=b"\x3c\x3c\x3c\x3c"): client.send_close() self.assertEqual(client.data_to_send(), [b"\x88\x80\x3c\x3c\x3c\x3c"]) self.assertIs(client.state, CLOSING) @@ -819,7 +820,7 @@ def test_server_sends_close(self): def test_client_receives_close(self): client = Protocol(CLIENT) - with self.enforce_mask(b"\x3c\x3c\x3c\x3c"): + with patch("secrets.token_bytes", return_value=b"\x3c\x3c\x3c\x3c"): client.receive_data(b"\x88\x00") self.assertEqual(client.events_received(), [Frame(OP_CLOSE, b"")]) self.assertEqual(client.data_to_send(), [b"\x88\x80\x3c\x3c\x3c\x3c"]) @@ -890,7 +891,7 @@ def test_server_receives_close_then_sends_close(self): def test_client_sends_close_with_code(self): client = Protocol(CLIENT) - with self.enforce_mask(b"\x00\x00\x00\x00"): + with patch("secrets.token_bytes", return_value=b"\x00\x00\x00\x00"): client.send_close(CloseCode.GOING_AWAY) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) self.assertIs(client.state, CLOSING) @@ -915,7 +916,7 @@ def test_server_receives_close_with_code(self): def test_client_sends_close_with_code_and_reason(self): client = Protocol(CLIENT) - with self.enforce_mask(b"\x00\x00\x00\x00"): + with patch("secrets.token_bytes", return_value=b"\x00\x00\x00\x00"): client.send_close(CloseCode.GOING_AWAY, "going away") self.assertEqual( client.data_to_send(), [b"\x88\x8c\x00\x00\x00\x00\x03\xe9going away"] @@ -1002,7 +1003,7 @@ def test_server_receives_close_with_non_utf8_reason(self): def test_client_sends_close_twice(self): client = Protocol(CLIENT) - with self.enforce_mask(b"\x00\x00\x00\x00"): + with patch("secrets.token_bytes", return_value=b"\x00\x00\x00\x00"): client.send_close(CloseCode.GOING_AWAY) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) with self.assertRaises(InvalidState) as raised: @@ -1040,7 +1041,7 @@ class PingTests(ProtocolTestCase): def test_client_sends_ping(self): client = Protocol(CLIENT) - with self.enforce_mask(b"\x00\x44\x88\xcc"): + with patch("secrets.token_bytes", return_value=b"\x00\x44\x88\xcc"): client.send_ping(b"") self.assertEqual(client.data_to_send(), [b"\x89\x80\x00\x44\x88\xcc"]) @@ -1075,7 +1076,7 @@ def test_server_receives_ping(self): def test_client_sends_ping_with_data(self): client = Protocol(CLIENT) - with self.enforce_mask(b"\x00\x44\x88\xcc"): + with patch("secrets.token_bytes", return_value=b"\x00\x44\x88\xcc"): client.send_ping(b"\x22\x66\xaa\xee") self.assertEqual( client.data_to_send(), [b"\x89\x84\x00\x44\x88\xcc\x22\x22\x22\x22"] @@ -1144,10 +1145,10 @@ def test_server_receives_fragmented_ping_frame(self): def test_client_sends_ping_after_sending_close(self): client = Protocol(CLIENT) - with self.enforce_mask(b"\x00\x00\x00\x00"): + with patch("secrets.token_bytes", return_value=b"\x00\x00\x00\x00"): client.send_close(CloseCode.GOING_AWAY) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) - with self.enforce_mask(b"\x00\x44\x88\xcc"): + with patch("secrets.token_bytes", return_value=b"\x00\x44\x88\xcc"): client.send_ping(b"") self.assertEqual(client.data_to_send(), [b"\x89\x80\x00\x44\x88\xcc"]) @@ -1199,7 +1200,7 @@ class PongTests(ProtocolTestCase): def test_client_sends_pong(self): client = Protocol(CLIENT) - with self.enforce_mask(b"\x00\x44\x88\xcc"): + with patch("secrets.token_bytes", return_value=b"\x00\x44\x88\xcc"): client.send_pong(b"") self.assertEqual(client.data_to_send(), [b"\x8a\x80\x00\x44\x88\xcc"]) @@ -1226,7 +1227,7 @@ def test_server_receives_pong(self): def test_client_sends_pong_with_data(self): client = Protocol(CLIENT) - with self.enforce_mask(b"\x00\x44\x88\xcc"): + with patch("secrets.token_bytes", return_value=b"\x00\x44\x88\xcc"): client.send_pong(b"\x22\x66\xaa\xee") self.assertEqual( client.data_to_send(), [b"\x8a\x84\x00\x44\x88\xcc\x22\x22\x22\x22"] @@ -1287,10 +1288,10 @@ def test_server_receives_fragmented_pong_frame(self): def test_client_sends_pong_after_sending_close(self): client = Protocol(CLIENT) - with self.enforce_mask(b"\x00\x00\x00\x00"): + with patch("secrets.token_bytes", return_value=b"\x00\x00\x00\x00"): client.send_close(CloseCode.GOING_AWAY) self.assertEqual(client.data_to_send(), [b"\x88\x82\x00\x00\x00\x00\x03\xe9"]) - with self.enforce_mask(b"\x00\x44\x88\xcc"): + with patch("secrets.token_bytes", return_value=b"\x00\x44\x88\xcc"): client.send_pong(b"") self.assertEqual(client.data_to_send(), [b"\x8a\x80\x00\x44\x88\xcc"]) @@ -1459,7 +1460,7 @@ def test_client_send_close_in_fragmented_message(self): client = Protocol(CLIENT) client.send_text(b"Spam", fin=False) self.assertFrameSent(client, Frame(OP_TEXT, b"Spam", fin=False)) - with self.enforce_mask(b"\x3c\x3c\x3c\x3c"): + with patch("secrets.token_bytes", return_value=b"\x3c\x3c\x3c\x3c"): client.send_close() self.assertEqual(client.data_to_send(), [b"\x88\x80\x3c\x3c\x3c\x3c"]) self.assertIs(client.state, CLOSING) @@ -1819,7 +1820,7 @@ class ErrorTests(ProtocolTestCase): def test_client_hits_internal_error_reading_frame(self): client = Protocol(CLIENT) # This isn't supposed to happen, so we're simulating it. - with unittest.mock.patch("struct.unpack", side_effect=RuntimeError("BOOM")): + with patch("struct.unpack", side_effect=RuntimeError("BOOM")): client.receive_data(b"\x81\x00") self.assertIsInstance(client.parser_exc, RuntimeError) self.assertEqual(str(client.parser_exc), "BOOM") @@ -1828,7 +1829,7 @@ def test_client_hits_internal_error_reading_frame(self): def test_server_hits_internal_error_reading_frame(self): server = Protocol(SERVER) # This isn't supposed to happen, so we're simulating it. - with unittest.mock.patch("struct.unpack", side_effect=RuntimeError("BOOM")): + with patch("struct.unpack", side_effect=RuntimeError("BOOM")): server.receive_data(b"\x81\x80\x00\x00\x00\x00") self.assertIsInstance(server.parser_exc, RuntimeError) self.assertEqual(str(server.parser_exc), "BOOM") @@ -1844,7 +1845,7 @@ class ExtensionsTests(ProtocolTestCase): def test_client_extension_encodes_frame(self): client = Protocol(CLIENT) client.extensions = [Rsv2Extension()] - with self.enforce_mask(b"\x00\x44\x88\xcc"): + with patch("secrets.token_bytes", return_value=b"\x00\x44\x88\xcc"): client.send_ping(b"") self.assertEqual(client.data_to_send(), [b"\xa9\x80\x00\x44\x88\xcc"]) diff --git a/tests/test_server.py b/tests/test_server.py index 9f328ded5..43970a7cd 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -314,12 +314,14 @@ def test_reject_response_supports_int_status(self): self.assertEqual(response.status_code, 404) self.assertEqual(response.reason_phrase, "Not Found") - @patch("websockets.server.ServerProtocol.process_request") + @patch( + "websockets.server.ServerProtocol.process_request", + side_effect=Exception("BOOM"), + ) def test_unexpected_error(self, process_request): """accept() handles unexpected errors and returns an error response.""" server = ServerProtocol() request = make_request() - process_request.side_effect = (Exception("BOOM"),) response = server.accept(request) self.assertEqual(response.status_code, 500) From 2d515c84acef39e9642358f442575bdc2d04c8de Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 26 Jan 2025 18:35:10 +0100 Subject: [PATCH 715/762] Increase test timing to reduce flakiness. --- tests/asyncio/test_connection.py | 14 +++++++------- tests/sync/test_connection.py | 14 +++++++------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index dc4539948..5230eca89 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -980,13 +980,14 @@ async def test_pong_unsupported_type(self): @patch("random.getrandbits", return_value=1918987876) async def test_keepalive(self, getrandbits): """keepalive sends pings at ping_interval and measures latency.""" - self.connection.ping_interval = 2 * MS + self.connection.ping_interval = 3 * MS self.connection.start_keepalive() + self.assertIsNotNone(self.connection.keepalive_task) self.assertEqual(self.connection.latency, 0) - # 2 ms: keepalive() sends a ping frame. - # 2.x ms: a pong frame is received. - await asyncio.sleep(3 * MS) - # 3 ms: check that the ping frame was sent. + # 3 ms: keepalive() sends a ping frame. + # 3.x ms: a pong frame is received. + await asyncio.sleep(4 * MS) + # 4 ms: check that the ping frame was sent. await self.assertFrameSent(Frame(Opcode.PING, b"rand")) self.assertGreater(self.connection.latency, 0) self.assertLess(self.connection.latency, MS) @@ -995,8 +996,7 @@ async def test_disable_keepalive(self): """keepalive is disabled when ping_interval is None.""" self.connection.ping_interval = None self.connection.start_keepalive() - await asyncio.sleep(3 * MS) - await self.assertNoFrameSent() + self.assertIsNone(self.connection.keepalive_task) @patch("random.getrandbits", return_value=1918987876) async def test_keepalive_times_out(self, getrandbits): diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index be7ff36f4..a5aee35bb 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -742,13 +742,14 @@ def test_pong_unsupported_type(self): @patch("random.getrandbits", return_value=1918987876) def test_keepalive(self, getrandbits): """keepalive sends pings at ping_interval and measures latency.""" - self.connection.ping_interval = 2 * MS + self.connection.ping_interval = 4 * MS self.connection.start_keepalive() + self.assertIsNotNone(self.connection.keepalive_thread) self.assertEqual(self.connection.latency, 0) - # 2 ms: keepalive() sends a ping frame. - # 2.x ms: a pong frame is received. - time.sleep(3 * MS) - # 3 ms: check that the ping frame was sent. + # 3 ms: keepalive() sends a ping frame. + # 3.x ms: a pong frame is received. + time.sleep(4 * MS) + # 4 ms: check that the ping frame was sent. self.assertFrameSent(Frame(Opcode.PING, b"rand")) self.assertGreater(self.connection.latency, 0) self.assertLess(self.connection.latency, MS) @@ -757,8 +758,7 @@ def test_disable_keepalive(self): """keepalive is disabled when ping_interval is None.""" self.connection.ping_interval = None self.connection.start_keepalive() - time.sleep(3 * MS) - self.assertNoFrameSent() + self.assertIsNone(self.connection.keepalive_thread) @patch("random.getrandbits", return_value=1918987876) def test_keepalive_times_out(self, getrandbits): From 4ea521f62dfbda4335f21db6604260c83a0fac67 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 25 Jan 2025 13:20:54 +0100 Subject: [PATCH 716/762] Add helpers to locate proxy for client connections. --- docs/reference/exceptions.rst | 10 +- src/websockets/__init__.py | 9 ++ src/websockets/exceptions.py | 42 +++++++- src/websockets/uri.py | 128 +++++++++++++++++++++++- tests/test_exceptions.py | 12 +++ tests/test_uri.py | 182 ++++++++++++++++++++++++++++++++-- tests/utils.py | 16 +++ 7 files changed, 382 insertions(+), 17 deletions(-) diff --git a/docs/reference/exceptions.rst b/docs/reference/exceptions.rst index d6b7f0f57..e0c2efdd1 100644 --- a/docs/reference/exceptions.rst +++ b/docs/reference/exceptions.rst @@ -28,14 +28,20 @@ also reported by :func:`~websockets.asyncio.server.serve` in logs. .. autoexception:: InvalidURI -.. autoexception:: InvalidHandshake +.. autoexception:: InvalidProxy -.. autoexception:: InvalidMessage +.. autoexception:: InvalidHandshake .. autoexception:: SecurityError +.. autoexception:: InvalidMessage + .. autoexception:: InvalidStatus +.. autoexception:: InvalidProxyMessage + +.. autoexception:: InvalidProxyStatus + .. autoexception:: InvalidHeader .. autoexception:: InvalidHeaderFormat diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index 1d0abe5cd..8bf282a73 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -39,6 +39,9 @@ "InvalidOrigin", "InvalidParameterName", "InvalidParameterValue", + "InvalidProxy", + "InvalidProxyMessage", + "InvalidProxyStatus", "InvalidState", "InvalidStatus", "InvalidUpgrade", @@ -99,6 +102,9 @@ InvalidOrigin, InvalidParameterName, InvalidParameterValue, + InvalidProxy, + InvalidProxyMessage, + InvalidProxyStatus, InvalidState, InvalidStatus, InvalidUpgrade, @@ -157,6 +163,9 @@ "InvalidOrigin": ".exceptions", "InvalidParameterName": ".exceptions", "InvalidParameterValue": ".exceptions", + "InvalidProxy": ".exceptions", + "InvalidProxyMessage": ".exceptions", + "InvalidProxyStatus": ".exceptions", "InvalidState": ".exceptions", "InvalidStatus": ".exceptions", "InvalidUpgrade": ".exceptions", diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index 73b24debf..e70aac92e 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -6,11 +6,14 @@ * :exc:`ConnectionClosedOK` * :exc:`ConnectionClosedError` * :exc:`InvalidURI` + * :exc:`InvalidProxy` * :exc:`InvalidHandshake` * :exc:`SecurityError` * :exc:`InvalidMessage` * :exc:`InvalidStatus` * :exc:`InvalidStatusCode` (legacy) + * :exc:`InvalidProxyMessage` + * :exc:`InvalidProxyStatus` * :exc:`InvalidHeader` * :exc:`InvalidHeaderFormat` * :exc:`InvalidHeaderValue` @@ -42,13 +45,16 @@ "ConnectionClosedOK", "ConnectionClosedError", "InvalidURI", + "InvalidProxy", "InvalidHandshake", "SecurityError", + "InvalidMessage", "InvalidStatus", + "InvalidProxyMessage", + "InvalidProxyStatus", "InvalidHeader", "InvalidHeaderFormat", "InvalidHeaderValue", - "InvalidMessage", "InvalidOrigin", "InvalidUpgrade", "NegotiationError", @@ -169,6 +175,20 @@ def __str__(self) -> str: return f"{self.uri} isn't a valid URI: {self.msg}" +class InvalidProxy(WebSocketException): + """ + Raised when connecting via a proxy that isn't valid. + + """ + + def __init__(self, proxy: str, msg: str) -> None: + self.proxy = proxy + self.msg = msg + + def __str__(self) -> str: + return f"{self.proxy} isn't a valid proxy: {self.msg}" + + class InvalidHandshake(WebSocketException): """ Base class for exceptions raised when the opening handshake fails. @@ -208,6 +228,26 @@ def __str__(self) -> str: ) +class InvalidProxyMessage(InvalidHandshake): + """ + Raised when a proxy response is malformed. + + """ + + +class InvalidProxyStatus(InvalidHandshake): + """ + Raised when a proxy rejects the connection. + + """ + + def __init__(self, response: http11.Response) -> None: + self.response = response + + def __str__(self) -> str: + return f"proxy rejected connection: HTTP {self.response.status_code:d}" + + class InvalidHeader(InvalidHandshake): """ Raised when an HTTP header doesn't have a valid format or value. diff --git a/src/websockets/uri.py b/src/websockets/uri.py index 16bb3f1c1..b925b99b5 100644 --- a/src/websockets/uri.py +++ b/src/websockets/uri.py @@ -2,13 +2,18 @@ import dataclasses import urllib.parse +import urllib.request -from .exceptions import InvalidURI +from .exceptions import InvalidProxy, InvalidURI __all__ = ["parse_uri", "WebSocketURI"] +# All characters from the gen-delims and sub-delims sets in RFC 3987. +DELIMS = ":/?#[]@!$&'()*+,;=" + + @dataclasses.dataclass class WebSocketURI: """ @@ -53,10 +58,6 @@ def user_info(self) -> tuple[str, str] | None: return (self.username, self.password) -# All characters from the gen-delims and sub-delims sets in RFC 3987. -DELIMS = ":/?#[]@!$&'()*+,;=" - - def parse_uri(uri: str) -> WebSocketURI: """ Parse and validate a WebSocket URI. @@ -105,3 +106,120 @@ def parse_uri(uri: str) -> WebSocketURI: password = urllib.parse.quote(password, safe=DELIMS) return WebSocketURI(secure, host, port, path, query, username, password) + + +@dataclasses.dataclass +class Proxy: + """ + Proxy. + + Attributes: + scheme: ``"socks5h"``, ``"socks5"``, ``"socks4a"``, ``"socks4"``, + ``"https"``, or ``"http"``. + host: Normalized to lower case. + port: Always set even if it's the default. + username: Available when the proxy address contains `User Information`_. + password: Available when the proxy address contains `User Information`_. + + .. _User Information: https://datatracker.ietf.org/doc/html/rfc3986#section-3.2.1 + + """ + + scheme: str + host: str + port: int + username: str | None = None + password: str | None = None + + @property + def user_info(self) -> tuple[str, str] | None: + if self.username is None: + return None + assert self.password is not None + return (self.username, self.password) + + +def parse_proxy(proxy: str) -> Proxy: + """ + Parse and validate a proxy. + + Args: + proxy: proxy. + + Returns: + Parsed proxy. + + Raises: + InvalidProxy: If ``proxy`` isn't a valid proxy. + + """ + parsed = urllib.parse.urlparse(proxy) + if parsed.scheme not in ["socks5h", "socks5", "socks4a", "socks4", "https", "http"]: + raise InvalidProxy(proxy, f"scheme {parsed.scheme} isn't supported") + if parsed.hostname is None: + raise InvalidProxy(proxy, "hostname isn't provided") + if parsed.path not in ["", "/"]: + raise InvalidProxy(proxy, "path is meaningless") + if parsed.query != "": + raise InvalidProxy(proxy, "query is meaningless") + if parsed.fragment != "": + raise InvalidProxy(proxy, "fragment is meaningless") + + scheme = parsed.scheme + host = parsed.hostname + port = parsed.port or (443 if parsed.scheme == "https" else 80) + username = parsed.username + password = parsed.password + # urllib.parse.urlparse accepts URLs with a username but without a + # password. This doesn't make sense for HTTP Basic Auth credentials. + if username is not None and password is None: + raise InvalidProxy(proxy, "username provided without password") + + try: + proxy.encode("ascii") + except UnicodeEncodeError: + # Input contains non-ASCII characters. + # It must be an IRI. Convert it to a URI. + host = host.encode("idna").decode() + if username is not None: + assert password is not None + username = urllib.parse.quote(username, safe=DELIMS) + password = urllib.parse.quote(password, safe=DELIMS) + + return Proxy(scheme, host, port, username, password) + + +def get_proxy(uri: WebSocketURI) -> str | None: + """ + Return the proxy to use for connecting to the given WebSocket URI, if any. + + """ + if urllib.request.proxy_bypass(f"{uri.host}:{uri.port}"): + return None + + # According to the _Proxy Usage_ section of RFC 6455, use a SOCKS5 proxy if + # available, else favor the proxy for HTTPS connections over the proxy for + # HTTP connections. + + # The priority of a proxy for WebSocket connections is unspecified. We give + # it the highest priority. This makes it easy to configure a specific proxy + # for websockets. + + # getproxies() may return SOCKS proxies as {"socks": "http://host:port"} or + # as {"https": "socks5h://host:port"} depending on whether they're declared + # in the operating system or in environment variables. + + proxies = urllib.request.getproxies() + if uri.secure: + schemes = ["wss", "socks", "https"] + else: + schemes = ["ws", "socks", "https", "http"] + + for scheme in schemes: + proxy = proxies.get(scheme) + if proxy is not None: + if scheme == "socks" and proxy.startswith("http://"): + proxy = "socks5h://" + proxy[7:] + return proxy + else: + return None diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index e0518b0e0..8b437ab5e 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -83,6 +83,10 @@ def test_str(self): InvalidURI("|", "not at all!"), "| isn't a valid URI: not at all!", ), + ( + InvalidProxy("|", "not at all!"), + "| isn't a valid proxy: not at all!", + ), ( InvalidHandshake("invalid request"), "invalid request", @@ -99,6 +103,14 @@ def test_str(self): InvalidStatus(Response(401, "Unauthorized", Headers())), "server rejected WebSocket connection: HTTP 401", ), + ( + InvalidProxyMessage("malformed HTTP message"), + "malformed HTTP message", + ), + ( + InvalidProxyStatus(Response(401, "Unauthorized", Headers())), + "proxy rejected connection: HTTP 401", + ), ( InvalidHeader("Name"), "missing Name header", diff --git a/tests/test_uri.py b/tests/test_uri.py index 8acc01c18..35b51fa58 100644 --- a/tests/test_uri.py +++ b/tests/test_uri.py @@ -1,7 +1,10 @@ import unittest -from websockets.exceptions import InvalidURI +from websockets.exceptions import InvalidProxy, InvalidURI from websockets.uri import * +from websockets.uri import Proxy, get_proxy, parse_proxy + +from .utils import patch_environ VALID_URIS = [ @@ -59,38 +62,199 @@ "ws:///path", ] -RESOURCE_NAMES = [ +URIS_WITH_RESOURCE_NAMES = [ ("ws://localhost/", "/"), ("ws://localhost", "/"), ("ws://localhost/path?query", "/path?query"), ("ws://høst/πass?qùéry", "/%CF%80ass?q%C3%B9%C3%A9ry"), ] -USER_INFOS = [ +URIS_WITH_USER_INFO = [ ("ws://localhost/", None), ("ws://user:pass@localhost/", ("user", "pass")), ("ws://üser:påss@høst/", ("%C3%BCser", "p%C3%A5ss")), ] +VALID_PROXIES = [ + ( + "http://proxy:8080", + Proxy("http", "proxy", 8080, None, None), + ), + ( + "https://proxy:8080", + Proxy("https", "proxy", 8080, None, None), + ), + ( + "http://proxy", + Proxy("http", "proxy", 80, None, None), + ), + ( + "http://proxy:8080/", + Proxy("http", "proxy", 8080, None, None), + ), + ( + "http://PROXY:8080", + Proxy("http", "proxy", 8080, None, None), + ), + ( + "http://user:pass@proxy:8080", + Proxy("http", "proxy", 8080, "user", "pass"), + ), + ( + "http://høst:8080/", + Proxy("http", "xn--hst-0na", 8080, None, None), + ), + ( + "http://üser:påss@høst:8080", + Proxy("http", "xn--hst-0na", 8080, "%C3%BCser", "p%C3%A5ss"), + ), +] + +INVALID_PROXIES = [ + "ws://proxy:8080", + "wss://proxy:8080", + "http://proxy:8080/path", + "http://proxy:8080/?query", + "http://proxy:8080/#fragment", + "http://user@proxy", + "http:///", +] + +PROXIES_WITH_USER_INFO = [ + ("http://proxy", None), + ("http://user:pass@proxy", ("user", "pass")), + ("http://üser:påss@høst", ("%C3%BCser", "p%C3%A5ss")), +] + +PROXY_ENVS = [ + ( + {"ws_proxy": "http://proxy:8080"}, + "ws://example.com/", + "http://proxy:8080", + ), + ( + {"ws_proxy": "http://proxy:8080"}, + "wss://example.com/", + None, + ), + ( + {"wss_proxy": "http://proxy:8080"}, + "ws://example.com/", + None, + ), + ( + {"wss_proxy": "http://proxy:8080"}, + "wss://example.com/", + "http://proxy:8080", + ), + ( + {"http_proxy": "http://proxy:8080"}, + "ws://example.com/", + "http://proxy:8080", + ), + ( + {"http_proxy": "http://proxy:8080"}, + "wss://example.com/", + None, + ), + ( + {"https_proxy": "http://proxy:8080"}, + "ws://example.com/", + "http://proxy:8080", + ), + ( + {"https_proxy": "http://proxy:8080"}, + "wss://example.com/", + "http://proxy:8080", + ), + ( + {"socks_proxy": "http://proxy:1080"}, + "ws://example.com/", + "socks5h://proxy:1080", + ), + ( + {"socks_proxy": "http://proxy:1080"}, + "wss://example.com/", + "socks5h://proxy:1080", + ), + ( + {"ws_proxy": "http://proxy1:8080", "wss_proxy": "http://proxy2:8080"}, + "ws://example.com/", + "http://proxy1:8080", + ), + ( + {"ws_proxy": "http://proxy1:8080", "wss_proxy": "http://proxy2:8080"}, + "wss://example.com/", + "http://proxy2:8080", + ), + ( + {"http_proxy": "http://proxy1:8080", "https_proxy": "http://proxy2:8080"}, + "ws://example.com/", + "http://proxy2:8080", + ), + ( + {"http_proxy": "http://proxy1:8080", "https_proxy": "http://proxy2:8080"}, + "wss://example.com/", + "http://proxy2:8080", + ), + ( + {"https_proxy": "http://proxy:8080", "socks_proxy": "http://proxy:1080"}, + "ws://example.com/", + "socks5h://proxy:1080", + ), + ( + {"https_proxy": "http://proxy:8080", "socks_proxy": "http://proxy:1080"}, + "wss://example.com/", + "socks5h://proxy:1080", + ), + ( + {"socks_proxy": "http://proxy:1080", "no_proxy": ".local"}, + "ws://example.local/", + None, + ), +] + class URITests(unittest.TestCase): - def test_success(self): + def test_parse_valid_uris(self): for uri, parsed in VALID_URIS: with self.subTest(uri=uri): self.assertEqual(parse_uri(uri), parsed) - def test_error(self): + def test_parse_invalid_uris(self): for uri in INVALID_URIS: with self.subTest(uri=uri): with self.assertRaises(InvalidURI): parse_uri(uri) - def test_resource_name(self): - for uri, resource_name in RESOURCE_NAMES: + def test_parse_resource_name(self): + for uri, resource_name in URIS_WITH_RESOURCE_NAMES: with self.subTest(uri=uri): self.assertEqual(parse_uri(uri).resource_name, resource_name) - def test_user_info(self): - for uri, user_info in USER_INFOS: + def test_parse_user_info(self): + for uri, user_info in URIS_WITH_USER_INFO: with self.subTest(uri=uri): self.assertEqual(parse_uri(uri).user_info, user_info) + + def test_parse_valid_proxies(self): + for proxy, parsed in VALID_PROXIES: + with self.subTest(proxy=proxy): + self.assertEqual(parse_proxy(proxy), parsed) + + def test_parse_invalid_proxies(self): + for proxy in INVALID_PROXIES: + with self.subTest(proxy=proxy): + with self.assertRaises(InvalidProxy): + parse_proxy(proxy) + + def test_parse_proxy_user_info(self): + for proxy, user_info in PROXIES_WITH_USER_INFO: + with self.subTest(proxy=proxy): + self.assertEqual(parse_proxy(proxy).user_info, user_info) + + def test_get_proxy(self): + for environ, uri, proxy in PROXY_ENVS: + with patch_environ(environ): + with self.subTest(environ=environ, uri=uri): + self.assertEqual(get_proxy(parse_uri(uri)), proxy) diff --git a/tests/utils.py b/tests/utils.py index 77d020726..f68a447b1 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -139,6 +139,22 @@ def assertNoLogs(self, logger=None, level=None): self.assertEqual(logs.output, [f"{level_name}:{logger}:dummy"]) +@contextlib.contextmanager +def patch_environ(environ): + backup = {} + for key, value in environ.items(): + backup[key] = os.environ.get(key) + os.environ[key] = value + try: + yield + finally: + for key, value in backup.items(): + if value is None: + del os.environ[key] + else: # pragma: no cover + os.environ[key] = value + + @contextlib.contextmanager def temp_unix_socket_path(): with tempfile.TemporaryDirectory() as temp_dir: From 4a89e5616ffed1a8662fe195ad14827bb93a9bed Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 25 Jan 2025 17:18:27 +0100 Subject: [PATCH 717/762] Add support for SOCKS proxies. Fix #475. --- docs/project/changelog.rst | 19 ++++++- docs/reference/features.rst | 3 +- docs/topics/index.rst | 1 + docs/topics/proxies.rst | 66 +++++++++++++++++++++++ src/websockets/asyncio/client.py | 64 +++++++++++++++++++++-- src/websockets/sync/client.py | 68 ++++++++++++++++++++++-- src/websockets/version.py | 2 +- tests/asyncio/test_client.py | 83 ++++++++++++++++++++++++++++- tests/proxy.py | 89 ++++++++++++++++++++++++++++++++ tests/requirements.txt | 2 + tests/sync/test_client.py | 80 +++++++++++++++++++++++++++- tox.ini | 21 ++++++-- 12 files changed, 479 insertions(+), 19 deletions(-) create mode 100644 docs/topics/proxies.rst create mode 100644 tests/proxy.py create mode 100644 tests/requirements.txt diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 7f341d942..2a429b43e 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -25,13 +25,28 @@ fixing regressions shortly after a release. Only documented APIs are public. Undocumented, private APIs may change without notice. -.. _14.3: +.. _15.0: -14.3 +15.0 ---- *In development* +Backwards-incompatible changes +.............................. + +.. admonition:: Client connections use SOCKS proxies automatically. + :class: important + + If a proxy is configured in the operating system or with an environment + variable, websockets uses it automatically when connecting to a server. + This feature requires installing the third-party library `python-socks`_. + + If you want to disable the proxy, add ``proxy=None`` when calling + :func:`~asyncio.client.connect`. See :doc:`../topics/proxies` for details. + + .. _python-socks: https://github.com/romis2012/python-socks + New features ............ diff --git a/docs/reference/features.rst b/docs/reference/features.rst index 6ba42f66b..eaecd02a9 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -168,11 +168,10 @@ Client +------------------------------------+--------+--------+--------+--------+ | Connect via HTTP proxy (`#364`_) | ❌ | ❌ | — | ❌ | +------------------------------------+--------+--------+--------+--------+ - | Connect via SOCKS5 proxy (`#475`_) | ❌ | ❌ | — | ❌ | + | Connect via SOCKS5 proxy | ✅ | ✅ | — | ❌ | +------------------------------------+--------+--------+--------+--------+ .. _#364: https://github.com/python-websockets/websockets/issues/364 -.. _#475: https://github.com/python-websockets/websockets/issues/475 .. _#784: https://github.com/python-websockets/websockets/issues/784 Known limitations diff --git a/docs/topics/index.rst b/docs/topics/index.rst index 616753c6c..a08d487c9 100644 --- a/docs/topics/index.rst +++ b/docs/topics/index.rst @@ -15,3 +15,4 @@ Get a deeper understanding of how websockets is built and why. memory security performance + proxies diff --git a/docs/topics/proxies.rst b/docs/topics/proxies.rst new file mode 100644 index 000000000..fd3ae78b6 --- /dev/null +++ b/docs/topics/proxies.rst @@ -0,0 +1,66 @@ +Proxies +======= + +.. currentmodule:: websockets + +If a proxy is configured in the operating system or with an environment +variable, websockets uses it automatically when connecting to a server. + +Configuration +------------- + +First, if the server is in the proxy bypass list of the operating system or in +the ``no_proxy`` environment variable, websockets connects directly. + +Then, it looks for a proxy in the following locations: + +1. The ``wss_proxy`` or ``ws_proxy`` environment variables for ``wss://`` and + ``ws://`` connections respectively. They allow configuring a specific proxy + for WebSocket connections. +2. A SOCKS proxy configured in the operating system. +3. An HTTP proxy configured in the operating system or in the ``https_proxy`` + environment variable, for both ``wss://`` and ``ws://`` connections. +4. An HTTP proxy configured in the operating system or in the ``http_proxy`` + environment variable, only for ``ws://`` connections. + +Finally, if no proxy is found, websockets connects directly. + +While environment variables are case-insensitive, the lower-case spelling is the +most common, for `historical reasons`_, and recommended. + +.. _historical reasons: https://unix.stackexchange.com/questions/212894/ + +.. admonition:: Any environment variable can configure a SOCKS proxy or an HTTP proxy. + :class: tip + + For example, ``https_proxy=socks5h://proxy:1080/`` configures a SOCKS proxy + for all WebSocket connections. Likewise, ``wss_proxy=http://proxy:8080/`` + configures an HTTP proxy only for ``wss://`` connections. + +.. admonition:: What if websockets doesn't select the right proxy? + :class: hint + + websockets relies on :func:`~urllib.request.getproxies()` to read the proxy + configuration. Check that it returns what you expect. If it doesn't, review + your proxy configuration. + +You can override the default configuration and configure a proxy explicitly with +the ``proxy`` argument of :func:`~asyncio.client.connect`. Set ``proxy=None`` to +disable the proxy. + +SOCKS proxies +------------- + +Connecting through a SOCKS proxy requires installing the third-party library +`python-socks`_:: + + $ pip install python-socks\[asyncio\] + +.. _python-socks: https://github.com/romis2012/python-socks + +python-socks supports SOCKS4, SOCKS4a, SOCKS5, and SOCKS5h. The protocol version +is configured in the address of the proxy e.g. ``socks5h://proxy:1080/``. When a +SOCKS proxy is configured in the operating system, python-socks uses SOCKS5h. + +python-socks supports username/password authentication for SOCKS5 (:rfc:`1929`) +but does not support other authentication methods such as GSSAPI (:rfc:`1961`). diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index bde0beeea..f76095ead 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -7,7 +7,7 @@ import urllib.parse from collections.abc import AsyncIterator, Generator, Sequence from types import TracebackType -from typing import Any, Callable +from typing import Any, Callable, Literal from ..client import ClientProtocol, backoff from ..datastructures import HeadersLike @@ -18,7 +18,7 @@ from ..http11 import USER_AGENT, Response from ..protocol import CONNECTING, Event from ..typing import LoggerLike, Origin, Subprotocol -from ..uri import WebSocketURI, parse_uri +from ..uri import Proxy, WebSocketURI, get_proxy, parse_proxy, parse_uri from .compatibility import TimeoutError, asyncio_timeout from .connection import Connection @@ -208,6 +208,10 @@ class connect: user_agent_header: Value of the ``User-Agent`` request header. It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to :obj:`None` removes the header. + proxy: If a proxy is configured, it is used by default. Set ``proxy`` + to :obj:`None` to disable the proxy or to the address of a proxy + to override the system configuration. See the :doc:`proxy docs + <../../topics/proxies>` for details. process_exception: When reconnecting automatically, tell whether an error is transient or fatal. The default behavior is defined by :func:`process_exception`. Refer to its documentation for details. @@ -279,6 +283,7 @@ def __init__( # HTTP additional_headers: HeadersLike | None = None, user_agent_header: str | None = USER_AGENT, + proxy: str | Literal[True] | None = True, process_exception: Callable[[Exception], Exception | None] = process_exception, # Timeouts open_timeout: float | None = 10, @@ -333,6 +338,7 @@ def protocol_factory(uri: WebSocketURI) -> ClientConnection: ) return connection + self.proxy = proxy self.protocol_factory = protocol_factory self.handshake_args = ( additional_headers, @@ -346,9 +352,20 @@ def protocol_factory(uri: WebSocketURI) -> ClientConnection: async def create_connection(self) -> ClientConnection: """Create TCP or Unix connection.""" loop = asyncio.get_running_loop() + kwargs = self.connection_kwargs.copy() ws_uri = parse_uri(self.uri) - kwargs = self.connection_kwargs.copy() + + proxy = self.proxy + proxy_uri: Proxy | None = None + if kwargs.get("unix", False): + proxy = None + if kwargs.get("sock") is not None: + proxy = None + if proxy is True: + proxy = get_proxy(ws_uri) + if proxy is not None: + proxy_uri = parse_proxy(proxy) def factory() -> ClientConnection: return self.protocol_factory(ws_uri) @@ -365,6 +382,47 @@ def factory() -> ClientConnection: if kwargs.pop("unix", False): _, connection = await loop.create_unix_connection(factory, **kwargs) else: + if proxy_uri is not None: + if proxy_uri.scheme[:5] == "socks": + try: + from python_socks import ProxyType + from python_socks.async_.asyncio import Proxy + except ImportError: + raise ImportError( + "python-socks is required to use a SOCKS proxy" + ) + if proxy_uri.scheme == "socks5h": + proxy_type = ProxyType.SOCKS5 + rdns = True + elif proxy_uri.scheme == "socks5": + proxy_type = ProxyType.SOCKS5 + rdns = False + # We use mitmproxy for testing and it doesn't support SOCKS4. + elif proxy_uri.scheme == "socks4a": # pragma: no cover + proxy_type = ProxyType.SOCKS4 + rdns = True + elif proxy_uri.scheme == "socks4": # pragma: no cover + proxy_type = ProxyType.SOCKS4 + rdns = False + # Proxy types are enforced in parse_proxy(). + else: + raise AssertionError("unsupported SOCKS proxy") + socks_proxy = Proxy( + proxy_type, + proxy_uri.host, + proxy_uri.port, + proxy_uri.username, + proxy_uri.password, + rdns, + ) + kwargs["sock"] = await socks_proxy.connect( + ws_uri.host, + ws_uri.port, + local_addr=kwargs.pop("local_addr", None), + ) + # Proxy types are enforced in parse_proxy(). + else: + raise AssertionError("unsupported proxy") if kwargs.get("sock") is None: kwargs.setdefault("host", ws_uri.host) kwargs.setdefault("port", ws_uri.port) diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index da2b88591..96f62edab 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -5,7 +5,7 @@ import threading import warnings from collections.abc import Sequence -from typing import Any +from typing import Any, Literal from ..client import ClientProtocol from ..datastructures import HeadersLike @@ -15,7 +15,7 @@ from ..http11 import USER_AGENT, Response from ..protocol import CONNECTING, Event from ..typing import LoggerLike, Origin, Subprotocol -from ..uri import parse_uri +from ..uri import Proxy, get_proxy, parse_proxy, parse_uri from .connection import Connection from .utils import Deadline @@ -139,6 +139,7 @@ def connect( # HTTP additional_headers: HeadersLike | None = None, user_agent_header: str | None = USER_AGENT, + proxy: str | Literal[True] | None = True, # Timeouts open_timeout: float | None = 10, ping_interval: float | None = 20, @@ -189,6 +190,10 @@ def connect( user_agent_header: Value of the ``User-Agent`` request header. It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to :obj:`None` removes the header. + proxy: If a proxy is configured, it is used by default. Set ``proxy`` + to :obj:`None` to disable the proxy or to the address of a proxy + to override the system configuration. See the :doc:`proxy docs + <../../topics/proxies>` for details. open_timeout: Timeout for opening the connection in seconds. :obj:`None` disables the timeout. ping_interval: Interval between keepalive pings in seconds. @@ -253,6 +258,16 @@ def connect( elif compression is not None: raise ValueError(f"unsupported compression: {compression}") + proxy_uri: Proxy | None = None + if unix: + proxy = None + if sock is not None: + proxy = None + if proxy is True: + proxy = get_proxy(ws_uri) + if proxy is not None: + proxy_uri = parse_proxy(proxy) + # Calculate timeouts on the TCP, TLS, and WebSocket handshakes. # The TCP and TLS timeouts must be set on the socket, then removed # to avoid conflicting with the WebSocket timeout in handshake(). @@ -271,8 +286,53 @@ def connect( assert path is not None # mypy cannot figure this out sock.connect(path) else: - kwargs.setdefault("timeout", deadline.timeout()) - sock = socket.create_connection((ws_uri.host, ws_uri.port), **kwargs) + if proxy_uri is not None: + if proxy_uri.scheme[:5] == "socks": + try: + from python_socks import ProxyType + from python_socks.sync import Proxy + except ImportError: + raise ImportError( + "python-socks is required to use a SOCKS proxy" + ) + if proxy_uri.scheme == "socks5h": + proxy_type = ProxyType.SOCKS5 + rdns = True + elif proxy_uri.scheme == "socks5": + proxy_type = ProxyType.SOCKS5 + rdns = False + # We use mitmproxy for testing and it doesn't support SOCKS4. + elif proxy_uri.scheme == "socks4a": # pragma: no cover + proxy_type = ProxyType.SOCKS4 + rdns = True + elif proxy_uri.scheme == "socks4": # pragma: no cover + proxy_type = ProxyType.SOCKS4 + rdns = False + # Proxy types are enforced in parse_proxy(). + else: + raise AssertionError("unsupported SOCKS proxy") + socks_proxy = Proxy( + proxy_type, + proxy_uri.host, + proxy_uri.port, + proxy_uri.username, + proxy_uri.password, + rdns, + ) + sock = socks_proxy.connect( + ws_uri.host, + ws_uri.port, + timeout=deadline.timeout(), + local_addr=kwargs.pop("local_addr", None), + ) + # Proxy types are enforced in parse_proxy(). + else: + raise AssertionError("unsupported proxy") + else: + kwargs.setdefault("timeout", deadline.timeout()) + sock = socket.create_connection( + (ws_uri.host, ws_uri.port), **kwargs + ) sock.settimeout(None) # Disable Nagle algorithm diff --git a/src/websockets/version.py b/src/websockets/version.py index ca9a9115b..611e7d238 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -20,7 +20,7 @@ released = False -tag = version = commit = "14.3" +tag = version = commit = "15.0" if not released: # pragma: no cover diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index f05bfc699..cb2b8ede6 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -4,6 +4,7 @@ import logging import socket import ssl +import sys import unittest from websockets.asyncio.client import * @@ -13,13 +14,21 @@ from websockets.exceptions import ( InvalidHandshake, InvalidMessage, + InvalidProxy, InvalidStatus, InvalidURI, SecurityError, ) from websockets.extensions.permessage_deflate import PerMessageDeflate -from ..utils import CLIENT_CONTEXT, MS, SERVER_CONTEXT, temp_unix_socket_path +from ..proxy import async_proxy +from ..utils import ( + CLIENT_CONTEXT, + MS, + SERVER_CONTEXT, + patch_environ, + temp_unix_socket_path, +) from .server import args, get_host_port, get_uri, handler @@ -555,6 +564,78 @@ def redirect(connection, request): ) +@unittest.skipUnless("mitmproxy" in sys.modules, "mitmproxy not installed") +class ProxyClientTests(unittest.IsolatedAsyncioTestCase): + @contextlib.asynccontextmanager + async def socks_proxy(self, auth=None): + if auth: + proxyauth = "hello:iloveyou" + proxy_uri = "http://hello:iloveyou@localhost:1080" + else: + proxyauth = None + proxy_uri = "http://localhost:1080" + async with async_proxy(mode=["socks5"], proxyauth=proxyauth) as record_flows: + with patch_environ({"socks_proxy": proxy_uri}): + yield record_flows + + async def test_socks_proxy(self): + """Client connects to server through a SOCKS5 proxy.""" + async with self.socks_proxy() as proxy: + async with serve(*args) as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertEqual(len(proxy.get_flows()), 1) + + async def test_secure_socks_proxy(self): + """Client connects to server securely through a SOCKS5 proxy.""" + async with self.socks_proxy() as proxy: + async with serve(*args, ssl=SERVER_CONTEXT) as server: + async with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertEqual(len(proxy.get_flows()), 1) + + async def test_authenticated_socks_proxy(self): + """Client connects to server through an authenticated SOCKS5 proxy.""" + async with self.socks_proxy(auth=True) as proxy: + async with serve(*args) as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertEqual(len(proxy.get_flows()), 1) + + async def test_explicit_proxy(self): + """Client connects to server through a proxy set explicitly.""" + async with async_proxy(mode=["socks5"]) as proxy: + async with serve(*args) as server: + async with connect( + get_uri(server), + # Take this opportunity to test socks5 instead of socks5h. + proxy="socks5://localhost:1080", + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertEqual(len(proxy.get_flows()), 1) + + async def test_ignore_proxy_with_existing_socket(self): + """Client connects using a pre-existing socket.""" + async with self.socks_proxy() as proxy: + async with serve(*args) as server: + with socket.create_connection(get_host_port(server)) as sock: + # Use a non-existing domain to ensure we connect to sock. + async with connect("ws://invalid/", sock=sock) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertEqual(len(proxy.get_flows()), 0) + + async def test_unsupported_proxy(self): + """Client connects to server through an unsupported proxy.""" + with patch_environ({"ws_proxy": "other://localhost:1080"}): + with self.assertRaises(InvalidProxy) as raised: + async with connect("ws://example.com/"): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "other://localhost:1080 isn't a valid proxy: scheme other isn't supported", + ) + + @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") class UnixClientTests(unittest.IsolatedAsyncioTestCase): async def test_connection(self): diff --git a/tests/proxy.py b/tests/proxy.py new file mode 100644 index 000000000..95525a360 --- /dev/null +++ b/tests/proxy.py @@ -0,0 +1,89 @@ +import asyncio +import contextlib +import pathlib +import threading +import warnings + + +warnings.filterwarnings("ignore", category=DeprecationWarning, module="mitmproxy") +warnings.filterwarnings("ignore", category=DeprecationWarning, module="passlib") +warnings.filterwarnings("ignore", category=DeprecationWarning, module="pyasn1") + +try: + from mitmproxy.addons import core, next_layer, proxyauth, proxyserver, tlsconfig + from mitmproxy.master import Master + from mitmproxy.options import Options +except ImportError: + pass + + +class RecordFlows: + def __init__(self): + self.ready = asyncio.get_running_loop().create_future() + self.flows = [] + + def running(self): + self.ready.set_result(None) + + def websocket_start(self, flow): + self.flows.append(flow) + + def get_flows(self): + flows, self.flows[:] = self.flows[:], [] + return flows + + +@contextlib.asynccontextmanager +async def async_proxy(mode, **config): + options = Options(mode=mode) + master = Master(options) + record_flows = RecordFlows() + master.addons.add( + core.Core(), + proxyauth.ProxyAuth(), + proxyserver.Proxyserver(), + next_layer.NextLayer(), + tlsconfig.TlsConfig(), + record_flows, + ) + config.update( + # Use our test certificate for TLS between client and proxy + # and disable TLS verification between proxy and upstream. + certs=[str(pathlib.Path(__file__).with_name("test_localhost.pem"))], + ssl_insecure=True, + ) + options.update(**config) + + asyncio.create_task(master.run()) + try: + await record_flows.ready + yield record_flows + finally: + for server in master.addons.get("proxyserver").servers: + await server.stop() + master.shutdown() + + +@contextlib.contextmanager +def sync_proxy(mode, **config): + loop = None + test_done = None + proxy_ready = threading.Event() + record_flows = None + + async def proxy_coroutine(): + nonlocal loop, test_done, proxy_ready, record_flows + loop = asyncio.get_running_loop() + test_done = loop.create_future() + async with async_proxy(mode, **config) as record_flows: + proxy_ready.set() + await test_done + + proxy_thread = threading.Thread(target=asyncio.run, args=(proxy_coroutine(),)) + proxy_thread.start() + try: + proxy_ready.wait() + yield record_flows + finally: + loop.call_soon_threadsafe(test_done.set_result, None) + proxy_thread.join() diff --git a/tests/requirements.txt b/tests/requirements.txt new file mode 100644 index 000000000..f375e6f69 --- /dev/null +++ b/tests/requirements.txt @@ -0,0 +1,2 @@ +python-socks[asyncio] +mitmproxy diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index 736a84c98..2f62dd34d 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -1,8 +1,10 @@ +import contextlib import http import logging import socket import socketserver import ssl +import sys import threading import time import unittest @@ -10,17 +12,20 @@ from websockets.exceptions import ( InvalidHandshake, InvalidMessage, + InvalidProxy, InvalidStatus, InvalidURI, ) from websockets.extensions.permessage_deflate import PerMessageDeflate from websockets.sync.client import * +from ..proxy import sync_proxy from ..utils import ( CLIENT_CONTEXT, MS, SERVER_CONTEXT, DeprecationTestCase, + patch_environ, temp_unix_socket_path, ) from .server import get_uri, run_server, run_unix_server @@ -37,7 +42,7 @@ def test_existing_socket(self): """Client connects using a pre-existing socket.""" with run_server() as server: with socket.create_connection(server.socket.getsockname()) as sock: - # Use a non-existing domain to ensure we connect to the right socket. + # Use a non-existing domain to ensure we connect to sock. with connect("ws://invalid/", sock=sock) as client: self.assertEqual(client.protocol.state.name, "OPEN") @@ -300,6 +305,79 @@ def test_reject_invalid_server_hostname(self): ) +@unittest.skipUnless("mitmproxy" in sys.modules, "mitmproxy not installed") +class ProxyClientTests(unittest.TestCase): + @contextlib.contextmanager + def socks_proxy(self, auth=None): + if auth: + proxyauth = "hello:iloveyou" + proxy_uri = "http://hello:iloveyou@localhost:1080" + else: + proxyauth = None + proxy_uri = "http://localhost:1080" + + with sync_proxy(mode=["socks5"], proxyauth=proxyauth) as record_flows: + with patch_environ({"socks_proxy": proxy_uri}): + yield record_flows + + def test_socks_proxy(self): + """Client connects to server through a SOCKS5 proxy.""" + with self.socks_proxy() as proxy: + with run_server() as server: + with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertEqual(len(proxy.get_flows()), 1) + + def test_secure_socks_proxy(self): + """Client connects to server securely through a SOCKS5 proxy.""" + with self.socks_proxy() as proxy: + with run_server(ssl=SERVER_CONTEXT) as server: + with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertEqual(len(proxy.get_flows()), 1) + + def test_authenticated_socks_proxy(self): + """Client connects to server through an authenticated SOCKS5 proxy.""" + with self.socks_proxy(auth=True) as proxy: + with run_server() as server: + with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertEqual(len(proxy.get_flows()), 1) + + def test_explicit_proxy(self): + """Client connects to server through a proxy set explicitly.""" + with sync_proxy(mode=["socks5"]) as proxy: + with run_server() as server: + with connect( + get_uri(server), + # Take this opportunity to test socks5 instead of socks5h. + proxy="socks5://localhost:1080", + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertEqual(len(proxy.get_flows()), 1) + + def test_ignore_proxy_with_existing_socket(self): + """Client connects using a pre-existing socket.""" + with self.socks_proxy() as proxy: + with run_server() as server: + with socket.create_connection(server.socket.getsockname()) as sock: + # Use a non-existing domain to ensure we connect to sock. + with connect("ws://invalid/", sock=sock) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertEqual(len(proxy.get_flows()), 0) + + def test_unsupported_proxy(self): + """Client connects to server through an unsupported proxy.""" + with patch_environ({"ws_proxy": "other://localhost:1080"}): + with self.assertRaises(InvalidProxy) as raised: + with connect("ws://example.com/"): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "other://localhost:1080 isn't a valid proxy: scheme other isn't supported", + ) + + @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") class UnixClientTests(unittest.TestCase): def test_connection(self): diff --git a/tox.ini b/tox.ini index 0bcec5ded..f5a2f5d3c 100644 --- a/tox.ini +++ b/tox.ini @@ -12,27 +12,38 @@ env_list = [testenv] commands = python -W error::DeprecationWarning -W error::PendingDeprecationWarning -m unittest {posargs} -pass_env = WEBSOCKETS_* +pass_env = + WEBSOCKETS_* +deps = + mitmproxy + python-socks[asyncio] [testenv:coverage] commands = python -m coverage run --source {envsitepackagesdir}/websockets,tests -m unittest {posargs} python -m coverage report --show-missing --fail-under=100 -deps = coverage +deps = + coverage + {[testenv]deps} [testenv:maxi_cov] commands = python tests/maxi_cov.py {envsitepackagesdir} python -m coverage report --show-missing --fail-under=100 -deps = coverage +deps = + coverage + {[testenv]deps} [testenv:ruff] commands = ruff format --check src tests ruff check src tests -deps = ruff +deps = + ruff [testenv:mypy] commands = mypy --strict src -deps = mypy +deps = + mypy + python-socks From 10175f7a41ea1cf17aff7983d7eace3c2e4da5c0 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 27 Jan 2025 22:07:21 +0100 Subject: [PATCH 718/762] Refactor SOCKS proxy implementation. --- src/websockets/asyncio/client.py | 109 +++++++++++++++----------- src/websockets/sync/client.py | 127 ++++++++++++++++++------------- 2 files changed, 141 insertions(+), 95 deletions(-) diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index f76095ead..7052ca85a 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -3,6 +3,7 @@ import asyncio import logging import os +import socket import traceback import urllib.parse from collections.abc import AsyncIterator, Generator, Sequence @@ -357,15 +358,12 @@ async def create_connection(self) -> ClientConnection: ws_uri = parse_uri(self.uri) proxy = self.proxy - proxy_uri: Proxy | None = None if kwargs.get("unix", False): proxy = None if kwargs.get("sock") is not None: proxy = None if proxy is True: proxy = get_proxy(ws_uri) - if proxy is not None: - proxy_uri = parse_proxy(proxy) def factory() -> ClientConnection: return self.protocol_factory(ws_uri) @@ -381,48 +379,14 @@ def factory() -> ClientConnection: if kwargs.pop("unix", False): _, connection = await loop.create_unix_connection(factory, **kwargs) + elif proxy is not None: + kwargs["sock"] = await connect_proxy( + parse_proxy(proxy), + ws_uri, + local_addr=kwargs.pop("local_addr", None), + ) + _, connection = await loop.create_connection(factory, **kwargs) else: - if proxy_uri is not None: - if proxy_uri.scheme[:5] == "socks": - try: - from python_socks import ProxyType - from python_socks.async_.asyncio import Proxy - except ImportError: - raise ImportError( - "python-socks is required to use a SOCKS proxy" - ) - if proxy_uri.scheme == "socks5h": - proxy_type = ProxyType.SOCKS5 - rdns = True - elif proxy_uri.scheme == "socks5": - proxy_type = ProxyType.SOCKS5 - rdns = False - # We use mitmproxy for testing and it doesn't support SOCKS4. - elif proxy_uri.scheme == "socks4a": # pragma: no cover - proxy_type = ProxyType.SOCKS4 - rdns = True - elif proxy_uri.scheme == "socks4": # pragma: no cover - proxy_type = ProxyType.SOCKS4 - rdns = False - # Proxy types are enforced in parse_proxy(). - else: - raise AssertionError("unsupported SOCKS proxy") - socks_proxy = Proxy( - proxy_type, - proxy_uri.host, - proxy_uri.port, - proxy_uri.username, - proxy_uri.password, - rdns, - ) - kwargs["sock"] = await socks_proxy.connect( - ws_uri.host, - ws_uri.port, - local_addr=kwargs.pop("local_addr", None), - ) - # Proxy types are enforced in parse_proxy(). - else: - raise AssertionError("unsupported proxy") if kwargs.get("sock") is None: kwargs.setdefault("host", ws_uri.host) kwargs.setdefault("port", ws_uri.port) @@ -624,3 +588,60 @@ def unix_connect( else: uri = "wss://localhost/" return connect(uri=uri, unix=True, path=path, **kwargs) + + +try: + from python_socks import ProxyType + from python_socks.async_.asyncio import Proxy as SocksProxy + + SOCKS_PROXY_TYPES = { + "socks5h": ProxyType.SOCKS5, + "socks5": ProxyType.SOCKS5, + "socks4a": ProxyType.SOCKS4, + "socks4": ProxyType.SOCKS4, + } + + SOCKS_PROXY_RDNS = { + "socks5h": True, + "socks5": False, + "socks4a": True, + "socks4": False, + } + + async def connect_socks_proxy( + proxy: Proxy, + ws_uri: WebSocketURI, + **kwargs: Any, + ) -> socket.socket: + """Connect via a SOCKS proxy and return the socket.""" + socks_proxy = SocksProxy( + SOCKS_PROXY_TYPES[proxy.scheme], + proxy.host, + proxy.port, + proxy.username, + proxy.password, + SOCKS_PROXY_RDNS[proxy.scheme], + ) + return await socks_proxy.connect(ws_uri.host, ws_uri.port, **kwargs) + +except ImportError: + + async def connect_socks_proxy( + proxy: Proxy, + ws_uri: WebSocketURI, + **kwargs: Any, + ) -> socket.socket: + raise ImportError("python-socks is required to use a SOCKS proxy") + + +async def connect_proxy( + proxy: Proxy, + ws_uri: WebSocketURI, + **kwargs: Any, +) -> socket.socket: + """Connect via a proxy and return the socket.""" + # parse_proxy() validates proxy.scheme. + if proxy.scheme[:5] == "socks": + return await connect_socks_proxy(proxy, ws_uri, **kwargs) + else: + raise AssertionError("unsupported proxy") diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 96f62edab..5fbec67ad 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -15,7 +15,7 @@ from ..http11 import USER_AGENT, Response from ..protocol import CONNECTING, Event from ..typing import LoggerLike, Origin, Subprotocol -from ..uri import Proxy, get_proxy, parse_proxy, parse_uri +from ..uri import Proxy, WebSocketURI, get_proxy, parse_proxy, parse_uri from .connection import Connection from .utils import Deadline @@ -258,15 +258,12 @@ def connect( elif compression is not None: raise ValueError(f"unsupported compression: {compression}") - proxy_uri: Proxy | None = None if unix: proxy = None if sock is not None: proxy = None if proxy is True: proxy = get_proxy(ws_uri) - if proxy is not None: - proxy_uri = parse_proxy(proxy) # Calculate timeouts on the TCP, TLS, and WebSocket handshakes. # The TCP and TLS timeouts must be set on the socket, then removed @@ -285,54 +282,21 @@ def connect( sock.settimeout(deadline.timeout()) assert path is not None # mypy cannot figure this out sock.connect(path) + elif proxy is not None: + sock = connect_proxy( + parse_proxy(proxy), + ws_uri, + deadline, + # websockets is consistent with the socket module while + # python_socks is consistent across implementations. + local_addr=kwargs.pop("source_address", None), + ) else: - if proxy_uri is not None: - if proxy_uri.scheme[:5] == "socks": - try: - from python_socks import ProxyType - from python_socks.sync import Proxy - except ImportError: - raise ImportError( - "python-socks is required to use a SOCKS proxy" - ) - if proxy_uri.scheme == "socks5h": - proxy_type = ProxyType.SOCKS5 - rdns = True - elif proxy_uri.scheme == "socks5": - proxy_type = ProxyType.SOCKS5 - rdns = False - # We use mitmproxy for testing and it doesn't support SOCKS4. - elif proxy_uri.scheme == "socks4a": # pragma: no cover - proxy_type = ProxyType.SOCKS4 - rdns = True - elif proxy_uri.scheme == "socks4": # pragma: no cover - proxy_type = ProxyType.SOCKS4 - rdns = False - # Proxy types are enforced in parse_proxy(). - else: - raise AssertionError("unsupported SOCKS proxy") - socks_proxy = Proxy( - proxy_type, - proxy_uri.host, - proxy_uri.port, - proxy_uri.username, - proxy_uri.password, - rdns, - ) - sock = socks_proxy.connect( - ws_uri.host, - ws_uri.port, - timeout=deadline.timeout(), - local_addr=kwargs.pop("local_addr", None), - ) - # Proxy types are enforced in parse_proxy(). - else: - raise AssertionError("unsupported proxy") - else: - kwargs.setdefault("timeout", deadline.timeout()) - sock = socket.create_connection( - (ws_uri.host, ws_uri.port), **kwargs - ) + kwargs.setdefault("timeout", deadline.timeout()) + sock = socket.create_connection( + (ws_uri.host, ws_uri.port), + **kwargs, + ) sock.settimeout(None) # Disable Nagle algorithm @@ -420,3 +384,64 @@ def unix_connect( else: uri = "wss://localhost/" return connect(uri=uri, unix=True, path=path, **kwargs) + + +try: + from python_socks import ProxyType + from python_socks.sync import Proxy as SocksProxy + + SOCKS_PROXY_TYPES = { + "socks5h": ProxyType.SOCKS5, + "socks5": ProxyType.SOCKS5, + "socks4a": ProxyType.SOCKS4, + "socks4": ProxyType.SOCKS4, + } + + SOCKS_PROXY_RDNS = { + "socks5h": True, + "socks5": False, + "socks4a": True, + "socks4": False, + } + + def connect_socks_proxy( + proxy: Proxy, + ws_uri: WebSocketURI, + deadline: Deadline, + **kwargs: Any, + ) -> socket.socket: + """Connect via a SOCKS proxy and return the socket.""" + socks_proxy = SocksProxy( + SOCKS_PROXY_TYPES[proxy.scheme], + proxy.host, + proxy.port, + proxy.username, + proxy.password, + SOCKS_PROXY_RDNS[proxy.scheme], + ) + kwargs.setdefault("timeout", deadline.timeout()) + return socks_proxy.connect(ws_uri.host, ws_uri.port, **kwargs) + +except ImportError: + + def connect_socks_proxy( + proxy: Proxy, + ws_uri: WebSocketURI, + deadline: Deadline, + **kwargs: Any, + ) -> socket.socket: + raise ImportError("python-socks is required to use a SOCKS proxy") + + +def connect_proxy( + proxy: Proxy, + ws_uri: WebSocketURI, + deadline: Deadline, + **kwargs: Any, +) -> socket.socket: + """Connect via a proxy and return the socket.""" + # parse_proxy() validates proxy.scheme. + if proxy.scheme[:5] == "socks": + return connect_socks_proxy(proxy, ws_uri, deadline, **kwargs) + else: + raise AssertionError("unsupported proxy") From 321be894176262a74a0b020aa1327f2aaca728ca Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 28 Jan 2025 22:42:11 +0100 Subject: [PATCH 719/762] Improve error handling for SOCKS proxy. --- docs/reference/exceptions.rst | 8 ++- src/websockets/__init__.py | 3 + src/websockets/asyncio/client.py | 17 +++++- src/websockets/exceptions.py | 41 ++++++++------ src/websockets/sync/client.py | 10 +++- tests/asyncio/test_client.py | 94 +++++++++++++++++++++++--------- tests/asyncio/test_server.py | 6 +- tests/sync/test_client.py | 94 +++++++++++++++++++++++--------- tests/sync/test_server.py | 4 +- tests/test_exceptions.py | 4 ++ 10 files changed, 203 insertions(+), 78 deletions(-) diff --git a/docs/reference/exceptions.rst b/docs/reference/exceptions.rst index e0c2efdd1..6c09a13fa 100644 --- a/docs/reference/exceptions.rst +++ b/docs/reference/exceptions.rst @@ -34,14 +34,16 @@ also reported by :func:`~websockets.asyncio.server.serve` in logs. .. autoexception:: SecurityError -.. autoexception:: InvalidMessage - -.. autoexception:: InvalidStatus +.. autoexception:: ProxyError .. autoexception:: InvalidProxyMessage .. autoexception:: InvalidProxyStatus +.. autoexception:: InvalidMessage + +.. autoexception:: InvalidStatus + .. autoexception:: InvalidHeader .. autoexception:: InvalidHeaderFormat diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index 8bf282a73..28a10910b 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -49,6 +49,7 @@ "NegotiationError", "PayloadTooBig", "ProtocolError", + "ProxyError", "SecurityError", "WebSocketException", # .frames @@ -112,6 +113,7 @@ NegotiationError, PayloadTooBig, ProtocolError, + ProxyError, SecurityError, WebSocketException, ) @@ -173,6 +175,7 @@ "NegotiationError": ".exceptions", "PayloadTooBig": ".exceptions", "ProtocolError": ".exceptions", + "ProxyError": ".exceptions", "SecurityError": ".exceptions", "WebSocketException": ".exceptions", # .frames diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 7052ca85a..9582a4bb9 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -12,7 +12,7 @@ from ..client import ClientProtocol, backoff from ..datastructures import HeadersLike -from ..exceptions import InvalidMessage, InvalidStatus, SecurityError +from ..exceptions import InvalidMessage, InvalidStatus, ProxyError, SecurityError from ..extensions.base import ClientExtensionFactory from ..extensions.permessage_deflate import enable_client_permessage_deflate from ..headers import validate_subprotocols @@ -148,7 +148,9 @@ def process_exception(exc: Exception) -> Exception | None: That exception will be raised, breaking out of the retry loop. """ - if isinstance(exc, (OSError, asyncio.TimeoutError)): + # This catches python-socks' ProxyConnectionError and ProxyTimeoutError. + # Remove asyncio.TimeoutError when dropping Python < 3.11. + if isinstance(exc, (OSError, TimeoutError, asyncio.TimeoutError)): return None if isinstance(exc, InvalidMessage) and isinstance(exc.__cause__, EOFError): return None @@ -266,6 +268,7 @@ class connect: Raises: InvalidURI: If ``uri`` isn't a valid WebSocket URI. + InvalidProxy: If ``proxy`` isn't a valid proxy. OSError: If the TCP connection fails. InvalidHandshake: If the opening handshake fails. TimeoutError: If the opening handshake times out. @@ -622,7 +625,15 @@ async def connect_socks_proxy( proxy.password, SOCKS_PROXY_RDNS[proxy.scheme], ) - return await socks_proxy.connect(ws_uri.host, ws_uri.port, **kwargs) + # connect() is documented to raise OSError. + # socks_proxy.connect() doesn't raise TimeoutError; it gets canceled. + # Wrap other exceptions in ProxyError, a subclass of InvalidHandshake. + try: + return await socks_proxy.connect(ws_uri.host, ws_uri.port, **kwargs) + except OSError: + raise + except Exception as exc: + raise ProxyError("failed to connect to SOCKS proxy") from exc except ImportError: diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index e70aac92e..ab1a15ca8 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -9,11 +9,12 @@ * :exc:`InvalidProxy` * :exc:`InvalidHandshake` * :exc:`SecurityError` + * :exc:`ProxyError` + * :exc:`InvalidProxyMessage` + * :exc:`InvalidProxyStatus` * :exc:`InvalidMessage` * :exc:`InvalidStatus` * :exc:`InvalidStatusCode` (legacy) - * :exc:`InvalidProxyMessage` - * :exc:`InvalidProxyStatus` * :exc:`InvalidHeader` * :exc:`InvalidHeaderFormat` * :exc:`InvalidHeaderValue` @@ -48,10 +49,11 @@ "InvalidProxy", "InvalidHandshake", "SecurityError", - "InvalidMessage", - "InvalidStatus", + "ProxyError", "InvalidProxyMessage", "InvalidProxyStatus", + "InvalidMessage", + "InvalidStatus", "InvalidHeader", "InvalidHeaderFormat", "InvalidHeaderValue", @@ -206,16 +208,23 @@ class SecurityError(InvalidHandshake): """ -class InvalidMessage(InvalidHandshake): +class ProxyError(InvalidHandshake): """ - Raised when a handshake request or response is malformed. + Raised when failing to connect to a proxy. """ -class InvalidStatus(InvalidHandshake): +class InvalidProxyMessage(ProxyError): """ - Raised when a handshake response rejects the WebSocket upgrade. + Raised when an HTTP proxy response is malformed. + + """ + + +class InvalidProxyStatus(ProxyError): + """ + Raised when an HTTP proxy rejects the connection. """ @@ -223,21 +232,19 @@ def __init__(self, response: http11.Response) -> None: self.response = response def __str__(self) -> str: - return ( - f"server rejected WebSocket connection: HTTP {self.response.status_code:d}" - ) + return f"proxy rejected connection: HTTP {self.response.status_code:d}" -class InvalidProxyMessage(InvalidHandshake): +class InvalidMessage(InvalidHandshake): """ - Raised when a proxy response is malformed. + Raised when a handshake request or response is malformed. """ -class InvalidProxyStatus(InvalidHandshake): +class InvalidStatus(InvalidHandshake): """ - Raised when a proxy rejects the connection. + Raised when a handshake response rejects the WebSocket upgrade. """ @@ -245,7 +252,9 @@ def __init__(self, response: http11.Response) -> None: self.response = response def __str__(self) -> str: - return f"proxy rejected connection: HTTP {self.response.status_code:d}" + return ( + f"server rejected WebSocket connection: HTTP {self.response.status_code:d}" + ) class InvalidHeader(InvalidHandshake): diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 5fbec67ad..e2a287648 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -9,6 +9,7 @@ from ..client import ClientProtocol from ..datastructures import HeadersLike +from ..exceptions import ProxyError from ..extensions.base import ClientExtensionFactory from ..extensions.permessage_deflate import enable_client_permessage_deflate from ..headers import validate_subprotocols @@ -420,7 +421,14 @@ def connect_socks_proxy( SOCKS_PROXY_RDNS[proxy.scheme], ) kwargs.setdefault("timeout", deadline.timeout()) - return socks_proxy.connect(ws_uri.host, ws_uri.port, **kwargs) + # connect() is documented to raise OSError and TimeoutError. + # Wrap other exceptions in ProxyError, a subclass of InvalidHandshake. + try: + return socks_proxy.connect(ws_uri.host, ws_uri.port, **kwargs) + except (OSError, TimeoutError, socket.timeout): + raise + except Exception as exc: + raise ProxyError("failed to connect to SOCKS proxy") from exc except ImportError: diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index cb2b8ede6..bdd519fb8 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -17,6 +17,7 @@ InvalidProxy, InvalidStatus, InvalidURI, + ProxyError, SecurityError, ) from websockets.extensions.permessage_deflate import PerMessageDeflate @@ -379,24 +380,16 @@ def remove_accept_header(self, request, response): async def test_timeout_during_handshake(self): """Client times out before receiving handshake response from server.""" - gate = asyncio.get_running_loop().create_future() - - async def stall_connection(self, request): - await gate - - # The connection will be open for the server but failed for the client. - # Use a connection handler that exits immediately to avoid an exception. - async with serve(*args, process_request=stall_connection) as server: - try: - with self.assertRaises(TimeoutError) as raised: - async with connect(get_uri(server) + "/no-op", open_timeout=2 * MS): - self.fail("did not raise") - self.assertEqual( - str(raised.exception), - "timed out during handshake", - ) - finally: - gate.set_result(None) + # Replace the WebSocket server with a TCP server that does't respond. + with socket.create_server(("localhost", 0)) as sock: + host, port = sock.getsockname() + with self.assertRaises(TimeoutError) as raised: + async with connect(f"ws://{host}:{port}", open_timeout=MS): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "timed out during handshake", + ) async def test_connection_closed_during_handshake(self): """Client reads EOF before receiving handshake response from server.""" @@ -570,11 +563,14 @@ class ProxyClientTests(unittest.IsolatedAsyncioTestCase): async def socks_proxy(self, auth=None): if auth: proxyauth = "hello:iloveyou" - proxy_uri = "http://hello:iloveyou@localhost:1080" + proxy_uri = "http://hello:iloveyou@localhost:51080" else: proxyauth = None - proxy_uri = "http://localhost:1080" - async with async_proxy(mode=["socks5"], proxyauth=proxyauth) as record_flows: + proxy_uri = "http://localhost:51080" + async with async_proxy( + mode=["socks5@51080"], + proxyauth=proxyauth, + ) as record_flows: with patch_environ({"socks_proxy": proxy_uri}): yield record_flows @@ -602,14 +598,62 @@ async def test_authenticated_socks_proxy(self): self.assertEqual(client.protocol.state.name, "OPEN") self.assertEqual(len(proxy.get_flows()), 1) + async def test_socks_proxy_connection_error(self): + """Client receives an error when connecting to the SOCKS5 proxy.""" + from python_socks import ProxyError as SocksProxyError + + async with self.socks_proxy(auth=True) as proxy: + with self.assertRaises(ProxyError) as raised: + async with connect( + "ws://example.com/", + proxy="socks5h://localhost:51080", # remove credentials + ): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "failed to connect to SOCKS proxy", + ) + self.assertIsInstance(raised.exception.__cause__, SocksProxyError) + self.assertEqual(len(proxy.get_flows()), 0) + + async def test_socks_proxy_connection_fails(self): + """Client fails to connect to the SOCKS5 proxy.""" + from python_socks import ProxyConnectionError as SocksProxyConnectionError + + with self.assertRaises(OSError) as raised: + async with connect( + "ws://example.com/", + proxy="socks5h://localhost:51080", # nothing at this address + ): + self.fail("did not raise") + # Don't test str(raised.exception) because we don't control it. + self.assertIsInstance(raised.exception, SocksProxyConnectionError) + + async def test_socks_proxy_connection_timeout(self): + """Client times out while connecting to the SOCKS5 proxy.""" + # Replace the proxy with a TCP server that does't respond. + with socket.create_server(("localhost", 0)) as sock: + host, port = sock.getsockname() + with self.assertRaises(TimeoutError) as raised: + async with connect( + "ws://example.com/", + proxy=f"socks5h://{host}:{port}/", + open_timeout=MS, + ): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "timed out during handshake", + ) + async def test_explicit_proxy(self): """Client connects to server through a proxy set explicitly.""" - async with async_proxy(mode=["socks5"]) as proxy: + async with async_proxy(mode=["socks5@51080"]) as proxy: async with serve(*args) as server: async with connect( get_uri(server), # Take this opportunity to test socks5 instead of socks5h. - proxy="socks5://localhost:1080", + proxy="socks5://localhost:51080", ) as client: self.assertEqual(client.protocol.state.name, "OPEN") self.assertEqual(len(proxy.get_flows()), 1) @@ -626,13 +670,13 @@ async def test_ignore_proxy_with_existing_socket(self): async def test_unsupported_proxy(self): """Client connects to server through an unsupported proxy.""" - with patch_environ({"ws_proxy": "other://localhost:1080"}): + with patch_environ({"ws_proxy": "other://localhost:51080"}): with self.assertRaises(InvalidProxy) as raised: async with connect("ws://example.com/"): self.fail("did not raise") self.assertEqual( str(raised.exception), - "other://localhost:1080 isn't a valid proxy: scheme other isn't supported", + "other://localhost:51080 isn't a valid proxy: scheme other isn't supported", ) diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index 38c0315a1..6adfff8e9 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -65,9 +65,9 @@ async def test_connection_handler_raises_exception(self): async def test_existing_socket(self): """Server receives connection using a pre-existing socket.""" with socket.create_server(("localhost", 0)) as sock: - async with serve(handler, sock=sock, host=None, port=None): - uri = "ws://{}:{}/".format(*sock.getsockname()) - async with connect(uri) as client: + host, port = sock.getsockname() + async with serve(handler, sock=sock): + async with connect(f"ws://{host}:{port}/") as client: await self.assertEval(client, "ws.protocol.state.name", "OPEN") async def test_select_subprotocol(self): diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index 2f62dd34d..dbecadcac 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -15,6 +15,7 @@ InvalidProxy, InvalidStatus, InvalidURI, + ProxyError, ) from websockets.extensions.permessage_deflate import PerMessageDeflate from websockets.sync.client import * @@ -148,24 +149,16 @@ def remove_accept_header(self, request, response): def test_timeout_during_handshake(self): """Client times out before receiving handshake response from server.""" - gate = threading.Event() - - def stall_connection(self, request): - gate.wait() - - # The connection will be open for the server but failed for the client. - # Use a connection handler that exits immediately to avoid an exception. - with run_server(process_request=stall_connection) as server: - try: - with self.assertRaises(TimeoutError) as raised: - with connect(get_uri(server) + "/no-op", open_timeout=2 * MS): - self.fail("did not raise") - self.assertEqual( - str(raised.exception), - "timed out during handshake", - ) - finally: - gate.set() + # Replace the WebSocket server with a TCP server that does't respond. + with socket.create_server(("localhost", 0)) as sock: + host, port = sock.getsockname() + with self.assertRaises(TimeoutError) as raised: + with connect(f"ws://{host}:{port}", open_timeout=MS): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "timed out during handshake", + ) def test_connection_closed_during_handshake(self): """Client reads EOF before receiving handshake response from server.""" @@ -311,12 +304,15 @@ class ProxyClientTests(unittest.TestCase): def socks_proxy(self, auth=None): if auth: proxyauth = "hello:iloveyou" - proxy_uri = "http://hello:iloveyou@localhost:1080" + proxy_uri = "http://hello:iloveyou@localhost:51080" else: proxyauth = None - proxy_uri = "http://localhost:1080" + proxy_uri = "http://localhost:51080" - with sync_proxy(mode=["socks5"], proxyauth=proxyauth) as record_flows: + with sync_proxy( + mode=["socks5@51080"], + proxyauth=proxyauth, + ) as record_flows: with patch_environ({"socks_proxy": proxy_uri}): yield record_flows @@ -344,14 +340,62 @@ def test_authenticated_socks_proxy(self): self.assertEqual(client.protocol.state.name, "OPEN") self.assertEqual(len(proxy.get_flows()), 1) + def test_socks_proxy_connection_error(self): + """Client receives an error when connecting to the SOCKS5 proxy.""" + from python_socks import ProxyError as SocksProxyError + + with self.socks_proxy(auth=True) as proxy: + with self.assertRaises(ProxyError) as raised: + with connect( + "ws://example.com/", + proxy="socks5h://localhost:51080", # remove credentials + ): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "failed to connect to SOCKS proxy", + ) + self.assertIsInstance(raised.exception.__cause__, SocksProxyError) + self.assertEqual(len(proxy.get_flows()), 0) + + def test_socks_proxy_connection_fails(self): + """Client fails to connect to the SOCKS5 proxy.""" + from python_socks import ProxyConnectionError as SocksProxyConnectionError + + with self.assertRaises(OSError) as raised: + with connect( + "ws://example.com/", + proxy="socks5h://localhost:51080", # nothing at this address + ): + self.fail("did not raise") + # Don't test str(raised.exception) because we don't control it. + self.assertIsInstance(raised.exception, SocksProxyConnectionError) + + def test_socks_proxy_timeout(self): + """Client times out before connecting to the SOCKS5 proxy.""" + from python_socks import ProxyTimeoutError as SocksProxyTimeoutError + + # Replace the proxy with a TCP server that does't respond. + with socket.create_server(("localhost", 0)) as sock: + host, port = sock.getsockname() + with self.assertRaises(TimeoutError) as raised: + with connect( + "ws://example.com/", + proxy=f"socks5h://{host}:{port}/", + open_timeout=MS, + ): + self.fail("did not raise") + # Don't test str(raised.exception) because we don't control it. + self.assertIsInstance(raised.exception, SocksProxyTimeoutError) + def test_explicit_proxy(self): """Client connects to server through a proxy set explicitly.""" - with sync_proxy(mode=["socks5"]) as proxy: + with sync_proxy(mode=["socks5@51080"]) as proxy: with run_server() as server: with connect( get_uri(server), # Take this opportunity to test socks5 instead of socks5h. - proxy="socks5://localhost:1080", + proxy="socks5://localhost:51080", ) as client: self.assertEqual(client.protocol.state.name, "OPEN") self.assertEqual(len(proxy.get_flows()), 1) @@ -368,13 +412,13 @@ def test_ignore_proxy_with_existing_socket(self): def test_unsupported_proxy(self): """Client connects to server through an unsupported proxy.""" - with patch_environ({"ws_proxy": "other://localhost:1080"}): + with patch_environ({"ws_proxy": "other://localhost:51080"}): with self.assertRaises(InvalidProxy) as raised: with connect("ws://example.com/"): self.fail("did not raise") self.assertEqual( str(raised.exception), - "other://localhost:1080 isn't a valid proxy: scheme other isn't supported", + "other://localhost:51080 isn't a valid proxy: scheme other isn't supported", ) diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index f59671efd..d04d1859a 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -64,9 +64,9 @@ def test_connection_handler_raises_exception(self): def test_existing_socket(self): """Server receives connection using a pre-existing socket.""" with socket.create_server(("localhost", 0)) as sock: + host, port = sock.getsockname() with run_server(sock=sock): - uri = "ws://{}:{}/".format(*sock.getsockname()) - with connect(uri) as client: + with connect(f"ws://{host}:{port}/") as client: self.assertEval(client, "ws.protocol.state.name", "OPEN") def test_select_subprotocol(self): diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 8b437ab5e..b4e7acee7 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -95,6 +95,10 @@ def test_str(self): SecurityError("redirect from WSS to WS"), "redirect from WSS to WS", ), + ( + ProxyError("failed to connect to SOCKS proxy"), + "failed to connect to SOCKS proxy", + ), ( InvalidMessage("malformed HTTP message"), "malformed HTTP message", From a00c18436a8c1500928f032f7131fa226f745334 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 29 Jan 2025 21:38:33 +0100 Subject: [PATCH 720/762] Support the legacy pattern for setting User-Agent. Support setting it with additional_headers for backwards compatibility with the legacy implementation and the API until websoockets 10.4. Before this change, the header was added twice: once with the custom value and once with the default value. Fix #1583. --- docs/howto/upgrade.rst | 14 ++++++++++---- src/websockets/asyncio/client.py | 2 +- src/websockets/sync/client.py | 2 +- tests/asyncio/test_client.py | 8 ++++++++ tests/sync/test_client.py | 8 ++++++++ 5 files changed, 28 insertions(+), 6 deletions(-) diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index db6bf11f1..8cfd7b4b5 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -259,9 +259,12 @@ Arguments of :func:`~asyncio.client.connect` ``extra_headers`` → ``additional_headers`` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -If you're adding headers to the handshake request sent by -:func:`~legacy.client.connect` with the ``extra_headers`` argument, you must -rename it to ``additional_headers``. +If you're setting the ``User-Agent`` header with the ``extra_headers`` argument, +you should set it with ``user_agent_header`` instead. + +If you're adding other headers to the handshake request sent by +:func:`~legacy.client.connect` with ``extra_headers``, you must rename it to +``additional_headers``. Arguments of :func:`~asyncio.server.serve` .......................................... @@ -310,7 +313,10 @@ replace it with a ``process_request`` function or coroutine. ``extra_headers`` → ``process_response`` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -If you're adding headers to the handshake response sent by +If you're setting the ``Server`` header with ``extra_headers``, you should set +it with the ``server_header`` argument instead. + +If you're adding other headers to the handshake response sent by :func:`~legacy.server.serve` with the ``extra_headers`` argument, you must write a ``process_response`` callable instead. diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 9582a4bb9..1e560fe0c 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -89,7 +89,7 @@ async def handshake( if additional_headers is not None: self.request.headers.update(additional_headers) if user_agent_header: - self.request.headers["User-Agent"] = user_agent_header + self.request.headers.setdefault("User-Agent", user_agent_header) self.protocol.send_request(self.request) await asyncio.wait( diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index e2a287648..b7ab83664 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -86,7 +86,7 @@ def handshake( if additional_headers is not None: self.request.headers.update(additional_headers) if user_agent_header is not None: - self.request.headers["User-Agent"] = user_agent_header + self.request.headers.setdefault("User-Agent", user_agent_header) self.protocol.send_request(self.request) if not self.response_rcvd.wait(timeout): diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index bdd519fb8..9c7ee46ad 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -121,6 +121,14 @@ async def test_remove_user_agent(self): async with connect(get_uri(server), user_agent_header=None) as client: self.assertNotIn("User-Agent", client.request.headers) + async def test_legacy_user_agent(self): + """Client can override User-Agent header with additional_headers.""" + async with serve(*args) as server: + async with connect( + get_uri(server), additional_headers={"User-Agent": "Smith"} + ) as client: + self.assertEqual(client.request.headers["User-Agent"], "Smith") + async def test_keepalive_is_enabled(self): """Client enables keepalive and measures latency by default.""" async with serve(*args) as server: diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index dbecadcac..4844d3b5e 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -82,6 +82,14 @@ def test_remove_user_agent(self): with connect(get_uri(server), user_agent_header=None) as client: self.assertNotIn("User-Agent", client.request.headers) + def test_legacy_user_agent(self): + """Client can override User-Agent header with additional_headers.""" + with run_server() as server: + with connect( + get_uri(server), additional_headers={"User-Agent": "Smith"} + ) as client: + self.assertEqual(client.request.headers["User-Agent"], "Smith") + def test_keepalive_is_enabled(self): """Client enables keepalive and measures latency by default.""" with run_server() as server: From 7bb226a4aff8479994851a1cc83fc1aaf21a3724 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 29 Jan 2025 21:46:45 +0100 Subject: [PATCH 721/762] Standardize spelling of canceling/canceled. --- docs/project/changelog.rst | 2 +- src/websockets/asyncio/messages.py | 4 ++-- tests/asyncio/test_connection.py | 8 ++++---- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 2a429b43e..bfbfa793f 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -84,7 +84,7 @@ Bug fixes :mod:`threading` implementation. If a message is already received, it is returned. Previously, :exc:`TimeoutError` was raised incorrectly. -* Fixed a crash in the :mod:`asyncio` implementation when cancelling a ping +* Fixed a crash in the :mod:`asyncio` implementation when canceling a ping then receiving the corresponding pong. * Prevented :meth:`~asyncio.connection.Connection.close` from blocking when diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index 581870037..1fd41811c 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -152,7 +152,7 @@ async def get(self, decode: bool | None = None) -> Data: self.get_in_progress = True # Locking with get_in_progress prevents concurrent execution - # until get() fetches a complete message or is cancelled. + # until get() fetches a complete message or is canceled. try: # First frame @@ -224,7 +224,7 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: self.get_in_progress = True # Locking with get_in_progress prevents concurrent execution - # until get_iter() fetches a complete message or is cancelled. + # until get_iter() fetches a complete message or is canceled. # If get_iter() raises an exception e.g. in decoder.decode(), # get_in_progress remains set and the connection becomes unusable. diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 5230eca89..668f55cbd 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -245,7 +245,7 @@ async def test_recv_during_recv_streaming(self): ) async def test_recv_cancellation_before_receiving(self): - """recv can be cancelled before receiving a frame.""" + """recv can be canceled before receiving a frame.""" recv_task = asyncio.create_task(self.connection.recv()) await asyncio.sleep(0) # let the event loop start recv_task @@ -257,7 +257,7 @@ async def test_recv_cancellation_before_receiving(self): self.assertEqual(await self.connection.recv(), "😀") async def test_recv_cancellation_while_receiving(self): - """recv cannot be cancelled after receiving a frame.""" + """recv cannot be canceled after receiving a frame.""" recv_task = asyncio.create_task(self.connection.recv()) await asyncio.sleep(0) # let the event loop start recv_task @@ -386,7 +386,7 @@ async def test_recv_streaming_during_recv_streaming(self): ) async def test_recv_streaming_cancellation_before_receiving(self): - """recv_streaming can be cancelled before receiving a frame.""" + """recv_streaming can be canceled before receiving a frame.""" recv_streaming_task = asyncio.create_task( alist(self.connection.recv_streaming()) ) @@ -403,7 +403,7 @@ async def test_recv_streaming_cancellation_before_receiving(self): ) async def test_recv_streaming_cancellation_while_receiving(self): - """recv_streaming cannot be cancelled after receiving a frame.""" + """recv_streaming cannot be canceled after receiving a frame.""" recv_streaming_task = asyncio.create_task( alist(self.connection.recv_streaming()) ) From 737d5ffb5121531f4c1efba2a25c3fd7db76b868 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 31 Jan 2025 21:37:15 +0100 Subject: [PATCH 722/762] Run mitmproxy only once per test case. This reduces the run time of the test suite by 40%, from 6.7s to 4.1s. --- tests/asyncio/test_client.py | 146 +++++++++++++++------------------ tests/proxy.py | 132 ++++++++++++++++-------------- tests/sync/test_client.py | 152 ++++++++++++++++------------------- 3 files changed, 207 insertions(+), 223 deletions(-) diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index 9c7ee46ad..be8ef8a42 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -22,7 +22,7 @@ ) from websockets.extensions.permessage_deflate import PerMessageDeflate -from ..proxy import async_proxy +from ..proxy import ProxyMixin from ..utils import ( CLIENT_CONTEXT, MS, @@ -388,7 +388,7 @@ def remove_accept_header(self, request, response): async def test_timeout_during_handshake(self): """Client times out before receiving handshake response from server.""" - # Replace the WebSocket server with a TCP server that does't respond. + # Replace the WebSocket server with a TCP server that doesn't respond. with socket.create_server(("localhost", 0)) as sock: host, port = sock.getsockname() with self.assertRaises(TimeoutError) as raised: @@ -508,7 +508,7 @@ async def test_reject_invalid_server_certificate(self): """Client rejects certificate where server certificate isn't trusted.""" async with serve(*args, ssl=SERVER_CONTEXT) as server: with self.assertRaises(ssl.SSLCertVerificationError) as raised: - # The test certificate isn't trusted system-wide. + # The test certificate is self-signed. async with connect(get_uri(server)): self.fail("did not raise") self.assertIn( @@ -566,126 +566,105 @@ def redirect(connection, request): @unittest.skipUnless("mitmproxy" in sys.modules, "mitmproxy not installed") -class ProxyClientTests(unittest.IsolatedAsyncioTestCase): - @contextlib.asynccontextmanager - async def socks_proxy(self, auth=None): - if auth: - proxyauth = "hello:iloveyou" - proxy_uri = "http://hello:iloveyou@localhost:51080" - else: - proxyauth = None - proxy_uri = "http://localhost:51080" - async with async_proxy( - mode=["socks5@51080"], - proxyauth=proxyauth, - ) as record_flows: - with patch_environ({"socks_proxy": proxy_uri}): - yield record_flows +class SocksProxyClientTests(ProxyMixin, unittest.IsolatedAsyncioTestCase): + proxy_mode = "socks5@51080" async def test_socks_proxy(self): """Client connects to server through a SOCKS5 proxy.""" - async with self.socks_proxy() as proxy: + with patch_environ({"socks_proxy": "http://localhost:51080"}): async with serve(*args) as server: async with connect(get_uri(server)) as client: self.assertEqual(client.protocol.state.name, "OPEN") - self.assertEqual(len(proxy.get_flows()), 1) + self.assertNumFlows(1) async def test_secure_socks_proxy(self): """Client connects to server securely through a SOCKS5 proxy.""" - async with self.socks_proxy() as proxy: + with patch_environ({"socks_proxy": "http://localhost:51080"}): async with serve(*args, ssl=SERVER_CONTEXT) as server: async with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: self.assertEqual(client.protocol.state.name, "OPEN") - self.assertEqual(len(proxy.get_flows()), 1) + self.assertNumFlows(1) async def test_authenticated_socks_proxy(self): """Client connects to server through an authenticated SOCKS5 proxy.""" - async with self.socks_proxy(auth=True) as proxy: - async with serve(*args) as server: - async with connect(get_uri(server)) as client: - self.assertEqual(client.protocol.state.name, "OPEN") - self.assertEqual(len(proxy.get_flows()), 1) + try: + self.proxy_options.update(proxyauth="hello:iloveyou") + with patch_environ( + {"socks_proxy": "http://hello:iloveyou@localhost:51080"} + ): + async with serve(*args) as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + finally: + self.proxy_options.update(proxyauth=None) + self.assertNumFlows(1) - async def test_socks_proxy_connection_error(self): - """Client receives an error when connecting to the SOCKS5 proxy.""" + async def test_authenticated_socks_proxy_error(self): + """Client fails to authenticate to the SOCKS5 proxy.""" from python_socks import ProxyError as SocksProxyError - async with self.socks_proxy(auth=True) as proxy: - with self.assertRaises(ProxyError) as raised: - async with connect( - "ws://example.com/", - proxy="socks5h://localhost:51080", # remove credentials - ): - self.fail("did not raise") + try: + self.proxy_options.update(proxyauth="any") + with patch_environ({"socks_proxy": "http://localhost:51080"}): + with self.assertRaises(ProxyError) as raised: + async with connect("ws://example.com/"): + self.fail("did not raise") + finally: + self.proxy_options.update(proxyauth=None) self.assertEqual( str(raised.exception), "failed to connect to SOCKS proxy", ) self.assertIsInstance(raised.exception.__cause__, SocksProxyError) - self.assertEqual(len(proxy.get_flows()), 0) + self.assertNumFlows(0) - async def test_socks_proxy_connection_fails(self): + async def test_socks_proxy_connection_failure(self): """Client fails to connect to the SOCKS5 proxy.""" from python_socks import ProxyConnectionError as SocksProxyConnectionError - with self.assertRaises(OSError) as raised: - async with connect( - "ws://example.com/", - proxy="socks5h://localhost:51080", # nothing at this address - ): - self.fail("did not raise") + with patch_environ({"socks_proxy": "http://localhost:61080"}): # bad port + with self.assertRaises(OSError) as raised: + async with connect("ws://example.com/"): + self.fail("did not raise") # Don't test str(raised.exception) because we don't control it. self.assertIsInstance(raised.exception, SocksProxyConnectionError) + self.assertNumFlows(0) async def test_socks_proxy_connection_timeout(self): """Client times out while connecting to the SOCKS5 proxy.""" - # Replace the proxy with a TCP server that does't respond. + # Replace the proxy with a TCP server that doesn't respond. with socket.create_server(("localhost", 0)) as sock: host, port = sock.getsockname() - with self.assertRaises(TimeoutError) as raised: - async with connect( - "ws://example.com/", - proxy=f"socks5h://{host}:{port}/", - open_timeout=MS, - ): - self.fail("did not raise") + with patch_environ({"socks_proxy": f"http://{host}:{port}"}): + with self.assertRaises(TimeoutError) as raised: + async with connect("ws://example.com/", open_timeout=MS): + self.fail("did not raise") self.assertEqual( str(raised.exception), "timed out during handshake", ) + self.assertNumFlows(0) - async def test_explicit_proxy(self): - """Client connects to server through a proxy set explicitly.""" - async with async_proxy(mode=["socks5@51080"]) as proxy: - async with serve(*args) as server: - async with connect( - get_uri(server), - # Take this opportunity to test socks5 instead of socks5h. - proxy="socks5://localhost:51080", - ) as client: - self.assertEqual(client.protocol.state.name, "OPEN") - self.assertEqual(len(proxy.get_flows()), 1) + async def test_explicit_socks_proxy(self): + """Client connects to server through a SOCKS5 proxy set explicitly.""" + async with serve(*args) as server: + async with connect( + get_uri(server), + # Take this opportunity to test socks5 instead of socks5h. + proxy="socks5://localhost:51080", + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertNumFlows(1) async def test_ignore_proxy_with_existing_socket(self): """Client connects using a pre-existing socket.""" - async with self.socks_proxy() as proxy: + with patch_environ({"socks_proxy": "http://localhost:51080"}): async with serve(*args) as server: with socket.create_connection(get_host_port(server)) as sock: # Use a non-existing domain to ensure we connect to sock. async with connect("ws://invalid/", sock=sock) as client: self.assertEqual(client.protocol.state.name, "OPEN") - self.assertEqual(len(proxy.get_flows()), 0) - - async def test_unsupported_proxy(self): - """Client connects to server through an unsupported proxy.""" - with patch_environ({"ws_proxy": "other://localhost:51080"}): - with self.assertRaises(InvalidProxy) as raised: - async with connect("ws://example.com/"): - self.fail("did not raise") - self.assertEqual( - str(raised.exception), - "other://localhost:51080 isn't a valid proxy: scheme other isn't supported", - ) + self.assertNumFlows(0) @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") @@ -724,10 +703,7 @@ def redirect(connection, request): "cannot follow cross-origin redirect to ws://other/ with a Unix socket", ) - -@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") -class SecureUnixClientTests(unittest.IsolatedAsyncioTestCase): - async def test_connection(self): + async def test_secure_connection(self): """Client connects to server securely over a Unix socket.""" with temp_unix_socket_path() as path: async with unix_serve(handler, path, ssl=SERVER_CONTEXT): @@ -769,6 +745,16 @@ async def test_secure_uri_without_ssl(self): "ssl=None is incompatible with a wss:// URI", ) + async def test_unsupported_proxy(self): + """Client rejects unsupported proxy.""" + with self.assertRaises(InvalidProxy) as raised: + async with connect("ws://example.com/", proxy="other://localhost:51080"): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "other://localhost:51080 isn't a valid proxy: scheme other isn't supported", + ) + async def test_unix_without_path_or_sock(self): """Unix client requires path when sock isn't provided.""" with self.assertRaises(ValueError) as raised: diff --git a/tests/proxy.py b/tests/proxy.py index 95525a360..8b62b0eb7 100644 --- a/tests/proxy.py +++ b/tests/proxy.py @@ -1,15 +1,14 @@ import asyncio -import contextlib import pathlib import threading import warnings -warnings.filterwarnings("ignore", category=DeprecationWarning, module="mitmproxy") -warnings.filterwarnings("ignore", category=DeprecationWarning, module="passlib") -warnings.filterwarnings("ignore", category=DeprecationWarning, module="pyasn1") - try: + # Ignore deprecation warnings raised by mitmproxy dependencies at import time. + warnings.filterwarnings("ignore", category=DeprecationWarning, module="passlib") + warnings.filterwarnings("ignore", category=DeprecationWarning, module="pyasn1") + from mitmproxy.addons import core, next_layer, proxyauth, proxyserver, tlsconfig from mitmproxy.master import Master from mitmproxy.options import Options @@ -18,13 +17,10 @@ class RecordFlows: - def __init__(self): - self.ready = asyncio.get_running_loop().create_future() + def __init__(self, on_running): + self.running = on_running self.flows = [] - def running(self): - self.ready.set_result(None) - def websocket_start(self, flow): self.flows.append(flow) @@ -32,58 +28,76 @@ def get_flows(self): flows, self.flows[:] = self.flows[:], [] return flows + def reset_flows(self): + self.flows = [] + + +class ProxyMixin: + """ + Run mitmproxy in a background thread. + + While it's uncommon to run two event loops in two threads, tests for the + asyncio implementation rely on this class too because it starts an event + loop for mitm proxy once, then a new event loop for each test. + """ + + proxy_mode = None + + @classmethod + async def run_proxy(cls): + cls.proxy_loop = loop = asyncio.get_event_loop() + cls.proxy_stop = stop = loop.create_future() + + cls.proxy_options = options = Options(mode=[cls.proxy_mode]) + cls.proxy_master = master = Master(options) + master.addons.add( + core.Core(), + proxyauth.ProxyAuth(), + proxyserver.Proxyserver(), + next_layer.NextLayer(), + tlsconfig.TlsConfig(), + RecordFlows(on_running=cls.proxy_ready.set), + ) + options.update( + # Use test certificate for TLS between client and proxy. + certs=[str(pathlib.Path(__file__).with_name("test_localhost.pem"))], + # Disable TLS verification between proxy and upstream. + ssl_insecure=True, + ) + + task = loop.create_task(cls.proxy_master.run()) + await stop -@contextlib.asynccontextmanager -async def async_proxy(mode, **config): - options = Options(mode=mode) - master = Master(options) - record_flows = RecordFlows() - master.addons.add( - core.Core(), - proxyauth.ProxyAuth(), - proxyserver.Proxyserver(), - next_layer.NextLayer(), - tlsconfig.TlsConfig(), - record_flows, - ) - config.update( - # Use our test certificate for TLS between client and proxy - # and disable TLS verification between proxy and upstream. - certs=[str(pathlib.Path(__file__).with_name("test_localhost.pem"))], - ssl_insecure=True, - ) - options.update(**config) - - asyncio.create_task(master.run()) - try: - await record_flows.ready - yield record_flows - finally: for server in master.addons.get("proxyserver").servers: await server.stop() master.shutdown() + await task + + @classmethod + def setUpClass(cls): + super().setUpClass() + + # Ignore deprecation warnings raised by mitmproxy at run time. + warnings.filterwarnings( + "ignore", category=DeprecationWarning, module="mitmproxy" + ) + + cls.proxy_ready = threading.Event() + cls.proxy_thread = threading.Thread(target=asyncio.run, args=(cls.run_proxy(),)) + cls.proxy_thread.start() + cls.proxy_ready.wait() + + def assertNumFlows(self, num_flows): + record_flows = self.proxy_master.addons.get("recordflows") + self.assertEqual(len(record_flows.get_flows()), num_flows) + def tearDown(self): + record_flows = self.proxy_master.addons.get("recordflows") + record_flows.reset_flows() + super().tearDown() -@contextlib.contextmanager -def sync_proxy(mode, **config): - loop = None - test_done = None - proxy_ready = threading.Event() - record_flows = None - - async def proxy_coroutine(): - nonlocal loop, test_done, proxy_ready, record_flows - loop = asyncio.get_running_loop() - test_done = loop.create_future() - async with async_proxy(mode, **config) as record_flows: - proxy_ready.set() - await test_done - - proxy_thread = threading.Thread(target=asyncio.run, args=(proxy_coroutine(),)) - proxy_thread.start() - try: - proxy_ready.wait() - yield record_flows - finally: - loop.call_soon_threadsafe(test_done.set_result, None) - proxy_thread.join() + @classmethod + def tearDownClass(cls): + cls.proxy_loop.call_soon_threadsafe(cls.proxy_stop.set_result, None) + cls.proxy_thread.join() + super().tearDownClass() diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index 4844d3b5e..38cfab7b3 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -1,4 +1,3 @@ -import contextlib import http import logging import socket @@ -20,7 +19,7 @@ from websockets.extensions.permessage_deflate import PerMessageDeflate from websockets.sync.client import * -from ..proxy import sync_proxy +from ..proxy import ProxyMixin from ..utils import ( CLIENT_CONTEXT, MS, @@ -157,7 +156,7 @@ def remove_accept_header(self, request, response): def test_timeout_during_handshake(self): """Client times out before receiving handshake response from server.""" - # Replace the WebSocket server with a TCP server that does't respond. + # Replace the WebSocket server with a TCP server that doesn't respond. with socket.create_server(("localhost", 0)) as sock: host, port = sock.getsockname() with self.assertRaises(TimeoutError) as raised: @@ -283,7 +282,7 @@ def test_reject_invalid_server_certificate(self): """Client rejects certificate where server certificate isn't trusted.""" with run_server(ssl=SERVER_CONTEXT) as server: with self.assertRaises(ssl.SSLCertVerificationError) as raised: - # The test certificate isn't trusted system-wide. + # The test certificate is self-signed. with connect(get_uri(server)): self.fail("did not raise") self.assertIn( @@ -307,127 +306,105 @@ def test_reject_invalid_server_hostname(self): @unittest.skipUnless("mitmproxy" in sys.modules, "mitmproxy not installed") -class ProxyClientTests(unittest.TestCase): - @contextlib.contextmanager - def socks_proxy(self, auth=None): - if auth: - proxyauth = "hello:iloveyou" - proxy_uri = "http://hello:iloveyou@localhost:51080" - else: - proxyauth = None - proxy_uri = "http://localhost:51080" - - with sync_proxy( - mode=["socks5@51080"], - proxyauth=proxyauth, - ) as record_flows: - with patch_environ({"socks_proxy": proxy_uri}): - yield record_flows +class SocksProxyClientTests(ProxyMixin, unittest.TestCase): + proxy_mode = "socks5@51080" def test_socks_proxy(self): """Client connects to server through a SOCKS5 proxy.""" - with self.socks_proxy() as proxy: + with patch_environ({"socks_proxy": "http://localhost:51080"}): with run_server() as server: with connect(get_uri(server)) as client: self.assertEqual(client.protocol.state.name, "OPEN") - self.assertEqual(len(proxy.get_flows()), 1) + self.assertNumFlows(1) def test_secure_socks_proxy(self): """Client connects to server securely through a SOCKS5 proxy.""" - with self.socks_proxy() as proxy: + with patch_environ({"socks_proxy": "http://localhost:51080"}): with run_server(ssl=SERVER_CONTEXT) as server: with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: self.assertEqual(client.protocol.state.name, "OPEN") - self.assertEqual(len(proxy.get_flows()), 1) + self.assertNumFlows(1) def test_authenticated_socks_proxy(self): """Client connects to server through an authenticated SOCKS5 proxy.""" - with self.socks_proxy(auth=True) as proxy: - with run_server() as server: - with connect(get_uri(server)) as client: - self.assertEqual(client.protocol.state.name, "OPEN") - self.assertEqual(len(proxy.get_flows()), 1) + try: + self.proxy_options.update(proxyauth="hello:iloveyou") + with patch_environ( + {"socks_proxy": "http://hello:iloveyou@localhost:51080"} + ): + with run_server() as server: + with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + finally: + self.proxy_options.update(proxyauth=None) + self.assertNumFlows(1) - def test_socks_proxy_connection_error(self): - """Client receives an error when connecting to the SOCKS5 proxy.""" + def test_authenticated_socks_proxy_error(self): + """Client fails to authenticate to the SOCKS5 proxy.""" from python_socks import ProxyError as SocksProxyError - with self.socks_proxy(auth=True) as proxy: - with self.assertRaises(ProxyError) as raised: - with connect( - "ws://example.com/", - proxy="socks5h://localhost:51080", # remove credentials - ): - self.fail("did not raise") + try: + self.proxy_options.update(proxyauth="any") + with patch_environ({"socks_proxy": "http://localhost:51080"}): + with self.assertRaises(ProxyError) as raised: + with connect("ws://example.com/"): + self.fail("did not raise") + finally: + self.proxy_options.update(proxyauth=None) self.assertEqual( str(raised.exception), "failed to connect to SOCKS proxy", ) self.assertIsInstance(raised.exception.__cause__, SocksProxyError) - self.assertEqual(len(proxy.get_flows()), 0) + self.assertNumFlows(0) - def test_socks_proxy_connection_fails(self): + def test_socks_proxy_connection_failure(self): """Client fails to connect to the SOCKS5 proxy.""" from python_socks import ProxyConnectionError as SocksProxyConnectionError - with self.assertRaises(OSError) as raised: - with connect( - "ws://example.com/", - proxy="socks5h://localhost:51080", # nothing at this address - ): - self.fail("did not raise") + with patch_environ({"socks_proxy": "http://localhost:61080"}): # bad port + with self.assertRaises(OSError) as raised: + with connect("ws://example.com/"): + self.fail("did not raise") # Don't test str(raised.exception) because we don't control it. self.assertIsInstance(raised.exception, SocksProxyConnectionError) + self.assertNumFlows(0) - def test_socks_proxy_timeout(self): - """Client times out before connecting to the SOCKS5 proxy.""" + def test_socks_proxy_connection_timeout(self): + """Client times out while connecting to the SOCKS5 proxy.""" from python_socks import ProxyTimeoutError as SocksProxyTimeoutError - # Replace the proxy with a TCP server that does't respond. + # Replace the proxy with a TCP server that doesn't respond. with socket.create_server(("localhost", 0)) as sock: host, port = sock.getsockname() - with self.assertRaises(TimeoutError) as raised: - with connect( - "ws://example.com/", - proxy=f"socks5h://{host}:{port}/", - open_timeout=MS, - ): - self.fail("did not raise") + with patch_environ({"socks_proxy": f"http://{host}:{port}"}): + with self.assertRaises(TimeoutError) as raised: + with connect("ws://example.com/", open_timeout=MS): + self.fail("did not raise") # Don't test str(raised.exception) because we don't control it. self.assertIsInstance(raised.exception, SocksProxyTimeoutError) + self.assertNumFlows(0) - def test_explicit_proxy(self): - """Client connects to server through a proxy set explicitly.""" - with sync_proxy(mode=["socks5@51080"]) as proxy: - with run_server() as server: - with connect( - get_uri(server), - # Take this opportunity to test socks5 instead of socks5h. - proxy="socks5://localhost:51080", - ) as client: - self.assertEqual(client.protocol.state.name, "OPEN") - self.assertEqual(len(proxy.get_flows()), 1) + def test_explicit_socks_proxy(self): + """Client connects to server through a SOCKS5 proxy set explicitly.""" + with run_server() as server: + with connect( + get_uri(server), + # Take this opportunity to test socks5 instead of socks5h. + proxy="socks5://localhost:51080", + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertNumFlows(1) def test_ignore_proxy_with_existing_socket(self): """Client connects using a pre-existing socket.""" - with self.socks_proxy() as proxy: + with patch_environ({"ws_proxy": "http://localhost:58080"}): with run_server() as server: with socket.create_connection(server.socket.getsockname()) as sock: # Use a non-existing domain to ensure we connect to sock. with connect("ws://invalid/", sock=sock) as client: self.assertEqual(client.protocol.state.name, "OPEN") - self.assertEqual(len(proxy.get_flows()), 0) - - def test_unsupported_proxy(self): - """Client connects to server through an unsupported proxy.""" - with patch_environ({"ws_proxy": "other://localhost:51080"}): - with self.assertRaises(InvalidProxy) as raised: - with connect("ws://example.com/"): - self.fail("did not raise") - self.assertEqual( - str(raised.exception), - "other://localhost:51080 isn't a valid proxy: scheme other isn't supported", - ) + self.assertNumFlows(0) @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") @@ -447,10 +424,7 @@ def test_set_host_header(self): with unix_connect(path, uri="ws://overridden/") as client: self.assertEqual(client.request.headers["Host"], "overridden") - -@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") -class SecureUnixClientTests(unittest.TestCase): - def test_connection(self): + def test_secure_connection(self): """Client connects to server securely over a Unix socket.""" with temp_unix_socket_path() as path: with run_unix_server(path, ssl=SERVER_CONTEXT): @@ -488,6 +462,16 @@ def test_unix_without_path_or_sock(self): "missing path argument", ) + def test_unsupported_proxy(self): + """Client rejects unsupported proxy.""" + with self.assertRaises(InvalidProxy) as raised: + with connect("ws://example.com/", proxy="other://localhost:58080"): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "other://localhost:58080 isn't a valid proxy: scheme other isn't supported", + ) + def test_unix_with_path_and_sock(self): """Unix client rejects path when sock is provided.""" sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) From 0eab9b82959cdcea658ff5ab943f03503caa082b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 31 Jan 2025 21:54:06 +0100 Subject: [PATCH 723/762] Simplify SOCKS proxy implementation. --- src/websockets/asyncio/client.py | 37 ++++++++++++++++---------------- src/websockets/sync/client.py | 35 +++++++++++------------------- 2 files changed, 31 insertions(+), 41 deletions(-) diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 1e560fe0c..a3fcab039 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -383,16 +383,28 @@ def factory() -> ClientConnection: if kwargs.pop("unix", False): _, connection = await loop.create_unix_connection(factory, **kwargs) elif proxy is not None: - kwargs["sock"] = await connect_proxy( - parse_proxy(proxy), - ws_uri, - local_addr=kwargs.pop("local_addr", None), - ) - _, connection = await loop.create_connection(factory, **kwargs) + proxy_parsed = parse_proxy(proxy) + if proxy_parsed.scheme[:5] == "socks": + # Connect to the server through the proxy. + sock = await connect_socks_proxy( + proxy_parsed, + ws_uri, + local_addr=kwargs.pop("local_addr", None), + ) + # Initialize WebSocket connection via the proxy. + _, connection = await loop.create_connection( + factory, + sock=sock, + **kwargs, + ) + else: + raise AssertionError("unsupported proxy") else: + # Connect to the server directly. if kwargs.get("sock") is None: kwargs.setdefault("host", ws_uri.host) kwargs.setdefault("port", ws_uri.port) + # Initialize WebSocket connection. _, connection = await loop.create_connection(factory, **kwargs) return connection @@ -643,16 +655,3 @@ async def connect_socks_proxy( **kwargs: Any, ) -> socket.socket: raise ImportError("python-socks is required to use a SOCKS proxy") - - -async def connect_proxy( - proxy: Proxy, - ws_uri: WebSocketURI, - **kwargs: Any, -) -> socket.socket: - """Connect via a proxy and return the socket.""" - # parse_proxy() validates proxy.scheme. - if proxy.scheme[:5] == "socks": - return await connect_socks_proxy(proxy, ws_uri, **kwargs) - else: - raise AssertionError("unsupported proxy") diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index b7ab83664..722def319 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -284,14 +284,19 @@ def connect( assert path is not None # mypy cannot figure this out sock.connect(path) elif proxy is not None: - sock = connect_proxy( - parse_proxy(proxy), - ws_uri, - deadline, - # websockets is consistent with the socket module while - # python_socks is consistent across implementations. - local_addr=kwargs.pop("source_address", None), - ) + proxy_parsed = parse_proxy(proxy) + if proxy_parsed.scheme[:5] == "socks": + # Connect to the server through the proxy. + sock = connect_socks_proxy( + proxy_parsed, + ws_uri, + deadline, + # websockets is consistent with the socket module while + # python_socks is consistent across implementations. + local_addr=kwargs.pop("source_address", None), + ) + else: + raise AssertionError("unsupported proxy") else: kwargs.setdefault("timeout", deadline.timeout()) sock = socket.create_connection( @@ -439,17 +444,3 @@ def connect_socks_proxy( **kwargs: Any, ) -> socket.socket: raise ImportError("python-socks is required to use a SOCKS proxy") - - -def connect_proxy( - proxy: Proxy, - ws_uri: WebSocketURI, - deadline: Deadline, - **kwargs: Any, -) -> socket.socket: - """Connect via a proxy and return the socket.""" - # parse_proxy() validates proxy.scheme. - if proxy.scheme[:5] == "socks": - return connect_socks_proxy(proxy, ws_uri, deadline, **kwargs) - else: - raise AssertionError("unsupported proxy") From 3def92ead354c8bd9dc5e2901e82fbfa43902485 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Feb 2025 12:12:11 +0100 Subject: [PATCH 724/762] Don't intercept TLS connections in tests. This avoids mixing TLS termination by mitmproxy and by websockets. --- tests/proxy.py | 17 ++++++++--------- tox.ini | 4 ++-- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/tests/proxy.py b/tests/proxy.py index 8b62b0eb7..6cae7b49e 100644 --- a/tests/proxy.py +++ b/tests/proxy.py @@ -1,5 +1,4 @@ import asyncio -import pathlib import threading import warnings @@ -21,7 +20,7 @@ def __init__(self, on_running): self.running = on_running self.flows = [] - def websocket_start(self, flow): + def tcp_start(self, flow): self.flows.append(flow) def get_flows(self): @@ -48,7 +47,13 @@ async def run_proxy(cls): cls.proxy_loop = loop = asyncio.get_event_loop() cls.proxy_stop = stop = loop.create_future() - cls.proxy_options = options = Options(mode=[cls.proxy_mode]) + cls.proxy_options = options = Options( + mode=[cls.proxy_mode], + # Don't intercept connections, but record them. + ignore_hosts=["^localhost:", "^127.0.0.1:", "^::1:"], + # This option requires mitmproxy 11.0.0, which requires Python 3.11. + show_ignored_hosts=True, + ) cls.proxy_master = master = Master(options) master.addons.add( core.Core(), @@ -58,12 +63,6 @@ async def run_proxy(cls): tlsconfig.TlsConfig(), RecordFlows(on_running=cls.proxy_ready.set), ) - options.update( - # Use test certificate for TLS between client and proxy. - certs=[str(pathlib.Path(__file__).with_name("test_localhost.pem"))], - # Disable TLS verification between proxy and upstream. - ssl_insecure=True, - ) task = loop.create_task(cls.proxy_master.run()) await stop diff --git a/tox.ini b/tox.ini index f5a2f5d3c..918aeaaec 100644 --- a/tox.ini +++ b/tox.ini @@ -15,8 +15,8 @@ commands = pass_env = WEBSOCKETS_* deps = - mitmproxy - python-socks[asyncio] + py311,py312,py313,coverage,maxi_cov: mitmproxy + py311,py312,py313,coverage,maxi_cov: python-socks[asyncio] [testenv:coverage] commands = From ba7104caa7540b29db7c02aaf9f07c95691e42f9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Feb 2025 21:28:58 +0100 Subject: [PATCH 725/762] Clarify timeout errors. Specifically, not hiding the __cause__ of TimeoutError makes it visible when it happens while connecting to a proxy. --- src/websockets/asyncio/client.py | 4 ++-- src/websockets/sync/client.py | 2 +- src/websockets/sync/server.py | 2 +- tests/asyncio/test_client.py | 4 ++-- tests/sync/test_client.py | 2 +- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index a3fcab039..058313388 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -511,9 +511,9 @@ async def __await_impl__(self) -> ClientConnection: else: raise SecurityError(f"more than {MAX_REDIRECTS} redirects") - except TimeoutError: + except TimeoutError as exc: # Re-raise exception with an informative error message. - raise TimeoutError("timed out during handshake") from None + raise TimeoutError("timed out during opening handshake") from exc # ... = yield from connect(...) - remove when dropping Python < 3.10 diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 722def319..8ce9e7d84 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -90,7 +90,7 @@ def handshake( self.protocol.send_request(self.request) if not self.response_rcvd.wait(timeout): - raise TimeoutError("timed out during handshake") + raise TimeoutError("timed out while waiting for handshake response") # self.protocol.handshake_exc is set when the connection is lost before # receiving a response, when the response cannot be parsed, or when the diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 2b753b2c5..10e3b6816 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -128,7 +128,7 @@ def handshake( """ if not self.request_rcvd.wait(timeout): - raise TimeoutError("timed out during handshake") + raise TimeoutError("timed out while waiting for handshake request") if self.request is not None: with self.send_context(expected_state=CONNECTING): diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index be8ef8a42..3728c7734 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -396,7 +396,7 @@ async def test_timeout_during_handshake(self): self.fail("did not raise") self.assertEqual( str(raised.exception), - "timed out during handshake", + "timed out during opening handshake", ) async def test_connection_closed_during_handshake(self): @@ -641,7 +641,7 @@ async def test_socks_proxy_connection_timeout(self): self.fail("did not raise") self.assertEqual( str(raised.exception), - "timed out during handshake", + "timed out during opening handshake", ) self.assertNumFlows(0) diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index 38cfab7b3..47886e015 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -164,7 +164,7 @@ def test_timeout_during_handshake(self): self.fail("did not raise") self.assertEqual( str(raised.exception), - "timed out during handshake", + "timed out while waiting for handshake response", ) def test_connection_closed_during_handshake(self): From 6c15b9c8b039d9da610d0a1563a302cc54cf62b0 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 25 Jan 2025 13:35:19 +0100 Subject: [PATCH 726/762] Add option to always include port in build_host helper. --- src/websockets/headers.py | 10 +++++++-- tests/test_headers.py | 43 +++++++++++++++++++++++---------------- 2 files changed, 33 insertions(+), 20 deletions(-) diff --git a/src/websockets/headers.py b/src/websockets/headers.py index e05948a1f..c42abd976 100644 --- a/src/websockets/headers.py +++ b/src/websockets/headers.py @@ -36,7 +36,13 @@ T = TypeVar("T") -def build_host(host: str, port: int, secure: bool) -> str: +def build_host( + host: str, + port: int, + secure: bool, + *, + always_include_port: bool = False, +) -> str: """ Build a ``Host`` header. @@ -53,7 +59,7 @@ def build_host(host: str, port: int, secure: bool) -> str: if address.version == 6: host = f"[{host}]" - if port != (443 if secure else 80): + if always_include_port or port != (443 if secure else 80): host = f"{host}:{port}" return host diff --git a/tests/test_headers.py b/tests/test_headers.py index 4ebd8b90c..816afc541 100644 --- a/tests/test_headers.py +++ b/tests/test_headers.py @@ -6,26 +6,33 @@ class HeadersTests(unittest.TestCase): def test_build_host(self): - for (host, port, secure), result in [ - (("localhost", 80, False), "localhost"), - (("localhost", 8000, False), "localhost:8000"), - (("localhost", 443, True), "localhost"), - (("localhost", 8443, True), "localhost:8443"), - (("example.com", 80, False), "example.com"), - (("example.com", 8000, False), "example.com:8000"), - (("example.com", 443, True), "example.com"), - (("example.com", 8443, True), "example.com:8443"), - (("127.0.0.1", 80, False), "127.0.0.1"), - (("127.0.0.1", 8000, False), "127.0.0.1:8000"), - (("127.0.0.1", 443, True), "127.0.0.1"), - (("127.0.0.1", 8443, True), "127.0.0.1:8443"), - (("::1", 80, False), "[::1]"), - (("::1", 8000, False), "[::1]:8000"), - (("::1", 443, True), "[::1]"), - (("::1", 8443, True), "[::1]:8443"), + for (host, port, secure), (result, result_with_port) in [ + (("localhost", 80, False), ("localhost", "localhost:80")), + (("localhost", 8000, False), ("localhost:8000", "localhost:8000")), + (("localhost", 443, True), ("localhost", "localhost:443")), + (("localhost", 8443, True), ("localhost:8443", "localhost:8443")), + (("example.com", 80, False), ("example.com", "example.com:80")), + (("example.com", 8000, False), ("example.com:8000", "example.com:8000")), + (("example.com", 443, True), ("example.com", "example.com:443")), + (("example.com", 8443, True), ("example.com:8443", "example.com:8443")), + (("127.0.0.1", 80, False), ("127.0.0.1", "127.0.0.1:80")), + (("127.0.0.1", 8000, False), ("127.0.0.1:8000", "127.0.0.1:8000")), + (("127.0.0.1", 443, True), ("127.0.0.1", "127.0.0.1:443")), + (("127.0.0.1", 8443, True), ("127.0.0.1:8443", "127.0.0.1:8443")), + (("::1", 80, False), ("[::1]", "[::1]:80")), + (("::1", 8000, False), ("[::1]:8000", "[::1]:8000")), + (("::1", 443, True), ("[::1]", "[::1]:443")), + (("::1", 8443, True), ("[::1]:8443", "[::1]:8443")), ]: with self.subTest(host=host, port=port, secure=secure): - self.assertEqual(build_host(host, port, secure), result) + self.assertEqual( + build_host(host, port, secure), + result, + ) + self.assertEqual( + build_host(host, port, secure, always_include_port=True), + result_with_port, + ) def test_parse_connection(self): for header, parsed in [ From e5e85d21a07995ac993c325c7b2c38441c2cbf83 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 31 Jan 2025 22:16:21 +0100 Subject: [PATCH 727/762] Add option not to read body in Response.parse. This allows keeping the connection open after reading the response to a CONNECT request. --- src/websockets/http11.py | 10 +++++++--- tests/test_http11.py | 8 +++++++- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/src/websockets/http11.py b/src/websockets/http11.py index 49d7b9a41..530ac3d09 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -210,6 +210,7 @@ def parse( read_line: Callable[[int], Generator[None, None, bytes]], read_exact: Callable[[int], Generator[None, None, bytes]], read_to_eof: Callable[[int], Generator[None, None, bytes]], + include_body: bool = True, ) -> Generator[None, None, Response]: """ Parse a WebSocket handshake response. @@ -265,9 +266,12 @@ def parse( headers = yield from parse_headers(read_line) - body = yield from read_body( - status_code, headers, read_line, read_exact, read_to_eof - ) + if include_body: + body = yield from read_body( + status_code, headers, read_line, read_exact, read_to_eof + ) + else: + body = b"" return cls(status_code, reason, headers, body) diff --git a/tests/test_http11.py b/tests/test_http11.py index bb0d27b95..3afb6d02c 100644 --- a/tests/test_http11.py +++ b/tests/test_http11.py @@ -130,11 +130,12 @@ def setUp(self): super().setUp() self.reader = StreamReader() - def parse(self): + def parse(self, **kwargs): return Response.parse( self.reader.read_line, self.reader.read_exact, self.reader.read_to_eof, + **kwargs, ) def test_parse(self): @@ -322,6 +323,11 @@ def test_parse_body_not_modified(self): response = self.assertGeneratorReturns(self.parse()) self.assertEqual(response.body, b"") + def test_parse_without_body(self): + self.reader.feed_data(b"HTTP/1.1 200 Connection Established\r\n\r\n") + response = self.assertGeneratorReturns(self.parse(include_body=False)) + self.assertEqual(response.body, b"") + def test_serialize(self): # Example from the protocol overview in RFC 6455 response = Response( From ec706c922deeaf6fcf98174d4b2f16a6819117c7 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 26 Jan 2025 21:21:49 +0100 Subject: [PATCH 728/762] Add support for HTTP(S) proxies. Fix #364. --- docs/project/changelog.rst | 4 +- docs/reference/features.rst | 3 +- docs/topics/proxies.rst | 19 +++ src/websockets/asyncio/client.py | 154 ++++++++++++++++++++++- src/websockets/sync/client.py | 204 ++++++++++++++++++++++++++++++- tests/asyncio/test_client.py | 198 +++++++++++++++++++++++++++++- tests/proxy.py | 37 +++++- tests/sync/test_client.py | 184 ++++++++++++++++++++++++++++ tests/utils.py | 7 +- 9 files changed, 791 insertions(+), 19 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index bfbfa793f..7bb94b349 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -35,12 +35,12 @@ notice. Backwards-incompatible changes .............................. -.. admonition:: Client connections use SOCKS proxies automatically. +.. admonition:: Client connections use SOCKS and HTTP proxies automatically. :class: important If a proxy is configured in the operating system or with an environment variable, websockets uses it automatically when connecting to a server. - This feature requires installing the third-party library `python-socks`_. + SOCKS proxies require installing the third-party library `python-socks`_. If you want to disable the proxy, add ``proxy=None`` when calling :func:`~asyncio.client.connect`. See :doc:`../topics/proxies` for details. diff --git a/docs/reference/features.rst b/docs/reference/features.rst index eaecd02a9..93b083d20 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -166,12 +166,11 @@ Client | Perform HTTP Digest Authentication | ❌ | ❌ | ❌ | ❌ | | (`#784`_) | | | | | +------------------------------------+--------+--------+--------+--------+ - | Connect via HTTP proxy (`#364`_) | ❌ | ❌ | — | ❌ | + | Connect via HTTP proxy | ✅ | ✅ | — | ❌ | +------------------------------------+--------+--------+--------+--------+ | Connect via SOCKS5 proxy | ✅ | ✅ | — | ❌ | +------------------------------------+--------+--------+--------+--------+ -.. _#364: https://github.com/python-websockets/websockets/issues/364 .. _#784: https://github.com/python-websockets/websockets/issues/784 Known limitations diff --git a/docs/topics/proxies.rst b/docs/topics/proxies.rst index fd3ae78b6..14fc68c0c 100644 --- a/docs/topics/proxies.rst +++ b/docs/topics/proxies.rst @@ -30,6 +30,9 @@ most common, for `historical reasons`_, and recommended. .. _historical reasons: https://unix.stackexchange.com/questions/212894/ +websockets authenticates automatically when the address of the proxy includes +credentials e.g. ``http://user:password@proxy:8080/``. + .. admonition:: Any environment variable can configure a SOCKS proxy or an HTTP proxy. :class: tip @@ -64,3 +67,19 @@ SOCKS proxy is configured in the operating system, python-socks uses SOCKS5h. python-socks supports username/password authentication for SOCKS5 (:rfc:`1929`) but does not support other authentication methods such as GSSAPI (:rfc:`1961`). + +HTTP proxies +------------ + +When the address of the proxy starts with ``https://``, websockets secures the +connection to the proxy with TLS. + +When the address of the server starts with ``wss://``, websockets secures the +connection from the proxy to the server with TLS. + +These two options are compatible. TLS-in-TLS is supported. + +The documentation of :func:`~asyncio.client.connect` describes how to configure +TLS from websockets to the proxy and from the proxy to the server. + +websockets supports proxy authentication with Basic Auth. diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 058313388..c19a53f8c 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -4,20 +4,29 @@ import logging import os import socket +import ssl as ssl_module import traceback import urllib.parse from collections.abc import AsyncIterator, Generator, Sequence from types import TracebackType -from typing import Any, Callable, Literal +from typing import Any, Callable, Literal, cast from ..client import ClientProtocol, backoff -from ..datastructures import HeadersLike -from ..exceptions import InvalidMessage, InvalidStatus, ProxyError, SecurityError +from ..datastructures import Headers, HeadersLike +from ..exceptions import ( + InvalidMessage, + InvalidProxyMessage, + InvalidProxyStatus, + InvalidStatus, + ProxyError, + SecurityError, +) from ..extensions.base import ClientExtensionFactory from ..extensions.permessage_deflate import enable_client_permessage_deflate -from ..headers import validate_subprotocols +from ..headers import build_authorization_basic, build_host, validate_subprotocols from ..http11 import USER_AGENT, Response from ..protocol import CONNECTING, Event +from ..streams import StreamReader from ..typing import LoggerLike, Origin, Subprotocol from ..uri import Proxy, WebSocketURI, get_proxy, parse_proxy, parse_uri from .compatibility import TimeoutError, asyncio_timeout @@ -266,6 +275,16 @@ class connect: :meth:`~asyncio.loop.create_connection` method) to create a suitable client socket and customize it. + When using a proxy: + + * Prefix keyword arguments with ``proxy_`` for configuring TLS between the + client and an HTTPS proxy: ``proxy_ssl``, ``proxy_server_hostname``, + ``proxy_ssl_handshake_timeout``, and ``proxy_ssl_shutdown_timeout``. + * Use the standard keyword arguments for configuring TLS between the proxy + and the WebSocket server: ``ssl``, ``server_hostname``, + ``ssl_handshake_timeout``, and ``ssl_shutdown_timeout``. + * Other keyword arguments are used only for connecting to the proxy. + Raises: InvalidURI: If ``uri`` isn't a valid WebSocket URI. InvalidProxy: If ``proxy`` isn't a valid proxy. @@ -397,6 +416,47 @@ def factory() -> ClientConnection: sock=sock, **kwargs, ) + elif proxy_parsed.scheme[:4] == "http": + # Split keyword arguments between the proxy and the server. + all_kwargs, proxy_kwargs, kwargs = kwargs, {}, {} + for key, value in all_kwargs.items(): + if key.startswith("ssl") or key == "server_hostname": + kwargs[key] = value + elif key.startswith("proxy_"): + proxy_kwargs[key[6:]] = value + else: + proxy_kwargs[key] = value + # Validate the proxy_ssl argument. + if proxy_parsed.scheme == "https": + proxy_kwargs.setdefault("ssl", True) + if proxy_kwargs.get("ssl") is None: + raise ValueError( + "proxy_ssl=None is incompatible with an https:// proxy" + ) + else: + if proxy_kwargs.get("ssl") is not None: + raise ValueError( + "proxy_ssl argument is incompatible with an http:// proxy" + ) + # Connect to the server through the proxy. + transport = await connect_http_proxy( + proxy_parsed, + ws_uri, + **proxy_kwargs, + ) + # Initialize WebSocket connection via the proxy. + connection = factory() + transport.set_protocol(connection) + ssl = kwargs.pop("ssl", None) + if ssl is True: + ssl = ssl_module.create_default_context() + if ssl is not None: + new_transport = await loop.start_tls( + transport, connection, ssl, **kwargs + ) + assert new_transport is not None # help mypy + transport = new_transport + connection.connection_made(transport) else: raise AssertionError("unsupported proxy") else: @@ -655,3 +715,89 @@ async def connect_socks_proxy( **kwargs: Any, ) -> socket.socket: raise ImportError("python-socks is required to use a SOCKS proxy") + + +def prepare_connect_request(proxy: Proxy, ws_uri: WebSocketURI) -> bytes: + host = build_host(ws_uri.host, ws_uri.port, ws_uri.secure, always_include_port=True) + headers = Headers() + headers["Host"] = build_host(ws_uri.host, ws_uri.port, ws_uri.secure) + if proxy.username is not None: + assert proxy.password is not None # enforced by parse_proxy() + headers["Proxy-Authorization"] = build_authorization_basic( + proxy.username, proxy.password + ) + # We cannot use the Request class because it supports only GET requests. + return f"CONNECT {host} HTTP/1.1\r\n".encode() + headers.serialize() + + +class HTTPProxyConnection(asyncio.Protocol): + def __init__(self, ws_uri: WebSocketURI, proxy: Proxy): + self.ws_uri = ws_uri + self.proxy = proxy + + self.reader = StreamReader() + self.parser = Response.parse( + self.reader.read_line, + self.reader.read_exact, + self.reader.read_to_eof, + include_body=False, + ) + + loop = asyncio.get_running_loop() + self.response: asyncio.Future[Response] = loop.create_future() + + def run_parser(self) -> None: + try: + next(self.parser) + except StopIteration as exc: + response = exc.value + if 200 <= response.status_code < 300: + self.response.set_result(response) + else: + self.response.set_exception(InvalidProxyStatus(response)) + except Exception as exc: + proxy_exc = InvalidProxyMessage( + "did not receive a valid HTTP response from proxy" + ) + proxy_exc.__cause__ = exc + self.response.set_exception(proxy_exc) + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + transport = cast(asyncio.Transport, transport) + self.transport = transport + self.transport.write(prepare_connect_request(self.proxy, self.ws_uri)) + + def data_received(self, data: bytes) -> None: + self.reader.feed_data(data) + self.run_parser() + + def eof_received(self) -> None: + self.reader.feed_eof() + self.run_parser() + + def connection_lost(self, exc: Exception | None) -> None: + self.reader.feed_eof() + if exc is not None: + self.response.set_exception(exc) + + +async def connect_http_proxy( + proxy: Proxy, + ws_uri: WebSocketURI, + **kwargs: Any, +) -> asyncio.Transport: + transport, protocol = await asyncio.get_running_loop().create_connection( + lambda: HTTPProxyConnection(ws_uri, proxy), + proxy.host, + proxy.port, + **kwargs, + ) + + try: + # This raises exceptions if the connection to the proxy fails. + await protocol.response + except Exception: + transport.close() + raise + + return transport diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 8ce9e7d84..c0fe6901a 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -5,16 +5,17 @@ import threading import warnings from collections.abc import Sequence -from typing import Any, Literal +from typing import Any, Callable, Literal, TypeVar, cast from ..client import ClientProtocol -from ..datastructures import HeadersLike -from ..exceptions import ProxyError +from ..datastructures import Headers, HeadersLike +from ..exceptions import InvalidProxyMessage, InvalidProxyStatus, ProxyError from ..extensions.base import ClientExtensionFactory from ..extensions.permessage_deflate import enable_client_permessage_deflate -from ..headers import validate_subprotocols +from ..headers import build_authorization_basic, build_host, validate_subprotocols from ..http11 import USER_AGENT, Response from ..protocol import CONNECTING, Event +from ..streams import StreamReader from ..typing import LoggerLike, Origin, Subprotocol from ..uri import Proxy, WebSocketURI, get_proxy, parse_proxy, parse_uri from .connection import Connection @@ -141,6 +142,8 @@ def connect( additional_headers: HeadersLike | None = None, user_agent_header: str | None = USER_AGENT, proxy: str | Literal[True] | None = True, + proxy_ssl: ssl_module.SSLContext | None = None, + proxy_server_hostname: str | None = None, # Timeouts open_timeout: float | None = 10, ping_interval: float | None = 20, @@ -195,6 +198,9 @@ def connect( to :obj:`None` to disable the proxy or to the address of a proxy to override the system configuration. See the :doc:`proxy docs <../../topics/proxies>` for details. + proxy_ssl: Configuration for enabling TLS on the proxy connection. + proxy_server_hostname: Host name for the TLS handshake with the proxy. + ``proxy_server_hostname`` overrides the host name from ``proxy``. open_timeout: Timeout for opening the connection in seconds. :obj:`None` disables the timeout. ping_interval: Interval between keepalive pings in seconds. @@ -295,6 +301,21 @@ def connect( # python_socks is consistent across implementations. local_addr=kwargs.pop("source_address", None), ) + elif proxy_parsed.scheme[:4] == "http": + # Validate the proxy_ssl argument. + if proxy_parsed.scheme != "https" and proxy_ssl is not None: + raise ValueError( + "proxy_ssl argument is incompatible with an http:// proxy" + ) + # Connect to the server through the proxy. + sock = connect_http_proxy( + proxy_parsed, + ws_uri, + deadline, + ssl=proxy_ssl, + server_hostname=proxy_server_hostname, + **kwargs, + ) else: raise AssertionError("unsupported proxy") else: @@ -318,7 +339,12 @@ def connect( if server_hostname is None: server_hostname = ws_uri.host sock.settimeout(deadline.timeout()) - sock = ssl.wrap_socket(sock, server_hostname=server_hostname) + if proxy_ssl is None: + sock = ssl.wrap_socket(sock, server_hostname=server_hostname) + else: + sock_2 = SSLSSLSocket(sock, ssl, server_hostname=server_hostname) + # Let's pretend that sock is a socket, even though it isn't. + sock = cast(socket.socket, sock_2) sock.settimeout(None) # Initialize WebSocket protocol @@ -444,3 +470,171 @@ def connect_socks_proxy( **kwargs: Any, ) -> socket.socket: raise ImportError("python-socks is required to use a SOCKS proxy") + + +def prepare_connect_request(proxy: Proxy, ws_uri: WebSocketURI) -> bytes: + host = build_host(ws_uri.host, ws_uri.port, ws_uri.secure, always_include_port=True) + headers = Headers() + headers["Host"] = build_host(ws_uri.host, ws_uri.port, ws_uri.secure) + if proxy.username is not None: + assert proxy.password is not None # enforced by parse_proxy() + headers["Proxy-Authorization"] = build_authorization_basic( + proxy.username, proxy.password + ) + # We cannot use the Request class because it supports only GET requests. + return f"CONNECT {host} HTTP/1.1\r\n".encode() + headers.serialize() + + +def read_connect_response(sock: socket.socket, deadline: Deadline) -> Response: + reader = StreamReader() + parser = Response.parse( + reader.read_line, + reader.read_exact, + reader.read_to_eof, + include_body=False, + ) + try: + while True: + sock.settimeout(deadline.timeout()) + data = sock.recv(4096) + if data: + reader.feed_data(data) + else: + reader.feed_eof() + next(parser) + except StopIteration as exc: + assert isinstance(exc.value, Response) # help mypy + response = exc.value + if 200 <= response.status_code < 300: + return response + else: + raise InvalidProxyStatus(response) + except socket.timeout: + raise TimeoutError("timed out while connecting to HTTP proxy") + except Exception as exc: + raise InvalidProxyMessage( + "did not receive a valid HTTP response from proxy" + ) from exc + finally: + sock.settimeout(None) + + +def connect_http_proxy( + proxy: Proxy, + ws_uri: WebSocketURI, + deadline: Deadline, + *, + ssl: ssl_module.SSLContext | None = None, + server_hostname: str | None = None, + **kwargs: Any, +) -> socket.socket: + # Connect socket + + kwargs.setdefault("timeout", deadline.timeout()) + sock = socket.create_connection((proxy.host, proxy.port), **kwargs) + + # Initialize TLS wrapper and perform TLS handshake + + if proxy.scheme == "https": + if ssl is None: + ssl = ssl_module.create_default_context() + if server_hostname is None: + server_hostname = proxy.host + sock.settimeout(deadline.timeout()) + sock = ssl.wrap_socket(sock, server_hostname=server_hostname) + sock.settimeout(None) + + # Send CONNECT request to the proxy and read response. + + sock.sendall(prepare_connect_request(proxy, ws_uri)) + try: + read_connect_response(sock, deadline) + except Exception: + sock.close() + raise + + return sock + + +T = TypeVar("T") +F = TypeVar("F", bound=Callable[..., T]) + + +class SSLSSLSocket: + """ + Socket-like object providing TLS-in-TLS. + + Only methods that are used by websockets are implemented. + + """ + + recv_bufsize = 65536 + + def __init__( + self, + sock: socket.socket, + ssl_context: ssl_module.SSLContext, + server_hostname: str | None = None, + ) -> None: + self.incoming = ssl_module.MemoryBIO() + self.outgoing = ssl_module.MemoryBIO() + self.ssl_socket = sock + self.ssl_object = ssl_context.wrap_bio( + self.incoming, + self.outgoing, + server_hostname=server_hostname, + ) + self.run_io(self.ssl_object.do_handshake) + + def run_io(self, func: Callable[..., T], *args: Any) -> T: + while True: + want_read = False + want_write = False + try: + result = func(*args) + except ssl_module.SSLWantReadError: + want_read = True + except ssl_module.SSLWantWriteError: # pragma: no cover + want_write = True + + # Write outgoing data in all cases. + data = self.outgoing.read() + if data: + self.ssl_socket.sendall(data) + + # Read incoming data and retry on SSLWantReadError. + if want_read: + data = self.ssl_socket.recv(self.recv_bufsize) + if data: + self.incoming.write(data) + else: + self.incoming.write_eof() + continue + # Retry after writing outgoing data on SSLWantWriteError. + if want_write: # pragma: no cover + continue + # Return result if no error happened. + return result + + def recv(self, buflen: int) -> bytes: + try: + return self.run_io(self.ssl_object.read, buflen) + except ssl_module.SSLEOFError: + return b"" # always ignore ragged EOFs + + def send(self, data: bytes) -> int: + return self.run_io(self.ssl_object.write, data) + + def sendall(self, data: bytes) -> None: + # adapted from ssl_module.SSLSocket.sendall() + count = 0 + with memoryview(data) as view, view.cast("B") as byte_view: + amount = len(byte_view) + while count < amount: + count += self.send(byte_view[count:]) + + # recv_into(), recvfrom(), recvfrom_into(), sendto(), unwrap(), and the + # flags argument aren't implemented because websockets doesn't need them. + + def __getattr__(self, name: str) -> Any: + return getattr(self.ssl_socket, name) diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index 3728c7734..c6ff26ae4 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -15,6 +15,7 @@ InvalidHandshake, InvalidMessage, InvalidProxy, + InvalidProxyMessage, InvalidStatus, InvalidURI, ProxyError, @@ -667,6 +668,181 @@ async def test_ignore_proxy_with_existing_socket(self): self.assertNumFlows(0) +@unittest.skipUnless("mitmproxy" in sys.modules, "mitmproxy not installed") +class HTTPProxyClientTests(ProxyMixin, unittest.IsolatedAsyncioTestCase): + proxy_mode = "regular@58080" + + async def test_http_proxy(self): + """Client connects to server through an HTTP proxy.""" + with patch_environ({"https_proxy": "http://localhost:58080"}): + async with serve(*args) as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertNumFlows(1) + + async def test_secure_http_proxy(self): + """Client connects to server securely through an HTTP proxy.""" + with patch_environ({"https_proxy": "http://localhost:58080"}): + async with serve(*args, ssl=SERVER_CONTEXT) as server: + async with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.version()[:3], "TLS") + self.assertNumFlows(1) + + async def test_authenticated_http_proxy(self): + """Client connects to server through an authenticated HTTP proxy.""" + try: + self.proxy_options.update(proxyauth="hello:iloveyou") + with patch_environ( + {"https_proxy": "http://hello:iloveyou@localhost:58080"} + ): + async with serve(*args) as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + finally: + self.proxy_options.update(proxyauth=None) + self.assertNumFlows(1) + + async def test_authenticated_http_proxy_error(self): + """Client fails to authenticate to the HTTP proxy.""" + try: + self.proxy_options.update(proxyauth="any") + with patch_environ({"https_proxy": "http://localhost:58080"}): + with self.assertRaises(ProxyError) as raised: + async with connect("ws://example.com/"): + self.fail("did not raise") + finally: + self.proxy_options.update(proxyauth=None) + self.assertEqual( + str(raised.exception), + "proxy rejected connection: HTTP 407", + ) + self.assertNumFlows(0) + + async def test_http_proxy_protocol_error(self): + """Client receives invalid data when connecting to the HTTP proxy.""" + try: + self.proxy_options.update(break_http_connect=True) + with patch_environ({"https_proxy": "http://localhost:58080"}): + with self.assertRaises(InvalidProxyMessage) as raised: + async with connect("ws://example.com/"): + self.fail("did not raise") + finally: + self.proxy_options.update(break_http_connect=False) + self.assertEqual( + str(raised.exception), + "did not receive a valid HTTP response from proxy", + ) + self.assertNumFlows(0) + + async def test_http_proxy_connection_error(self): + """Client receives no response when connecting to the HTTP proxy.""" + try: + self.proxy_options.update(close_http_connect=True) + with patch_environ({"https_proxy": "http://localhost:58080"}): + with self.assertRaises(InvalidProxyMessage) as raised: + async with connect("ws://example.com/"): + self.fail("did not raise") + finally: + self.proxy_options.update(close_http_connect=False) + self.assertEqual( + str(raised.exception), + "did not receive a valid HTTP response from proxy", + ) + self.assertNumFlows(0) + + async def test_http_proxy_connection_failure(self): + """Client fails to connect to the HTTP proxy.""" + with patch_environ({"https_proxy": "http://localhost:61080"}): # bad port + with self.assertRaises(OSError): + async with connect("ws://example.com/"): + self.fail("did not raise") + # Don't test str(raised.exception) because we don't control it. + self.assertNumFlows(0) + + async def test_http_proxy_connection_timeout(self): + """Client times out while connecting to the HTTP proxy.""" + # Replace the proxy with a TCP server that doesn't respond. + with socket.create_server(("localhost", 0)) as sock: + host, port = sock.getsockname() + with patch_environ({"https_proxy": f"http://{host}:{port}"}): + with self.assertRaises(TimeoutError) as raised: + async with connect("ws://example.com/", open_timeout=MS): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "timed out during opening handshake", + ) + + async def test_https_proxy(self): + """Client connects to server through an HTTPS proxy.""" + with patch_environ({"https_proxy": "https://localhost:58080"}): + async with serve(*args) as server: + async with connect( + get_uri(server), + proxy_ssl=self.proxy_context, + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertNumFlows(1) + + async def test_secure_https_proxy(self): + """Client connects to server securely through an HTTPS proxy.""" + with patch_environ({"https_proxy": "https://localhost:58080"}): + async with serve(*args, ssl=SERVER_CONTEXT) as server: + async with connect( + get_uri(server), + ssl=CLIENT_CONTEXT, + proxy_ssl=self.proxy_context, + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.version()[:3], "TLS") + self.assertNumFlows(1) + + async def test_https_server_hostname(self): + """Client sets server_hostname to the value of proxy_server_hostname.""" + with patch_environ({"https_proxy": "https://localhost:58080"}): + async with serve(*args) as server: + # Pass an argument not prefixed with proxy_ for coverage. + kwargs = {"all_errors": True} if sys.version_info >= (3, 12) else {} + async with connect( + get_uri(server), + proxy_ssl=self.proxy_context, + proxy_server_hostname="overridden", + **kwargs, + ) as client: + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.server_hostname, "overridden") + self.assertNumFlows(1) + + async def test_https_proxy_invalid_proxy_certificate(self): + """Client rejects certificate when proxy certificate isn't trusted.""" + with patch_environ({"https_proxy": "https://localhost:58080"}): + with self.assertRaises(ssl.SSLCertVerificationError) as raised: + # The proxy certificate isn't trusted. + async with connect("wss://example.com/"): + self.fail("did not raise") + self.assertIn( + "certificate verify failed: unable to get local issuer certificate", + str(raised.exception), + ) + + async def test_https_proxy_invalid_server_certificate(self): + """Client rejects certificate when proxy certificate isn't trusted.""" + with patch_environ({"https_proxy": "https://localhost:58080"}): + async with serve(*args, ssl=SERVER_CONTEXT) as server: + with self.assertRaises(ssl.SSLCertVerificationError) as raised: + # The test certificate is self-signed. + async with connect(get_uri(server), proxy_ssl=self.proxy_context): + self.fail("did not raise") + self.assertIn( + "certificate verify failed: self signed certificate", + str(raised.exception).replace("-", " "), + ) + self.assertNumFlows(1) + + @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") class UnixClientTests(unittest.IsolatedAsyncioTestCase): async def test_connection(self): @@ -737,7 +913,7 @@ async def test_ssl_without_secure_uri(self): ) async def test_secure_uri_without_ssl(self): - """Client rejects no ssl when URI is secure.""" + """Client rejects ssl=None when URI is secure.""" with self.assertRaises(ValueError) as raised: await connect("wss://localhost/", ssl=None) self.assertEqual( @@ -745,6 +921,26 @@ async def test_secure_uri_without_ssl(self): "ssl=None is incompatible with a wss:// URI", ) + async def test_proxy_ssl_without_https_proxy(self): + """Client rejects proxy_ssl when proxy isn't HTTPS.""" + with patch_environ({"https_proxy": "http://localhost:8080"}): + with self.assertRaises(ValueError) as raised: + await connect("ws://localhost/", proxy_ssl=True) + self.assertEqual( + str(raised.exception), + "proxy_ssl argument is incompatible with an http:// proxy", + ) + + async def test_https_proxy_without_ssl(self): + """Client rejects proxy_ssl=None when proxy is HTTPS.""" + with patch_environ({"https_proxy": "https://localhost:8080"}): + with self.assertRaises(ValueError) as raised: + await connect("ws://localhost/", proxy_ssl=None) + self.assertEqual( + str(raised.exception), + "proxy_ssl=None is incompatible with an https:// proxy", + ) + async def test_unsupported_proxy(self): """Client rejects unsupported proxy.""" with self.assertRaises(InvalidProxy) as raised: diff --git a/tests/proxy.py b/tests/proxy.py index 6cae7b49e..9746e3382 100644 --- a/tests/proxy.py +++ b/tests/proxy.py @@ -1,4 +1,6 @@ import asyncio +import pathlib +import ssl import threading import warnings @@ -8,9 +10,11 @@ warnings.filterwarnings("ignore", category=DeprecationWarning, module="passlib") warnings.filterwarnings("ignore", category=DeprecationWarning, module="pyasn1") + from mitmproxy import ctx from mitmproxy.addons import core, next_layer, proxyauth, proxyserver, tlsconfig + from mitmproxy.http import Response from mitmproxy.master import Master - from mitmproxy.options import Options + from mitmproxy.options import CONF_BASENAME, CONF_DIR, Options except ImportError: pass @@ -31,6 +35,31 @@ def reset_flows(self): self.flows = [] +class AlterRequest: + def load(self, loader): + loader.add_option( + name="break_http_connect", + typespec=bool, + default=False, + help="Respond to HTTP CONNECT requests with a 999 status code.", + ) + loader.add_option( + name="close_http_connect", + typespec=bool, + default=False, + help="Do not respond to HTTP CONNECT requests.", + ) + + def http_connect(self, flow): + if ctx.options.break_http_connect: + # mitmproxy can send a response with a status code not between 100 + # and 599, while websockets treats it as a protocol error. + # This is used for testing HTTP parsing errors. + flow.response = Response.make(999, "not a valid HTTP response") + if ctx.options.close_http_connect: + flow.kill() + + class ProxyMixin: """ Run mitmproxy in a background thread. @@ -62,6 +91,7 @@ async def run_proxy(cls): next_layer.NextLayer(), tlsconfig.TlsConfig(), RecordFlows(on_running=cls.proxy_ready.set), + AlterRequest(), ) task = loop.create_task(cls.proxy_master.run()) @@ -86,6 +116,11 @@ def setUpClass(cls): cls.proxy_thread.start() cls.proxy_ready.wait() + certificate = pathlib.Path(CONF_DIR) / f"{CONF_BASENAME}-ca-cert.pem" + certificate = certificate.expanduser() + cls.proxy_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + cls.proxy_context.load_verify_locations(bytes(certificate)) + def assertNumFlows(self, num_flows): record_flows = self.proxy_master.addons.get("recordflows") self.assertEqual(len(record_flows.get_flows()), num_flows) diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index 47886e015..386caf56a 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -12,6 +12,7 @@ InvalidHandshake, InvalidMessage, InvalidProxy, + InvalidProxyMessage, InvalidStatus, InvalidURI, ProxyError, @@ -407,6 +408,179 @@ def test_ignore_proxy_with_existing_socket(self): self.assertNumFlows(0) +@unittest.skipUnless("mitmproxy" in sys.modules, "mitmproxy not installed") +class HTTPProxyClientTests(ProxyMixin, unittest.IsolatedAsyncioTestCase): + proxy_mode = "regular@58080" + + def test_http_proxy(self): + """Client connects to server through an HTTP proxy.""" + with patch_environ({"https_proxy": "http://localhost:58080"}): + with run_server() as server: + with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertNumFlows(1) + + def test_secure_http_proxy(self): + """Client connects to server securely through an HTTP proxy.""" + with patch_environ({"https_proxy": "http://localhost:58080"}): + with run_server(ssl=SERVER_CONTEXT) as server: + with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertEqual(client.socket.version()[:3], "TLS") + self.assertNumFlows(1) + + def test_authenticated_http_proxy(self): + """Client connects to server through an authenticated HTTP proxy.""" + try: + self.proxy_options.update(proxyauth="hello:iloveyou") + with patch_environ( + {"https_proxy": "http://hello:iloveyou@localhost:58080"} + ): + with run_server() as server: + with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + finally: + self.proxy_options.update(proxyauth=None) + self.assertNumFlows(1) + + def test_authenticated_http_proxy_error(self): + """Client fails to authenticate to the HTTP proxy.""" + try: + self.proxy_options.update(proxyauth="any") + with patch_environ({"https_proxy": "http://localhost:58080"}): + with self.assertRaises(ProxyError) as raised: + with connect("ws://example.com/"): + self.fail("did not raise") + finally: + self.proxy_options.update(proxyauth=None) + self.assertEqual( + str(raised.exception), + "proxy rejected connection: HTTP 407", + ) + self.assertNumFlows(0) + + def test_http_proxy_protocol_error(self): + """Client receives invalid data when connecting to the HTTP proxy.""" + try: + self.proxy_options.update(break_http_connect=True) + with patch_environ({"https_proxy": "http://localhost:58080"}): + with self.assertRaises(InvalidProxyMessage) as raised: + with connect("ws://example.com/"): + self.fail("did not raise") + finally: + self.proxy_options.update(break_http_connect=False) + self.assertEqual( + str(raised.exception), + "did not receive a valid HTTP response from proxy", + ) + self.assertNumFlows(0) + + def test_http_proxy_connection_error(self): + """Client receives no response when connecting to the HTTP proxy.""" + try: + self.proxy_options.update(close_http_connect=True) + with patch_environ({"https_proxy": "http://localhost:58080"}): + with self.assertRaises(InvalidProxyMessage) as raised: + with connect("ws://example.com/"): + self.fail("did not raise") + finally: + self.proxy_options.update(close_http_connect=False) + self.assertEqual( + str(raised.exception), + "did not receive a valid HTTP response from proxy", + ) + self.assertNumFlows(0) + + def test_http_proxy_connection_failure(self): + """Client fails to connect to the HTTP proxy.""" + with patch_environ({"https_proxy": "http://localhost:61080"}): # bad port + with self.assertRaises(OSError): + with connect("ws://example.com/"): + self.fail("did not raise") + # Don't test str(raised.exception) because we don't control it. + self.assertNumFlows(0) + + def test_http_proxy_connection_timeout(self): + """Client times out while connecting to the HTTP proxy.""" + # Replace the proxy with a TCP server that does't respond. + with socket.create_server(("localhost", 0)) as sock: + host, port = sock.getsockname() + with patch_environ({"https_proxy": f"http://{host}:{port}"}): + with self.assertRaises(TimeoutError) as raised: + with connect("ws://example.com/", open_timeout=MS): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "timed out while connecting to HTTP proxy", + ) + + def test_https_proxy(self): + """Client connects to server through an HTTPS proxy.""" + with patch_environ({"https_proxy": "https://localhost:58080"}): + with run_server() as server: + with connect( + get_uri(server), + proxy_ssl=self.proxy_context, + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertNumFlows(1) + + def test_secure_https_proxy(self): + """Client connects to server securely through an HTTPS proxy.""" + with patch_environ({"https_proxy": "https://localhost:58080"}): + with run_server(ssl=SERVER_CONTEXT) as server: + with connect( + get_uri(server), + ssl=CLIENT_CONTEXT, + proxy_ssl=self.proxy_context, + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertEqual(client.socket.version()[:3], "TLS") + self.assertNumFlows(1) + + def test_https_proxy_server_hostname(self): + """Client sets server_hostname to the value of proxy_server_hostname.""" + with patch_environ({"https_proxy": "https://localhost:58080"}): + with run_server() as server: + # Pass an argument not prefixed with proxy_ for coverage. + kwargs = {"all_errors": True} if sys.version_info >= (3, 11) else {} + with connect( + get_uri(server), + proxy_ssl=self.proxy_context, + proxy_server_hostname="overridden", + **kwargs, + ) as client: + self.assertEqual(client.socket.server_hostname, "overridden") + self.assertNumFlows(1) + + def test_https_proxy_invalid_proxy_certificate(self): + """Client rejects certificate when proxy certificate isn't trusted.""" + with patch_environ({"https_proxy": "https://localhost:58080"}): + with self.assertRaises(ssl.SSLCertVerificationError) as raised: + # The proxy certificate isn't trusted. + with connect("wss://example.com/"): + self.fail("did not raise") + self.assertIn( + "certificate verify failed: unable to get local issuer certificate", + str(raised.exception), + ) + self.assertNumFlows(0) + + def test_https_proxy_invalid_server_certificate(self): + """Client rejects certificate when server certificate isn't trusted.""" + with patch_environ({"https_proxy": "https://localhost:58080"}): + with run_server(ssl=SERVER_CONTEXT) as server: + with self.assertRaises(ssl.SSLCertVerificationError) as raised: + # The test certificate is self-signed. + with connect(get_uri(server), proxy_ssl=self.proxy_context): + self.fail("did not raise") + self.assertIn( + "certificate verify failed: self signed certificate", + str(raised.exception).replace("-", " "), + ) + self.assertNumFlows(1) + + @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") class UnixClientTests(unittest.TestCase): def test_connection(self): @@ -453,6 +627,16 @@ def test_ssl_without_secure_uri(self): "ssl argument is incompatible with a ws:// URI", ) + def test_proxy_ssl_without_https_proxy(self): + """Client rejects proxy_ssl when proxy isn't HTTPS.""" + with patch_environ({"https_proxy": "http://localhost:8080"}): + with self.assertRaises(ValueError) as raised: + connect("ws://localhost/", proxy_ssl=True) + self.assertEqual( + str(raised.exception), + "proxy_ssl argument is incompatible with an http:// proxy", + ) + def test_unix_without_path_or_sock(self): """Unix client requires path when sock isn't provided.""" with self.assertRaises(ValueError) as raised: diff --git a/tests/utils.py b/tests/utils.py index f68a447b1..389381345 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -20,14 +20,13 @@ # $ cat test_localhost.key test_localhost.crt > test_localhost.pem # $ rm test_localhost.key test_localhost.crt -CERTIFICATE = bytes(pathlib.Path(__file__).with_name("test_localhost.pem")) +CERTIFICATE = pathlib.Path(__file__).with_name("test_localhost.pem") CLIENT_CONTEXT = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) -CLIENT_CONTEXT.load_verify_locations(CERTIFICATE) - +CLIENT_CONTEXT.load_verify_locations(bytes(CERTIFICATE)) SERVER_CONTEXT = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) -SERVER_CONTEXT.load_cert_chain(CERTIFICATE) +SERVER_CONTEXT.load_cert_chain(bytes(CERTIFICATE)) # Work around https://github.com/openssl/openssl/issues/7967 From 2b9a90a7edc305f5619b229ae5cbda755e173da9 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Feb 2025 21:51:10 +0100 Subject: [PATCH 729/762] Replace patch_environ with unittest.mock.patch.dict. --- tests/asyncio/test_client.py | 228 +++++++++++++++++------------------ tests/sync/test_client.py | 206 +++++++++++++++---------------- tests/test_uri.py | 6 +- tests/utils.py | 16 --- 4 files changed, 219 insertions(+), 237 deletions(-) diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index c6ff26ae4..c2a96f3ec 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -2,10 +2,12 @@ import contextlib import http import logging +import os import socket import ssl import sys import unittest +from unittest.mock import patch from websockets.asyncio.client import * from websockets.asyncio.compatibility import TimeoutError @@ -24,13 +26,7 @@ from websockets.extensions.permessage_deflate import PerMessageDeflate from ..proxy import ProxyMixin -from ..utils import ( - CLIENT_CONTEXT, - MS, - SERVER_CONTEXT, - patch_environ, - temp_unix_socket_path, -) +from ..utils import CLIENT_CONTEXT, MS, SERVER_CONTEXT, temp_unix_socket_path from .server import args, get_host_port, get_uri, handler @@ -570,46 +566,44 @@ def redirect(connection, request): class SocksProxyClientTests(ProxyMixin, unittest.IsolatedAsyncioTestCase): proxy_mode = "socks5@51080" + @patch.dict(os.environ, {"socks_proxy": "http://localhost:51080"}) async def test_socks_proxy(self): """Client connects to server through a SOCKS5 proxy.""" - with patch_environ({"socks_proxy": "http://localhost:51080"}): - async with serve(*args) as server: - async with connect(get_uri(server)) as client: - self.assertEqual(client.protocol.state.name, "OPEN") + async with serve(*args) as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(1) + @patch.dict(os.environ, {"socks_proxy": "http://localhost:51080"}) async def test_secure_socks_proxy(self): """Client connects to server securely through a SOCKS5 proxy.""" - with patch_environ({"socks_proxy": "http://localhost:51080"}): - async with serve(*args, ssl=SERVER_CONTEXT) as server: - async with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: - self.assertEqual(client.protocol.state.name, "OPEN") + async with serve(*args, ssl=SERVER_CONTEXT) as server: + async with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: + self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(1) + @patch.dict(os.environ, {"socks_proxy": "http://hello:iloveyou@localhost:51080"}) async def test_authenticated_socks_proxy(self): """Client connects to server through an authenticated SOCKS5 proxy.""" try: self.proxy_options.update(proxyauth="hello:iloveyou") - with patch_environ( - {"socks_proxy": "http://hello:iloveyou@localhost:51080"} - ): - async with serve(*args) as server: - async with connect(get_uri(server)) as client: - self.assertEqual(client.protocol.state.name, "OPEN") + async with serve(*args) as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") finally: self.proxy_options.update(proxyauth=None) self.assertNumFlows(1) + @patch.dict(os.environ, {"socks_proxy": "http://localhost:51080"}) async def test_authenticated_socks_proxy_error(self): """Client fails to authenticate to the SOCKS5 proxy.""" from python_socks import ProxyError as SocksProxyError try: self.proxy_options.update(proxyauth="any") - with patch_environ({"socks_proxy": "http://localhost:51080"}): - with self.assertRaises(ProxyError) as raised: - async with connect("ws://example.com/"): - self.fail("did not raise") + with self.assertRaises(ProxyError) as raised: + async with connect("ws://example.com/"): + self.fail("did not raise") finally: self.proxy_options.update(proxyauth=None) self.assertEqual( @@ -619,14 +613,14 @@ async def test_authenticated_socks_proxy_error(self): self.assertIsInstance(raised.exception.__cause__, SocksProxyError) self.assertNumFlows(0) + @patch.dict(os.environ, {"socks_proxy": "http://localhost:61080"}) # bad port async def test_socks_proxy_connection_failure(self): """Client fails to connect to the SOCKS5 proxy.""" from python_socks import ProxyConnectionError as SocksProxyConnectionError - with patch_environ({"socks_proxy": "http://localhost:61080"}): # bad port - with self.assertRaises(OSError) as raised: - async with connect("ws://example.com/"): - self.fail("did not raise") + with self.assertRaises(OSError) as raised: + async with connect("ws://example.com/"): + self.fail("did not raise") # Don't test str(raised.exception) because we don't control it. self.assertIsInstance(raised.exception, SocksProxyConnectionError) self.assertNumFlows(0) @@ -636,7 +630,7 @@ async def test_socks_proxy_connection_timeout(self): # Replace the proxy with a TCP server that doesn't respond. with socket.create_server(("localhost", 0)) as sock: host, port = sock.getsockname() - with patch_environ({"socks_proxy": f"http://{host}:{port}"}): + with patch.dict(os.environ, {"socks_proxy": f"http://{host}:{port}"}): with self.assertRaises(TimeoutError) as raised: async with connect("ws://example.com/", open_timeout=MS): self.fail("did not raise") @@ -657,14 +651,14 @@ async def test_explicit_socks_proxy(self): self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(1) + @patch.dict(os.environ, {"socks_proxy": "http://localhost:51080"}) async def test_ignore_proxy_with_existing_socket(self): """Client connects using a pre-existing socket.""" - with patch_environ({"socks_proxy": "http://localhost:51080"}): - async with serve(*args) as server: - with socket.create_connection(get_host_port(server)) as sock: - # Use a non-existing domain to ensure we connect to sock. - async with connect("ws://invalid/", sock=sock) as client: - self.assertEqual(client.protocol.state.name, "OPEN") + async with serve(*args) as server: + with socket.create_connection(get_host_port(server)) as sock: + # Use a non-existing domain to ensure we connect to sock. + async with connect("ws://invalid/", sock=sock) as client: + self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(0) @@ -672,46 +666,44 @@ async def test_ignore_proxy_with_existing_socket(self): class HTTPProxyClientTests(ProxyMixin, unittest.IsolatedAsyncioTestCase): proxy_mode = "regular@58080" + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) async def test_http_proxy(self): """Client connects to server through an HTTP proxy.""" - with patch_environ({"https_proxy": "http://localhost:58080"}): - async with serve(*args) as server: - async with connect(get_uri(server)) as client: - self.assertEqual(client.protocol.state.name, "OPEN") + async with serve(*args) as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(1) + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) async def test_secure_http_proxy(self): """Client connects to server securely through an HTTP proxy.""" - with patch_environ({"https_proxy": "http://localhost:58080"}): - async with serve(*args, ssl=SERVER_CONTEXT) as server: - async with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: - self.assertEqual(client.protocol.state.name, "OPEN") - ssl_object = client.transport.get_extra_info("ssl_object") - self.assertEqual(ssl_object.version()[:3], "TLS") + async with serve(*args, ssl=SERVER_CONTEXT) as server: + async with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.version()[:3], "TLS") self.assertNumFlows(1) + @patch.dict(os.environ, {"https_proxy": "http://hello:iloveyou@localhost:58080"}) async def test_authenticated_http_proxy(self): """Client connects to server through an authenticated HTTP proxy.""" try: self.proxy_options.update(proxyauth="hello:iloveyou") - with patch_environ( - {"https_proxy": "http://hello:iloveyou@localhost:58080"} - ): - async with serve(*args) as server: - async with connect(get_uri(server)) as client: - self.assertEqual(client.protocol.state.name, "OPEN") + async with serve(*args) as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") finally: self.proxy_options.update(proxyauth=None) self.assertNumFlows(1) + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) async def test_authenticated_http_proxy_error(self): """Client fails to authenticate to the HTTP proxy.""" try: self.proxy_options.update(proxyauth="any") - with patch_environ({"https_proxy": "http://localhost:58080"}): - with self.assertRaises(ProxyError) as raised: - async with connect("ws://example.com/"): - self.fail("did not raise") + with self.assertRaises(ProxyError) as raised: + async with connect("ws://example.com/"): + self.fail("did not raise") finally: self.proxy_options.update(proxyauth=None) self.assertEqual( @@ -720,14 +712,14 @@ async def test_authenticated_http_proxy_error(self): ) self.assertNumFlows(0) + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) async def test_http_proxy_protocol_error(self): """Client receives invalid data when connecting to the HTTP proxy.""" try: self.proxy_options.update(break_http_connect=True) - with patch_environ({"https_proxy": "http://localhost:58080"}): - with self.assertRaises(InvalidProxyMessage) as raised: - async with connect("ws://example.com/"): - self.fail("did not raise") + with self.assertRaises(InvalidProxyMessage) as raised: + async with connect("ws://example.com/"): + self.fail("did not raise") finally: self.proxy_options.update(break_http_connect=False) self.assertEqual( @@ -736,14 +728,14 @@ async def test_http_proxy_protocol_error(self): ) self.assertNumFlows(0) + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) async def test_http_proxy_connection_error(self): """Client receives no response when connecting to the HTTP proxy.""" try: self.proxy_options.update(close_http_connect=True) - with patch_environ({"https_proxy": "http://localhost:58080"}): - with self.assertRaises(InvalidProxyMessage) as raised: - async with connect("ws://example.com/"): - self.fail("did not raise") + with self.assertRaises(InvalidProxyMessage) as raised: + async with connect("ws://example.com/"): + self.fail("did not raise") finally: self.proxy_options.update(close_http_connect=False) self.assertEqual( @@ -752,12 +744,12 @@ async def test_http_proxy_connection_error(self): ) self.assertNumFlows(0) + @patch.dict(os.environ, {"https_proxy": "http://localhost:48080"}) # bad port async def test_http_proxy_connection_failure(self): """Client fails to connect to the HTTP proxy.""" - with patch_environ({"https_proxy": "http://localhost:61080"}): # bad port - with self.assertRaises(OSError): - async with connect("ws://example.com/"): - self.fail("did not raise") + with self.assertRaises(OSError): + async with connect("ws://example.com/"): + self.fail("did not raise") # Don't test str(raised.exception) because we don't control it. self.assertNumFlows(0) @@ -766,7 +758,7 @@ async def test_http_proxy_connection_timeout(self): # Replace the proxy with a TCP server that doesn't respond. with socket.create_server(("localhost", 0)) as sock: host, port = sock.getsockname() - with patch_environ({"https_proxy": f"http://{host}:{port}"}): + with patch.dict(os.environ, {"https_proxy": f"http://{host}:{port}"}): with self.assertRaises(TimeoutError) as raised: async with connect("ws://example.com/", open_timeout=MS): self.fail("did not raise") @@ -775,67 +767,67 @@ async def test_http_proxy_connection_timeout(self): "timed out during opening handshake", ) + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) async def test_https_proxy(self): """Client connects to server through an HTTPS proxy.""" - with patch_environ({"https_proxy": "https://localhost:58080"}): - async with serve(*args) as server: - async with connect( - get_uri(server), - proxy_ssl=self.proxy_context, - ) as client: - self.assertEqual(client.protocol.state.name, "OPEN") + async with serve(*args) as server: + async with connect( + get_uri(server), + proxy_ssl=self.proxy_context, + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(1) + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) async def test_secure_https_proxy(self): """Client connects to server securely through an HTTPS proxy.""" - with patch_environ({"https_proxy": "https://localhost:58080"}): - async with serve(*args, ssl=SERVER_CONTEXT) as server: - async with connect( - get_uri(server), - ssl=CLIENT_CONTEXT, - proxy_ssl=self.proxy_context, - ) as client: - self.assertEqual(client.protocol.state.name, "OPEN") - ssl_object = client.transport.get_extra_info("ssl_object") - self.assertEqual(ssl_object.version()[:3], "TLS") + async with serve(*args, ssl=SERVER_CONTEXT) as server: + async with connect( + get_uri(server), + ssl=CLIENT_CONTEXT, + proxy_ssl=self.proxy_context, + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.version()[:3], "TLS") self.assertNumFlows(1) + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) async def test_https_server_hostname(self): """Client sets server_hostname to the value of proxy_server_hostname.""" - with patch_environ({"https_proxy": "https://localhost:58080"}): - async with serve(*args) as server: - # Pass an argument not prefixed with proxy_ for coverage. - kwargs = {"all_errors": True} if sys.version_info >= (3, 12) else {} - async with connect( - get_uri(server), - proxy_ssl=self.proxy_context, - proxy_server_hostname="overridden", - **kwargs, - ) as client: - ssl_object = client.transport.get_extra_info("ssl_object") - self.assertEqual(ssl_object.server_hostname, "overridden") + async with serve(*args) as server: + # Pass an argument not prefixed with proxy_ for coverage. + kwargs = {"all_errors": True} if sys.version_info >= (3, 12) else {} + async with connect( + get_uri(server), + proxy_ssl=self.proxy_context, + proxy_server_hostname="overridden", + **kwargs, + ) as client: + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.server_hostname, "overridden") self.assertNumFlows(1) + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) async def test_https_proxy_invalid_proxy_certificate(self): """Client rejects certificate when proxy certificate isn't trusted.""" - with patch_environ({"https_proxy": "https://localhost:58080"}): - with self.assertRaises(ssl.SSLCertVerificationError) as raised: - # The proxy certificate isn't trusted. - async with connect("wss://example.com/"): - self.fail("did not raise") + with self.assertRaises(ssl.SSLCertVerificationError) as raised: + # The proxy certificate isn't trusted. + async with connect("wss://example.com/"): + self.fail("did not raise") self.assertIn( "certificate verify failed: unable to get local issuer certificate", str(raised.exception), ) + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) async def test_https_proxy_invalid_server_certificate(self): """Client rejects certificate when proxy certificate isn't trusted.""" - with patch_environ({"https_proxy": "https://localhost:58080"}): - async with serve(*args, ssl=SERVER_CONTEXT) as server: - with self.assertRaises(ssl.SSLCertVerificationError) as raised: - # The test certificate is self-signed. - async with connect(get_uri(server), proxy_ssl=self.proxy_context): - self.fail("did not raise") + async with serve(*args, ssl=SERVER_CONTEXT) as server: + with self.assertRaises(ssl.SSLCertVerificationError) as raised: + # The test certificate is self-signed. + async with connect(get_uri(server), proxy_ssl=self.proxy_context): + self.fail("did not raise") self.assertIn( "certificate verify failed: self signed certificate", str(raised.exception).replace("-", " "), @@ -923,9 +915,12 @@ async def test_secure_uri_without_ssl(self): async def test_proxy_ssl_without_https_proxy(self): """Client rejects proxy_ssl when proxy isn't HTTPS.""" - with patch_environ({"https_proxy": "http://localhost:8080"}): - with self.assertRaises(ValueError) as raised: - await connect("ws://localhost/", proxy_ssl=True) + with self.assertRaises(ValueError) as raised: + await connect( + "ws://localhost/", + proxy="http://localhost:8080", + proxy_ssl=True, + ) self.assertEqual( str(raised.exception), "proxy_ssl argument is incompatible with an http:// proxy", @@ -933,9 +928,12 @@ async def test_proxy_ssl_without_https_proxy(self): async def test_https_proxy_without_ssl(self): """Client rejects proxy_ssl=None when proxy is HTTPS.""" - with patch_environ({"https_proxy": "https://localhost:8080"}): - with self.assertRaises(ValueError) as raised: - await connect("ws://localhost/", proxy_ssl=None) + with self.assertRaises(ValueError) as raised: + await connect( + "ws://localhost/", + proxy="https://localhost:8080", + proxy_ssl=None, + ) self.assertEqual( str(raised.exception), "proxy_ssl=None is incompatible with an https:// proxy", diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index 386caf56a..e4927bb32 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -1,5 +1,6 @@ import http import logging +import os import socket import socketserver import ssl @@ -7,6 +8,7 @@ import threading import time import unittest +from unittest.mock import patch from websockets.exceptions import ( InvalidHandshake, @@ -26,7 +28,6 @@ MS, SERVER_CONTEXT, DeprecationTestCase, - patch_environ, temp_unix_socket_path, ) from .server import get_uri, run_server, run_unix_server @@ -310,46 +311,44 @@ def test_reject_invalid_server_hostname(self): class SocksProxyClientTests(ProxyMixin, unittest.TestCase): proxy_mode = "socks5@51080" + @patch.dict(os.environ, {"socks_proxy": "http://localhost:51080"}) def test_socks_proxy(self): """Client connects to server through a SOCKS5 proxy.""" - with patch_environ({"socks_proxy": "http://localhost:51080"}): - with run_server() as server: - with connect(get_uri(server)) as client: - self.assertEqual(client.protocol.state.name, "OPEN") + with run_server() as server: + with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(1) + @patch.dict(os.environ, {"socks_proxy": "http://localhost:51080"}) def test_secure_socks_proxy(self): """Client connects to server securely through a SOCKS5 proxy.""" - with patch_environ({"socks_proxy": "http://localhost:51080"}): - with run_server(ssl=SERVER_CONTEXT) as server: - with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: - self.assertEqual(client.protocol.state.name, "OPEN") + with run_server(ssl=SERVER_CONTEXT) as server: + with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: + self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(1) + @patch.dict(os.environ, {"socks_proxy": "http://hello:iloveyou@localhost:51080"}) def test_authenticated_socks_proxy(self): """Client connects to server through an authenticated SOCKS5 proxy.""" try: self.proxy_options.update(proxyauth="hello:iloveyou") - with patch_environ( - {"socks_proxy": "http://hello:iloveyou@localhost:51080"} - ): - with run_server() as server: - with connect(get_uri(server)) as client: - self.assertEqual(client.protocol.state.name, "OPEN") + with run_server() as server: + with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") finally: self.proxy_options.update(proxyauth=None) self.assertNumFlows(1) + @patch.dict(os.environ, {"socks_proxy": "http://localhost:51080"}) def test_authenticated_socks_proxy_error(self): """Client fails to authenticate to the SOCKS5 proxy.""" from python_socks import ProxyError as SocksProxyError try: self.proxy_options.update(proxyauth="any") - with patch_environ({"socks_proxy": "http://localhost:51080"}): - with self.assertRaises(ProxyError) as raised: - with connect("ws://example.com/"): - self.fail("did not raise") + with self.assertRaises(ProxyError) as raised: + with connect("ws://example.com/"): + self.fail("did not raise") finally: self.proxy_options.update(proxyauth=None) self.assertEqual( @@ -359,14 +358,14 @@ def test_authenticated_socks_proxy_error(self): self.assertIsInstance(raised.exception.__cause__, SocksProxyError) self.assertNumFlows(0) + @patch.dict(os.environ, {"socks_proxy": "http://localhost:61080"}) # bad port def test_socks_proxy_connection_failure(self): """Client fails to connect to the SOCKS5 proxy.""" from python_socks import ProxyConnectionError as SocksProxyConnectionError - with patch_environ({"socks_proxy": "http://localhost:61080"}): # bad port - with self.assertRaises(OSError) as raised: - with connect("ws://example.com/"): - self.fail("did not raise") + with self.assertRaises(OSError) as raised: + with connect("ws://example.com/"): + self.fail("did not raise") # Don't test str(raised.exception) because we don't control it. self.assertIsInstance(raised.exception, SocksProxyConnectionError) self.assertNumFlows(0) @@ -378,7 +377,7 @@ def test_socks_proxy_connection_timeout(self): # Replace the proxy with a TCP server that doesn't respond. with socket.create_server(("localhost", 0)) as sock: host, port = sock.getsockname() - with patch_environ({"socks_proxy": f"http://{host}:{port}"}): + with patch.dict(os.environ, {"socks_proxy": f"http://{host}:{port}"}): with self.assertRaises(TimeoutError) as raised: with connect("ws://example.com/", open_timeout=MS): self.fail("did not raise") @@ -397,14 +396,14 @@ def test_explicit_socks_proxy(self): self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(1) + @patch.dict(os.environ, {"ws_proxy": "http://localhost:58080"}) def test_ignore_proxy_with_existing_socket(self): """Client connects using a pre-existing socket.""" - with patch_environ({"ws_proxy": "http://localhost:58080"}): - with run_server() as server: - with socket.create_connection(server.socket.getsockname()) as sock: - # Use a non-existing domain to ensure we connect to sock. - with connect("ws://invalid/", sock=sock) as client: - self.assertEqual(client.protocol.state.name, "OPEN") + with run_server() as server: + with socket.create_connection(server.socket.getsockname()) as sock: + # Use a non-existing domain to ensure we connect to sock. + with connect("ws://invalid/", sock=sock) as client: + self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(0) @@ -412,45 +411,43 @@ def test_ignore_proxy_with_existing_socket(self): class HTTPProxyClientTests(ProxyMixin, unittest.IsolatedAsyncioTestCase): proxy_mode = "regular@58080" + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) def test_http_proxy(self): """Client connects to server through an HTTP proxy.""" - with patch_environ({"https_proxy": "http://localhost:58080"}): - with run_server() as server: - with connect(get_uri(server)) as client: - self.assertEqual(client.protocol.state.name, "OPEN") + with run_server() as server: + with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(1) + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) def test_secure_http_proxy(self): """Client connects to server securely through an HTTP proxy.""" - with patch_environ({"https_proxy": "http://localhost:58080"}): - with run_server(ssl=SERVER_CONTEXT) as server: - with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: - self.assertEqual(client.protocol.state.name, "OPEN") - self.assertEqual(client.socket.version()[:3], "TLS") + with run_server(ssl=SERVER_CONTEXT) as server: + with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertEqual(client.socket.version()[:3], "TLS") self.assertNumFlows(1) + @patch.dict(os.environ, {"https_proxy": "http://hello:iloveyou@localhost:58080"}) def test_authenticated_http_proxy(self): """Client connects to server through an authenticated HTTP proxy.""" try: self.proxy_options.update(proxyauth="hello:iloveyou") - with patch_environ( - {"https_proxy": "http://hello:iloveyou@localhost:58080"} - ): - with run_server() as server: - with connect(get_uri(server)) as client: - self.assertEqual(client.protocol.state.name, "OPEN") + with run_server() as server: + with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") finally: self.proxy_options.update(proxyauth=None) self.assertNumFlows(1) + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) def test_authenticated_http_proxy_error(self): """Client fails to authenticate to the HTTP proxy.""" try: self.proxy_options.update(proxyauth="any") - with patch_environ({"https_proxy": "http://localhost:58080"}): - with self.assertRaises(ProxyError) as raised: - with connect("ws://example.com/"): - self.fail("did not raise") + with self.assertRaises(ProxyError) as raised: + with connect("ws://example.com/"): + self.fail("did not raise") finally: self.proxy_options.update(proxyauth=None) self.assertEqual( @@ -459,14 +456,14 @@ def test_authenticated_http_proxy_error(self): ) self.assertNumFlows(0) + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) def test_http_proxy_protocol_error(self): """Client receives invalid data when connecting to the HTTP proxy.""" try: self.proxy_options.update(break_http_connect=True) - with patch_environ({"https_proxy": "http://localhost:58080"}): - with self.assertRaises(InvalidProxyMessage) as raised: - with connect("ws://example.com/"): - self.fail("did not raise") + with self.assertRaises(InvalidProxyMessage) as raised: + with connect("ws://example.com/"): + self.fail("did not raise") finally: self.proxy_options.update(break_http_connect=False) self.assertEqual( @@ -475,14 +472,14 @@ def test_http_proxy_protocol_error(self): ) self.assertNumFlows(0) + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) def test_http_proxy_connection_error(self): """Client receives no response when connecting to the HTTP proxy.""" try: self.proxy_options.update(close_http_connect=True) - with patch_environ({"https_proxy": "http://localhost:58080"}): - with self.assertRaises(InvalidProxyMessage) as raised: - with connect("ws://example.com/"): - self.fail("did not raise") + with self.assertRaises(InvalidProxyMessage) as raised: + with connect("ws://example.com/"): + self.fail("did not raise") finally: self.proxy_options.update(close_http_connect=False) self.assertEqual( @@ -491,12 +488,12 @@ def test_http_proxy_connection_error(self): ) self.assertNumFlows(0) + @patch.dict(os.environ, {"https_proxy": "http://localhost:48080"}) # bad port def test_http_proxy_connection_failure(self): """Client fails to connect to the HTTP proxy.""" - with patch_environ({"https_proxy": "http://localhost:61080"}): # bad port - with self.assertRaises(OSError): - with connect("ws://example.com/"): - self.fail("did not raise") + with self.assertRaises(OSError): + with connect("ws://example.com/"): + self.fail("did not raise") # Don't test str(raised.exception) because we don't control it. self.assertNumFlows(0) @@ -505,7 +502,7 @@ def test_http_proxy_connection_timeout(self): # Replace the proxy with a TCP server that does't respond. with socket.create_server(("localhost", 0)) as sock: host, port = sock.getsockname() - with patch_environ({"https_proxy": f"http://{host}:{port}"}): + with patch.dict(os.environ, {"https_proxy": f"http://{host}:{port}"}): with self.assertRaises(TimeoutError) as raised: with connect("ws://example.com/", open_timeout=MS): self.fail("did not raise") @@ -514,66 +511,66 @@ def test_http_proxy_connection_timeout(self): "timed out while connecting to HTTP proxy", ) + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) def test_https_proxy(self): """Client connects to server through an HTTPS proxy.""" - with patch_environ({"https_proxy": "https://localhost:58080"}): - with run_server() as server: - with connect( - get_uri(server), - proxy_ssl=self.proxy_context, - ) as client: - self.assertEqual(client.protocol.state.name, "OPEN") + with run_server() as server: + with connect( + get_uri(server), + proxy_ssl=self.proxy_context, + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(1) + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) def test_secure_https_proxy(self): """Client connects to server securely through an HTTPS proxy.""" - with patch_environ({"https_proxy": "https://localhost:58080"}): - with run_server(ssl=SERVER_CONTEXT) as server: - with connect( - get_uri(server), - ssl=CLIENT_CONTEXT, - proxy_ssl=self.proxy_context, - ) as client: - self.assertEqual(client.protocol.state.name, "OPEN") - self.assertEqual(client.socket.version()[:3], "TLS") + with run_server(ssl=SERVER_CONTEXT) as server: + with connect( + get_uri(server), + ssl=CLIENT_CONTEXT, + proxy_ssl=self.proxy_context, + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertEqual(client.socket.version()[:3], "TLS") self.assertNumFlows(1) + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) def test_https_proxy_server_hostname(self): """Client sets server_hostname to the value of proxy_server_hostname.""" - with patch_environ({"https_proxy": "https://localhost:58080"}): - with run_server() as server: - # Pass an argument not prefixed with proxy_ for coverage. - kwargs = {"all_errors": True} if sys.version_info >= (3, 11) else {} - with connect( - get_uri(server), - proxy_ssl=self.proxy_context, - proxy_server_hostname="overridden", - **kwargs, - ) as client: - self.assertEqual(client.socket.server_hostname, "overridden") + with run_server() as server: + # Pass an argument not prefixed with proxy_ for coverage. + kwargs = {"all_errors": True} if sys.version_info >= (3, 11) else {} + with connect( + get_uri(server), + proxy_ssl=self.proxy_context, + proxy_server_hostname="overridden", + **kwargs, + ) as client: + self.assertEqual(client.socket.server_hostname, "overridden") self.assertNumFlows(1) + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) def test_https_proxy_invalid_proxy_certificate(self): """Client rejects certificate when proxy certificate isn't trusted.""" - with patch_environ({"https_proxy": "https://localhost:58080"}): - with self.assertRaises(ssl.SSLCertVerificationError) as raised: - # The proxy certificate isn't trusted. - with connect("wss://example.com/"): - self.fail("did not raise") + with self.assertRaises(ssl.SSLCertVerificationError) as raised: + # The proxy certificate isn't trusted. + with connect("wss://example.com/"): + self.fail("did not raise") self.assertIn( "certificate verify failed: unable to get local issuer certificate", str(raised.exception), ) self.assertNumFlows(0) + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) def test_https_proxy_invalid_server_certificate(self): """Client rejects certificate when server certificate isn't trusted.""" - with patch_environ({"https_proxy": "https://localhost:58080"}): - with run_server(ssl=SERVER_CONTEXT) as server: - with self.assertRaises(ssl.SSLCertVerificationError) as raised: - # The test certificate is self-signed. - with connect(get_uri(server), proxy_ssl=self.proxy_context): - self.fail("did not raise") + with run_server(ssl=SERVER_CONTEXT) as server: + with self.assertRaises(ssl.SSLCertVerificationError) as raised: + # The test certificate is self-signed. + with connect(get_uri(server), proxy_ssl=self.proxy_context): + self.fail("did not raise") self.assertIn( "certificate verify failed: self signed certificate", str(raised.exception).replace("-", " "), @@ -629,9 +626,12 @@ def test_ssl_without_secure_uri(self): def test_proxy_ssl_without_https_proxy(self): """Client rejects proxy_ssl when proxy isn't HTTPS.""" - with patch_environ({"https_proxy": "http://localhost:8080"}): - with self.assertRaises(ValueError) as raised: - connect("ws://localhost/", proxy_ssl=True) + with self.assertRaises(ValueError) as raised: + connect( + "ws://localhost/", + proxy="http://localhost:8080", + proxy_ssl=True, + ) self.assertEqual( str(raised.exception), "proxy_ssl argument is incompatible with an http:// proxy", diff --git a/tests/test_uri.py b/tests/test_uri.py index 35b51fa58..3ccf21158 100644 --- a/tests/test_uri.py +++ b/tests/test_uri.py @@ -1,11 +1,11 @@ +import os import unittest +from unittest.mock import patch from websockets.exceptions import InvalidProxy, InvalidURI from websockets.uri import * from websockets.uri import Proxy, get_proxy, parse_proxy -from .utils import patch_environ - VALID_URIS = [ ( @@ -255,6 +255,6 @@ def test_parse_proxy_user_info(self): def test_get_proxy(self): for environ, uri, proxy in PROXY_ENVS: - with patch_environ(environ): + with patch.dict(os.environ, environ): with self.subTest(environ=environ, uri=uri): self.assertEqual(get_proxy(parse_uri(uri)), proxy) diff --git a/tests/utils.py b/tests/utils.py index 389381345..7932aae60 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -138,22 +138,6 @@ def assertNoLogs(self, logger=None, level=None): self.assertEqual(logs.output, [f"{level_name}:{logger}:dummy"]) -@contextlib.contextmanager -def patch_environ(environ): - backup = {} - for key, value in environ.items(): - backup[key] = os.environ.get(key) - os.environ[key] = value - try: - yield - finally: - for key, value in backup.items(): - if value is None: - del os.environ[key] - else: # pragma: no cover - os.environ[key] = value - - @contextlib.contextmanager def temp_unix_socket_path(): with tempfile.TemporaryDirectory() as temp_dir: From 1a90f1e9ebf6a9ea034d89d1e706bb8e3bee1f25 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Feb 2025 23:14:32 +0100 Subject: [PATCH 730/762] Set User-Agent header in CONNECT requests. --- src/websockets/asyncio/client.py | 37 ++++++++++++++++++++++--------- src/websockets/sync/client.py | 12 ++++++++-- tests/asyncio/test_client.py | 18 +++++++++++++++ tests/proxy.py | 38 ++++++++++++++++++++++---------- tests/sync/test_client.py | 18 +++++++++++++++ 5 files changed, 99 insertions(+), 24 deletions(-) diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index c19a53f8c..38a56ddda 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -97,7 +97,7 @@ async def handshake( self.request = self.protocol.connect() if additional_headers is not None: self.request.headers.update(additional_headers) - if user_agent_header: + if user_agent_header is not None: self.request.headers.setdefault("User-Agent", user_agent_header) self.protocol.send_request(self.request) @@ -363,10 +363,8 @@ def protocol_factory(uri: WebSocketURI) -> ClientConnection: self.proxy = proxy self.protocol_factory = protocol_factory - self.handshake_args = ( - additional_headers, - user_agent_header, - ) + self.additional_headers = additional_headers + self.user_agent_header = user_agent_header self.process_exception = process_exception self.open_timeout = open_timeout self.logger = logger @@ -442,6 +440,7 @@ def factory() -> ClientConnection: transport = await connect_http_proxy( proxy_parsed, ws_uri, + user_agent_header=self.user_agent_header, **proxy_kwargs, ) # Initialize WebSocket connection via the proxy. @@ -541,7 +540,10 @@ async def __await_impl__(self) -> ClientConnection: for _ in range(MAX_REDIRECTS): self.connection = await self.create_connection() try: - await self.connection.handshake(*self.handshake_args) + await self.connection.handshake( + self.additional_headers, + self.user_agent_header, + ) except asyncio.CancelledError: self.connection.transport.abort() raise @@ -717,10 +719,16 @@ async def connect_socks_proxy( raise ImportError("python-socks is required to use a SOCKS proxy") -def prepare_connect_request(proxy: Proxy, ws_uri: WebSocketURI) -> bytes: +def prepare_connect_request( + proxy: Proxy, + ws_uri: WebSocketURI, + user_agent_header: str | None = None, +) -> bytes: host = build_host(ws_uri.host, ws_uri.port, ws_uri.secure, always_include_port=True) headers = Headers() headers["Host"] = build_host(ws_uri.host, ws_uri.port, ws_uri.secure) + if user_agent_header is not None: + headers["User-Agent"] = user_agent_header if proxy.username is not None: assert proxy.password is not None # enforced by parse_proxy() headers["Proxy-Authorization"] = build_authorization_basic( @@ -731,9 +739,15 @@ def prepare_connect_request(proxy: Proxy, ws_uri: WebSocketURI) -> bytes: class HTTPProxyConnection(asyncio.Protocol): - def __init__(self, ws_uri: WebSocketURI, proxy: Proxy): + def __init__( + self, + ws_uri: WebSocketURI, + proxy: Proxy, + user_agent_header: str | None = None, + ): self.ws_uri = ws_uri self.proxy = proxy + self.user_agent_header = user_agent_header self.reader = StreamReader() self.parser = Response.parse( @@ -765,7 +779,9 @@ def run_parser(self) -> None: def connection_made(self, transport: asyncio.BaseTransport) -> None: transport = cast(asyncio.Transport, transport) self.transport = transport - self.transport.write(prepare_connect_request(self.proxy, self.ws_uri)) + self.transport.write( + prepare_connect_request(self.proxy, self.ws_uri, self.user_agent_header) + ) def data_received(self, data: bytes) -> None: self.reader.feed_data(data) @@ -784,10 +800,11 @@ def connection_lost(self, exc: Exception | None) -> None: async def connect_http_proxy( proxy: Proxy, ws_uri: WebSocketURI, + user_agent_header: str | None = None, **kwargs: Any, ) -> asyncio.Transport: transport, protocol = await asyncio.get_running_loop().create_connection( - lambda: HTTPProxyConnection(ws_uri, proxy), + lambda: HTTPProxyConnection(ws_uri, proxy, user_agent_header), proxy.host, proxy.port, **kwargs, diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index c0fe6901a..58cb84710 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -312,6 +312,7 @@ def connect( proxy_parsed, ws_uri, deadline, + user_agent_header=user_agent_header, ssl=proxy_ssl, server_hostname=proxy_server_hostname, **kwargs, @@ -472,10 +473,16 @@ def connect_socks_proxy( raise ImportError("python-socks is required to use a SOCKS proxy") -def prepare_connect_request(proxy: Proxy, ws_uri: WebSocketURI) -> bytes: +def prepare_connect_request( + proxy: Proxy, + ws_uri: WebSocketURI, + user_agent_header: str | None = None, +) -> bytes: host = build_host(ws_uri.host, ws_uri.port, ws_uri.secure, always_include_port=True) headers = Headers() headers["Host"] = build_host(ws_uri.host, ws_uri.port, ws_uri.secure) + if user_agent_header is not None: + headers["User-Agent"] = user_agent_header if proxy.username is not None: assert proxy.password is not None # enforced by parse_proxy() headers["Proxy-Authorization"] = build_authorization_basic( @@ -524,6 +531,7 @@ def connect_http_proxy( ws_uri: WebSocketURI, deadline: Deadline, *, + user_agent_header: str | None = None, ssl: ssl_module.SSLContext | None = None, server_hostname: str | None = None, **kwargs: Any, @@ -546,7 +554,7 @@ def connect_http_proxy( # Send CONNECT request to the proxy and read response. - sock.sendall(prepare_connect_request(proxy, ws_uri)) + sock.sendall(prepare_connect_request(proxy, ws_uri, user_agent_header)) try: read_connect_response(sock, deadline) except Exception: diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index c2a96f3ec..465ea2bdb 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -712,6 +712,24 @@ async def test_authenticated_http_proxy_error(self): ) self.assertNumFlows(0) + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) + async def test_http_proxy_override_user_agent(self): + """Client can override User-Agent header with user_agent_header.""" + async with serve(*args) as server: + async with connect(get_uri(server), user_agent_header="Smith") as client: + self.assertEqual(client.protocol.state.name, "OPEN") + [http_connect] = self.get_http_connects() + self.assertEqual(http_connect.request.headers[b"User-Agent"], "Smith") + + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) + async def test_http_proxy_remove_user_agent(self): + """Client can remove User-Agent header with user_agent_header.""" + async with serve(*args) as server: + async with connect(get_uri(server), user_agent_header=None) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + [http_connect] = self.get_http_connects() + self.assertNotIn(b"User-Agent", http_connect.request.headers) + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) async def test_http_proxy_protocol_error(self): """Client receives invalid data when connecting to the HTTP proxy.""" diff --git a/tests/proxy.py b/tests/proxy.py index 9746e3382..236c49337 100644 --- a/tests/proxy.py +++ b/tests/proxy.py @@ -22,17 +22,26 @@ class RecordFlows: def __init__(self, on_running): self.running = on_running - self.flows = [] + self.http_connects = [] + self.tcp_flows = [] + + def http_connect(self, flow): + self.http_connects.append(flow) def tcp_start(self, flow): - self.flows.append(flow) + self.tcp_flows.append(flow) + + def get_http_connects(self): + http_connects, self.http_connects[:] = self.http_connects[:], [] + return http_connects - def get_flows(self): - flows, self.flows[:] = self.flows[:], [] - return flows + def get_tcp_flows(self): + tcp_flows, self.tcp_flows[:] = self.tcp_flows[:], [] + return tcp_flows - def reset_flows(self): - self.flows = [] + def reset(self): + self.http_connects = [] + self.tcp_flows = [] class AlterRequest: @@ -121,13 +130,18 @@ def setUpClass(cls): cls.proxy_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) cls.proxy_context.load_verify_locations(bytes(certificate)) - def assertNumFlows(self, num_flows): - record_flows = self.proxy_master.addons.get("recordflows") - self.assertEqual(len(record_flows.get_flows()), num_flows) + def get_http_connects(self): + return self.proxy_master.addons.get("recordflows").get_http_connects() + + def get_tcp_flows(self): + return self.proxy_master.addons.get("recordflows").get_tcp_flows() + + def assertNumFlows(self, num_tcp_flows): + self.assertEqual(len(self.get_tcp_flows()), num_tcp_flows) def tearDown(self): - record_flows = self.proxy_master.addons.get("recordflows") - record_flows.reset_flows() + record_tcp_flows = self.proxy_master.addons.get("recordflows") + record_tcp_flows.reset() super().tearDown() @classmethod diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index e4927bb32..415343911 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -456,6 +456,24 @@ def test_authenticated_http_proxy_error(self): ) self.assertNumFlows(0) + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) + def test_http_proxy_override_user_agent(self): + """Client can override User-Agent header with user_agent_header.""" + with run_server() as server: + with connect(get_uri(server), user_agent_header="Smith") as client: + self.assertEqual(client.protocol.state.name, "OPEN") + [http_connect] = self.get_http_connects() + self.assertEqual(http_connect.request.headers[b"User-Agent"], "Smith") + + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) + def test_http_proxy_remove_user_agent(self): + """Client can remove User-Agent header with user_agent_header.""" + with run_server() as server: + with connect(get_uri(server), user_agent_header=None) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + [http_connect] = self.get_http_connects() + self.assertNotIn(b"User-Agent", http_connect.request.headers) + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) def test_http_proxy_protocol_error(self): """Client receives invalid data when connecting to the HTTP proxy.""" From d60255bfbb21853f764999dad462e06d9ac3ba72 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Feb 2025 23:27:01 +0100 Subject: [PATCH 731/762] Highlight potential backwards incompatibility. --- docs/project/changelog.rst | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 7bb94b349..4eaa0a4d1 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -47,11 +47,14 @@ Backwards-incompatible changes .. _python-socks: https://github.com/romis2012/python-socks -New features -............ +.. admonition:: Keepalive is enabled in the :mod:`threading` implementation. + :class: note + + The :mod:`threading` implementation now sends Ping frames at regular + intervals and closes the connection if it doesn't receive a matching Pong + frame just like the :mod:`asyncio` implementation. -* Added :doc:`keepalive and latency measurement <../topics/keepalive>` to the - :mod:`threading` implementation. + See :doc:`keepalive and latency <../topics/keepalive>` for details. Improvements ............ From 7a2f8f40af801de9e09b4411aa4af3cc5a86139f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Feb 2025 23:24:43 +0100 Subject: [PATCH 732/762] Start recv_events only after attributes are initialized. Else, a race condition could lead to accessing self.pong_waiters before it is defined. --- src/websockets/asyncio/connection.py | 16 ++++++++-------- src/websockets/sync/connection.py | 28 +++++++++++++++------------- 2 files changed, 23 insertions(+), 21 deletions(-) diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 79429923e..1b51e4791 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -101,14 +101,6 @@ def __init__( # Protect sending fragmented messages. self.fragmented_send_waiter: asyncio.Future[None] | None = None - # Exception raised while reading from the connection, to be chained to - # ConnectionClosed in order to show why the TCP connection dropped. - self.recv_exc: BaseException | None = None - - # Completed when the TCP connection is closed and the WebSocket - # connection state becomes CLOSED. - self.connection_lost_waiter: asyncio.Future[None] = self.loop.create_future() - # Mapping of ping IDs to pong waiters, in chronological order. self.pong_waiters: dict[bytes, tuple[asyncio.Future[float], float]] = {} @@ -128,6 +120,14 @@ def __init__( # Task that sends keepalive pings. None when ping_interval is None. self.keepalive_task: asyncio.Task[None] | None = None + # Exception raised while reading from the connection, to be chained to + # ConnectionClosed in order to show why the TCP connection dropped. + self.recv_exc: BaseException | None = None + + # Completed when the TCP connection is closed and the WebSocket + # connection state becomes CLOSED. + self.connection_lost_waiter: asyncio.Future[None] = self.loop.create_future() + # Adapted from asyncio.FlowControlMixin self.paused: bool = False self.drain_waiters: collections.deque[asyncio.Future[None]] = ( diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 0c517cc64..8b9e06257 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -101,19 +101,6 @@ def __init__( # Whether we are busy sending a fragmented message. self.send_in_progress = False - # Exception raised in recv_events, to be chained to ConnectionClosed - # in the user thread in order to show why the TCP connection dropped. - self.recv_exc: BaseException | None = None - - # Receiving events from the socket. This thread is marked as daemon to - # allow creating a connection in a non-daemon thread and using it in a - # daemon thread. This mustn't prevent the interpreter from exiting. - self.recv_events_thread = threading.Thread( - target=self.recv_events, - daemon=True, - ) - self.recv_events_thread.start() - # Mapping of ping IDs to pong waiters, in chronological order. self.pong_waiters: dict[bytes, tuple[threading.Event, float, bool]] = {} @@ -133,6 +120,21 @@ def __init__( # Thread that sends keepalive pings. None when ping_interval is None. self.keepalive_thread: threading.Thread | None = None + # Exception raised in recv_events, to be chained to ConnectionClosed + # in the user thread in order to show why the TCP connection dropped. + self.recv_exc: BaseException | None = None + + # Receiving events from the socket. This thread is marked as daemon to + # allow creating a connection in a non-daemon thread and using it in a + # daemon thread. This mustn't prevent the interpreter from exiting. + self.recv_events_thread = threading.Thread( + target=self.recv_events, + daemon=True, + ) + + # Start recv_events only after all attributes are initialized. + self.recv_events_thread.start() + # Public attributes @property From 602d7195c09c3c568bfd0268311a7b07a07ffb3d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 4 Feb 2025 22:09:49 +0100 Subject: [PATCH 733/762] Standardized and improved admonition types. --- docs/faq/asyncio.rst | 2 +- docs/faq/client.rst | 2 +- docs/faq/server.rst | 2 +- docs/project/changelog.rst | 8 ++++---- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/faq/asyncio.rst b/docs/faq/asyncio.rst index 3bc381cfd..a1bb663b5 100644 --- a/docs/faq/asyncio.rst +++ b/docs/faq/asyncio.rst @@ -4,7 +4,7 @@ Using asyncio .. currentmodule:: websockets.asyncio.connection .. admonition:: This FAQ is written for the new :mod:`asyncio` implementation. - :class: hint + :class: tip Answers are also valid for the legacy :mod:`asyncio` implementation. diff --git a/docs/faq/client.rst b/docs/faq/client.rst index cc9856a8b..d3f627684 100644 --- a/docs/faq/client.rst +++ b/docs/faq/client.rst @@ -4,7 +4,7 @@ Client .. currentmodule:: websockets.asyncio.client .. admonition:: This FAQ is written for the new :mod:`asyncio` implementation. - :class: hint + :class: tip Answers are also valid for the legacy :mod:`asyncio` implementation. diff --git a/docs/faq/server.rst b/docs/faq/server.rst index ce7e1962d..e6a3abe85 100644 --- a/docs/faq/server.rst +++ b/docs/faq/server.rst @@ -4,7 +4,7 @@ Server .. currentmodule:: websockets.asyncio.server .. admonition:: This FAQ is written for the new :mod:`asyncio` implementation. - :class: hint + :class: tip Answers are also valid for the legacy :mod:`asyncio` implementation. diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 4eaa0a4d1..1f02a6cde 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -48,7 +48,7 @@ Backwards-incompatible changes .. _python-socks: https://github.com/romis2012/python-socks .. admonition:: Keepalive is enabled in the :mod:`threading` implementation. - :class: note + :class: important The :mod:`threading` implementation now sends Ping frames at regular intervals and closes the connection if it doesn't receive a matching Pong @@ -135,7 +135,7 @@ Backwards-incompatible changes websockets 13.1 is the last version supporting Python 3.8. .. admonition:: The new :mod:`asyncio` implementation is now the default. - :class: danger + :class: attention The following aliases in the ``websockets`` package were switched to the new :mod:`asyncio` implementation:: @@ -749,7 +749,7 @@ Security fix ............ .. admonition:: websockets 9.1 fixes a security issue introduced in 8.0. - :class: important + :class: danger Version 8.0 was vulnerable to timing attacks on HTTP Basic Auth passwords (`CVE-2021-33880`_). @@ -1168,7 +1168,7 @@ Security fix ............ .. admonition:: websockets 5.0 fixes a security issue introduced in 4.0. - :class: important + :class: danger Version 4.0 was vulnerable to denial of service by memory exhaustion because it didn't enforce ``max_size`` when decompressing compressed From 4b9caad779c3c995845abb099c185e7d6009570f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 2 Feb 2025 23:15:59 +0100 Subject: [PATCH 734/762] Add a router based on werkzeug.routing. Fix #311. --- docs/conf.py | 5 +- docs/faq/client.rst | 2 +- docs/faq/server.rst | 4 +- docs/project/changelog.rst | 6 + docs/reference/asyncio/server.rst | 19 ++- docs/reference/features.rst | 2 + docs/reference/sync/server.rst | 27 +++- docs/requirements.txt | 1 + src/websockets/__init__.py | 9 ++ src/websockets/asyncio/router.py | 196 +++++++++++++++++++++++++++++ src/websockets/asyncio/server.py | 4 +- src/websockets/sync/router.py | 190 ++++++++++++++++++++++++++++ src/websockets/sync/server.py | 4 +- tests/asyncio/server.py | 8 +- tests/asyncio/test_router.py | 198 ++++++++++++++++++++++++++++++ tests/sync/server.py | 44 +++++-- tests/sync/test_router.py | 174 ++++++++++++++++++++++++++ tests/test_exports.py | 2 + tox.ini | 2 + 19 files changed, 879 insertions(+), 18 deletions(-) create mode 100644 src/websockets/asyncio/router.py create mode 100644 src/websockets/sync/router.py create mode 100644 tests/asyncio/test_router.py create mode 100644 tests/sync/test_router.py diff --git a/docs/conf.py b/docs/conf.py index 2c621bf41..c6b9ac7d8 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -82,7 +82,10 @@ assert PythonDomain.object_types["data"].roles == ("data", "obj") PythonDomain.object_types["data"].roles = ("data", "class", "obj") -intersphinx_mapping = {"python": ("https://docs.python.org/3", None)} +intersphinx_mapping = { + "python": ("https://docs.python.org/3", None), + "werkzeug": ("https://werkzeug.palletsprojects.com/en/stable/", None), +} spelling_show_suggestions = True diff --git a/docs/faq/client.rst b/docs/faq/client.rst index d3f627684..c39e588ca 100644 --- a/docs/faq/client.rst +++ b/docs/faq/client.rst @@ -81,7 +81,7 @@ The connection is closed when exiting the context manager. How do I reconnect when the connection drops? --------------------------------------------- -Use :func:`~websockets.asyncio.client.connect` as an asynchronous iterator:: +Use :func:`connect` as an asynchronous iterator:: from websockets.asyncio.client import connect from websockets.exceptions import ConnectionClosed diff --git a/docs/faq/server.rst b/docs/faq/server.rst index e6a3abe85..d00dcafba 100644 --- a/docs/faq/server.rst +++ b/docs/faq/server.rst @@ -116,7 +116,7 @@ Record all connections in a global variable:: finally: CONNECTIONS.remove(websocket) -Then, call :func:`~websockets.asyncio.server.broadcast`:: +Then, call :func:`broadcast`:: from websockets.asyncio.server import broadcast @@ -219,6 +219,8 @@ You may route a connection to different handlers depending on the request path:: # No handler for this path; close the connection. return +For more complex routing, you may use :func:`~websockets.asyncio.router.route`. + You may also route the connection based on the first message received from the client, as shown in the :doc:`tutorial <../intro/tutorial2>`. When you want to authenticate the connection before routing it, this is usually more convenient. diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 1f02a6cde..d7db6167a 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -56,6 +56,12 @@ Backwards-incompatible changes See :doc:`keepalive and latency <../topics/keepalive>` for details. +New features +............ + +* Added :func:`~asyncio.router.route` and :func:`~asyncio.router.unix_route` to + dispatch connections to different handlers depending on the URL. + Improvements ............ diff --git a/docs/reference/asyncio/server.rst b/docs/reference/asyncio/server.rst index 49bd6f072..8d8b700f3 100644 --- a/docs/reference/asyncio/server.rst +++ b/docs/reference/asyncio/server.rst @@ -12,6 +12,21 @@ Creating a server .. autofunction:: unix_serve :async: +Routing connections +------------------- + +.. automodule:: websockets.asyncio.router + +.. autofunction:: route + :async: + +.. autofunction:: unix_route + :async: + +.. autoclass:: Router + +.. currentmodule:: websockets.asyncio.server + Running a server ---------------- @@ -89,7 +104,7 @@ Using a connection Broadcast --------- -.. autofunction:: websockets.asyncio.server.broadcast +.. autofunction:: broadcast HTTP Basic Authentication ------------------------- @@ -97,4 +112,4 @@ HTTP Basic Authentication websockets supports HTTP Basic Authentication according to :rfc:`7235` and :rfc:`7617`. -.. autofunction:: websockets.asyncio.server.basic_auth +.. autofunction:: basic_auth diff --git a/docs/reference/features.rst b/docs/reference/features.rst index 93b083d20..0da966ccc 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -127,6 +127,8 @@ Server +------------------------------------+--------+--------+--------+--------+ | Perform HTTP Digest Authentication | ❌ | ❌ | ❌ | ❌ | +------------------------------------+--------+--------+--------+--------+ + | Dispatch connections to handlers | ✅ | ✅ | — | ❌ | + +------------------------------------+--------+--------+--------+--------+ Client ------ diff --git a/docs/reference/sync/server.rst b/docs/reference/sync/server.rst index c3d0e8f25..f6a45a659 100644 --- a/docs/reference/sync/server.rst +++ b/docs/reference/sync/server.rst @@ -10,6 +10,31 @@ Creating a server .. autofunction:: unix_serve +Routing connections +------------------- + +.. automodule:: websockets.sync.router + +.. autofunction:: route + +.. autofunction:: unix_route + +.. autoclass:: Router + +.. currentmodule:: websockets.sync.server + +Routing connections +------------------- + +.. autofunction:: route + :async: + +.. autofunction:: unix_route + :async: + +.. autoclass:: Server + + Running a server ---------------- @@ -78,4 +103,4 @@ HTTP Basic Authentication websockets supports HTTP Basic Authentication according to :rfc:`7235` and :rfc:`7617`. -.. autofunction:: websockets.sync.server.basic_auth +.. autofunction:: basic_auth diff --git a/docs/requirements.txt b/docs/requirements.txt index bcd1d7114..77c87f4dc 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -6,3 +6,4 @@ sphinx-inline-tabs sphinxcontrib-spelling sphinxcontrib-trio sphinxext-opengraph +werkzeug diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index 28a10910b..f90aff5b9 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -12,6 +12,10 @@ "connect", "unix_connect", "ClientConnection", + # .asyncio.router + "route", + "unix_route", + "Router", # .asyncio.server "basic_auth", "broadcast", @@ -79,6 +83,7 @@ # When type checking, import non-deprecated aliases eagerly. Else, import on demand. if TYPE_CHECKING: from .asyncio.client import ClientConnection, connect, unix_connect + from .asyncio.router import Router, route, unix_route from .asyncio.server import ( Server, ServerConnection, @@ -138,6 +143,10 @@ "connect": ".asyncio.client", "unix_connect": ".asyncio.client", "ClientConnection": ".asyncio.client", + # .asyncio.router + "route": ".asyncio.router", + "unix_route": ".asyncio.router", + "Router": ".asyncio.router", # .asyncio.server "basic_auth": ".asyncio.server", "broadcast": ".asyncio.server", diff --git a/src/websockets/asyncio/router.py b/src/websockets/asyncio/router.py new file mode 100644 index 000000000..cd95022c1 --- /dev/null +++ b/src/websockets/asyncio/router.py @@ -0,0 +1,196 @@ +from __future__ import annotations + +import http +import ssl as ssl_module +import urllib.parse +from typing import Any, Awaitable, Callable, Literal + +from werkzeug.exceptions import NotFound +from werkzeug.routing import Map, RequestRedirect + +from ..http11 import Request, Response +from .server import Server, ServerConnection, serve + + +__all__ = ["route", "unix_route", "Router"] + + +class Router: + """WebSocket router supporting :func:`route`.""" + + def __init__( + self, + url_map: Map, + server_name: str | None = None, + url_scheme: str = "ws", + ) -> None: + self.url_map = url_map + self.server_name = server_name + self.url_scheme = url_scheme + for rule in self.url_map.iter_rules(): + rule.websocket = True + + def get_server_name(self, connection: ServerConnection, request: Request) -> str: + if self.server_name is None: + return request.headers["Host"] + else: + return self.server_name + + def redirect(self, connection: ServerConnection, url: str) -> Response: + response = connection.respond(http.HTTPStatus.FOUND, f"Found at {url}") + response.headers["Location"] = url + return response + + def not_found(self, connection: ServerConnection) -> Response: + return connection.respond(http.HTTPStatus.NOT_FOUND, "Not Found") + + def route_request( + self, connection: ServerConnection, request: Request + ) -> Response | None: + """Route incoming request.""" + url_map_adapter = self.url_map.bind( + server_name=self.get_server_name(connection, request), + url_scheme=self.url_scheme, + ) + try: + parsed = urllib.parse.urlparse(request.path) + handler, kwargs = url_map_adapter.match( + path_info=parsed.path, + query_args=parsed.query, + ) + except RequestRedirect as redirect: + return self.redirect(connection, redirect.new_url) + except NotFound: + return self.not_found(connection) + connection.handler, connection.handler_kwargs = handler, kwargs + return None + + async def handler(self, connection: ServerConnection) -> None: + """Handle a connection.""" + return await connection.handler(connection, **connection.handler_kwargs) + + +def route( + url_map: Map, + *args: Any, + server_name: str | None = None, + ssl: ssl_module.SSLContext | Literal[True] | None = None, + create_router: type[Router] | None = None, + **kwargs: Any, +) -> Awaitable[Server]: + """ + Create a WebSocket server dispatching connections to different handlers. + + This feature requires the third-party library `werkzeug`_:: + + $ pip install werkzeug + + .. _werkzeug: https://werkzeug.palletsprojects.com/ + + :func:`route` accepts the same arguments as + :func:`~websockets.sync.server.serve`, except as described below. + + The first argument is a :class:`werkzeug.routing.Map` that maps URL patterns + to connection handlers. In addition to the connection, handlers receive + parameters captured in the URL as keyword arguments. + + Here's an example:: + + + from websockets.asyncio.router import route + from werkzeug.routing import Map, Rule + + async def channel_handler(websocket, channel_id): + ... + + url_map = Map([ + Rule("/channel/", endpoint=channel_handler), + ... + ]) + + # set this future to exit the server + stop = asyncio.get_running_loop().create_future() + + async with route(url_map, ...) as server: + await stop + + + Refer to the documentation of :mod:`werkzeug.routing` for details. + + If you define redirects with ``Rule(..., redirect_to=...)`` in the URL map, + when the server runs behind a reverse proxy that modifies the ``Host`` + header or terminates TLS, you need additional configuration: + + * Set ``server_name`` to the name of the server as seen by clients. When not + provided, websockets uses the value of the ``Host`` header. + + * Set ``ssl=True`` to generate ``wss://`` URIs without actually enabling + TLS. Under the hood, this bind the URL map with a ``url_scheme`` of + ``wss://`` instead of ``ws://``. + + There is no need to specify ``websocket=True`` in each rule. It is added + automatically. + + Args: + url_map: Mapping of URL patterns to connection handlers. + server_name: Name of the server as seen by clients. If :obj:`None`, + websockets uses the value of the ``Host`` header. + ssl: Configuration for enabling TLS on the connection. Set it to + :obj:`True` if a reverse proxy terminates TLS connections. + create_router: Factory for the :class:`Router` dispatching requests to + handlers. Set it to a wrapper or a subclass to customize routing. + + """ + url_scheme = "ws" if ssl is None else "wss" + if ssl is not True and ssl is not None: + kwargs["ssl"] = ssl + + if create_router is None: + create_router = Router + + router = create_router(url_map, server_name, url_scheme) + + _process_request: ( + Callable[ + [ServerConnection, Request], + Awaitable[Response | None] | Response | None, + ] + | None + ) = kwargs.pop("process_request", None) + if _process_request is None: + process_request: Callable[ + [ServerConnection, Request], + Awaitable[Response | None] | Response | None, + ] = router.route_request + else: + + async def process_request( + connection: ServerConnection, request: Request + ) -> Response | None: + response = _process_request(connection, request) + if isinstance(response, Awaitable): + response = await response + if response is not None: + return response + return router.route_request(connection, request) + + return serve(router.handler, *args, process_request=process_request, **kwargs) + + +def unix_route( + url_map: Map, + path: str | None = None, + **kwargs: Any, +) -> Awaitable[Server]: + """ + Create a WebSocket Unix server dispatching connections to different handlers. + + :func:`unix_route` combines the behaviors of :func:`route` and + :func:`~websockets.asyncio.server.unix_serve`. + + Args: + url_map: Mapping of URL patterns to connection handlers. + path: File system path to the Unix socket. + + """ + return route(url_map, unix=True, path=path, **kwargs) diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index 2e2b78782..ec7fc4383 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -9,7 +9,7 @@ import sys from collections.abc import Awaitable, Generator, Iterable, Sequence from types import TracebackType -from typing import Any, Callable, cast +from typing import Any, Callable, Mapping, cast from ..exceptions import InvalidHeader from ..extensions.base import ServerExtensionFactory @@ -87,6 +87,8 @@ def __init__( self.server = server self.request_rcvd: asyncio.Future[None] = self.loop.create_future() self.username: str # see basic_auth() + self.handler: Callable[[ServerConnection], Awaitable[None]] # see route() + self.handler_kwargs: Mapping[str, Any] # see route() def respond(self, status: StatusLike, text: str) -> Response: """ diff --git a/src/websockets/sync/router.py b/src/websockets/sync/router.py new file mode 100644 index 000000000..33105bf32 --- /dev/null +++ b/src/websockets/sync/router.py @@ -0,0 +1,190 @@ +from __future__ import annotations + +import http +import ssl as ssl_module +import urllib.parse +from typing import Any, Callable, Literal + +from werkzeug.exceptions import NotFound +from werkzeug.routing import Map, RequestRedirect + +from ..http11 import Request, Response +from .server import Server, ServerConnection, serve + + +__all__ = ["route", "unix_route", "Router"] + + +class Router: + """WebSocket router supporting :func:`route`.""" + + def __init__( + self, + url_map: Map, + server_name: str | None = None, + url_scheme: str = "ws", + ) -> None: + self.url_map = url_map + self.server_name = server_name + self.url_scheme = url_scheme + for rule in self.url_map.iter_rules(): + rule.websocket = True + + def get_server_name(self, connection: ServerConnection, request: Request) -> str: + if self.server_name is None: + return request.headers["Host"] + else: + return self.server_name + + def redirect(self, connection: ServerConnection, url: str) -> Response: + response = connection.respond(http.HTTPStatus.FOUND, f"Found at {url}") + response.headers["Location"] = url + return response + + def not_found(self, connection: ServerConnection) -> Response: + return connection.respond(http.HTTPStatus.NOT_FOUND, "Not Found") + + def route_request( + self, connection: ServerConnection, request: Request + ) -> Response | None: + """Route incoming request.""" + url_map_adapter = self.url_map.bind( + server_name=self.get_server_name(connection, request), + url_scheme=self.url_scheme, + ) + try: + parsed = urllib.parse.urlparse(request.path) + handler, kwargs = url_map_adapter.match( + path_info=parsed.path, + query_args=parsed.query, + ) + except RequestRedirect as redirect: + return self.redirect(connection, redirect.new_url) + except NotFound: + return self.not_found(connection) + connection.handler, connection.handler_kwargs = handler, kwargs + return None + + def handler(self, connection: ServerConnection) -> None: + """Handle a connection.""" + return connection.handler(connection, **connection.handler_kwargs) + + +def route( + url_map: Map, + *args: Any, + server_name: str | None = None, + ssl: ssl_module.SSLContext | Literal[True] | None = None, + create_router: type[Router] | None = None, + **kwargs: Any, +) -> Server: + """ + Create a WebSocket server dispatching connections to different handlers. + + This feature requires the third-party library `werkzeug`_:: + + $ pip install werkzeug + + .. _werkzeug: https://werkzeug.palletsprojects.com/ + + :func:`route` accepts the same arguments as + :func:`~websockets.sync.server.serve`, except as described below. + + The first argument is a :class:`werkzeug.routing.Map` that maps URL patterns + to connection handlers. In addition to the connection, handlers receive + parameters captured in the URL as keyword arguments. + + Here's an example:: + + + from websockets.sync.router import route + from werkzeug.routing import Map, Rule + + def channel_handler(websocket, channel_id): + ... + + url_map = Map([ + Rule("/channel/", endpoint=channel_handler), + ... + ]) + + with route(url_map, ...) as server: + server.serve_forever() + + Refer to the documentation of :mod:`werkzeug.routing` for details. + + If you define redirects with ``Rule(..., redirect_to=...)`` in the URL map, + when the server runs behind a reverse proxy that modifies the ``Host`` + header or terminates TLS, you need additional configuration: + + * Set ``server_name`` to the name of the server as seen by clients. When not + provided, websockets uses the value of the ``Host`` header. + + * Set ``ssl=True`` to generate ``wss://`` URIs without actually enabling + TLS. Under the hood, this bind the URL map with a ``url_scheme`` of + ``wss://`` instead of ``ws://``. + + There is no need to specify ``websocket=True`` in each rule. It is added + automatically. + + Args: + url_map: Mapping of URL patterns to connection handlers. + server_name: Name of the server as seen by clients. If :obj:`None`, + websockets uses the value of the ``Host`` header. + ssl: Configuration for enabling TLS on the connection. Set it to + :obj:`True` if a reverse proxy terminates TLS connections. + create_router: Factory for the :class:`Router` dispatching requests to + handlers. Set it to a wrapper or a subclass to customize routing. + + """ + url_scheme = "ws" if ssl is None else "wss" + if ssl is not True and ssl is not None: + kwargs["ssl"] = ssl + + if create_router is None: + create_router = Router + + router = create_router(url_map, server_name, url_scheme) + + _process_request: ( + Callable[ + [ServerConnection, Request], + Response | None, + ] + | None + ) = kwargs.pop("process_request", None) + if _process_request is None: + process_request: Callable[ + [ServerConnection, Request], + Response | None, + ] = router.route_request + else: + + def process_request( + connection: ServerConnection, request: Request + ) -> Response | None: + response = _process_request(connection, request) + if response is not None: + return response + return router.route_request(connection, request) + + return serve(router.handler, *args, process_request=process_request, **kwargs) + + +def unix_route( + url_map: Map, + path: str | None = None, + **kwargs: Any, +) -> Server: + """ + Create a WebSocket Unix server dispatching connections to different handlers. + + :func:`unix_route` combines the behaviors of :func:`route` and + :func:`~websockets.sync.server.unix_serve`. + + Args: + url_map: Mapping of URL patterns to connection handlers. + path: File system path to the Unix socket. + + """ + return route(url_map, unix=True, path=path, **kwargs) diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 10e3b6816..efb40a7f4 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -13,7 +13,7 @@ import warnings from collections.abc import Iterable, Sequence from types import TracebackType -from typing import Any, Callable, cast +from typing import Any, Callable, Mapping, cast from ..exceptions import InvalidHeader from ..extensions.base import ServerExtensionFactory @@ -82,6 +82,8 @@ def __init__( max_queue=max_queue, ) self.username: str # see basic_auth() + self.handler: Callable[[ServerConnection], None] # see route() + self.handler_kwargs: Mapping[str, Any] # see route() def respond(self, status: StatusLike, text: str) -> Response: """ diff --git a/tests/asyncio/server.py b/tests/asyncio/server.py index acf6500c6..b142bcd7e 100644 --- a/tests/asyncio/server.py +++ b/tests/asyncio/server.py @@ -1,5 +1,6 @@ import asyncio import socket +import urllib.parse def get_host_port(server): @@ -9,15 +10,16 @@ def get_host_port(server): raise AssertionError("expected at least one IPv4 socket") -def get_uri(server): - secure = server.server._ssl_context is not None # hack +def get_uri(server, secure=None): + if secure is None: + secure = server.server._ssl_context is not None # hack protocol = "wss" if secure else "ws" host, port = get_host_port(server) return f"{protocol}://{host}:{port}" async def handler(ws): - path = ws.request.path + path = urllib.parse.urlparse(ws.request.path).path if path == "/": # The default path is an eval shell. async for expr in ws: diff --git a/tests/asyncio/test_router.py b/tests/asyncio/test_router.py new file mode 100644 index 000000000..1426cc9f3 --- /dev/null +++ b/tests/asyncio/test_router.py @@ -0,0 +1,198 @@ +import http +import socket +import sys +import unittest +from unittest.mock import patch + +from websockets.asyncio.client import connect, unix_connect +from websockets.asyncio.router import * +from websockets.exceptions import InvalidStatus + +from ..utils import CLIENT_CONTEXT, SERVER_CONTEXT, temp_unix_socket_path +from .server import EvalShellMixin, get_uri, handler +from .utils import alist + + +try: + from werkzeug.routing import Map, Rule +except ImportError: + pass + + +async def echo(websocket, count): + message = await websocket.recv() + for _ in range(count): + await websocket.send(message) + + +@unittest.skipUnless("werkzeug" in sys.modules, "werkzeug not installed") +class RouterTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): + # This is a small realistic example of werkzeug's basic URL routing + # features: path matching, parameter extraction, and default values. + + async def test_router_matches_paths_and_extracts_parameters(self): + """Router matches paths and extracts parameters.""" + url_map = Map( + [ + Rule("/echo", defaults={"count": 1}, endpoint=echo), + Rule("/echo/", endpoint=echo), + ] + ) + async with route(url_map, "localhost", 0) as server: + async with connect(get_uri(server) + "/echo") as client: + await client.send("hello") + messages = await alist(client) + self.assertEqual(messages, ["hello"]) + + async with connect(get_uri(server) + "/echo/3") as client: + await client.send("hello") + messages = await alist(client) + self.assertEqual(messages, ["hello", "hello", "hello"]) + + @property # avoids an import-time dependency on werkzeug + def url_map(self): + return Map( + [ + Rule("/", endpoint=handler), + Rule("/r", redirect_to="/"), + ] + ) + + async def test_route_with_query_string(self): + """Router ignores query strings when matching paths.""" + async with route(self.url_map, "localhost", 0) as server: + async with connect(get_uri(server) + "/?a=b") as client: + await self.assertEval(client, "ws.request.path", "/?a=b") + + async def test_redirect(self): + """Router redirects connections according to redirect_to.""" + async with route(self.url_map, "localhost", 0) as server: + async with connect(get_uri(server) + "/r") as client: + await self.assertEval(client, "ws.request.path", "/") + + async def test_secure_redirect(self): + """Router redirects connections to a wss:// URI when TLS is enabled.""" + async with route(self.url_map, "localhost", 0, ssl=SERVER_CONTEXT) as server: + async with connect(get_uri(server) + "/r", ssl=CLIENT_CONTEXT) as client: + await self.assertEval(client, "ws.request.path", "/") + + @patch("websockets.asyncio.client.connect.process_redirect", lambda _, exc: exc) + async def test_force_secure_redirect(self): + """Router redirects ws:// connections to a wss:// URI when ssl=True.""" + async with route(self.url_map, "localhost", 0, ssl=True) as server: + redirect_uri = get_uri(server, secure=True) + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server) + "/r"): + self.fail("did not raise") + self.assertEqual( + raised.exception.response.headers["Location"], + redirect_uri + "/", + ) + + @patch("websockets.asyncio.client.connect.process_redirect", lambda _, exc: exc) + async def test_force_redirect_server_name(self): + """Router redirects connections to the host declared in server_name.""" + async with route(self.url_map, "localhost", 0, server_name="other") as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server) + "/r"): + self.fail("did not raise") + self.assertEqual( + raised.exception.response.headers["Location"], + "ws://other/", + ) + + async def test_not_found(self): + """Router rejects requests to unknown paths with an HTTP 404 error.""" + async with route(self.url_map, "localhost", 0) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server) + "/n"): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 404", + ) + + async def test_process_request_function_returning_none(self): + """Router supports a process_request function returning None.""" + + def process_request(ws, request): + ws.process_request_ran = True + + async with route( + self.url_map, "localhost", 0, process_request=process_request + ) as server: + async with connect(get_uri(server) + "/") as client: + await self.assertEval(client, "ws.process_request_ran", "True") + + async def test_process_request_coroutine_returning_none(self): + """Router supports a process_request coroutine returning None.""" + + async def process_request(ws, request): + ws.process_request_ran = True + + async with route( + self.url_map, "localhost", 0, process_request=process_request + ) as server: + async with connect(get_uri(server) + "/") as client: + await self.assertEval(client, "ws.process_request_ran", "True") + + async def test_process_request_function_returning_response(self): + """Router supports a process_request function returning a response.""" + + def process_request(ws, request): + return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden") + + async with route( + self.url_map, "localhost", 0, process_request=process_request + ) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server) + "/"): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 403", + ) + + async def test_process_request_coroutine_returning_response(self): + """Router supports a process_request coroutine returning a response.""" + + async def process_request(ws, request): + return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden") + + async with route( + self.url_map, "localhost", 0, process_request=process_request + ) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server) + "/"): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 403", + ) + + async def test_custom_router_factory(self): + """Router supports a custom router factory.""" + + class MyRouter(Router): + async def handler(self, connection): + connection.my_router_ran = True + return await super().handler(connection) + + async with route( + self.url_map, "localhost", 0, create_router=MyRouter + ) as server: + async with connect(get_uri(server)) as client: + await self.assertEval(client, "ws.my_router_ran", "True") + + +@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") +class UnixRouterTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): + async def test_router_supports_unix_sockets(self): + """Router supports Unix sockets.""" + url_map = Map([Rule("/echo/", endpoint=echo)]) + with temp_unix_socket_path() as path: + async with unix_route(url_map, path): + async with unix_connect(path, "ws://localhost/echo/3") as client: + await client.send("hello") + messages = await alist(client) + self.assertEqual(messages, ["hello", "hello", "hello"]) diff --git a/tests/sync/server.py b/tests/sync/server.py index fd7a03d82..cadaa267e 100644 --- a/tests/sync/server.py +++ b/tests/sync/server.py @@ -1,19 +1,22 @@ import contextlib import ssl import threading +import urllib.parse +from websockets.sync.router import * from websockets.sync.server import * -def get_uri(server): - secure = isinstance(server.socket, ssl.SSLSocket) # hack +def get_uri(server, secure=None): + if secure is None: + secure = isinstance(server.socket, ssl.SSLSocket) # hack protocol = "wss" if secure else "ws" host, port = server.socket.getsockname() return f"{protocol}://{host}:{port}" def handler(ws): - path = ws.request.path + path = urllib.parse.urlparse(ws.request.path).path if path == "/": # The default path is an eval shell. for expr in ws: @@ -34,8 +37,14 @@ def assertEval(self, client, expr, value): @contextlib.contextmanager -def run_server(handler=handler, host="localhost", port=0, **kwargs): - with serve(handler, host, port, **kwargs) as server: +def run_server_or_router( + serve_or_route, + handler_or_url_map, + host="localhost", + port=0, + **kwargs, +): + with serve_or_route(handler_or_url_map, host, port, **kwargs) as server: thread = threading.Thread(target=server.serve_forever) thread.start() @@ -63,9 +72,22 @@ def handler(sock, addr): handler_thread.join() +def run_server(handler=handler, **kwargs): + return run_server_or_router(serve, handler, **kwargs) + + +def run_router(url_map, **kwargs): + return run_server_or_router(route, url_map, **kwargs) + + @contextlib.contextmanager -def run_unix_server(path, handler=handler, **kwargs): - with unix_serve(handler, path, **kwargs) as server: +def run_unix_server_or_router( + path, + unix_serve_or_route, + handler_or_url_map, + **kwargs, +): + with unix_serve_or_route(handler_or_url_map, path, **kwargs) as server: thread = threading.Thread(target=server.serve_forever) thread.start() try: @@ -73,3 +95,11 @@ def run_unix_server(path, handler=handler, **kwargs): finally: server.shutdown() thread.join() + + +def run_unix_server(path, handler=handler, **kwargs): + return run_unix_server_or_router(path, unix_serve, handler, **kwargs) + + +def run_unix_router(path, url_map, **kwargs): + return run_unix_server_or_router(path, unix_route, url_map, **kwargs) diff --git a/tests/sync/test_router.py b/tests/sync/test_router.py new file mode 100644 index 000000000..07274e625 --- /dev/null +++ b/tests/sync/test_router.py @@ -0,0 +1,174 @@ +import http +import socket +import sys +import unittest +from unittest.mock import patch + +from websockets.exceptions import InvalidStatus +from websockets.sync.client import connect, unix_connect +from websockets.sync.router import * + +from ..utils import CLIENT_CONTEXT, SERVER_CONTEXT, temp_unix_socket_path +from .server import EvalShellMixin, get_uri, handler, run_router, run_unix_router + + +try: + from werkzeug.routing import Map, Rule +except ImportError: + pass + + +def echo(websocket, count): + message = websocket.recv() + for _ in range(count): + websocket.send(message) + + +@unittest.skipUnless("werkzeug" in sys.modules, "werkzeug not installed") +class RouterTests(EvalShellMixin, unittest.TestCase): + # This is a small realistic example of werkzeug's basic URL routing + # features: path matching, parameter extraction, and default values. + + def test_router_matches_paths_and_extracts_parameters(self): + """Router matches paths and extracts parameters.""" + url_map = Map( + [ + Rule("/echo", defaults={"count": 1}, endpoint=echo), + Rule("/echo/", endpoint=echo), + ] + ) + with run_router(url_map) as server: + with connect(get_uri(server) + "/echo") as client: + client.send("hello") + messages = list(client) + self.assertEqual(messages, ["hello"]) + + with connect(get_uri(server) + "/echo/3") as client: + client.send("hello") + messages = list(client) + self.assertEqual(messages, ["hello", "hello", "hello"]) + + @property # avoids an import-time dependency on werkzeug + def url_map(self): + return Map( + [ + Rule("/", endpoint=handler), + Rule("/r", redirect_to="/"), + ] + ) + + def test_route_with_query_string(self): + """Router ignores query strings when matching paths.""" + with run_router(self.url_map) as server: + with connect(get_uri(server) + "/?a=b") as client: + self.assertEval(client, "ws.request.path", "/?a=b") + + def test_redirect(self): + """Router redirects connections according to redirect_to.""" + with run_router(self.url_map, server_name="localhost") as server: + with self.assertRaises(InvalidStatus) as raised: + with connect(get_uri(server) + "/r"): + self.fail("did not raise") + self.assertEqual( + raised.exception.response.headers["Location"], + "ws://localhost/", + ) + + def test_secure_redirect(self): + """Router redirects connections to a wss:// URI when TLS is enabled.""" + with run_router( + self.url_map, server_name="localhost", ssl=SERVER_CONTEXT + ) as server: + with self.assertRaises(InvalidStatus) as raised: + with connect(get_uri(server) + "/r", ssl=CLIENT_CONTEXT): + self.fail("did not raise") + self.assertEqual( + raised.exception.response.headers["Location"], + "wss://localhost/", + ) + + @patch("websockets.asyncio.client.connect.process_redirect", lambda _, exc: exc) + def test_force_secure_redirect(self): + """Router redirects ws:// connections to a wss:// URI when ssl=True.""" + with run_router(self.url_map, ssl=True) as server: + redirect_uri = get_uri(server, secure=True) + with self.assertRaises(InvalidStatus) as raised: + with connect(get_uri(server) + "/r"): + self.fail("did not raise") + self.assertEqual( + raised.exception.response.headers["Location"], + redirect_uri + "/", + ) + + @patch("websockets.asyncio.client.connect.process_redirect", lambda _, exc: exc) + def test_force_redirect_server_name(self): + """Router redirects connections to the host declared in server_name.""" + with run_router(self.url_map, server_name="other") as server: + with self.assertRaises(InvalidStatus) as raised: + with connect(get_uri(server) + "/r"): + self.fail("did not raise") + self.assertEqual( + raised.exception.response.headers["Location"], + "ws://other/", + ) + + def test_not_found(self): + """Router rejects requests to unknown paths with an HTTP 404 error.""" + with run_router(self.url_map) as server: + with self.assertRaises(InvalidStatus) as raised: + with connect(get_uri(server) + "/n"): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 404", + ) + + def test_process_request_returning_none(self): + """Router supports a process_request returning None.""" + + def process_request(ws, request): + ws.process_request_ran = True + + with run_router(self.url_map, process_request=process_request) as server: + with connect(get_uri(server) + "/") as client: + self.assertEval(client, "ws.process_request_ran", "True") + + def test_process_request_returning_response(self): + """Router supports a process_request returning a response.""" + + def process_request(ws, request): + return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden") + + with run_router(self.url_map, process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + with connect(get_uri(server) + "/"): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 403", + ) + + def test_custom_router_factory(self): + """Router supports a custom router factory.""" + + class MyRouter(Router): + def handler(self, connection): + connection.my_router_ran = True + return super().handler(connection) + + with run_router(self.url_map, create_router=MyRouter) as server: + with connect(get_uri(server)) as client: + self.assertEval(client, "ws.my_router_ran", "True") + + +@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") +class UnixRouterTests(EvalShellMixin, unittest.IsolatedAsyncioTestCase): + def test_router_supports_unix_sockets(self): + """Router supports Unix sockets.""" + url_map = Map([Rule("/echo/", endpoint=echo)]) + with temp_unix_socket_path() as path: + with run_unix_router(path, url_map): + with unix_connect(path, "ws://localhost/echo/3") as client: + client.send("hello") + messages = list(client) + self.assertEqual(messages, ["hello", "hello", "hello"]) diff --git a/tests/test_exports.py b/tests/test_exports.py index 88e27e69d..34a470661 100644 --- a/tests/test_exports.py +++ b/tests/test_exports.py @@ -2,6 +2,7 @@ import websockets import websockets.asyncio.client +import websockets.asyncio.router import websockets.asyncio.server import websockets.client import websockets.datastructures @@ -16,6 +17,7 @@ for name in ( [] + websockets.asyncio.client.__all__ + + websockets.asyncio.router.__all__ + websockets.asyncio.server.__all__ + websockets.client.__all__ + websockets.datastructures.__all__ diff --git a/tox.ini b/tox.ini index 918aeaaec..9450e9714 100644 --- a/tox.ini +++ b/tox.ini @@ -17,6 +17,7 @@ pass_env = deps = py311,py312,py313,coverage,maxi_cov: mitmproxy py311,py312,py313,coverage,maxi_cov: python-socks[asyncio] + werkzeug [testenv:coverage] commands = @@ -47,3 +48,4 @@ commands = deps = mypy python-socks + werkzeug From 983e484cb77a1f0c490aaab6a092d26c38fd4f61 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 4 Feb 2025 22:38:42 +0100 Subject: [PATCH 735/762] Add example of routing. --- example/routing.py | 154 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 154 insertions(+) create mode 100644 example/routing.py diff --git a/example/routing.py b/example/routing.py new file mode 100644 index 000000000..9f2df4980 --- /dev/null +++ b/example/routing.py @@ -0,0 +1,154 @@ +#!/usr/bin/env python + +import asyncio +import datetime +import time +import zoneinfo + +from websockets.asyncio.router import route +from websockets.exceptions import ConnectionClosed +from werkzeug.routing import BaseConverter, Map, Rule, ValidationError + + +async def clock(websocket, tzinfo): + """Send the current time in the given timezone every second.""" + loop = asyncio.get_running_loop() + loop_offset = (loop.time() - time.time()) % 1 + try: + while True: + # Sleep until the next second according to the wall clock. + await asyncio.sleep(1 - (loop.time() - loop_offset) % 1) + now = datetime.datetime.now(tzinfo).replace(microsecond=0) + await websocket.send(now.isoformat()) + except ConnectionClosed: + return + + +async def alarm(websocket, alarm_at, tzinfo): + """Send the alarm time in the given timezone when it is reached.""" + alarm_at = alarm_at.replace(tzinfo=tzinfo) + now = datetime.datetime.now(tz=datetime.timezone.utc) + + try: + async with asyncio.timeout((alarm_at - now).total_seconds()): + await websocket.wait_closed() + except asyncio.TimeoutError: + try: + await websocket.send(alarm_at.isoformat()) + except ConnectionClosed: + return + + +async def timer(websocket, alarm_after): + """Send the remaining time until the alarm time every second.""" + alarm_at = datetime.datetime.now(tz=datetime.timezone.utc) + alarm_after + loop = asyncio.get_running_loop() + loop_offset = (loop.time() - time.time() + alarm_at.timestamp()) % 1 + + try: + while alarm_after.total_seconds() > 0: + # Sleep until the next second as a delta to the alarm time. + await asyncio.sleep(1 - (loop.time() - loop_offset) % 1) + alarm_after = alarm_at - datetime.datetime.now(tz=datetime.timezone.utc) + # Round up to the next second. + alarm_after += datetime.timedelta( + seconds=1, + microseconds=-alarm_after.microseconds, + ) + await websocket.send(format_timedelta(alarm_after)) + except ConnectionClosed: + return + + +class ZoneInfoConverter(BaseConverter): + regex = r"[A-Za-z0-9_/+-]+" + + def to_python(self, value): + try: + return zoneinfo.ZoneInfo(value) + except zoneinfo.ZoneInfoNotFoundError: + raise ValidationError + + def to_url(self, value): + return value.key + + +class DateTimeConverter(BaseConverter): + regex = r"[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}(?:\.[0-9]{3})?" + + def to_python(self, value): + try: + return datetime.datetime.fromisoformat(value) + except ValueError: + raise ValidationError + + def to_url(self, value): + return value.isoformat() + + +class TimeDeltaConverter(BaseConverter): + regex = r"[0-9]{2}:[0-9]{2}:[0-9]{2}(?:\.[0-9]{3}(?:[0-9]{3})?)?" + + def to_python(self, value): + return datetime.timedelta( + hours=int(value[0:2]), + minutes=int(value[3:5]), + seconds=int(value[6:8]), + milliseconds=int(value[9:12]) if len(value) == 12 else 0, + microseconds=int(value[9:15]) if len(value) == 15 else 0, + ) + + def to_url(self, value): + return format_timedelta(value) + + +def format_timedelta(delta): + assert 0 <= delta.seconds < 86400 + hours = delta.seconds // 3600 + minutes = (delta.seconds % 3600) // 60 + seconds = delta.seconds % 60 + if delta.microseconds: + return f"{hours:02d}:{minutes:02d}:{seconds:02d}.{delta.microseconds:06d}" + else: + return f"{hours:02d}:{minutes:02d}:{seconds:02d}" + + +url_map = Map( + [ + Rule( + "/", + redirect_to="/clock", + ), + Rule( + "/clock", + defaults={"tzinfo": datetime.timezone.utc}, + endpoint=clock, + ), + Rule( + "/clock/", + endpoint=clock, + ), + Rule( + "/alarm//", + endpoint=alarm, + ), + Rule( + "/timer/", + endpoint=timer, + ), + ], + converters={ + "tzinfo": ZoneInfoConverter, + "datetime": DateTimeConverter, + "timedelta": TimeDeltaConverter, + }, +) + + +async def main(): + async with route(url_map, "localhost", 8888): + await asyncio.get_running_loop().create_future() # run forever + + +if __name__ == "__main__": + asyncio.run(main()) From a28bb7fff5d4141a805697a4645fe84c4101c718 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 4 Feb 2025 23:12:37 +0100 Subject: [PATCH 736/762] Update references to third-party servers. --- README.rst | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/README.rst b/README.rst index 94cd79ab9..39ae1e8ba 100644 --- a/README.rst +++ b/README.rst @@ -128,11 +128,12 @@ Why shouldn't I use ``websockets``? and :rfc:`7692`: Compression Extensions for WebSocket. Its support for HTTP is minimal — just enough for an HTTP health check. - If you want to do both in the same server, look at HTTP frameworks that - build on top of ``websockets`` to support WebSocket connections, like - Sanic_. + If you want to do both in the same server, look at HTTP + WebSocket servers + that build on top of ``websockets`` to support WebSocket connections, like + uvicorn_ or Sanic_. -.. _Sanic: https://sanicframework.org/en/ +.. _uvicorn: https://www.uvicorn.org/ +.. _Sanic: https://sanic.dev/en/ What else? ---------- From dde3716725332f76d7f62fdedfde172017bd02b8 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 4 Feb 2025 23:12:54 +0100 Subject: [PATCH 737/762] There are faster libraries today. --- docs/faq/misc.rst | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/faq/misc.rst b/docs/faq/misc.rst index 4936aa6f3..3b5106006 100644 --- a/docs/faq/misc.rst +++ b/docs/faq/misc.rst @@ -24,7 +24,9 @@ you must disable: * Keepalive: set ``ping_interval=None`` * UTF-8 decoding: send ``bytes`` rather than ``str`` -If websockets is still slower than another Python library, please file a bug. +Then, please consider whether websockets is the bottleneck of the performance +of your application. Usually, in real-world applications, CPU time spent in +websockets is negligible compared to time spent in the application logic. Are there ``onopen``, ``onmessage``, ``onerror``, and ``onclose`` callbacks? ............................................................................ From 0bdfbd1dbe72ce3770428325eb4cfa6f3952773c Mon Sep 17 00:00:00 2001 From: Forrest Li Date: Tue, 4 Feb 2025 18:38:27 -0500 Subject: [PATCH 738/762] Fix incorrect asyncio import in README --- README.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.rst b/README.rst index 39ae1e8ba..dc2530c23 100644 --- a/README.rst +++ b/README.rst @@ -47,7 +47,7 @@ Here's an echo server with the ``asyncio`` API: #!/usr/bin/env python import asyncio - from websockets.server import serve + from websockets.asyncio.server import serve async def echo(websocket): async for message in websocket: From 5b516463c9a2ab9cf76ea087f0687cc0c8d54f0c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 7 Feb 2025 10:17:51 +0100 Subject: [PATCH 739/762] Revert "Disable PyPy in CI." This reverts commit 1e62b3c384a9fc345c661ebdd66a2a1a0fbbff52. Fix #1581. --- .github/workflows/tests.yml | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index ca73cd499..5ab9c4c72 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -60,13 +60,12 @@ jobs: - "3.11" - "3.12" - "3.13" -# Disable PyPy per https://github.com/python-websockets/websockets/issues/1581 -# - "pypy-3.10" + - "pypy-3.10" is_main: - ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }} -# exclude: -# - python: "pypy-3.10" -# is_main: false + exclude: + - python: "pypy-3.10" + is_main: false steps: - name: Check out repository uses: actions/checkout@v4 From 90db2de19feb1bd6147f2335f9facdec9f87e690 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 8 Feb 2025 10:53:51 +0100 Subject: [PATCH 740/762] Document HTTP Digest Auth as a known limitation. Ref #784. --- docs/reference/features.rst | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/docs/reference/features.rst b/docs/reference/features.rst index 0da966ccc..321e2e832 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -125,8 +125,6 @@ Server +------------------------------------+--------+--------+--------+--------+ | Perform HTTP Basic Authentication | ✅ | ✅ | ❌ | ✅ | +------------------------------------+--------+--------+--------+--------+ - | Perform HTTP Digest Authentication | ❌ | ❌ | ❌ | ❌ | - +------------------------------------+--------+--------+--------+--------+ | Dispatch connections to handlers | ✅ | ✅ | — | ❌ | +------------------------------------+--------+--------+--------+--------+ @@ -165,16 +163,11 @@ Client +------------------------------------+--------+--------+--------+--------+ | Perform HTTP Basic Authentication | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ - | Perform HTTP Digest Authentication | ❌ | ❌ | ❌ | ❌ | - | (`#784`_) | | | | | - +------------------------------------+--------+--------+--------+--------+ | Connect via HTTP proxy | ✅ | ✅ | — | ❌ | +------------------------------------+--------+--------+--------+--------+ | Connect via SOCKS5 proxy | ✅ | ✅ | — | ❌ | +------------------------------------+--------+--------+--------+--------+ -.. _#784: https://github.com/python-websockets/websockets/issues/784 - Known limitations ----------------- @@ -188,6 +181,10 @@ Request if it is missing or invalid (`#1246`). .. _#1246: https://github.com/python-websockets/websockets/issues/1246 +The client doesn't support HTTP Digest Authentication (`#784`_). + +.. _#784: https://github.com/python-websockets/websockets/issues/784 + The client API doesn't attempt to guarantee that there is no more than one connection to a given IP address in a CONNECTING state. This behavior is mandated by :rfc:`6455`, section 4.1. However, :func:`~asyncio.client.connect()` From 2dd33c86dfac2861d2adddb6e32eac9d1787a221 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 8 Feb 2025 17:49:02 +0100 Subject: [PATCH 741/762] Add guide for deployment to Koyeb. Fix #1369. --- docs/howto/fly.rst | 2 +- docs/howto/heroku.rst | 4 +- docs/howto/index.rst | 1 + docs/howto/koyeb.rst | 165 ++++++++++++++++++++++ docs/spelling_wordlist.txt | 1 + example/deployment/koyeb/Procfile | 1 + example/deployment/koyeb/app.py | 37 +++++ example/deployment/koyeb/requirements.txt | 1 + 8 files changed, 208 insertions(+), 4 deletions(-) create mode 100644 docs/howto/koyeb.rst create mode 100644 example/deployment/koyeb/Procfile create mode 100644 example/deployment/koyeb/app.py create mode 100644 example/deployment/koyeb/requirements.txt diff --git a/docs/howto/fly.rst b/docs/howto/fly.rst index ed001a2ae..4262e8aeb 100644 --- a/docs/howto/fly.rst +++ b/docs/howto/fly.rst @@ -1,5 +1,5 @@ Deploy to Fly -================ +============= This guide describes how to deploy a websockets server to Fly_. diff --git a/docs/howto/heroku.rst b/docs/howto/heroku.rst index b335e14c5..8d16eccfb 100644 --- a/docs/howto/heroku.rst +++ b/docs/howto/heroku.rst @@ -58,12 +58,10 @@ on websockets: .. literalinclude:: ../../example/deployment/heroku/requirements.txt :language: text -Create a ``Procfile``. +Create a ``Procfile`` to tell Heroku how to run the app. .. literalinclude:: ../../example/deployment/heroku/Procfile -This tells Heroku how to run the app. - Confirm that you created the correct files and commit them to git: .. code-block:: console diff --git a/docs/howto/index.rst b/docs/howto/index.rst index 863c1c63c..619b11fa8 100644 --- a/docs/howto/index.rst +++ b/docs/howto/index.rst @@ -48,6 +48,7 @@ Once your application is ready, learn how to deploy it on various platforms. :titlesonly: render + koyeb fly heroku kubernetes diff --git a/docs/howto/koyeb.rst b/docs/howto/koyeb.rst new file mode 100644 index 000000000..1ac17daac --- /dev/null +++ b/docs/howto/koyeb.rst @@ -0,0 +1,165 @@ +Deploy to Koyeb +================ + +This guide describes how to deploy a websockets server to Koyeb_. + +.. _Koyeb: https://www.koyeb.com + +.. admonition:: The free tier of Koyeb is sufficient for trying this guide. + :class: tip + + The `free tier`__ include one web service, which this guide uses. + + __ https://www.koyeb.com/pricing + +We’re going to deploy a very simple app. The process would be identical to a +more realistic app. + +Create repository +----------------- + +Koyeb supports multiple deployment methods. Its quick start guides recommend +git-driven deployment as the first option. Let's initialize a git repository: + +.. code-block:: console + + $ mkdir websockets-echo + $ cd websockets-echo + $ git init -b main + Initialized empty Git repository in websockets-echo/.git/ + $ git commit --allow-empty -m "Initial commit." + [main (root-commit) 740f699] Initial commit. + +Render requires the git repository to be hosted at GitHub. + +Sign up or log in to GitHub. Create a new repository named ``websockets-echo``. +Don't enable any of the initialization options offered by GitHub. Then, follow +instructions for pushing an existing repository from the command line. + +After pushing, refresh your repository's homepage on GitHub. You should see an +empty repository with an empty initial commit. + +Create application +------------------ + +Here’s the implementation of the app, an echo server. Save it in a file +called ``app.py``: + +.. literalinclude:: ../../example/deployment/koyeb/app.py + :language: python + +This app implements typical requirements for running on a Platform as a Service: + +* it listens on the port provided in the ``$PORT`` environment variable; +* it provides a health check at ``/healthz``; +* it closes connections and exits cleanly when it receives a ``SIGTERM`` signal; + while not documented, this is how Koyeb terminates apps. + +Create a ``requirements.txt`` file containing this line to declare a dependency +on websockets: + +.. literalinclude:: ../../example/deployment/koyeb/requirements.txt + :language: text + +Create a ``Procfile`` to tell Koyeb how to run the app. + +.. literalinclude:: ../../example/deployment/koyeb/Procfile + +Confirm that you created the correct files and commit them to git: + +.. code-block:: console + + $ ls + Procfile app.py requirements.txt + $ git add . + $ git commit -m "Initial implementation." + [main f634b8b] Initial implementation. +  3 files changed, 39 insertions(+) +  create mode 100644 Procfile +  create mode 100644 app.py +  create mode 100644 requirements.txt + +The app is ready. Let's deploy it! + +Deploy application +------------------ + +Sign up or log in to Koyeb. + +In the Koyeb control panel, create a web service with GitHub as the deployment +method. Install and authorize Koyeb's GitHub app if you haven't done that yet. + +Follow the steps to create a new service: + +1. Select the ``websockets-echo`` repository in the list of your repositories. +2. Confirm that the **Free** instance type is selected. Click **Next**. +3. Configure health checks: change the protocol from TCP to HTTP and set the + path to ``/healthz``. Review other settings; defaults should be correct. + Click **Deploy**. + +Koyeb builds the app, deploys it, verifies that the health checks passes, and +makes the deployment active. + +Validate deployment +------------------- + +Let's confirm that your application is running as expected. + +Since it's a WebSocket server, you need a WebSocket client, such as the +interactive client that comes with websockets. + +If you're currently building a websockets server, perhaps you're already in a +virtualenv where websockets is installed. If not, you can install it in a new +virtualenv as follows: + +.. code-block:: console + + $ python -m venv websockets-client + $ . websockets-client/bin/activate + $ pip install websockets + +Look for the URL of your app in the Koyeb control panel. It looks like +``https://--.koyeb.app/``. Connect the +interactive client — you must replace ``https`` with ``wss`` in the URL: + +.. code-block:: console + + $ python -m websockets wss://--.koyeb.app/ + Connected to wss://--.koyeb.app/. + > + +Great! Your app is running! + +Once you're connected, you can send any message and the server will echo it, +or press Ctrl-D to terminate the connection: + +.. code-block:: console + + > Hello! + < Hello! + Connection closed: 1000 (OK). + + +You can also confirm that your application shuts down gracefully. + +Connect an interactive client again: + +.. code-block:: console + + $ python -m websockets wss://--.koyeb.app/ + Connected to wss://--.koyeb.app/. + > + +In the Koyeb control panel, go to the **Settings** tab, click **Pause**, and +confirm. + +Eventually, the connection gets closed with code 1001 (going away). + +.. code-block:: console + + $ python -m websockets wss://--.koyeb.app/ + Connected to wss://--.koyeb.app/. + Connection closed: 1001 (going away). + +If graceful shutdown wasn't working, the server wouldn't perform a closing +handshake and the connection would be closed with code 1006 (abnormal closure). diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 11b13250a..c841bda2d 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -38,6 +38,7 @@ iterable js keepalive KiB +Koyeb kubernetes lifecycle linkerd diff --git a/example/deployment/koyeb/Procfile b/example/deployment/koyeb/Procfile new file mode 100644 index 000000000..2e35818f6 --- /dev/null +++ b/example/deployment/koyeb/Procfile @@ -0,0 +1 @@ +web: python app.py diff --git a/example/deployment/koyeb/app.py b/example/deployment/koyeb/app.py new file mode 100644 index 000000000..978467a30 --- /dev/null +++ b/example/deployment/koyeb/app.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python + +import asyncio +import http +import os +import signal + +from websockets.asyncio.server import serve + + +async def echo(websocket): + async for message in websocket: + await websocket.send(message) + + +def health_check(connection, request): + if request.path == "/healthz": + return connection.respond(http.HTTPStatus.OK, "OK\n") + + +async def main(): + # Set the stop condition when receiving SIGINT. + loop = asyncio.get_running_loop() + stop = loop.create_future() + loop.add_signal_handler(signal.SIGINT, stop.set_result, None) + + async with serve( + echo, + host="", + port=int(os.environ["PORT"]), + process_request=health_check, + ): + await stop + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/deployment/koyeb/requirements.txt b/example/deployment/koyeb/requirements.txt new file mode 100644 index 000000000..14774b465 --- /dev/null +++ b/example/deployment/koyeb/requirements.txt @@ -0,0 +1 @@ +websockets From 277f0f80e00fca87a9dbda8f5766d8dbf1347eda Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 9 Feb 2025 16:24:39 +0100 Subject: [PATCH 742/762] Switch tutorial to Koyeb. Fix #1299. --- docs/howto/koyeb.rst | 14 +-- docs/intro/tutorial3.rst | 167 +++++++++++++++++---------------- docs/spelling_wordlist.txt | 3 +- example/tutorial/step3/app.py | 8 +- example/tutorial/step3/main.js | 2 +- 5 files changed, 101 insertions(+), 93 deletions(-) diff --git a/docs/howto/koyeb.rst b/docs/howto/koyeb.rst index 1ac17daac..0ad126dd8 100644 --- a/docs/howto/koyeb.rst +++ b/docs/howto/koyeb.rst @@ -119,13 +119,13 @@ virtualenv as follows: $ pip install websockets Look for the URL of your app in the Koyeb control panel. It looks like -``https://--.koyeb.app/``. Connect the +``https://--.koyeb.app/``. Connect the interactive client — you must replace ``https`` with ``wss`` in the URL: .. code-block:: console - $ python -m websockets wss://--.koyeb.app/ - Connected to wss://--.koyeb.app/. + $ python -m websockets wss://--.koyeb.app/ + Connected to wss://--.koyeb.app/. > Great! Your app is running! @@ -146,8 +146,8 @@ Connect an interactive client again: .. code-block:: console - $ python -m websockets wss://--.koyeb.app/ - Connected to wss://--.koyeb.app/. + $ python -m websockets wss://--.koyeb.app/ + Connected to wss://--.koyeb.app/. > In the Koyeb control panel, go to the **Settings** tab, click **Pause**, and @@ -157,8 +157,8 @@ Eventually, the connection gets closed with code 1001 (going away). .. code-block:: console - $ python -m websockets wss://--.koyeb.app/ - Connected to wss://--.koyeb.app/. + $ python -m websockets wss://--.koyeb.app/ + Connected to wss://--.koyeb.app/. Connection closed: 1001 (going away). If graceful shutdown wasn't working, the server wouldn't perform a closing diff --git a/docs/intro/tutorial3.rst b/docs/intro/tutorial3.rst index 21d51371b..9356fdbe1 100644 --- a/docs/intro/tutorial3.rst +++ b/docs/intro/tutorial3.rst @@ -28,14 +28,20 @@ and a WebSocket server on ``ws://localhost:8001/`` with: Now you want to deploy these servers on the Internet. There's a vast range of hosting providers to choose from. For the sake of simplicity, we'll rely on: -* GitHub Pages for the HTTP server; -* Heroku for the WebSocket server. +* `GitHub Pages`_ for the HTTP server; +* Koyeb_ for the WebSocket server. + +.. _GitHub Pages: https://pages.github.com/ +.. _Koyeb: https://www.koyeb.com/ + +Koyeb is a modern Platform as a Service provider whose free tier allows you to +run a web application, including a WebSocket server. Commit project to git --------------------- Perhaps you committed your work to git while you were progressing through the -tutorial. If you didn't, now is a good time, because GitHub and Heroku offer +tutorial. If you didn't, now is a good time, because GitHub and Koyeb offer git-based deployment workflows. Initialize a git repository: @@ -45,7 +51,7 @@ Initialize a git repository: $ git init -b main Initialized empty Git repository in websockets-tutorial/.git/ $ git commit --allow-empty -m "Initial commit." - [main (root-commit) ...] Initial commit. + [main (root-commit) 8195c1d] Initial commit. Add all files and commit: @@ -53,7 +59,7 @@ Add all files and commit: $ git add . $ git commit -m "Initial implementation of Connect Four game." - [main ...] Initial implementation of Connect Four game. + [main 7f0b2c4] Initial implementation of Connect Four game. 6 files changed, 500 insertions(+) create mode 100644 app.py create mode 100644 connect4.css @@ -62,32 +68,58 @@ Add all files and commit: create mode 100644 index.html create mode 100644 main.js -Prepare the WebSocket server ----------------------------- +Sign up or log in to GitHub. -Before you deploy the server, you must adapt it to meet requirements of -Heroku's runtime. This involves two small changes: +Create a new repository. Set the repository name to ``websockets-tutorial``, +the visibility to Public, and click **Create repository**. -1. Heroku expects the server to `listen on a specific port`_, provided in the - ``$PORT`` environment variable. +Push your code to this repository. You must replace ``python-websockets`` by +your GitHub username in the following command: -2. Heroku sends a ``SIGTERM`` signal when `shutting down a dyno`_, which - should trigger a clean exit. +.. code-block:: console -.. _listen on a specific port: https://devcenter.heroku.com/articles/preparing-a-codebase-for-heroku-deployment#4-listen-on-the-correct-port + $ git remote add origin git@github.com:python-websockets/websockets-tutorial.git + $ git branch -M main + $ git push -u origin main + ... + To github.com:python-websockets/websockets-tutorial.git + * [new branch] main -> main + Branch 'main' set up to track remote branch 'main' from 'origin'. + +Adapt the WebSocket server +-------------------------- -.. _shutting down a dyno: https://devcenter.heroku.com/articles/dynos#shutdown +Before you deploy the server, you must adapt it for Koyeb's environment. This +involves three small changes: + +1. Koyeb provides the port on which the server should listen in the ``$PORT`` + environment variable. + +2. Koyeb requires a health check to verify that the server is running. We'll add + a HTTP health check. + +3. Koyeb sends a ``SIGTERM`` signal when terminating the server. We'll catch it + and trigger a clean exit. Adapt the ``main()`` coroutine accordingly: .. code-block:: python + import http import os import signal +.. literalinclude:: ../../example/tutorial/step3/app.py + :pyobject: health_check + .. literalinclude:: ../../example/tutorial/step3/app.py :pyobject: main +The ``process_request`` parameter of :func:`~asyncio.server.serve` is a callback +that runs for each request. When it returns an HTTP response, websockets sends +that response instead of opening a WebSocket connection. Here, requests to +``/healthz`` return an HTTP 200 status code. + To catch the ``SIGTERM`` signal, ``main()`` creates a :class:`~asyncio.Future` called ``stop`` and registers a signal handler that sets the result of this future. The value of the future doesn't matter; it's only for waiting for @@ -97,8 +129,6 @@ Then, by using :func:`~asyncio.server.serve` as a context manager and exiting the context when ``stop`` has a result, ``main()`` ensures that the server closes connections cleanly and exits on ``SIGTERM``. -The app is now fully compatible with Heroku. - Deploy the WebSocket server --------------------------- @@ -108,12 +138,12 @@ when building the image: .. literalinclude:: ../../example/tutorial/step3/requirements.txt :language: text -.. admonition:: Heroku treats ``requirements.txt`` as a signal to `detect a Python app`_. +.. admonition:: Koyeb treats ``requirements.txt`` as a signal to `detect a Python app`__. :class: tip That's why you don't need to declare that you need a Python runtime. -.. _detect a Python app: https://devcenter.heroku.com/articles/python-support#recognizing-a-python-app + __ https://www.koyeb.com/docs/build-and-deploy/build-from-git/python#detection Create a ``Procfile`` file with this content to configure the command for running the server: @@ -121,66 +151,52 @@ running the server: .. literalinclude:: ../../example/tutorial/step3/Procfile :language: text -Commit your changes: +Commit and push your changes: .. code-block:: console $ git add . - $ git commit -m "Deploy to Heroku." - [main ...] Deploy to Heroku. - 3 files changed, 12 insertions(+), 2 deletions(-) + $ git commit -m "Deploy to Koyeb." + [main ac96d65] Deploy to Koyeb. + 3 files changed, 18 insertions(+), 2 deletions(-) create mode 100644 Procfile create mode 100644 requirements.txt + $ git push + ... + To github.com:python-websockets/websockets-tutorial.git + + 6bd6032...ac96d65 main -> main -Follow the `set-up instructions`_ to install the Heroku CLI and to log in, if -you haven't done that yet. - -.. _set-up instructions: https://devcenter.heroku.com/articles/getting-started-with-python#set-up - -Create a Heroku app. You must choose a unique name and replace -``websockets-tutorial`` by this name in the following command: +Sign up or log in to Koyeb. -.. code-block:: console - - $ heroku create websockets-tutorial - Creating ⬢ websockets-tutorial... done - https://websockets-tutorial.herokuapp.com/ | https://git.heroku.com/websockets-tutorial.git - -If you reuse a name that someone else already uses, you will receive this -error; if this happens, try another name: +In the Koyeb control panel, create a web service with GitHub as the deployment +method. `Install and authorize Koyeb's GitHub app`__ if you haven't done that yet. -.. code-block:: console - - $ heroku create websockets-tutorial - Creating ⬢ websockets-tutorial... ! - ▸ Name websockets-tutorial is already taken +__ https://www.koyeb.com/docs/build-and-deploy/deploy-with-git#connect-your-github-account-to-koyeb -Deploy by pushing the code to Heroku: +Follow the steps to create a new service: -.. code-block:: console +1. Select the ``websockets-tutorial`` repository in the list of your repositories. +2. Confirm that the **Free** instance type is selected. Click **Next**. +3. Configure health checks: change the protocol from TCP to HTTP and set the + path to ``/healthz``. Review other settings; defaults should be correct. + Click **Deploy**. - $ git push heroku - - ... lots of output... - - remote: Released v1 - remote: https://websockets-tutorial.herokuapp.com/ deployed to Heroku - remote: - remote: Verifying deploy... done. - To https://git.heroku.com/websockets-tutorial.git - * [new branch] main -> main +Koyeb builds the app, deploys it, verifies that the health checks passes, and +makes the deployment active. You can test the WebSocket server with the interactive client exactly like you -did in the first part of the tutorial. Replace ``websockets-tutorial`` by the -name of your app in the following command: +did in the first part of the tutorial. The Koyeb control panel provides the URL +of your app in the format: ``https://--.koyeb.app/``. Replace +``https`` with ``wss`` in the URL and connect the interactive client: .. code-block:: console - $ python -m websockets wss://websockets-tutorial.herokuapp.com/ - Connected to wss://websockets-tutorial.herokuapp.com/. + $ python -m websockets wss://--.koyeb.app/ + Connected to wss://--.koyeb.app/. > {"type": "init"} < {"type": "init", "join": "54ICxFae_Ip7TJE2", "watch": "634w44TblL5Dbd9a"} - Connection closed: 1000 (OK). + +Press Ctrl-D to terminate the connection. It works! @@ -199,7 +215,7 @@ You can take this strategy one step further by checking the address of the HTTP server and determining the address of the WebSocket server accordingly. Add this function to ``main.js``; replace ``python-websockets`` by your GitHub -username and ``websockets-tutorial`` by the name of your app on Heroku: +username and ``websockets-tutorial`` by the name of your app on Koyeb: .. literalinclude:: ../../example/tutorial/step3/main.js :language: js @@ -218,42 +234,27 @@ Commit your changes: $ git add . $ git commit -m "Configure WebSocket server address." - [main ...] Configure WebSocket server address. + [main 0903526] Configure WebSocket server address. 1 file changed, 11 insertions(+), 1 deletion(-) + $ git push + ... + To github.com:python-websockets/websockets-tutorial.git + + ac96d65...0903526 main -> main Deploy the web application -------------------------- -Go to GitHub and create a new repository called ``websockets-tutorial``. - -Push your code to this repository. You must replace ``python-websockets`` by -your GitHub username in the following command: - -.. code-block:: console - - $ git remote add origin git@github.com:python-websockets/websockets-tutorial.git - $ git push -u origin main - Enumerating objects: 11, done. - Counting objects: 100% (11/11), done. - Delta compression using up to 8 threads - Compressing objects: 100% (10/10), done. - Writing objects: 100% (11/11), 5.90 KiB | 2.95 MiB/s, done. - Total 11 (delta 0), reused 0 (delta 0), pack-reused 0 - To github.com:/websockets-tutorial.git - * [new branch] main -> main - Branch 'main' set up to track remote branch 'main' from 'origin'. - Go back to GitHub, open the Settings tab of the repository and select Pages in the menu. Select the main branch as source and click Save. GitHub tells you that your site is published. -Follow the link and start a game! +Open https://.github.io/websockets-tutorial/ and start a game! Summary ------- In this third part of the tutorial, you learned how to deploy a WebSocket -application with Heroku. +application with Koyeb. You can start a Connect Four game, send the JOIN link to a friend, and play over the Internet! diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index c841bda2d..dd32a78c3 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -27,6 +27,7 @@ Dockerfile dyno formatter fractalideas +github gunicorn healthz html @@ -38,7 +39,7 @@ iterable js keepalive KiB -Koyeb +koyeb kubernetes lifecycle linkerd diff --git a/example/tutorial/step3/app.py b/example/tutorial/step3/app.py index 261057f9a..335fd48d3 100644 --- a/example/tutorial/step3/app.py +++ b/example/tutorial/step3/app.py @@ -1,6 +1,7 @@ #!/usr/bin/env python import asyncio +import http import json import os import secrets @@ -183,6 +184,11 @@ async def handler(websocket): await start(websocket) +def health_check(connection, request): + if request.path == "/healthz": + return connection.respond(http.HTTPStatus.OK, "OK\n") + + async def main(): # Set the stop condition when receiving SIGTERM. loop = asyncio.get_running_loop() @@ -190,7 +196,7 @@ async def main(): loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) port = int(os.environ.get("PORT", "8001")) - async with serve(handler, "", port): + async with serve(handler, "", port, process_request=health_check): await stop diff --git a/example/tutorial/step3/main.js b/example/tutorial/step3/main.js index 3000fa2f7..3a7a0db49 100644 --- a/example/tutorial/step3/main.js +++ b/example/tutorial/step3/main.js @@ -2,7 +2,7 @@ import { createBoard, playMove } from "./connect4.js"; function getWebSocketServer() { if (window.location.host === "python-websockets.github.io") { - return "wss://websockets-tutorial.herokuapp.com/"; + return "wss://websockets-tutorial.koyeb.app/"; } else if (window.location.host === "localhost:8000") { return "ws://localhost:8001/"; } else { From 755eac176700548d430ef580c4ad1128a096674a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 9 Feb 2025 17:28:02 +0100 Subject: [PATCH 743/762] Simplify pattern for shutting down servers. This avoids having to create manually a future. Also, it's robust to receiving multiple times the signal. Fix #1593. --- docs/faq/client.rst | 2 +- docs/faq/server.rst | 2 +- docs/howto/haproxy.rst | 2 +- docs/howto/nginx.rst | 2 +- docs/intro/tutorial3.rst | 20 ++++++++------------ docs/topics/deployment.rst | 3 +-- example/deployment/fly/app.py | 16 ++++------------ example/deployment/haproxy/app.py | 16 +++++----------- example/deployment/heroku/app.py | 16 +++++----------- example/deployment/koyeb/app.py | 17 +++++------------ example/deployment/kubernetes/app.py | 16 ++++------------ example/deployment/nginx/app.py | 15 +++++---------- example/deployment/render/app.py | 16 ++++------------ example/deployment/supervisor/app.py | 16 ++++------------ example/faq/shutdown_client.py | 3 +-- example/faq/shutdown_server.py | 16 +++++++--------- example/tutorial/step3/app.py | 11 ++++------- experiments/compression/server.py | 19 ++++++++----------- 18 files changed, 69 insertions(+), 139 deletions(-) diff --git a/docs/faq/client.rst b/docs/faq/client.rst index c39e588ca..cf27fcd45 100644 --- a/docs/faq/client.rst +++ b/docs/faq/client.rst @@ -103,7 +103,7 @@ You can close the connection. Here's an example that terminates cleanly when it receives SIGTERM on Unix: .. literalinclude:: ../../example/faq/shutdown_client.py - :emphasize-lines: 11-13 + :emphasize-lines: 10-12 How do I disable TLS/SSL certificate verification? -------------------------------------------------- diff --git a/docs/faq/server.rst b/docs/faq/server.rst index d00dcafba..bb04c5e1c 100644 --- a/docs/faq/server.rst +++ b/docs/faq/server.rst @@ -310,7 +310,7 @@ Exit the :func:`~serve` context manager. Here's an example that terminates cleanly when it receives SIGTERM on Unix: .. literalinclude:: ../../example/faq/shutdown_server.py - :emphasize-lines: 13-16,19 + :emphasize-lines: 14-16 How do I stop a server while keeping existing connections open? --------------------------------------------------------------- diff --git a/docs/howto/haproxy.rst b/docs/howto/haproxy.rst index fdaab0401..5ffe21ea3 100644 --- a/docs/howto/haproxy.rst +++ b/docs/howto/haproxy.rst @@ -15,7 +15,7 @@ Run server processes Save this app to ``app.py``: .. literalinclude:: ../../example/deployment/haproxy/app.py - :emphasize-lines: 24 + :language: python Each server process listens on a different port by extracting an incremental index from an environment variable set by Supervisor. diff --git a/docs/howto/nginx.rst b/docs/howto/nginx.rst index 872353cad..5b37c9b36 100644 --- a/docs/howto/nginx.rst +++ b/docs/howto/nginx.rst @@ -15,7 +15,7 @@ Run server processes Save this app to ``app.py``: .. literalinclude:: ../../example/deployment/nginx/app.py - :emphasize-lines: 21,23 + :language: python We'd like nginx to connect to websockets servers via Unix sockets in order to avoid the overhead of TCP for communicating between processes running in the diff --git a/docs/intro/tutorial3.rst b/docs/intro/tutorial3.rst index 9356fdbe1..bdd1bd50b 100644 --- a/docs/intro/tutorial3.rst +++ b/docs/intro/tutorial3.rst @@ -120,14 +120,10 @@ that runs for each request. When it returns an HTTP response, websockets sends that response instead of opening a WebSocket connection. Here, requests to ``/healthz`` return an HTTP 200 status code. -To catch the ``SIGTERM`` signal, ``main()`` creates a :class:`~asyncio.Future` -called ``stop`` and registers a signal handler that sets the result of this -future. The value of the future doesn't matter; it's only for waiting for -``SIGTERM``. - -Then, by using :func:`~asyncio.server.serve` as a context manager and exiting -the context when ``stop`` has a result, ``main()`` ensures that the server -closes connections cleanly and exits on ``SIGTERM``. +``main()`` registers a signal handler that closes the server when receiving the +``SIGTERM`` signal. Then, it waits for the server to be closed. Additionally, +using :func:`~asyncio.server.serve` as a context manager ensures that the server +will always be closed cleanly, even if the program crashes. Deploy the WebSocket server --------------------------- @@ -157,14 +153,14 @@ Commit and push your changes: $ git add . $ git commit -m "Deploy to Koyeb." - [main ac96d65] Deploy to Koyeb. - 3 files changed, 18 insertions(+), 2 deletions(-) + [main 4a4b6e9] Deploy to Koyeb. + 3 files changed, 15 insertions(+), 2 deletions(-) create mode 100644 Procfile create mode 100644 requirements.txt $ git push ... To github.com:python-websockets/websockets-tutorial.git - + 6bd6032...ac96d65 main -> main + + 6bd6032...4a4b6e9 main -> main Sign up or log in to Koyeb. @@ -239,7 +235,7 @@ Commit your changes: $ git push ... To github.com:python-websockets/websockets-tutorial.git - + ac96d65...0903526 main -> main + + 4a4b6e9...968eaaa main -> main Deploy the web application -------------------------- diff --git a/docs/topics/deployment.rst b/docs/topics/deployment.rst index 48ef72b56..00d1f9285 100644 --- a/docs/topics/deployment.rst +++ b/docs/topics/deployment.rst @@ -101,8 +101,7 @@ Here's an example: :emphasize-lines: 13-16,19 When exiting the context manager, :func:`~asyncio.server.serve` closes all -connections -with code 1001 (going away). As a consequence: +connections with code 1001 (going away). As a consequence: * If the connection handler is awaiting :meth:`~asyncio.server.ServerConnection.recv`, it receives a diff --git a/example/deployment/fly/app.py b/example/deployment/fly/app.py index c8e6af4f9..a841831cf 100644 --- a/example/deployment/fly/app.py +++ b/example/deployment/fly/app.py @@ -18,18 +18,10 @@ def health_check(connection, request): async def main(): - # Set the stop condition when receiving SIGTERM. - loop = asyncio.get_running_loop() - stop = loop.create_future() - loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - - async with serve( - echo, - host="", - port=8080, - process_request=health_check, - ): - await stop + async with serve(echo, "", 8080, process_request=health_check) as server: + loop = asyncio.get_running_loop() + loop.add_signal_handler(signal.SIGTERM, server.close) + await server.wait_closed() if __name__ == "__main__": diff --git a/example/deployment/haproxy/app.py b/example/deployment/haproxy/app.py index ef6d9c42d..6596c9f32 100644 --- a/example/deployment/haproxy/app.py +++ b/example/deployment/haproxy/app.py @@ -13,17 +13,11 @@ async def echo(websocket): async def main(): - # Set the stop condition when receiving SIGTERM. - loop = asyncio.get_running_loop() - stop = loop.create_future() - loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - - async with serve( - echo, - host="localhost", - port=8000 + int(os.environ["SUPERVISOR_PROCESS_NAME"][-2:]), - ): - await stop + port = 8000 + int(os.environ["SUPERVISOR_PROCESS_NAME"][-2:]) + async with serve(echo, "localhost", port) as server: + loop = asyncio.get_running_loop() + loop.add_signal_handler(signal.SIGTERM, server.close) + await server.wait_closed() if __name__ == "__main__": diff --git a/example/deployment/heroku/app.py b/example/deployment/heroku/app.py index 17ad09d26..524fb35f8 100644 --- a/example/deployment/heroku/app.py +++ b/example/deployment/heroku/app.py @@ -13,17 +13,11 @@ async def echo(websocket): async def main(): - # Set the stop condition when receiving SIGTERM. - loop = asyncio.get_running_loop() - stop = loop.create_future() - loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - - async with serve( - echo, - host="", - port=int(os.environ["PORT"]), - ): - await stop + port = int(os.environ["PORT"]) + async with serve(echo, "localhost", port) as server: + loop = asyncio.get_running_loop() + loop.add_signal_handler(signal.SIGTERM, server.close) + await server.wait_closed() if __name__ == "__main__": diff --git a/example/deployment/koyeb/app.py b/example/deployment/koyeb/app.py index 978467a30..4bfbee793 100644 --- a/example/deployment/koyeb/app.py +++ b/example/deployment/koyeb/app.py @@ -19,18 +19,11 @@ def health_check(connection, request): async def main(): - # Set the stop condition when receiving SIGINT. - loop = asyncio.get_running_loop() - stop = loop.create_future() - loop.add_signal_handler(signal.SIGINT, stop.set_result, None) - - async with serve( - echo, - host="", - port=int(os.environ["PORT"]), - process_request=health_check, - ): - await stop + port = int(os.environ["PORT"]) + async with serve(echo, "", port, process_request=health_check) as server: + loop = asyncio.get_running_loop() + loop.add_signal_handler(signal.SIGINT, server.close) + await server.wait_closed() if __name__ == "__main__": diff --git a/example/deployment/kubernetes/app.py b/example/deployment/kubernetes/app.py index 387f0ade1..95125773d 100755 --- a/example/deployment/kubernetes/app.py +++ b/example/deployment/kubernetes/app.py @@ -31,18 +31,10 @@ def health_check(connection, request): async def main(): - # Set the stop condition when receiving SIGTERM. - loop = asyncio.get_running_loop() - stop = loop.create_future() - loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - - async with serve( - slow_echo, - host="", - port=80, - process_request=health_check, - ): - await stop + async with serve(slow_echo, "", 80, process_request=health_check) as server: + loop = asyncio.get_running_loop() + loop.add_signal_handler(signal.SIGTERM, server.close) + await server.wait_closed() if __name__ == "__main__": diff --git a/example/deployment/nginx/app.py b/example/deployment/nginx/app.py index 134070f61..4b3ad9b13 100644 --- a/example/deployment/nginx/app.py +++ b/example/deployment/nginx/app.py @@ -13,16 +13,11 @@ async def echo(websocket): async def main(): - # Set the stop condition when receiving SIGTERM. - loop = asyncio.get_running_loop() - stop = loop.create_future() - loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - - async with unix_serve( - echo, - path=f"{os.environ['SUPERVISOR_PROCESS_NAME']}.sock", - ): - await stop + path = f"{os.environ['SUPERVISOR_PROCESS_NAME']}.sock" + async with unix_serve(echo, path) as server: + loop = asyncio.get_running_loop() + loop.add_signal_handler(signal.SIGTERM, server.close) + await server.wait_closed() if __name__ == "__main__": diff --git a/example/deployment/render/app.py b/example/deployment/render/app.py index c8e6af4f9..a841831cf 100644 --- a/example/deployment/render/app.py +++ b/example/deployment/render/app.py @@ -18,18 +18,10 @@ def health_check(connection, request): async def main(): - # Set the stop condition when receiving SIGTERM. - loop = asyncio.get_running_loop() - stop = loop.create_future() - loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - - async with serve( - echo, - host="", - port=8080, - process_request=health_check, - ): - await stop + async with serve(echo, "", 8080, process_request=health_check) as server: + loop = asyncio.get_running_loop() + loop.add_signal_handler(signal.SIGTERM, server.close) + await server.wait_closed() if __name__ == "__main__": diff --git a/example/deployment/supervisor/app.py b/example/deployment/supervisor/app.py index 5e69f16a6..1ca70bdc0 100644 --- a/example/deployment/supervisor/app.py +++ b/example/deployment/supervisor/app.py @@ -12,18 +12,10 @@ async def echo(websocket): async def main(): - # Set the stop condition when receiving SIGTERM. - loop = asyncio.get_running_loop() - stop = loop.create_future() - loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - - async with serve( - echo, - host="", - port=8080, - reuse_port=True, - ): - await stop + async with serve(echo, "", 8080, reuse_port=True) as server: + loop = asyncio.get_running_loop() + loop.add_signal_handler(signal.SIGTERM, server.close) + await server.wait_closed() if __name__ == "__main__": diff --git a/example/faq/shutdown_client.py b/example/faq/shutdown_client.py index 5c8bd8cbe..3280c6f9b 100755 --- a/example/faq/shutdown_client.py +++ b/example/faq/shutdown_client.py @@ -6,8 +6,7 @@ from websockets.asyncio.client import connect async def client(): - uri = "ws://localhost:8765" - async with connect(uri) as websocket: + async with connect("ws://localhost:8765") as websocket: # Close the connection when receiving SIGTERM. loop = asyncio.get_running_loop() loop.add_signal_handler(signal.SIGTERM, loop.create_task, websocket.close()) diff --git a/example/faq/shutdown_server.py b/example/faq/shutdown_server.py index 3f7bc5732..ea00e2520 100755 --- a/example/faq/shutdown_server.py +++ b/example/faq/shutdown_server.py @@ -5,17 +5,15 @@ from websockets.asyncio.server import serve -async def echo(websocket): +async def handler(websocket): async for message in websocket: - await websocket.send(message) + ... async def server(): - # Set the stop condition when receiving SIGTERM. - loop = asyncio.get_running_loop() - stop = loop.create_future() - loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - - async with serve(echo, "localhost", 8765): - await stop + async with serve(handler, "localhost", 8765) as server: + # Close the server when receiving SIGTERM. + loop = asyncio.get_running_loop() + loop.add_signal_handler(signal.SIGTERM, server.close) + await server.wait_closed() asyncio.run(server()) diff --git a/example/tutorial/step3/app.py b/example/tutorial/step3/app.py index 335fd48d3..8a285e92e 100644 --- a/example/tutorial/step3/app.py +++ b/example/tutorial/step3/app.py @@ -190,14 +190,11 @@ def health_check(connection, request): async def main(): - # Set the stop condition when receiving SIGTERM. - loop = asyncio.get_running_loop() - stop = loop.create_future() - loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - port = int(os.environ.get("PORT", "8001")) - async with serve(handler, "", port, process_request=health_check): - await stop + async with serve(handler, "", port, process_request=health_check) as server: + loop = asyncio.get_running_loop() + loop.add_signal_handler(signal.SIGTERM, server.close) + await server.wait_closed() if __name__ == "__main__": diff --git a/experiments/compression/server.py b/experiments/compression/server.py index 1c28f7355..dd399a29f 100644 --- a/experiments/compression/server.py +++ b/experiments/compression/server.py @@ -35,15 +35,6 @@ async def handler(ws): async def server(): - loop = asyncio.get_running_loop() - stop = loop.create_future() - - # Set the stop condition when receiving SIGTERM. - print("Stop the server with:") - print(f"kill -TERM {os.getpid()}") - print() - loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - async with serve( handler, "localhost", @@ -55,9 +46,15 @@ async def server(): compress_settings={"memLevel": ML}, ) ], - ): + ) as server: + print("Stop the server with:") + print(f"kill -TERM {os.getpid()}") + print() + loop = asyncio.get_running_loop() + loop.add_signal_handler(signal.SIGTERM, server.close) + tracemalloc.start() - await stop + await server.wait_closed() asyncio.run(server()) From 9f01cefde6f0d55de8bd2738a9601ddf661a1db1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 9 Feb 2025 18:44:28 +0100 Subject: [PATCH 744/762] Simplify pattern for running servers forever. --- README.rst | 4 ++-- docs/intro/tutorial1.rst | 4 ++-- example/django/authentication.py | 4 ++-- example/faq/health_check_server.py | 4 ++-- example/quickstart/counter.py | 4 ++-- example/quickstart/server.py | 4 ++-- example/quickstart/server_secure.py | 4 ++-- example/quickstart/show_time.py | 4 ++-- example/routing.py | 4 ++-- example/tutorial/step1/app.py | 4 ++-- example/tutorial/step2/app.py | 4 ++-- 11 files changed, 22 insertions(+), 22 deletions(-) diff --git a/README.rst b/README.rst index dc2530c23..cc47b2910 100644 --- a/README.rst +++ b/README.rst @@ -54,8 +54,8 @@ Here's an echo server with the ``asyncio`` API: await websocket.send(message) async def main(): - async with serve(echo, "localhost", 8765): - await asyncio.get_running_loop().create_future() # run forever + async with serve(echo, "localhost", 8765) as server: + await server.serve_forever() asyncio.run(main()) diff --git a/docs/intro/tutorial1.rst b/docs/intro/tutorial1.rst index 87074caee..cc88e6a4a 100644 --- a/docs/intro/tutorial1.rst +++ b/docs/intro/tutorial1.rst @@ -194,8 +194,8 @@ Create an ``app.py`` file next to ``connect4.py`` with this content: async def main(): - async with serve(handler, "", 8001): - await asyncio.get_running_loop().create_future() # run forever + async with serve(handler, "", 8001) as server: + await server.serve_forever() if __name__ == "__main__": diff --git a/example/django/authentication.py b/example/django/authentication.py index c4f12a3f8..e61d70432 100644 --- a/example/django/authentication.py +++ b/example/django/authentication.py @@ -22,8 +22,8 @@ async def handler(websocket): async def main(): - async with serve(handler, "localhost", 8888): - await asyncio.get_running_loop().create_future() # run forever + async with serve(handler, "localhost", 8888) as server: + await server.serve_forever() if __name__ == "__main__": diff --git a/example/faq/health_check_server.py b/example/faq/health_check_server.py index 30623a4bb..3fdffb501 100755 --- a/example/faq/health_check_server.py +++ b/example/faq/health_check_server.py @@ -13,7 +13,7 @@ async def echo(websocket): await websocket.send(message) async def main(): - async with serve(echo, "localhost", 8765, process_request=health_check): - await asyncio.get_running_loop().create_future() # run forever + async with serve(echo, "localhost", 8765, process_request=health_check) as server: + await server.serve_forever() asyncio.run(main()) diff --git a/example/quickstart/counter.py b/example/quickstart/counter.py index 91eedc56a..8f0ff81be 100755 --- a/example/quickstart/counter.py +++ b/example/quickstart/counter.py @@ -42,8 +42,8 @@ async def counter(websocket): broadcast(USERS, users_event()) async def main(): - async with serve(counter, "localhost", 6789): - await asyncio.get_running_loop().create_future() # run forever + async with serve(counter, "localhost", 6789) as server: + await server.serve_forever() if __name__ == "__main__": asyncio.run(main()) diff --git a/example/quickstart/server.py b/example/quickstart/server.py index bde5e6126..a01f91703 100755 --- a/example/quickstart/server.py +++ b/example/quickstart/server.py @@ -14,8 +14,8 @@ async def hello(websocket): print(f">>> {greeting}") async def main(): - async with serve(hello, "localhost", 8765): - await asyncio.get_running_loop().create_future() # run forever + async with serve(hello, "localhost", 8765) as server: + await server.serve_forever() if __name__ == "__main__": asyncio.run(main()) diff --git a/example/quickstart/server_secure.py b/example/quickstart/server_secure.py index 8b456ed6e..92c6629b5 100755 --- a/example/quickstart/server_secure.py +++ b/example/quickstart/server_secure.py @@ -20,8 +20,8 @@ async def hello(websocket): ssl_context.load_cert_chain(localhost_pem) async def main(): - async with serve(hello, "localhost", 8765, ssl=ssl_context): - await asyncio.get_running_loop().create_future() # run forever + async with serve(hello, "localhost", 8765, ssl=ssl_context) as server: + await server.serve_forever() if __name__ == "__main__": asyncio.run(main()) diff --git a/example/quickstart/show_time.py b/example/quickstart/show_time.py index 8aeb811db..ecb908f30 100755 --- a/example/quickstart/show_time.py +++ b/example/quickstart/show_time.py @@ -13,8 +13,8 @@ async def show_time(websocket): await asyncio.sleep(random.random() * 2 + 1) async def main(): - async with serve(show_time, "localhost", 5678): - await asyncio.get_running_loop().create_future() # run forever + async with serve(show_time, "localhost", 5678) as server: + await server.serve_forever() if __name__ == "__main__": asyncio.run(main()) diff --git a/example/routing.py b/example/routing.py index 9f2df4980..7fc4ad4b3 100644 --- a/example/routing.py +++ b/example/routing.py @@ -146,8 +146,8 @@ def format_timedelta(delta): async def main(): - async with route(url_map, "localhost", 8888): - await asyncio.get_running_loop().create_future() # run forever + async with route(url_map, "localhost", 8888) as server: + await server.serve_forever() if __name__ == "__main__": diff --git a/example/tutorial/step1/app.py b/example/tutorial/step1/app.py index 595a10dc7..bc8f02484 100644 --- a/example/tutorial/step1/app.py +++ b/example/tutorial/step1/app.py @@ -57,8 +57,8 @@ async def handler(websocket): async def main(): - async with serve(handler, "", 8001): - await asyncio.get_running_loop().create_future() # run forever + async with serve(handler, "", 8001) as server: + await server.serve_forever() if __name__ == "__main__": diff --git a/example/tutorial/step2/app.py b/example/tutorial/step2/app.py index ef3dd9483..fe50fb3af 100644 --- a/example/tutorial/step2/app.py +++ b/example/tutorial/step2/app.py @@ -182,8 +182,8 @@ async def handler(websocket): async def main(): - async with serve(handler, "", 8001): - await asyncio.get_running_loop().create_future() # run forever + async with serve(handler, "", 8001) as server: + await server.serve_forever() if __name__ == "__main__": From 930defe196547aba051dbac8d409d6008a86f454 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 9 Feb 2025 18:57:13 +0100 Subject: [PATCH 745/762] Fix signal name. --- example/deployment/koyeb/app.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/deployment/koyeb/app.py b/example/deployment/koyeb/app.py index 4bfbee793..62ba9d843 100644 --- a/example/deployment/koyeb/app.py +++ b/example/deployment/koyeb/app.py @@ -22,7 +22,7 @@ async def main(): port = int(os.environ["PORT"]) async with serve(echo, "", port, process_request=health_check) as server: loop = asyncio.get_running_loop() - loop.add_signal_handler(signal.SIGINT, server.close) + loop.add_signal_handler(signal.SIGTERM, server.close) await server.wait_closed() From 281a5b39b5529e40946f5e54d6c205acb7cfefda Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 9 Feb 2025 20:14:17 +0100 Subject: [PATCH 746/762] Remove cheat sheet. The information wasn't presented in a useful format and some of it was out of date. Refs #1209. --- docs/faq/common.rst | 32 ++++++++++++++- docs/faq/index.rst | 10 +++-- docs/howto/cheatsheet.rst | 86 --------------------------------------- docs/howto/index.rst | 1 - docs/topics/logging.rst | 2 +- 5 files changed, 38 insertions(+), 93 deletions(-) delete mode 100644 docs/howto/cheatsheet.rst diff --git a/docs/faq/common.rst b/docs/faq/common.rst index ba7a95932..b31d4504f 100644 --- a/docs/faq/common.rst +++ b/docs/faq/common.rst @@ -3,6 +3,33 @@ Both sides .. currentmodule:: websockets.asyncio.connection +.. _enable-debug-logs: + +How do I enable debug logs? +--------------------------- + +You can enable debug logs to see exactly what websockets is doing. + +If logging isn't configured in your application:: + + import logging + + logging.basicConfig( + format="%(asctime)s %(message)s", + level=logging.DEBUG, + ) + +If logging is already configured:: + + import logging + + logger = logging.getLogger("websockets") + logger.setLevel(logging.DEBUG) + logger.addHandler(logging.StreamHandler()) + +Refer to the :doc:`logging documentation <../topics/logging>` for more details +on logging in websockets. + What does ``ConnectionClosedError: no close frame received or sent`` mean? -------------------------------------------------------------------------- @@ -39,8 +66,9 @@ There are several reasons why long-lived connections may be lost: connections may terminate connections after a short amount of time, usually 30 seconds, despite websockets' keepalive mechanism. -If you're facing a reproducible issue, :ref:`enable debug logs ` to -see when and how connections are closed. +If you're facing a reproducible issue, :ref:`enable debug logs +` to see when and how connections are closed. connections are +closed. What does ``ConnectionClosedError: sent 1011 (internal error) keepalive ping timeout; no close frame received`` mean? --------------------------------------------------------------------------------------------------------------------- diff --git a/docs/faq/index.rst b/docs/faq/index.rst index 9d5b0d538..7488a5397 100644 --- a/docs/faq/index.rst +++ b/docs/faq/index.rst @@ -7,10 +7,14 @@ Frequently asked questions about :mod:`asyncio`. :class: seealso - Python's documentation about `developing with asyncio`_ is a good - complement. + If you're new to ``asyncio``, you will certainly encounter issues that are + related to asynchronous programming in general rather than to websockets in + particular. - .. _developing with asyncio: https://docs.python.org/3/library/asyncio-dev.html + Fortunately, Python's official documentation provides advice to `develop + with asyncio`_. Check it out: it's invaluable! + + .. _develop with asyncio: https://docs.python.org/3/library/asyncio-dev.html .. toctree:: diff --git a/docs/howto/cheatsheet.rst b/docs/howto/cheatsheet.rst deleted file mode 100644 index 8df2f234b..000000000 --- a/docs/howto/cheatsheet.rst +++ /dev/null @@ -1,86 +0,0 @@ -Cheat sheet -=========== - -.. currentmodule:: websockets - -Server ------- - -* Write a coroutine that handles a single connection. It receives a WebSocket - protocol instance and the URI path in argument. - - * Call :meth:`~asyncio.connection.Connection.recv` and - :meth:`~asyncio.connection.Connection.send` to receive and send messages at - any time. - - * When :meth:`~asyncio.connection.Connection.recv` or - :meth:`~asyncio.connection.Connection.send` raises - :exc:`~exceptions.ConnectionClosed`, clean up and exit. If you started other - :class:`asyncio.Task`, terminate them before exiting. - - * If you aren't awaiting :meth:`~asyncio.connection.Connection.recv`, consider - awaiting :meth:`~asyncio.connection.Connection.wait_closed` to detect - quickly when the connection is closed. - - * You may :meth:`~asyncio.connection.Connection.ping` or - :meth:`~asyncio.connection.Connection.pong` if you wish but it isn't needed - in general. - -* Create a server with :func:`~asyncio.server.serve` which is similar to asyncio's - :meth:`~asyncio.loop.create_server`. You can also use it as an asynchronous - context manager. - - * The server takes care of establishing connections, then lets the handler - execute the application logic, and finally closes the connection after the - handler exits normally or with an exception. - - * For advanced customization, you may subclass - :class:`~asyncio.server.ServerConnection` and pass either this subclass or a - factory function as the ``create_connection`` argument. - -Client ------- - -* Create a client with :func:`~asyncio.client.connect` which is similar to - asyncio's :meth:`~asyncio.loop.create_connection`. You can also use it as an - asynchronous context manager. - - * For advanced customization, you may subclass - :class:`~asyncio.client.ClientConnection` and pass either this subclass or - a factory function as the ``create_connection`` argument. - -* Call :meth:`~asyncio.connection.Connection.recv` and - :meth:`~asyncio.connection.Connection.send` to receive and send messages at - any time. - -* You may :meth:`~asyncio.connection.Connection.ping` or - :meth:`~asyncio.connection.Connection.pong` if you wish but it isn't needed in - general. - -* If you aren't using :func:`~asyncio.client.connect` as a context manager, call - :meth:`~asyncio.connection.Connection.close` to terminate the connection. - -.. _debugging: - -Debugging ---------- - -If you don't understand what websockets is doing, enable logging:: - - import logging - logger = logging.getLogger('websockets') - logger.setLevel(logging.DEBUG) - logger.addHandler(logging.StreamHandler()) - -The logs contain: - -* Exceptions in the connection handler at the ``ERROR`` level -* Exceptions in the opening or closing handshake at the ``INFO`` level -* All frames at the ``DEBUG`` level — this can be very verbose - -If you're new to ``asyncio``, you will certainly encounter issues that are -related to asynchronous programming in general rather than to websockets in -particular. Fortunately Python's official documentation provides advice to -`develop with asyncio`_. Check it out: it's invaluable! - -.. _develop with asyncio: https://docs.python.org/3/library/asyncio-dev.html diff --git a/docs/howto/index.rst b/docs/howto/index.rst index 619b11fa8..0b573b229 100644 --- a/docs/howto/index.rst +++ b/docs/howto/index.rst @@ -21,7 +21,6 @@ If you're stuck, perhaps you'll find the answer here. .. toctree:: :titlesonly: - cheatsheet patterns autoreload diff --git a/docs/topics/logging.rst b/docs/topics/logging.rst index fff33a024..03c5bef5f 100644 --- a/docs/topics/logging.rst +++ b/docs/topics/logging.rst @@ -72,7 +72,7 @@ Here's a basic configuration for a server in production:: Here's how to enable debug logs for development:: logging.basicConfig( - format="%(message)s", + format="%(asctime)s %(message)s", level=logging.DEBUG, ) From 476aaac5bd146f82c8b30658233897acaedbe82e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 9 Feb 2025 20:29:42 +0100 Subject: [PATCH 747/762] Deindent module contents in API docs. This the dominant style in the project. --- docs/reference/datastructures.rst | 76 +++++++++++++++---------------- docs/reference/extensions.rst | 41 ++++++++--------- docs/reference/types.rst | 14 +++--- 3 files changed, 65 insertions(+), 66 deletions(-) diff --git a/docs/reference/datastructures.rst b/docs/reference/datastructures.rst index ec02d4210..04a7466fa 100644 --- a/docs/reference/datastructures.rst +++ b/docs/reference/datastructures.rst @@ -6,61 +6,61 @@ WebSocket events .. automodule:: websockets.frames - .. autoclass:: Frame - - .. autoclass:: Opcode - - .. autoattribute:: CONT - .. autoattribute:: TEXT - .. autoattribute:: BINARY - .. autoattribute:: CLOSE - .. autoattribute:: PING - .. autoattribute:: PONG - - .. autoclass:: Close - - .. autoclass:: CloseCode - - .. autoattribute:: NORMAL_CLOSURE - .. autoattribute:: GOING_AWAY - .. autoattribute:: PROTOCOL_ERROR - .. autoattribute:: UNSUPPORTED_DATA - .. autoattribute:: NO_STATUS_RCVD - .. autoattribute:: ABNORMAL_CLOSURE - .. autoattribute:: INVALID_DATA - .. autoattribute:: POLICY_VIOLATION - .. autoattribute:: MESSAGE_TOO_BIG - .. autoattribute:: MANDATORY_EXTENSION - .. autoattribute:: INTERNAL_ERROR - .. autoattribute:: SERVICE_RESTART - .. autoattribute:: TRY_AGAIN_LATER - .. autoattribute:: BAD_GATEWAY - .. autoattribute:: TLS_HANDSHAKE +.. autoclass:: Frame + +.. autoclass:: Opcode + + .. autoattribute:: CONT + .. autoattribute:: TEXT + .. autoattribute:: BINARY + .. autoattribute:: CLOSE + .. autoattribute:: PING + .. autoattribute:: PONG + +.. autoclass:: Close + +.. autoclass:: CloseCode + + .. autoattribute:: NORMAL_CLOSURE + .. autoattribute:: GOING_AWAY + .. autoattribute:: PROTOCOL_ERROR + .. autoattribute:: UNSUPPORTED_DATA + .. autoattribute:: NO_STATUS_RCVD + .. autoattribute:: ABNORMAL_CLOSURE + .. autoattribute:: INVALID_DATA + .. autoattribute:: POLICY_VIOLATION + .. autoattribute:: MESSAGE_TOO_BIG + .. autoattribute:: MANDATORY_EXTENSION + .. autoattribute:: INTERNAL_ERROR + .. autoattribute:: SERVICE_RESTART + .. autoattribute:: TRY_AGAIN_LATER + .. autoattribute:: BAD_GATEWAY + .. autoattribute:: TLS_HANDSHAKE HTTP events ----------- .. automodule:: websockets.http11 - .. autoclass:: Request +.. autoclass:: Request - .. autoclass:: Response +.. autoclass:: Response .. automodule:: websockets.datastructures - .. autoclass:: Headers +.. autoclass:: Headers - .. automethod:: get_all + .. automethod:: get_all - .. automethod:: raw_items + .. automethod:: raw_items - .. autoexception:: MultipleValuesError +.. autoexception:: MultipleValuesError URIs ---- .. automodule:: websockets.uri - .. autofunction:: parse_uri +.. autofunction:: parse_uri - .. autoclass:: WebSocketURI +.. autoclass:: WebSocketURI diff --git a/docs/reference/extensions.rst b/docs/reference/extensions.rst index f3da464a5..880ef4a2a 100644 --- a/docs/reference/extensions.rst +++ b/docs/reference/extensions.rst @@ -16,45 +16,44 @@ Per-Message Deflate .. automodule:: websockets.extensions.permessage_deflate - :mod:`websockets.extensions.permessage_deflate` implements WebSocket - Per-Message Deflate. +:mod:`websockets.extensions.permessage_deflate` implements WebSocket Per-Message +Deflate. - This extension is specified in :rfc:`7692`. +This extension is specified in :rfc:`7692`. - Refer to the :doc:`topic guide on compression <../topics/compression>` to - learn more about tuning compression settings. +Refer to the :doc:`topic guide on compression <../topics/compression>` to learn +more about tuning compression settings. - .. autoclass:: ClientPerMessageDeflateFactory +.. autoclass:: ServerPerMessageDeflateFactory - .. autoclass:: ServerPerMessageDeflateFactory +.. autoclass:: ClientPerMessageDeflateFactory Base classes ------------ .. automodule:: websockets.extensions - :mod:`websockets.extensions` defines base classes for implementing - extensions. +:mod:`websockets.extensions` defines base classes for implementing extensions. - Refer to the :doc:`how-to guide on extensions <../howto/extensions>` to - learn more about writing an extension. +Refer to the :doc:`how-to guide on extensions <../howto/extensions>` to learn +more about writing an extension. - .. autoclass:: Extension +.. autoclass:: Extension - .. autoattribute:: name + .. autoattribute:: name - .. automethod:: decode + .. automethod:: decode - .. automethod:: encode + .. automethod:: encode - .. autoclass:: ClientExtensionFactory +.. autoclass:: ServerExtensionFactory - .. autoattribute:: name + .. automethod:: process_request_params - .. automethod:: get_request_params +.. autoclass:: ClientExtensionFactory - .. automethod:: process_response_params + .. autoattribute:: name - .. autoclass:: ServerExtensionFactory + .. automethod:: get_request_params - .. automethod:: process_request_params + .. automethod:: process_response_params diff --git a/docs/reference/types.rst b/docs/reference/types.rst index 9d3aa8310..d249b9294 100644 --- a/docs/reference/types.rst +++ b/docs/reference/types.rst @@ -3,19 +3,19 @@ Types .. automodule:: websockets.typing - .. autodata:: Data +.. autodata:: Data - .. autodata:: LoggerLike +.. autodata:: LoggerLike - .. autodata:: StatusLike +.. autodata:: StatusLike - .. autodata:: Origin +.. autodata:: Origin - .. autodata:: Subprotocol +.. autodata:: Subprotocol - .. autodata:: ExtensionName +.. autodata:: ExtensionName - .. autodata:: ExtensionParameter +.. autodata:: ExtensionParameter .. autodata:: websockets.protocol.Event From 4a2cfd5eaccb8e48971d21f65c4d04cf766a273f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 9 Feb 2025 20:30:33 +0100 Subject: [PATCH 748/762] Assorted fixes to the docs. --- docs/intro/index.rst | 1 + docs/reference/asyncio/client.rst | 4 ++-- docs/reference/asyncio/common.rst | 4 ++-- docs/reference/asyncio/server.rst | 4 ++-- docs/reference/features.rst | 3 ++- docs/reference/index.rst | 11 +++++------ docs/reference/legacy/client.rst | 4 ++-- docs/reference/legacy/common.rst | 4 ++-- docs/reference/legacy/server.rst | 4 ++-- docs/spelling_wordlist.txt | 1 + docs/topics/design.rst | 4 ++-- 11 files changed, 23 insertions(+), 21 deletions(-) diff --git a/docs/intro/index.rst b/docs/intro/index.rst index 642e50094..76994b6a2 100644 --- a/docs/intro/index.rst +++ b/docs/intro/index.rst @@ -35,6 +35,7 @@ Tutorial Learn how to build an real-time web application with websockets. .. toctree:: + :maxdepth: 1 tutorial1 tutorial2 diff --git a/docs/reference/asyncio/client.rst b/docs/reference/asyncio/client.rst index ea7b21506..72c7dce37 100644 --- a/docs/reference/asyncio/client.rst +++ b/docs/reference/asyncio/client.rst @@ -1,5 +1,5 @@ -Client (new :mod:`asyncio`) -=========================== +Client (:mod:`asyncio`) +======================= .. automodule:: websockets.asyncio.client diff --git a/docs/reference/asyncio/common.rst b/docs/reference/asyncio/common.rst index 325f20450..d772adc25 100644 --- a/docs/reference/asyncio/common.rst +++ b/docs/reference/asyncio/common.rst @@ -1,7 +1,7 @@ :orphan: -Both sides (new :mod:`asyncio`) -=============================== +Both sides (:mod:`asyncio`) +=========================== .. automodule:: websockets.asyncio.connection diff --git a/docs/reference/asyncio/server.rst b/docs/reference/asyncio/server.rst index 8d8b700f3..a245929ef 100644 --- a/docs/reference/asyncio/server.rst +++ b/docs/reference/asyncio/server.rst @@ -1,5 +1,5 @@ -Server (new :mod:`asyncio`) -=========================== +Server (:mod:`asyncio`) +======================= .. automodule:: websockets.asyncio.server diff --git a/docs/reference/features.rst b/docs/reference/features.rst index 321e2e832..e5f6e0de0 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -68,6 +68,7 @@ Both sides | Measure latency | ✅ | ✅ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ | Perform the closing handshake | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+ | Enforce closing timeout | ✅ | ✅ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ | Report close codes and reasons | ✅ | ✅ | ✅ | ❌ | @@ -177,7 +178,7 @@ There is no way to control compression of outgoing frames on a per-frame basis .. _#538: https://github.com/python-websockets/websockets/issues/538 The server doesn't check the Host header and doesn't respond with HTTP 400 Bad -Request if it is missing or invalid (`#1246`). +Request if it is missing or invalid (`#1246`_). .. _#1246: https://github.com/python-websockets/websockets/issues/1246 diff --git a/docs/reference/index.rst b/docs/reference/index.rst index c78a3c095..cc9542c24 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -51,12 +51,11 @@ application servers. sansio/server sansio/client -:mod:`asyncio` (legacy) ------------------------ - -This is the historical implementation. +Legacy +------ -It is deprecated and will be removed. +This is the historical implementation. It is deprecated. It will be removed by +2030. .. toctree:: :titlesonly: @@ -67,7 +66,7 @@ It is deprecated and will be removed. Extensions ---------- -The Per-Message Deflate extension is built in. You may also define custom +The Per-Message Deflate extension is built-in. You may also define custom extensions. .. toctree:: diff --git a/docs/reference/legacy/client.rst b/docs/reference/legacy/client.rst index a798409f0..ede887f32 100644 --- a/docs/reference/legacy/client.rst +++ b/docs/reference/legacy/client.rst @@ -1,5 +1,5 @@ -Client (legacy :mod:`asyncio`) -============================== +Client (legacy) +=============== .. admonition:: The legacy :mod:`asyncio` implementation is deprecated. :class: caution diff --git a/docs/reference/legacy/common.rst b/docs/reference/legacy/common.rst index 45c56fccd..821576020 100644 --- a/docs/reference/legacy/common.rst +++ b/docs/reference/legacy/common.rst @@ -1,7 +1,7 @@ :orphan: -Both sides (legacy :mod:`asyncio`) -================================== +Both sides (legacy) +=================== .. admonition:: The legacy :mod:`asyncio` implementation is deprecated. :class: caution diff --git a/docs/reference/legacy/server.rst b/docs/reference/legacy/server.rst index 3c1d19fc6..8636034e2 100644 --- a/docs/reference/legacy/server.rst +++ b/docs/reference/legacy/server.rst @@ -1,5 +1,5 @@ -Server (legacy :mod:`asyncio`) -============================== +Server (legacy) +=============== .. admonition:: The legacy :mod:`asyncio` implementation is deprecated. :class: caution diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index dd32a78c3..1980aafa3 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -33,6 +33,7 @@ healthz html hypercorn iframe +io IPv istio iterable diff --git a/docs/topics/design.rst b/docs/topics/design.rst index bc14bd332..c1f55a9dc 100644 --- a/docs/topics/design.rst +++ b/docs/topics/design.rst @@ -1,7 +1,7 @@ :orphan: -Design (legacy :mod:`asyncio`) -============================== +Design (legacy) +=============== .. currentmodule:: websockets.legacy From 2a612686a0c7f0f39d09e2478d23ca91d9cfe5d6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 9 Feb 2025 22:03:58 +0100 Subject: [PATCH 749/762] Move deployment guides to their own section. --- docs/{howto => deploy}/fly.rst | 0 docs/{howto => deploy}/haproxy.rst | 0 docs/{howto => deploy}/heroku.rst | 0 docs/deploy/index.rst | 26 ++++++++++++++++++++++++++ docs/{howto => deploy}/koyeb.rst | 0 docs/{howto => deploy}/kubernetes.rst | 0 docs/{howto => deploy}/nginx.rst | 0 docs/{howto => deploy}/render.rst | 0 docs/{howto => deploy}/supervisor.rst | 0 docs/howto/index.rst | 14 -------------- docs/index.rst | 1 + 11 files changed, 27 insertions(+), 14 deletions(-) rename docs/{howto => deploy}/fly.rst (100%) rename docs/{howto => deploy}/haproxy.rst (100%) rename docs/{howto => deploy}/heroku.rst (100%) create mode 100644 docs/deploy/index.rst rename docs/{howto => deploy}/koyeb.rst (100%) rename docs/{howto => deploy}/kubernetes.rst (100%) rename docs/{howto => deploy}/nginx.rst (100%) rename docs/{howto => deploy}/render.rst (100%) rename docs/{howto => deploy}/supervisor.rst (100%) diff --git a/docs/howto/fly.rst b/docs/deploy/fly.rst similarity index 100% rename from docs/howto/fly.rst rename to docs/deploy/fly.rst diff --git a/docs/howto/haproxy.rst b/docs/deploy/haproxy.rst similarity index 100% rename from docs/howto/haproxy.rst rename to docs/deploy/haproxy.rst diff --git a/docs/howto/heroku.rst b/docs/deploy/heroku.rst similarity index 100% rename from docs/howto/heroku.rst rename to docs/deploy/heroku.rst diff --git a/docs/deploy/index.rst b/docs/deploy/index.rst new file mode 100644 index 000000000..a965e649b --- /dev/null +++ b/docs/deploy/index.rst @@ -0,0 +1,26 @@ +Deployment guides +================= + +Discover how to deploy your application on various platforms. + +Platforms-as-a-Service +---------------------- + +.. toctree:: + :titlesonly: + + render + koyeb + fly + heroku + +Self-hosted +----------- + +.. toctree:: + :titlesonly: + + kubernetes + supervisor + nginx + haproxy diff --git a/docs/howto/koyeb.rst b/docs/deploy/koyeb.rst similarity index 100% rename from docs/howto/koyeb.rst rename to docs/deploy/koyeb.rst diff --git a/docs/howto/kubernetes.rst b/docs/deploy/kubernetes.rst similarity index 100% rename from docs/howto/kubernetes.rst rename to docs/deploy/kubernetes.rst diff --git a/docs/howto/nginx.rst b/docs/deploy/nginx.rst similarity index 100% rename from docs/howto/nginx.rst rename to docs/deploy/nginx.rst diff --git a/docs/howto/render.rst b/docs/deploy/render.rst similarity index 100% rename from docs/howto/render.rst rename to docs/deploy/render.rst diff --git a/docs/howto/supervisor.rst b/docs/deploy/supervisor.rst similarity index 100% rename from docs/howto/supervisor.rst rename to docs/deploy/supervisor.rst diff --git a/docs/howto/index.rst b/docs/howto/index.rst index 0b573b229..deae44942 100644 --- a/docs/howto/index.rst +++ b/docs/howto/index.rst @@ -41,20 +41,6 @@ features, which websockets supports fully. .. _deployment-howto: -Once your application is ready, learn how to deploy it on various platforms. - -.. toctree:: - :titlesonly: - - render - koyeb - fly - heroku - kubernetes - supervisor - nginx - haproxy - If you're integrating the Sans-I/O layer of websockets into a library, rather than building an application with websockets, follow this guide. diff --git a/docs/index.rst b/docs/index.rst index de14fa2d0..bc3cc9df4 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -100,6 +100,7 @@ Do you like it? :doc:`Let's dive in! ` intro/index howto/index + deploy/index faq/index reference/index topics/index From 8c9f6fc27bb54610f1c389d7769cddc1504ca57e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 9 Feb 2025 22:46:17 +0100 Subject: [PATCH 750/762] Merge deployment topic guide into deployment section. --- .../architecture.svg} | 0 docs/deploy/index.rst | 202 +++++++++++++++++- docs/howto/index.rst | 2 - docs/spelling_wordlist.txt | 1 + docs/topics/deployment.rst | 181 ---------------- docs/topics/index.rst | 1 - 6 files changed, 197 insertions(+), 190 deletions(-) rename docs/{topics/deployment.svg => deploy/architecture.svg} (100%) delete mode 100644 docs/topics/deployment.rst diff --git a/docs/topics/deployment.svg b/docs/deploy/architecture.svg similarity index 100% rename from docs/topics/deployment.svg rename to docs/deploy/architecture.svg diff --git a/docs/deploy/index.rst b/docs/deploy/index.rst index a965e649b..a6432d835 100644 --- a/docs/deploy/index.rst +++ b/docs/deploy/index.rst @@ -1,11 +1,39 @@ -Deployment guides -================= +Deployment +========== -Discover how to deploy your application on various platforms. +.. currentmodule:: websockets -Platforms-as-a-Service +Architecture decisions ---------------------- +When you deploy your websockets server to production, at a high level, your +architecture will almost certainly look like the following diagram: + +.. image:: architecture.svg + +The basic unit for scaling a websockets server is "one server process". Each +blue box in the diagram represents one server process. + +There's more variation in routing. While the routing layer is shown as one big +box, it is likely to involve several subsystems. + +As a consequence, when you design a deployment, you must answer two questions: + +1. How will I run the appropriate number of server processes? +2. How will I route incoming connections to these processes? + +These questions are interrelated. There's a wide range of valid answers, +depending on your goals and your constraints. + +Platforms-as-a-Service +...................... + +Platforms-as-a-Service are the easiest option. They provide end-to-end, +integrated solutions and they require little configuration. + +Here's how to deploy on some popular PaaS providers. Since all PaaS use +similar patterns, the concepts translate to other providers. + .. toctree:: :titlesonly: @@ -14,8 +42,13 @@ Platforms-as-a-Service fly heroku -Self-hosted ------------ +Self-hosted infrastructure +.......................... + +If you need more control over your infrastructure, you can deploy on your own +infrastructure. This requires more configuration. + +Here's how to configure some components mentioned in this guide. .. toctree:: :titlesonly: @@ -24,3 +57,160 @@ Self-hosted supervisor nginx haproxy + +Running server processes +------------------------ + +How many processes do I need? +............................. + +Typically, one server process will manage a few hundreds or thousands +connections, depending on the frequency of messages and the amount of work +they require. + +CPU and memory usage increase with the number of connections to the server. + +Often CPU is the limiting factor. If a server process goes to 100% CPU, then +you reached the limit. How much headroom you want to keep is up to you. + +Once you know how many connections a server process can manage and how many +connections you need to handle, you can calculate how many processes to run. + +You can also automate this calculation by configuring an autoscaler to keep +CPU usage or connection count within acceptable limits. + +.. admonition:: Don't scale with threads. Scale only with processes. + :class: tip + + Threads don't make sense for a server built with :mod:`asyncio`. + +How do I run processes? +....................... + +Most solutions for running multiple instances of a server process fall into +one of these three buckets: + +1. Running N processes on a platform: + + * a Kubernetes Deployment + + * its equivalent on a Platform as a Service provider + +2. Running N servers: + + * an AWS Auto Scaling group, a GCP Managed instance group, etc. + + * a fixed set of long-lived servers + +3. Running N processes on a server: + + * preferably via a process manager or supervisor + +Option 1 is easiest if you have access to such a platform. Option 2 usually +combines with option 3. + +How do I start a process? +......................... + +Run a Python program that invokes :func:`~asyncio.server.serve` or +:func:`~asyncio.router.route`. That's it! + +Don't run an ASGI server such as Uvicorn, Hypercorn, or Daphne. They're +alternatives to websockets, not complements. + +Don't run a WSGI server such as Gunicorn, Waitress, or mod_wsgi. They aren't +designed to run WebSocket applications. + +Applications servers handle network connections and expose a Python API. You +don't need one because websockets handles network connections directly. + +How do I stop a process? +........................ + +Process managers send the SIGTERM signal to terminate processes. Catch this +signal and exit the server to ensure a graceful shutdown. + +Here's an example: + +.. literalinclude:: ../../example/faq/shutdown_server.py + :emphasize-lines: 14-16 + +When exiting the context manager, :func:`~asyncio.server.serve` closes all +connections with code 1001 (going away). As a consequence: + +* If the connection handler is awaiting + :meth:`~asyncio.server.ServerConnection.recv`, it receives a + :exc:`~exceptions.ConnectionClosedOK` exception. It can catch the exception + and clean up before exiting. + +* Otherwise, it should be waiting on + :meth:`~asyncio.server.ServerConnection.wait_closed`, so it can receive the + :exc:`~exceptions.ConnectionClosedOK` exception and exit. + +This example is easily adapted to handle other signals. + +If you override the default signal handler for SIGINT, which raises +:exc:`KeyboardInterrupt`, be aware that you won't be able to interrupt a +program with Ctrl-C anymore when it's stuck in a loop. + +Routing connections +------------------- + +What does routing involve? +.......................... + +Since the routing layer is directly exposed to the Internet, it should provide +appropriate protection against threats ranging from Internet background noise +to targeted attacks. + +You should always secure WebSocket connections with TLS. Since the routing +layer carries the public domain name, it should terminate TLS connections. + +Finally, it must route connections to the server processes, balancing new +connections across them. + +How do I route connections? +........................... + +Here are typical solutions for load balancing, matched to ways of running +processes: + +1. If you're running on a platform, it comes with a routing layer: + + * a Kubernetes Ingress and Service + + * a service mesh: Istio, Consul, Linkerd, etc. + + * the routing mesh of a Platform as a Service + +2. If you're running N servers, you may load balance with: + + * a cloud load balancer: AWS Elastic Load Balancing, GCP Cloud Load + Balancing, etc. + + * A software load balancer: HAProxy, NGINX, etc. + +3. If you're running N processes on a server, you may load balance with: + + * A software load balancer: HAProxy, NGINX, etc. + + * The operating system — all processes listen on the same port + +You may trust the load balancer to handle encryption and to provide security. +You may add another layer in front of the load balancer for these purposes. + +There are many possibilities. Don't add layers that you don't need, though. + +How do I implement a health check? +.................................. + +Load balancers need a way to check whether server processes are up and running +to avoid routing connections to a non-functional backend. + +websockets provide minimal support for responding to HTTP requests with the +``process_request`` hook. + +Here's an example: + +.. literalinclude:: ../../example/faq/health_check_server.py + :emphasize-lines: 7-9,16 diff --git a/docs/howto/index.rst b/docs/howto/index.rst index deae44942..c7859f3cd 100644 --- a/docs/howto/index.rst +++ b/docs/howto/index.rst @@ -39,8 +39,6 @@ features, which websockets supports fully. extensions -.. _deployment-howto: - If you're integrating the Sans-I/O layer of websockets into a library, rather than building an application with websockets, follow this guide. diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 1980aafa3..4a7dcd5ab 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -51,6 +51,7 @@ middleware mutex mypy nginx +PaaS Paketo permessage pid diff --git a/docs/topics/deployment.rst b/docs/topics/deployment.rst deleted file mode 100644 index 00d1f9285..000000000 --- a/docs/topics/deployment.rst +++ /dev/null @@ -1,181 +0,0 @@ -Deployment -========== - -.. currentmodule:: websockets - -When you deploy your websockets server to production, at a high level, your -architecture will almost certainly look like the following diagram: - -.. image:: deployment.svg - -The basic unit for scaling a websockets server is "one server process". Each -blue box in the diagram represents one server process. - -There's more variation in routing. While the routing layer is shown as one big -box, it is likely to involve several subsystems. - -When you design a deployment, your should consider two questions: - -1. How will I run the appropriate number of server processes? -2. How will I route incoming connections to these processes? - -These questions are strongly related. There's a wide range of acceptable -answers, depending on your goals and your constraints. - -You can find a few concrete examples in the :ref:`deployment how-to guides -`. - -Running server processes ------------------------- - -How many processes do I need? -............................. - -Typically, one server process will manage a few hundreds or thousands -connections, depending on the frequency of messages and the amount of work -they require. - -CPU and memory usage increase with the number of connections to the server. - -Often CPU is the limiting factor. If a server process goes to 100% CPU, then -you reached the limit. How much headroom you want to keep is up to you. - -Once you know how many connections a server process can manage and how many -connections you need to handle, you can calculate how many processes to run. - -You can also automate this calculation by configuring an autoscaler to keep -CPU usage or connection count within acceptable limits. - -Don't scale with threads. Threads doesn't make sense for a server built with -:mod:`asyncio`. - -How do I run processes? -....................... - -Most solutions for running multiple instances of a server process fall into -one of these three buckets: - -1. Running N processes on a platform: - - * a Kubernetes Deployment - - * its equivalent on a Platform as a Service provider - -2. Running N servers: - - * an AWS Auto Scaling group, a GCP Managed instance group, etc. - - * a fixed set of long-lived servers - -3. Running N processes on a server: - - * preferably via a process manager or supervisor - -Option 1 is easiest of you have access to such a platform. - -Option 2 almost always combines with option 3. - -How do I start a process? -......................... - -Run a Python program that invokes :func:`~asyncio.server.serve`. That's it. - -Don't run an ASGI server such as Uvicorn, Hypercorn, or Daphne. They're -alternatives to websockets, not complements. - -Don't run a WSGI server such as Gunicorn, Waitress, or mod_wsgi. They aren't -designed to run WebSocket applications. - -Applications servers handle network connections and expose a Python API. You -don't need one because websockets handles network connections directly. - -How do I stop a process? -........................ - -Process managers send the SIGTERM signal to terminate processes. Catch this -signal and exit the server to ensure a graceful shutdown. - -Here's an example: - -.. literalinclude:: ../../example/faq/shutdown_server.py - :emphasize-lines: 13-16,19 - -When exiting the context manager, :func:`~asyncio.server.serve` closes all -connections with code 1001 (going away). As a consequence: - -* If the connection handler is awaiting - :meth:`~asyncio.server.ServerConnection.recv`, it receives a - :exc:`~exceptions.ConnectionClosedOK` exception. It can catch the exception - and clean up before exiting. - -* Otherwise, it should be waiting on - :meth:`~asyncio.server.ServerConnection.wait_closed`, so it can receive the - :exc:`~exceptions.ConnectionClosedOK` exception and exit. - -This example is easily adapted to handle other signals. - -If you override the default signal handler for SIGINT, which raises -:exc:`KeyboardInterrupt`, be aware that you won't be able to interrupt a -program with Ctrl-C anymore when it's stuck in a loop. - -Routing connections -------------------- - -What does routing involve? -.......................... - -Since the routing layer is directly exposed to the Internet, it should provide -appropriate protection against threats ranging from Internet background noise -to targeted attacks. - -You should always secure WebSocket connections with TLS. Since the routing -layer carries the public domain name, it should terminate TLS connections. - -Finally, it must route connections to the server processes, balancing new -connections across them. - -How do I route connections? -........................... - -Here are typical solutions for load balancing, matched to ways of running -processes: - -1. If you're running on a platform, it comes with a routing layer: - - * a Kubernetes Ingress and Service - - * a service mesh: Istio, Consul, Linkerd, etc. - - * the routing mesh of a Platform as a Service - -2. If you're running N servers, you may load balance with: - - * a cloud load balancer: AWS Elastic Load Balancing, GCP Cloud Load - Balancing, etc. - - * A software load balancer: HAProxy, NGINX, etc. - -3. If you're running N processes on a server, you may load balance with: - - * A software load balancer: HAProxy, NGINX, etc. - - * The operating system — all processes listen on the same port - -You may trust the load balancer to handle encryption and to provide security. -You may add another layer in front of the load balancer for these purposes. - -There are many possibilities. Don't add layers that you don't need, though. - -How do I implement a health check? -.................................. - -Load balancers need a way to check whether server processes are up and running -to avoid routing connections to a non-functional backend. - -websockets provide minimal support for responding to HTTP requests with the -``process_request`` hook. - -Here's an example: - -.. literalinclude:: ../../example/faq/health_check_server.py - :emphasize-lines: 7-9,18 diff --git a/docs/topics/index.rst b/docs/topics/index.rst index a08d487c9..ebffff71f 100644 --- a/docs/topics/index.rst +++ b/docs/topics/index.rst @@ -6,7 +6,6 @@ Get a deeper understanding of how websockets is built and why. .. toctree:: :titlesonly: - deployment logging authentication broadcast From 89b037fbf7aca8af23ff68f773615b242ff75ae7 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 15 Feb 2025 10:53:52 +0100 Subject: [PATCH 751/762] Remove content added by mistake in 4b9caad7. --- docs/reference/sync/server.rst | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/docs/reference/sync/server.rst b/docs/reference/sync/server.rst index f6a45a659..59dde9b35 100644 --- a/docs/reference/sync/server.rst +++ b/docs/reference/sync/server.rst @@ -23,18 +23,6 @@ Routing connections .. currentmodule:: websockets.sync.server -Routing connections -------------------- - -.. autofunction:: route - :async: - -.. autofunction:: unix_route - :async: - -.. autoclass:: Server - - Running a server ---------------- From 74a7ac20a7009d68cd4a38705d8aac969a17e78a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 15 Feb 2025 13:27:12 +0100 Subject: [PATCH 752/762] Add topic guide on routing. --- docs/deploy/index.rst | 8 ++-- docs/faq/server.rst | 22 +--------- docs/project/changelog.rst | 7 +++- docs/topics/index.rst | 3 +- docs/topics/routing.rst | 83 ++++++++++++++++++++++++++++++++++++++ 5 files changed, 96 insertions(+), 27 deletions(-) create mode 100644 docs/topics/routing.rst diff --git a/docs/deploy/index.rst b/docs/deploy/index.rst index a6432d835..2bdab9464 100644 --- a/docs/deploy/index.rst +++ b/docs/deploy/index.rst @@ -14,8 +14,8 @@ architecture will almost certainly look like the following diagram: The basic unit for scaling a websockets server is "one server process". Each blue box in the diagram represents one server process. -There's more variation in routing. While the routing layer is shown as one big -box, it is likely to involve several subsystems. +There's more variation in routing connections to processes. While the routing +layer is shown as one big box, it is likely to involve several subsystems. As a consequence, when you design a deployment, you must answer two questions: @@ -153,8 +153,8 @@ If you override the default signal handler for SIGINT, which raises :exc:`KeyboardInterrupt`, be aware that you won't be able to interrupt a program with Ctrl-C anymore when it's stuck in a loop. -Routing connections -------------------- +Routing connections to processes +-------------------------------- What does routing involve? .......................... diff --git a/docs/faq/server.rst b/docs/faq/server.rst index bb04c5e1c..10b041095 100644 --- a/docs/faq/server.rst +++ b/docs/faq/server.rst @@ -208,26 +208,8 @@ How do I access the request path? It is available in the :attr:`~ServerConnection.request` object. -You may route a connection to different handlers depending on the request path:: - - async def handler(websocket): - if websocket.request.path == "/blue": - await blue_handler(websocket) - elif websocket.request.path == "/green": - await green_handler(websocket) - else: - # No handler for this path; close the connection. - return - -For more complex routing, you may use :func:`~websockets.asyncio.router.route`. - -You may also route the connection based on the first message received from the -client, as shown in the :doc:`tutorial <../intro/tutorial2>`. When you want to -authenticate the connection before routing it, this is usually more convenient. - -Generally speaking, there is far less emphasis on the request path in WebSocket -servers than in HTTP servers. When a WebSocket server provides a single endpoint, -it may ignore the request path entirely. +Refer to the :doc:`routing guide <../topics/routing>` for details on how to +route connections to different handlers depending on the request path. How do I access HTTP headers? ----------------------------- diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index d7db6167a..31a537af5 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -43,7 +43,9 @@ Backwards-incompatible changes SOCKS proxies require installing the third-party library `python-socks`_. If you want to disable the proxy, add ``proxy=None`` when calling - :func:`~asyncio.client.connect`. See :doc:`../topics/proxies` for details. + :func:`~asyncio.client.connect`. + + See :doc:`proxies <../topics/proxies>` for details. .. _python-socks: https://github.com/romis2012/python-socks @@ -60,7 +62,8 @@ New features ............ * Added :func:`~asyncio.router.route` and :func:`~asyncio.router.unix_route` to - dispatch connections to different handlers depending on the URL. + dispatch connections to handlers based on the request path. Read more about + routing in :doc:`routing <../topics/routing>`. Improvements ............ diff --git a/docs/topics/index.rst b/docs/topics/index.rst index ebffff71f..4273599b7 100644 --- a/docs/topics/index.rst +++ b/docs/topics/index.rst @@ -6,12 +6,13 @@ Get a deeper understanding of how websockets is built and why. .. toctree:: :titlesonly: - logging authentication broadcast compression keepalive + logging memory security performance proxies + routing diff --git a/docs/topics/routing.rst b/docs/topics/routing.rst new file mode 100644 index 000000000..64c675e74 --- /dev/null +++ b/docs/topics/routing.rst @@ -0,0 +1,83 @@ +Routing +======= + +.. currentmodule:: websockets + +Many WebSocket servers provide just one endpoint. That's why +:func:`~asyncio.server.serve` accepts a single connection handler as its first +argument. + +This may come as a surprise to you if you're used to HTTP servers. In a standard +HTTP application, each request gets dispatched to a handler based on the request +path. Clients know which path to use for which operation. + +In a WebSocket application, clients open a persistent connection then they send +all messages over that unique connection. When different messages correspond to +different operations, they must be dispatched based on the message content. + +Simple routing +-------------- + +If you need different handlers for different clients or different use cases, you +may route each connection to the right handler based on the request path. + +Since WebSocket servers typically provide fewer routes than HTTP servers, you +can keep it simple:: + + async def handler(websocket): + match websocket.request.path: + case "/blue": + await blue_handler(websocket) + case "/green": + await green_handler(websocket) + case _: + # No handler for this path. Close the connection. + return + +You may also route connections based on the first message received from the +client, as demonstrated in the :doc:`tutorial <../intro/tutorial2>`:: + + import json + + async def handler(websocket): + message = await websocket.recv() + settings = json.loads(message) + match settings["color"]: + case "blue": + await blue_handler(websocket) + case "green": + await green_handler(websocket) + case _: + # No handler for this message. Close the connection. + return + +When you need to authenticate the connection before routing it, this pattern is +more convenient. + +Complex routing +--------------- + +If you have outgrow these simple patterns, websockets provides full-fledged +routing based on the request path with :func:`~asyncio.router.route`. + +This feature builds upon Flask_'s router. To use it, you must install the +third-party library `werkzeug`_:: + + $ pip install werkzeug + +.. _Flask: https://flask.palletsprojects.com/ +.. _werkzeug: https://werkzeug.palletsprojects.com/ + +:func:`~asyncio.router.route` expects a :class:`werkzeug.routing.Map` as its +first argument to declare which URL patterns map to which handlers. Review the +documentation of :mod:`werkzeug.routing` to learn about its functionality. + +To give you a sense of what's possible, here's the URL map of the example in +`example/routing.py`_: + +.. _example/routing.py: https://github.com/python-websockets/websockets/blob/main/example/routing.py + +.. literalinclude:: ../../example/routing.py + :language: python + :start-at: url_map = Map( + :end-at: await server.serve_forever() From 7bfb1140e1c19aa3b59ecf2b08bde56a82cfe04a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 15 Feb 2025 13:48:51 +0100 Subject: [PATCH 753/762] Add standalone guide to enable debug logs. Pointing to this information is frequently needed in the issue tracker. --- docs/faq/common.rst | 32 ++------------------------------ docs/howto/autoreload.rst | 7 +++---- docs/howto/debugging.rst | 33 +++++++++++++++++++++++++++++++++ docs/howto/index.rst | 29 +++++++++++++---------------- docs/intro/tutorial1.rst | 5 ++++- 5 files changed, 55 insertions(+), 51 deletions(-) create mode 100644 docs/howto/debugging.rst diff --git a/docs/faq/common.rst b/docs/faq/common.rst index b31d4504f..1ee0062af 100644 --- a/docs/faq/common.rst +++ b/docs/faq/common.rst @@ -3,33 +3,6 @@ Both sides .. currentmodule:: websockets.asyncio.connection -.. _enable-debug-logs: - -How do I enable debug logs? ---------------------------- - -You can enable debug logs to see exactly what websockets is doing. - -If logging isn't configured in your application:: - - import logging - - logging.basicConfig( - format="%(asctime)s %(message)s", - level=logging.DEBUG, - ) - -If logging is already configured:: - - import logging - - logger = logging.getLogger("websockets") - logger.setLevel(logging.DEBUG) - logger.addHandler(logging.StreamHandler()) - -Refer to the :doc:`logging documentation <../topics/logging>` for more details -on logging in websockets. - What does ``ConnectionClosedError: no close frame received or sent`` mean? -------------------------------------------------------------------------- @@ -66,9 +39,8 @@ There are several reasons why long-lived connections may be lost: connections may terminate connections after a short amount of time, usually 30 seconds, despite websockets' keepalive mechanism. -If you're facing a reproducible issue, :ref:`enable debug logs -` to see when and how connections are closed. connections are -closed. +If you're facing a reproducible issue, :doc:`enable debug logs +<../howto/debugging>` to see when and how connections are closed. What does ``ConnectionClosedError: sent 1011 (internal error) keepalive ping timeout; no close frame received`` mean? --------------------------------------------------------------------------------------------------------------------- diff --git a/docs/howto/autoreload.rst b/docs/howto/autoreload.rst index fc736a591..4d1adee8b 100644 --- a/docs/howto/autoreload.rst +++ b/docs/howto/autoreload.rst @@ -7,8 +7,7 @@ stop the server and restart it, which slows down your development process. Web frameworks such as Django or Flask provide a development server that reloads the application automatically when you make code changes. There is no -such functionality in websockets because it's designed for production rather -than development. +such functionality in websockets because it's designed only for production. However, you can achieve the same result easily. @@ -27,5 +26,5 @@ Run your server with ``watchmedo auto-restart``: $ watchmedo auto-restart --pattern "*.py" --recursive --signal SIGTERM \ python app.py -This example assumes that the server is defined in a script called ``app.py``. -Adapt it as necessary. +This example assumes that the server is defined in a script called ``app.py`` +and exits cleanly when receiving the ``SIGTERM`` signal. Adapt it as necessary. diff --git a/docs/howto/debugging.rst b/docs/howto/debugging.rst new file mode 100644 index 000000000..fc5dcba56 --- /dev/null +++ b/docs/howto/debugging.rst @@ -0,0 +1,33 @@ +Enable debug logs +================== + +websockets logs events with the :mod:`logging` module from the standard library. + +It writes to the ``"websockets.server"`` and ``"websockets.client"`` loggers. + +Enable logs at the ``DEBUG`` level to see exactly what websockets is doing. + +If logging isn't configured in your application:: + + import logging + + logging.basicConfig( + format="%(asctime)s %(message)s", + level=logging.DEBUG, + ) + +If logging is already configured:: + + import logging + + logger = logging.getLogger("websockets") + logger.setLevel(logging.DEBUG) + logger.addHandler(logging.StreamHandler()) + +Refer to the :doc:`logging guide <../topics/logging>` for more details on +logging in websockets. + +In addition, you may enable asyncio's `debug mode`_ to see what asyncio is +doing. + +.. _debug mode: https://docs.python.org/3/library/asyncio-dev.html#asyncio-debug-mode diff --git a/docs/howto/index.rst b/docs/howto/index.rst index c7859f3cd..8a9717ed1 100644 --- a/docs/howto/index.rst +++ b/docs/howto/index.rst @@ -4,40 +4,30 @@ How-to guides In a hurry? Check out these examples. .. toctree:: - :titlesonly: quickstart - -Upgrading from the legacy :mod:`asyncio` implementation to the new one? -Read this. - -.. toctree:: - :titlesonly: - - upgrade + debugging + autoreload If you're stuck, perhaps you'll find the answer here. .. toctree:: - :titlesonly: patterns - autoreload This guide will help you integrate websockets into a broader system. .. toctree:: - :titlesonly: django -The WebSocket protocol makes provisions for extending or specializing its -features, which websockets supports fully. +Upgrading from the legacy :mod:`asyncio` implementation to the new one? +Read this. .. toctree:: - :titlesonly: + :maxdepth: 2 - extensions + upgrade If you're integrating the Sans-I/O layer of websockets into a library, rather than building an application with websockets, follow this guide. @@ -46,3 +36,10 @@ than building an application with websockets, follow this guide. :maxdepth: 2 sansio + +The WebSocket protocol makes provisions for extending or specializing its +features, which websockets supports fully. + +.. toctree:: + + extensions diff --git a/docs/intro/tutorial1.rst b/docs/intro/tutorial1.rst index cc88e6a4a..39e693aae 100644 --- a/docs/intro/tutorial1.rst +++ b/docs/intro/tutorial1.rst @@ -545,7 +545,10 @@ taking alternate turns. import logging - logging.basicConfig(format="%(message)s", level=logging.DEBUG) + logging.basicConfig( + format="%(asctime)s %(message)s", + level=logging.DEBUG, + ) If you're stuck, a solution is available at the bottom of this document. From e60e04cdc755fe41e8057b23842e87cabacdc694 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 15 Feb 2025 14:12:00 +0100 Subject: [PATCH 754/762] Unify examples for topic guides under experiments. --- docs/topics/logging.rst | 2 +- docs/topics/routing.rst | 6 +++--- {example/logging => experiments}/json_log_formatter.py | 0 {example => experiments}/routing.py | 0 4 files changed, 4 insertions(+), 4 deletions(-) rename {example/logging => experiments}/json_log_formatter.py (100%) rename {example => experiments}/routing.py (100%) diff --git a/docs/topics/logging.rst b/docs/topics/logging.rst index 03c5bef5f..2eedd32a4 100644 --- a/docs/topics/logging.rst +++ b/docs/topics/logging.rst @@ -146,7 +146,7 @@ output logs as JSON with a bit of effort. First, we need a :class:`~logging.Formatter` that renders JSON: -.. literalinclude:: ../../example/logging/json_log_formatter.py +.. literalinclude:: ../../experiments/json_log_formatter.py Then, we configure logging to apply this formatter:: diff --git a/docs/topics/routing.rst b/docs/topics/routing.rst index 64c675e74..d1790532a 100644 --- a/docs/topics/routing.rst +++ b/docs/topics/routing.rst @@ -73,11 +73,11 @@ first argument to declare which URL patterns map to which handlers. Review the documentation of :mod:`werkzeug.routing` to learn about its functionality. To give you a sense of what's possible, here's the URL map of the example in -`example/routing.py`_: +`experiments/routing.py`_: -.. _example/routing.py: https://github.com/python-websockets/websockets/blob/main/example/routing.py +.. _experiments/routing.py: https://github.com/python-websockets/websockets/blob/main/experiments/routing.py -.. literalinclude:: ../../example/routing.py +.. literalinclude:: ../../experiments/routing.py :language: python :start-at: url_map = Map( :end-at: await server.serve_forever() diff --git a/example/logging/json_log_formatter.py b/experiments/json_log_formatter.py similarity index 100% rename from example/logging/json_log_formatter.py rename to experiments/json_log_formatter.py diff --git a/example/routing.py b/experiments/routing.py similarity index 100% rename from example/routing.py rename to experiments/routing.py From 82dfd83cee4814f27867c909da785ebe3228cbf8 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 15 Feb 2025 14:16:39 +0100 Subject: [PATCH 755/762] Improve index page for topic guides. --- docs/topics/authentication.rst | 4 ++-- docs/topics/index.rst | 18 +++++++++++++----- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/docs/topics/authentication.rst b/docs/topics/authentication.rst index e2de4332e..7c022066f 100644 --- a/docs/topics/authentication.rst +++ b/docs/topics/authentication.rst @@ -313,8 +313,8 @@ To authenticate a websockets client with HTTP Basic Authentication async with connect(f"wss://{username}:{password}@.../") as websocket: ... -(You must :func:`~urllib.parse.quote` ``username`` and ``password`` if they -contain unsafe characters.) +You must :func:`~urllib.parse.quote` ``username`` and ``password`` if they +contain unsafe characters. To authenticate a websockets client with HTTP Bearer Authentication (:rfc:`6750`), add a suitable ``Authorization`` header: diff --git a/docs/topics/index.rst b/docs/topics/index.rst index 4273599b7..ca5d83c97 100644 --- a/docs/topics/index.rst +++ b/docs/topics/index.rst @@ -1,18 +1,26 @@ Topic guides ============ -Get a deeper understanding of how websockets is built and why. +These documents discuss how websockets is designed and how to make the best of +its features when building applications. .. toctree:: - :titlesonly: + :maxdepth: 2 authentication broadcast + logging + proxies + routing + +These guides describe how to optimize the configuration of websockets +applications for performance and reliability. + +.. toctree:: + :maxdepth: 2 + compression keepalive - logging memory security performance - proxies - routing From 61a6a7a99534e2522773476fc72ca615f5a6ea97 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 15 Feb 2025 21:48:27 +0100 Subject: [PATCH 756/762] Clean up quick start guides. * Move most examples back to the getting started section. * Add a dedicated howto on encryption. And some drive-by fixes. Refs #1209. --- docs/deploy/koyeb.rst | 1 - docs/howto/encryption.rst | 65 +++++++ docs/howto/index.rst | 12 +- docs/howto/quickstart.rst | 170 ------------------ docs/intro/examples.rst | 112 ++++++++++++ docs/intro/index.rst | 9 +- docs/reference/legacy/server.rst | 1 - docs/topics/compression.rst | 1 - docs/topics/proxies.rst | 4 +- docs/topics/routing.rst | 5 +- docs/topics/security.rst | 9 +- example/quick/client.py | 17 ++ example/{quickstart => quick}/counter.css | 0 example/{quickstart => quick}/counter.html | 0 example/{quickstart => quick}/counter.js | 0 example/{quickstart => quick}/counter.py | 1 + example/{quickstart => quick}/server.py | 0 example/{quickstart => quick}/show_time.html | 0 example/{quickstart => quick}/show_time.js | 0 example/{quickstart => quick}/show_time.py | 2 +- example/quick/sync_time.py | 23 +++ example/quickstart/client.py | 19 -- example/quickstart/show_time_2.py | 29 --- .../client_secure.py => tls/client.py} | 13 +- example/{quickstart => tls}/localhost.pem | 0 .../server_secure.py => tls/server.py} | 0 src/websockets/asyncio/router.py | 4 +- src/websockets/sync/router.py | 4 +- 28 files changed, 258 insertions(+), 243 deletions(-) create mode 100644 docs/howto/encryption.rst delete mode 100644 docs/howto/quickstart.rst create mode 100644 docs/intro/examples.rst create mode 100755 example/quick/client.py rename example/{quickstart => quick}/counter.css (100%) rename example/{quickstart => quick}/counter.html (100%) rename example/{quickstart => quick}/counter.js (100%) rename example/{quickstart => quick}/counter.py (99%) rename example/{quickstart => quick}/server.py (100%) rename example/{quickstart => quick}/show_time.html (100%) rename example/{quickstart => quick}/show_time.js (100%) rename example/{quickstart => quick}/show_time.py (84%) create mode 100755 example/quick/sync_time.py delete mode 100755 example/quickstart/client.py delete mode 100755 example/quickstart/show_time_2.py rename example/{quickstart/client_secure.py => tls/client.py} (61%) rename example/{quickstart => tls}/localhost.pem (100%) rename example/{quickstart/server_secure.py => tls/server.py} (100%) diff --git a/docs/deploy/koyeb.rst b/docs/deploy/koyeb.rst index 0ad126dd8..0b7c96cb9 100644 --- a/docs/deploy/koyeb.rst +++ b/docs/deploy/koyeb.rst @@ -139,7 +139,6 @@ or press Ctrl-D to terminate the connection: < Hello! Connection closed: 1000 (OK). - You can also confirm that your application shuts down gracefully. Connect an interactive client again: diff --git a/docs/howto/encryption.rst b/docs/howto/encryption.rst new file mode 100644 index 000000000..af19fefd0 --- /dev/null +++ b/docs/howto/encryption.rst @@ -0,0 +1,65 @@ +Encrypt connections +==================== + +.. currentmodule:: websockets + +You should always secure WebSocket connections with TLS_ (Transport Layer +Security). + +.. admonition:: TLS vs. SSL + :class: tip + + TLS is sometimes referred to as SSL (Secure Sockets Layer). SSL was an + earlier encryption protocol; the name stuck. + +The ``wss`` protocol is to ``ws`` what ``https`` is to ``http``. + +Secure WebSocket connections require certificates just like HTTPS. + +.. _TLS: https://developer.mozilla.org/en-US/docs/Web/Security/Transport_Layer_Security + +.. admonition:: Configure the TLS context securely + :class: attention + + The examples below demonstrate the ``ssl`` argument with a TLS certificate + shared between the client and the server. This is a simplistic setup. + + Please review the advice and security considerations in the documentation of + the :mod:`ssl` module to configure the TLS context appropriately. + +Servers +------- + +In a typical :doc:`deployment <../deploy/index>`, the server is behind a reverse +proxy that terminates TLS. The client connects to the reverse proxy with TLS and +the reverse proxy connects to the server without TLS. + +In that case, you don't need to configure TLS in websockets. + +If needed in your setup, you can terminate TLS in the server. + +In the example below, :func:`~asyncio.server.serve` is configured to receive +secure connections. Before running this server, download +:download:`localhost.pem <../../example/tls/localhost.pem>` and save it in the +same directory as ``server.py``. + +.. literalinclude:: ../../example/tls/server.py + :caption: server.py + +Receive both plain and TLS connections on the same port isn't supported. + +Clients +------- + +:func:`~asyncio.client.connect` enables TLS automatically when connecting to a +``wss://...`` URI. + +This works out of the box when the TLS certificate of the server is valid, +meaning it's signed by a certificate authority that your Python installation +trusts. + +In the example above, since the server uses a self-signed certificate, the +client needs to be configured to trust the certificate. Here's how to do so. + +.. literalinclude:: ../../example/tls/client.py + :caption: client.py diff --git a/docs/howto/index.rst b/docs/howto/index.rst index 8a9717ed1..ffded9ff0 100644 --- a/docs/howto/index.rst +++ b/docs/howto/index.rst @@ -1,13 +1,18 @@ How-to guides ============= -In a hurry? Check out these examples. +Set up your development environment comfortably. .. toctree:: - quickstart - debugging autoreload + debugging + +Configure websockets securely in production. + +.. toctree:: + + encryption If you're stuck, perhaps you'll find the answer here. @@ -25,7 +30,6 @@ Upgrading from the legacy :mod:`asyncio` implementation to the new one? Read this. .. toctree:: - :maxdepth: 2 upgrade diff --git a/docs/howto/quickstart.rst b/docs/howto/quickstart.rst deleted file mode 100644 index e6bd362a4..000000000 --- a/docs/howto/quickstart.rst +++ /dev/null @@ -1,170 +0,0 @@ -Quick start -=========== - -.. currentmodule:: websockets - -Here are a few examples to get you started quickly with websockets. - -Say "Hello world!" ------------------- - -Here's a WebSocket server. - -It receives a name from the client, sends a greeting, and closes the connection. - -.. literalinclude:: ../../example/quickstart/server.py - :caption: server.py - :language: python - :linenos: - -:func:`~asyncio.server.serve` executes the connection handler coroutine -``hello()`` once for each WebSocket connection. It closes the WebSocket -connection when the handler returns. - -Here's a corresponding WebSocket client. - -It sends a name to the server, receives a greeting, and closes the connection. - -.. literalinclude:: ../../example/quickstart/client.py - :caption: client.py - :language: python - :linenos: - -Using :func:`~asyncio.client.connect` as an asynchronous context manager ensures -the WebSocket connection is closed. - -.. _secure-server-example: - -Encrypt connections -------------------- - -Secure WebSocket connections improve confidentiality and also reliability -because they reduce the risk of interference by bad proxies. - -The ``wss`` protocol is to ``ws`` what ``https`` is to ``http``. The -connection is encrypted with TLS_ (Transport Layer Security). ``wss`` -requires certificates like ``https``. - -.. _TLS: https://developer.mozilla.org/en-US/docs/Web/Security/Transport_Layer_Security - -.. admonition:: TLS vs. SSL - :class: tip - - TLS is sometimes referred to as SSL (Secure Sockets Layer). SSL was an - earlier encryption protocol; the name stuck. - -Here's how to adapt the server to encrypt connections. You must download -:download:`localhost.pem <../../example/quickstart/localhost.pem>` and save it -in the same directory as ``server_secure.py``. - -.. literalinclude:: ../../example/quickstart/server_secure.py - :caption: server_secure.py - :language: python - :linenos: - -Here's how to adapt the client similarly. - -.. literalinclude:: ../../example/quickstart/client_secure.py - :caption: client_secure.py - :language: python - :linenos: - -In this example, the client needs a TLS context because the server uses a -self-signed certificate. - -When connecting to a secure WebSocket server with a valid certificate — any -certificate signed by a CA that your Python installation trusts — you can simply -pass ``ssl=True`` to :func:`~asyncio.client.connect`. - -.. admonition:: Configure the TLS context securely - :class: attention - - This example demonstrates the ``ssl`` argument with a TLS certificate shared - between the client and the server. This is a simplistic setup. - - Please review the advice and security considerations in the documentation of - the :mod:`ssl` module to configure the TLS context securely. - -Connect from a browser ----------------------- - -The WebSocket protocol was invented for the web — as the name says! - -Here's how to connect to a WebSocket server from a browser. - -Run this script in a console: - -.. literalinclude:: ../../example/quickstart/show_time.py - :caption: show_time.py - :language: python - :linenos: - -Save this file as ``show_time.html``: - -.. literalinclude:: ../../example/quickstart/show_time.html - :caption: show_time.html - :language: html - :linenos: - -Save this file as ``show_time.js``: - -.. literalinclude:: ../../example/quickstart/show_time.js - :caption: show_time.js - :language: js - :linenos: - -Then, open ``show_time.html`` in several browsers. Clocks tick irregularly. - -Broadcast messages ------------------- - -Let's change the previous example to send the same timestamps to all browsers, -instead of generating independent sequences for each client. - -Stop the previous script if it's still running and run this script in a console: - -.. literalinclude:: ../../example/quickstart/show_time_2.py - :caption: show_time_2.py - :language: python - :linenos: - -Refresh ``show_time.html`` in all browsers. Clocks tick in sync. - -Manage application state ------------------------- - -A WebSocket server can receive events from clients, process them to update the -application state, and broadcast the updated state to all connected clients. - -Here's an example where any client can increment or decrement a counter. The -concurrency model of :mod:`asyncio` guarantees that updates are serialized. - -Run this script in a console: - -.. literalinclude:: ../../example/quickstart/counter.py - :caption: counter.py - :language: python - :linenos: - -Save this file as ``counter.html``: - -.. literalinclude:: ../../example/quickstart/counter.html - :caption: counter.html - :language: html - :linenos: - -Save this file as ``counter.css``: - -.. literalinclude:: ../../example/quickstart/counter.css - :caption: counter.css - :language: css - :linenos: - -Save this file as ``counter.js``: - -.. literalinclude:: ../../example/quickstart/counter.js - :caption: counter.js - :language: js - :linenos: - -Then open ``counter.html`` file in several browsers and play with [+] and [-]. diff --git a/docs/intro/examples.rst b/docs/intro/examples.rst new file mode 100644 index 000000000..341712475 --- /dev/null +++ b/docs/intro/examples.rst @@ -0,0 +1,112 @@ +Quick examples +============== + +.. currentmodule:: websockets + +Start a server +-------------- + +This WebSocket server receives a name from the client, sends a greeting, and +closes the connection. + +.. literalinclude:: ../../example/quick/server.py + :caption: server.py + :language: python + +:func:`~asyncio.server.serve` executes the connection handler coroutine +``hello()`` once for each WebSocket connection. It closes the WebSocket +connection when the handler returns. + +Connect a client +---------------- + +This WebSocket client sends a name to the server, receives a greeting, and +closes the connection. + +.. literalinclude:: ../../example/quick/client.py + :caption: client.py + :language: python + +Using :func:`~sync.client.connect` as a context manager ensures that the +WebSocket connection is closed. + +Connect a browser +----------------- + +The WebSocket protocol was invented for the web — as the name says! + +Here's how to connect a browser to a WebSocket server. + +Run this script in a console: + +.. literalinclude:: ../../example/quick/show_time.py + :caption: show_time.py + :language: python + +Save this file as ``show_time.html``: + +.. literalinclude:: ../../example/quick/show_time.html + :caption: show_time.html + :language: html + +Save this file as ``show_time.js``: + +.. literalinclude:: ../../example/quick/show_time.js + :caption: show_time.js + :language: js + +Then, open ``show_time.html`` in several browsers or tabs. Clocks tick +irregularly. + +Broadcast messages +------------------ + +Let's send the same timestamps to everyone instead of generating independent +sequences for each connection. + +Stop the previous script if it's still running and run this script in a console: + +.. literalinclude:: ../../example/quick/sync_time.py + :caption: sync_time.py + :language: python + +Refresh ``show_time.html`` in all browsers or tabs. Clocks tick in sync. + +Manage application state +------------------------ + +A WebSocket server can receive events from clients, process them to update the +application state, and broadcast the updated state to all connected clients. + +Here's an example where any client can increment or decrement a counter. The +concurrency model of :mod:`asyncio` guarantees that updates are serialized. + +This example keep tracks of connected users explicitly in ``USERS`` instead of +relying on :attr:`server.connections `. The +result is the same. + +Run this script in a console: + +.. literalinclude:: ../../example/quick/counter.py + :caption: counter.py + :language: python + +Save this file as ``counter.html``: + +.. literalinclude:: ../../example/quick/counter.html + :caption: counter.html + :language: html + +Save this file as ``counter.css``: + +.. literalinclude:: ../../example/quick/counter.css + :caption: counter.css + :language: css + +Save this file as ``counter.js``: + +.. literalinclude:: ../../example/quick/counter.js + :caption: counter.js + :language: js + +Then open ``counter.html`` file in several browsers and play with [+] and [-]. diff --git a/docs/intro/index.rst b/docs/intro/index.rst index 76994b6a2..d6f8fb9e0 100644 --- a/docs/intro/index.rst +++ b/docs/intro/index.rst @@ -35,7 +35,7 @@ Tutorial Learn how to build an real-time web application with websockets. .. toctree:: - :maxdepth: 1 + :maxdepth: 2 tutorial1 tutorial2 @@ -44,4 +44,9 @@ Learn how to build an real-time web application with websockets. In a hurry? ----------- -Look at the :doc:`quick start guide <../howto/quickstart>`. +These examples will get you started quickly with websockets. + +.. toctree:: + :maxdepth: 2 + + examples diff --git a/docs/reference/legacy/server.rst b/docs/reference/legacy/server.rst index 8636034e2..0ac84156d 100644 --- a/docs/reference/legacy/server.rst +++ b/docs/reference/legacy/server.rst @@ -94,7 +94,6 @@ Using a connection .. autoproperty:: close_reason - Broadcast --------- diff --git a/docs/topics/compression.rst b/docs/topics/compression.rst index 5f09bbf73..06bba0922 100644 --- a/docs/topics/compression.rst +++ b/docs/topics/compression.rst @@ -20,7 +20,6 @@ based on the Deflate_ algorithm specified in :rfc:`7692`. the reduction in network bandwidth is usually worth the additional memory and CPU cost. - Configuring compression ----------------------- diff --git a/docs/topics/proxies.rst b/docs/topics/proxies.rst index 14fc68c0c..a2536d4c0 100644 --- a/docs/topics/proxies.rst +++ b/docs/topics/proxies.rst @@ -55,7 +55,9 @@ SOCKS proxies ------------- Connecting through a SOCKS proxy requires installing the third-party library -`python-socks`_:: +`python-socks`_: + +.. code-block:: console $ pip install python-socks\[asyncio\] diff --git a/docs/topics/routing.rst b/docs/topics/routing.rst index d1790532a..44d89e00b 100644 --- a/docs/topics/routing.rst +++ b/docs/topics/routing.rst @@ -61,7 +61,9 @@ If you have outgrow these simple patterns, websockets provides full-fledged routing based on the request path with :func:`~asyncio.router.route`. This feature builds upon Flask_'s router. To use it, you must install the -third-party library `werkzeug`_:: +third-party library `werkzeug`_: + +.. code-block:: console $ pip install werkzeug @@ -78,6 +80,5 @@ To give you a sense of what's possible, here's the URL map of the example in .. _experiments/routing.py: https://github.com/python-websockets/websockets/blob/main/experiments/routing.py .. literalinclude:: ../../experiments/routing.py - :language: python :start-at: url_map = Map( :end-at: await server.serve_forever() diff --git a/docs/topics/security.rst b/docs/topics/security.rst index a22b752c7..e91f73b15 100644 --- a/docs/topics/security.rst +++ b/docs/topics/security.rst @@ -6,10 +6,13 @@ Security Encryption ---------- -For production use, a server should require encrypted connections. +In production, you should always secure WebSocket connections with TLS. -See this example of :ref:`encrypting connections with TLS -`. +Secure WebSocket connections provide confidentiality and integrity, as well as +better reliability because they reduce the risk of interference by bad proxies. + +WebSocket servers are usually deployed behind a reverse proxy that terminates +TLS. Else, you can :doc:`configure TLS <../howto/encryption>` for the server. Memory usage ------------ diff --git a/example/quick/client.py b/example/quick/client.py new file mode 100755 index 000000000..4f34c0628 --- /dev/null +++ b/example/quick/client.py @@ -0,0 +1,17 @@ +#!/usr/bin/env python + +from websockets.sync.client import connect + +def hello(): + uri = "ws://localhost:8765" + with connect(uri) as websocket: + name = input("What's your name? ") + + websocket.send(name) + print(f">>> {name}") + + greeting = websocket.recv() + print(f"<<< {greeting}") + +if __name__ == "__main__": + hello() diff --git a/example/quickstart/counter.css b/example/quick/counter.css similarity index 100% rename from example/quickstart/counter.css rename to example/quick/counter.css diff --git a/example/quickstart/counter.html b/example/quick/counter.html similarity index 100% rename from example/quickstart/counter.html rename to example/quick/counter.html diff --git a/example/quickstart/counter.js b/example/quick/counter.js similarity index 100% rename from example/quickstart/counter.js rename to example/quick/counter.js diff --git a/example/quickstart/counter.py b/example/quick/counter.py similarity index 99% rename from example/quickstart/counter.py rename to example/quick/counter.py index 8f0ff81be..b31345ce2 100755 --- a/example/quickstart/counter.py +++ b/example/quick/counter.py @@ -3,6 +3,7 @@ import asyncio import json import logging + from websockets.asyncio.server import broadcast, serve logging.basicConfig() diff --git a/example/quickstart/server.py b/example/quick/server.py similarity index 100% rename from example/quickstart/server.py rename to example/quick/server.py diff --git a/example/quickstart/show_time.html b/example/quick/show_time.html similarity index 100% rename from example/quickstart/show_time.html rename to example/quick/show_time.html diff --git a/example/quickstart/show_time.js b/example/quick/show_time.js similarity index 100% rename from example/quickstart/show_time.js rename to example/quick/show_time.js diff --git a/example/quickstart/show_time.py b/example/quick/show_time.py similarity index 84% rename from example/quickstart/show_time.py rename to example/quick/show_time.py index ecb908f30..b56aada7b 100755 --- a/example/quickstart/show_time.py +++ b/example/quick/show_time.py @@ -8,7 +8,7 @@ async def show_time(websocket): while True: - message = datetime.datetime.utcnow().isoformat() + "Z" + message = datetime.datetime.now(tz=datetime.timezone.utc).isoformat() await websocket.send(message) await asyncio.sleep(random.random() * 2 + 1) diff --git a/example/quick/sync_time.py b/example/quick/sync_time.py new file mode 100755 index 000000000..cdbe731af --- /dev/null +++ b/example/quick/sync_time.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python + +import asyncio +import datetime +import random + +from websockets.asyncio.server import broadcast, serve + +async def noop(websocket): + await websocket.wait_closed() + +async def show_time(server): + while True: + message = datetime.datetime.now(tz=datetime.timezone.utc).isoformat() + broadcast(server.connections, message) + await asyncio.sleep(random.random() * 2 + 1) + +async def main(): + async with serve(noop, "localhost", 5678) as server: + await show_time(server) + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/example/quickstart/client.py b/example/quickstart/client.py deleted file mode 100755 index 934af69e3..000000000 --- a/example/quickstart/client.py +++ /dev/null @@ -1,19 +0,0 @@ -#!/usr/bin/env python - -import asyncio - -from websockets.asyncio.client import connect - -async def hello(): - uri = "ws://localhost:8765" - async with connect(uri) as websocket: - name = input("What's your name? ") - - await websocket.send(name) - print(f">>> {name}") - - greeting = await websocket.recv() - print(f"<<< {greeting}") - -if __name__ == "__main__": - asyncio.run(hello()) diff --git a/example/quickstart/show_time_2.py b/example/quickstart/show_time_2.py deleted file mode 100755 index 9c9659d14..000000000 --- a/example/quickstart/show_time_2.py +++ /dev/null @@ -1,29 +0,0 @@ -#!/usr/bin/env python - -import asyncio -import datetime -import random - -from websockets.asyncio.server import broadcast, serve - -CONNECTIONS = set() - -async def register(websocket): - CONNECTIONS.add(websocket) - try: - await websocket.wait_closed() - finally: - CONNECTIONS.remove(websocket) - -async def show_time(): - while True: - message = datetime.datetime.utcnow().isoformat() + "Z" - broadcast(CONNECTIONS, message) - await asyncio.sleep(random.random() * 2 + 1) - -async def main(): - async with serve(register, "localhost", 5678): - await show_time() - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/example/quickstart/client_secure.py b/example/tls/client.py similarity index 61% rename from example/quickstart/client_secure.py rename to example/tls/client.py index a1449587a..c97ccf8e4 100755 --- a/example/quickstart/client_secure.py +++ b/example/tls/client.py @@ -1,25 +1,24 @@ #!/usr/bin/env python -import asyncio import pathlib import ssl -from websockets.asyncio.client import connect +from websockets.sync.client import connect ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) localhost_pem = pathlib.Path(__file__).with_name("localhost.pem") ssl_context.load_verify_locations(localhost_pem) -async def hello(): +def hello(): uri = "wss://localhost:8765" - async with connect(uri, ssl=ssl_context) as websocket: + with connect(uri, ssl=ssl_context) as websocket: name = input("What's your name? ") - await websocket.send(name) + websocket.send(name) print(f">>> {name}") - greeting = await websocket.recv() + greeting = websocket.recv() print(f"<<< {greeting}") if __name__ == "__main__": - asyncio.run(hello()) + hello() diff --git a/example/quickstart/localhost.pem b/example/tls/localhost.pem similarity index 100% rename from example/quickstart/localhost.pem rename to example/tls/localhost.pem diff --git a/example/quickstart/server_secure.py b/example/tls/server.py similarity index 100% rename from example/quickstart/server_secure.py rename to example/tls/server.py diff --git a/src/websockets/asyncio/router.py b/src/websockets/asyncio/router.py index cd95022c1..047e7ef1c 100644 --- a/src/websockets/asyncio/router.py +++ b/src/websockets/asyncio/router.py @@ -81,7 +81,9 @@ def route( """ Create a WebSocket server dispatching connections to different handlers. - This feature requires the third-party library `werkzeug`_:: + This feature requires the third-party library `werkzeug`_: + + .. code-block:: console $ pip install werkzeug diff --git a/src/websockets/sync/router.py b/src/websockets/sync/router.py index 33105bf32..5572c4261 100644 --- a/src/websockets/sync/router.py +++ b/src/websockets/sync/router.py @@ -81,7 +81,9 @@ def route( """ Create a WebSocket server dispatching connections to different handlers. - This feature requires the third-party library `werkzeug`_:: + This feature requires the third-party library `werkzeug`_: + + .. code-block:: console $ pip install werkzeug From e934680c21b719bad631d8a6ee12bbc54c783601 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 15 Feb 2025 22:20:05 +0100 Subject: [PATCH 757/762] Broaden type of extension parameters. --- src/websockets/extensions/base.py | 2 +- src/websockets/extensions/permessage_deflate.py | 2 +- src/websockets/headers.py | 2 +- src/websockets/typing.py | 4 ++-- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/websockets/extensions/base.py b/src/websockets/extensions/base.py index 42dd6c5fa..2fdc59f0f 100644 --- a/src/websockets/extensions/base.py +++ b/src/websockets/extensions/base.py @@ -58,7 +58,7 @@ class ClientExtensionFactory: name: ExtensionName """Extension identifier.""" - def get_request_params(self) -> list[ExtensionParameter]: + def get_request_params(self) -> Sequence[ExtensionParameter]: """ Build parameters to send to the server for this extension. diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index 8e74cb282..7e9e7a5dd 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -351,7 +351,7 @@ def __init__( self.client_max_window_bits = client_max_window_bits self.compress_settings = compress_settings - def get_request_params(self) -> list[ExtensionParameter]: + def get_request_params(self) -> Sequence[ExtensionParameter]: """ Build request parameters. diff --git a/src/websockets/headers.py b/src/websockets/headers.py index c42abd976..e05ff5b4c 100644 --- a/src/websockets/headers.py +++ b/src/websockets/headers.py @@ -390,7 +390,7 @@ def parse_extension(header: str) -> list[ExtensionHeader]: def build_extension_item( - name: ExtensionName, parameters: list[ExtensionParameter] + name: ExtensionName, parameters: Sequence[ExtensionParameter] ) -> str: """ Build an extension definition. diff --git a/src/websockets/typing.py b/src/websockets/typing.py index f10481b8b..ab7ddd33e 100644 --- a/src/websockets/typing.py +++ b/src/websockets/typing.py @@ -2,7 +2,7 @@ import http import logging -from typing import TYPE_CHECKING, Any, NewType, Optional, Union +from typing import TYPE_CHECKING, Any, NewType, Optional, Sequence, Union __all__ = [ @@ -62,7 +62,7 @@ # Private types -ExtensionHeader = tuple[ExtensionName, list[ExtensionParameter]] +ExtensionHeader = tuple[ExtensionName, Sequence[ExtensionParameter]] """Extension in a ``Sec-WebSocket-Extensions`` header.""" From 667e418ae474235c2bc26024df11c8e1d573d27a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 15 Feb 2025 22:55:58 +0100 Subject: [PATCH 758/762] Review how-to guides, notably the patterns guide. Fix #1209. --- docs/conf.py | 1 + docs/howto/autoreload.rst | 17 +++++----- docs/howto/debugging.rst | 11 ++++--- docs/howto/django.rst | 57 ++++++++++++++++----------------- docs/howto/extensions.rst | 39 ++++++++++++++--------- docs/howto/index.rst | 9 ++---- docs/howto/patterns.rst | 64 +++++++++++++++++++++++--------------- docs/project/changelog.rst | 2 ++ 8 files changed, 113 insertions(+), 87 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index c6b9ac7d8..798d595db 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -84,6 +84,7 @@ intersphinx_mapping = { "python": ("https://docs.python.org/3", None), + "sesame": ("https://django-sesame.readthedocs.io/en/stable/", None), "werkzeug": ("https://werkzeug.palletsprojects.com/en/stable/", None), } diff --git a/docs/howto/autoreload.rst b/docs/howto/autoreload.rst index 4d1adee8b..dfa84ada3 100644 --- a/docs/howto/autoreload.rst +++ b/docs/howto/autoreload.rst @@ -1,15 +1,16 @@ Reload on code changes ====================== -When developing a websockets server, you may run it locally to test changes. -Unfortunately, whenever you want to try a new version of the code, you must -stop the server and restart it, which slows down your development process. +When developing a websockets server, you are likely to run it locally to test +changes. Unfortunately, whenever you want to try a new version of the code, you +must stop the server and restart it, which slows down your development process. -Web frameworks such as Django or Flask provide a development server that -reloads the application automatically when you make code changes. There is no -such functionality in websockets because it's designed only for production. +Web frameworks such as Django or Flask provide a development server that reloads +the application automatically when you make code changes. There is no equivalent +functionality in websockets because it's designed only for production. -However, you can achieve the same result easily. +However, you can achieve the same result easily with a third-party library and a +shell command. Install watchdog_ with the ``watchmedo`` shell utility: @@ -27,4 +28,4 @@ Run your server with ``watchmedo auto-restart``: python app.py This example assumes that the server is defined in a script called ``app.py`` -and exits cleanly when receiving the ``SIGTERM`` signal. Adapt it as necessary. +and exits cleanly when receiving the ``SIGTERM`` signal. Adapt as necessary. diff --git a/docs/howto/debugging.rst b/docs/howto/debugging.rst index fc5dcba56..546f70a6f 100644 --- a/docs/howto/debugging.rst +++ b/docs/howto/debugging.rst @@ -3,9 +3,10 @@ Enable debug logs websockets logs events with the :mod:`logging` module from the standard library. -It writes to the ``"websockets.server"`` and ``"websockets.client"`` loggers. +It emits logs in the ``"websockets.server"`` and ``"websockets.client"`` +loggers. -Enable logs at the ``DEBUG`` level to see exactly what websockets is doing. +You can enable logs at the ``DEBUG`` level to see exactly what websockets does. If logging isn't configured in your application:: @@ -24,10 +25,10 @@ If logging is already configured:: logger.setLevel(logging.DEBUG) logger.addHandler(logging.StreamHandler()) -Refer to the :doc:`logging guide <../topics/logging>` for more details on +Refer to the :doc:`logging guide <../topics/logging>` for more information about logging in websockets. -In addition, you may enable asyncio's `debug mode`_ to see what asyncio is -doing. +You may also enable asyncio's `debug mode`_ to get warnings about classic +pitfalls. .. _debug mode: https://docs.python.org/3/library/asyncio-dev.html#asyncio-debug-mode diff --git a/docs/howto/django.rst b/docs/howto/django.rst index 4fe2311cb..2fa0b89fd 100644 --- a/docs/howto/django.rst +++ b/docs/howto/django.rst @@ -101,6 +101,7 @@ call its APIs in the websockets server. Now here's how to implement authentication. .. literalinclude:: ../../example/django/authentication.py + :caption: authentication.py Let's unpack this code. @@ -113,23 +114,25 @@ your settings module. The connection handler reads the first message received from the client, which is expected to contain a django-sesame token. Then it authenticates the user -with ``get_user()``, the API for `authentication outside a view`_. If -authentication fails, it closes the connection and exits. +with :func:`~sesame.utils.get_user`, the API provided by django-sesame for +`authentication outside a view`_. .. _authentication outside a view: https://django-sesame.readthedocs.io/en/stable/howto.html#outside-a-view -When we call an API that makes a database query such as ``get_user()``, we -wrap the call in :func:`~asyncio.to_thread`. Indeed, the Django ORM doesn't -support asynchronous I/O. It would block the event loop if it didn't run in a -separate thread. +If authentication fails, it closes the connection and exits. + +When we call an API that makes a database query such as +:func:`~sesame.utils.get_user`, we wrap the call in :func:`~asyncio.to_thread`. +Indeed, the Django ORM doesn't support asynchronous I/O. It would block the +event loop if it didn't run in a separate thread. Finally, we start a server with :func:`~websockets.asyncio.server.serve`. We're ready to test! -Save this code to a file called ``authentication.py``, make sure the -``DJANGO_SETTINGS_MODULE`` environment variable is set properly, and start the -websockets server: +Download :download:`authentication.py <../../example/django/authentication.py>`, +make sure the ``DJANGO_SETTINGS_MODULE`` environment variable is set properly, +and start the websockets server: .. code-block:: console @@ -169,7 +172,7 @@ following code in the JavaScript console of the browser: websocket.onmessage = (event) => console.log(event.data); If you don't want to import your entire Django project into the websockets -server, you can build a separate Django project with ``django.contrib.auth``, +server, you can create a simpler Django project with ``django.contrib.auth``, ``django-sesame``, a suitable ``User`` model, and a subset of the settings of the main project. @@ -184,11 +187,11 @@ action was made. This may be used for showing notifications to other users. Many use cases for WebSocket with Django follow a similar pattern. -Set up event bus -................ +Set up event stream +................... -We need a event bus to enable communications between Django and websockets. -Both sides connect permanently to the bus. Then Django writes events and +We need an event stream to enable communications between Django and websockets. +Both sides connect permanently to the stream. Then Django writes events and websockets reads them. For the sake of simplicity, we'll rely on `Redis Pub/Sub`_. @@ -219,14 +222,15 @@ change ``get_redis_connection("default")`` in the code below to the same name. Publish events .............. -Now let's write events to the bus. +Now let's write events to the stream. Add the following code to a module that is imported when your Django project -starts. Typically, you would put it in a ``signals.py`` module, which you -would import in the ``AppConfig.ready()`` method of one of your apps: +starts. Typically, you would put it in a :download:`signals.py +<../../example/django/signals.py>` module, which you would import in the +``AppConfig.ready()`` method of one of your apps: .. literalinclude:: ../../example/django/signals.py - + :caption: signals.py This code runs every time the admin saves a ``LogEntry`` object to keep track of a change. It extracts interesting data, serializes it to JSON, and writes an event to Redis. @@ -256,13 +260,13 @@ We need to add several features: * Keep track of connected clients so we can broadcast messages. * Tell which content types the user has permission to view or to change. -* Connect to the message bus and read events. +* Connect to the message stream and read events. * Broadcast these events to users who have corresponding permissions. Here's a complete implementation. .. literalinclude:: ../../example/django/notifications.py - + :caption: notifications.py Since the ``get_content_types()`` function makes a database query, it is wrapped inside :func:`asyncio.to_thread()`. It runs once when each WebSocket connection is open; then its result is cached for the lifetime of the @@ -273,13 +277,10 @@ The connection handler merely registers the connection in a global variable, associated to the list of content types for which events should be sent to that connection, and waits until the client disconnects. -The ``process_events()`` function reads events from Redis and broadcasts them -to all connections that should receive them. We don't care much if a sending a -notification fails — this happens when a connection drops between the moment -we iterate on connections and the moment the corresponding message is sent — -so we start a task with for each message and forget about it. Also, this means -we're immediately ready to process the next event, even if it takes time to -send a message to a slow client. +The ``process_events()`` function reads events from Redis and broadcasts them to +all connections that should receive them. We don't care much if a sending a +notification fails. This happens when a connection drops between the moment we +iterate on connections and the moment the corresponding message is sent. Since Redis can publish a message to multiple subscribers, multiple instances of this server can safely run in parallel. @@ -290,4 +291,4 @@ Does it scale? In theory, given enough servers, this design can scale to a hundred million clients, since Redis can handle ten thousand servers and each server can handle ten thousand clients. In practice, you would need a more scalable -message bus before reaching that scale, due to the volume of messages. +message stream before reaching that scale, due to the volume of messages. diff --git a/docs/howto/extensions.rst b/docs/howto/extensions.rst index c4e9da626..2f73e2f87 100644 --- a/docs/howto/extensions.rst +++ b/docs/howto/extensions.rst @@ -1,30 +1,39 @@ Write an extension ================== -.. currentmodule:: websockets.extensions +.. currentmodule:: websockets During the opening handshake, WebSocket clients and servers negotiate which -extensions_ will be used with which parameters. Then each frame is processed -by extensions before being sent or after being received. +extensions_ will be used and with which parameters. .. _extensions: https://datatracker.ietf.org/doc/html/rfc6455.html#section-9 -As a consequence, writing an extension requires implementing several classes: +Then, each frame is processed before being sent and after being received +according to the extensions that were negotiated. -* Extension Factory: it negotiates parameters and instantiates the extension. +Writing an extension requires implementing at least two classes, an extension +factory and an extension. They inherit from base classes provided by websockets. - Clients and servers require separate extension factories with distinct APIs. +Extension factory +----------------- - Extension factories are the public API of an extension. +An extension factory negotiates parameters and instantiates the extension. -* Extension: it decodes incoming frames and encodes outgoing frames. +Clients and servers require separate extension factories with distinct APIs. +Base classes are :class:`~extensions.ClientExtensionFactory` and +:class:`~extensions.ServerExtensionFactory`. - If the extension is symmetrical, clients and servers can use the same - class. +Extension factories are the public API of an extension. Extensions are enabled +with the ``extensions`` parameter of :func:`~asyncio.client.connect` or +:func:`~asyncio.server.serve`. - Extensions are initialized by extension factories, so they don't need to be - part of the public API of an extension. +Extension +--------- -websockets provides base classes for extension factories and extensions. -See :class:`ClientExtensionFactory`, :class:`ServerExtensionFactory`, -and :class:`Extension` for details. +An extension decodes incoming frames and encodes outgoing frames. + +If the extension is symmetrical, clients and servers can use the same class. The +base class is :class:`~extensions.Extension`. + +Since extensions are initialized by extension factories, they don't need to be +part of the public API of an extension. diff --git a/docs/howto/index.rst b/docs/howto/index.rst index ffded9ff0..12b38ed06 100644 --- a/docs/howto/index.rst +++ b/docs/howto/index.rst @@ -14,22 +14,19 @@ Configure websockets securely in production. encryption -If you're stuck, perhaps you'll find the answer here. +These guides will help you design and build your application. .. toctree:: + :maxdepth: 2 patterns - -This guide will help you integrate websockets into a broader system. - -.. toctree:: - django Upgrading from the legacy :mod:`asyncio` implementation to the new one? Read this. .. toctree:: + :maxdepth: 2 upgrade diff --git a/docs/howto/patterns.rst b/docs/howto/patterns.rst index bfb78b6ca..e0602b6f8 100644 --- a/docs/howto/patterns.rst +++ b/docs/howto/patterns.rst @@ -1,46 +1,52 @@ -Patterns -======== +Design a WebSocket application +============================== .. currentmodule:: websockets -Here are typical patterns for processing messages in a WebSocket server or -client. You will certainly implement some of them in your application. +WebSocket server or client applications follow common patterns. This guide +describes patterns that you're likely to implement in your application. -This page gives examples of connection handlers for a server. However, they're -also applicable to a client, simply by assuming that ``websocket`` is a -connection created with :func:`~asyncio.client.connect`. +All examples are connection handlers for a server. However, they would also +apply to a client, assuming that ``websocket`` is a connection created with +:func:`~asyncio.client.connect`. -WebSocket connections are long-lived. You will usually write a loop to process -several messages during the lifetime of a connection. +.. admonition:: WebSocket connections are long-lived. + :class: tip -Consumer --------- + You need a loop to process several messages during the lifetime of a + connection. + +Consumer pattern +---------------- To receive messages from the WebSocket connection:: async def consumer_handler(websocket): async for message in websocket: - await consumer(message) + await consume(message) -In this example, ``consumer()`` is a coroutine implementing your business -logic for processing a message received on the WebSocket connection. Each -message may be :class:`str` or :class:`bytes`. +In this example, ``consume()`` is a coroutine implementing your business logic +for processing a message received on the WebSocket connection. Iteration terminates when the client disconnects. -Producer --------- +Producer pattern +---------------- To send messages to the WebSocket connection:: + from websockets.exceptions import ConnectionClosed + async def producer_handler(websocket): - while True: - message = await producer() - await websocket.send(message) + try: + while True: + message = await produce() + await websocket.send(message) + except ConnectionClosed: + break -In this example, ``producer()`` is a coroutine implementing your business -logic for generating the next message to send on the WebSocket connection. -Each message must be :class:`str` or :class:`bytes`. +In this example, ``produce()`` is a coroutine implementing your business logic +for generating the next message to send on the WebSocket connection. Iteration terminates when the client disconnects because :meth:`~asyncio.server.ServerConnection.send` raises a @@ -51,8 +57,12 @@ Consumer and producer --------------------- You can receive and send messages on the same WebSocket connection by -combining the consumer and producer patterns. This requires running two tasks -in parallel:: +combining the consumer and producer patterns. + +This requires running two tasks in parallel. The simplest option offered by +:mod:`asyncio` is:: + + import asyncio async def handler(websocket): await asyncio.gather( @@ -99,6 +109,10 @@ connect and unregister them when they disconnect:: This example maintains the set of connected clients in memory. This works as long as you run a single process. It doesn't scale to multiple processes. +If you just need the set of connected clients, as in this example, use the +:attr:`~asyncio.server.Server.connections` property of the server. This pattern +is needed only when recording additional information about each client. + Publish–subscribe ----------------- diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 31a537af5..7dc79cf7d 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -68,6 +68,8 @@ New features Improvements ............ +* Refreshed several how-to guides and topic guides. + * Added type overloads for the ``decode`` argument of :meth:`~asyncio.connection.Connection.recv`. This may simplify static typing. From 963f13f0bc9ed72afc38fdc0f373b2549744d570 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 16 Feb 2025 08:41:55 +0100 Subject: [PATCH 759/762] Restore display of close code and reason. Fix #1591. --- src/websockets/__main__.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/websockets/__main__.py b/src/websockets/__main__.py index 8647481d0..fbc8b0568 100644 --- a/src/websockets/__main__.py +++ b/src/websockets/__main__.py @@ -12,6 +12,7 @@ except ImportError: # Windows has no `readline` normally pass +from .frames import Close from .sync.client import ClientConnection, connect from .version import version as websockets_version @@ -150,7 +151,10 @@ def main() -> None: except (KeyboardInterrupt, EOFError): # ^C, ^D stop.set() websocket.close() - print_over_input("Connection closed.") + + assert websocket.close_code is not None and websocket.close_reason is not None + close_status = Close(websocket.close_code, websocket.close_reason) + print_over_input(f"Connection closed: {close_status}.") thread.join() From bc3fd2946a67276faff6e1a1043e6f1f1d0ad69f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 16 Feb 2025 09:22:55 +0100 Subject: [PATCH 760/762] Simplify enabling VT100 mode on Windows. Arguably, this relies on a bug, but: * The current implementation was vulnerable to platform-specific bugs, as highlighted in discussions of VT100 on Windows in the Python bug tracker, while the new one is clearly not going to cause harm. * If the bug in Python is fixed, hopefully we will gain a proper way to enable VT100. --- src/websockets/__main__.py | 49 ++++---------------------------------- 1 file changed, 5 insertions(+), 44 deletions(-) diff --git a/src/websockets/__main__.py b/src/websockets/__main__.py index fbc8b0568..c356510c5 100644 --- a/src/websockets/__main__.py +++ b/src/websockets/__main__.py @@ -9,7 +9,7 @@ try: import readline # noqa: F401 -except ImportError: # Windows has no `readline` normally +except ImportError: # readline isn't available on all platforms pass from .frames import Close @@ -17,38 +17,6 @@ from .version import version as websockets_version -if sys.platform == "win32": - - def win_enable_vt100() -> None: - """ - Enable VT-100 for console output on Windows. - - See also https://github.com/python/cpython/issues/73245. - - """ - import ctypes - - STD_OUTPUT_HANDLE = ctypes.c_uint(-11) - INVALID_HANDLE_VALUE = ctypes.c_uint(-1) - ENABLE_VIRTUAL_TERMINAL_PROCESSING = 0x004 - - handle = ctypes.windll.kernel32.GetStdHandle(STD_OUTPUT_HANDLE) - if handle == INVALID_HANDLE_VALUE: - raise RuntimeError("unable to obtain stdout handle") - - cur_mode = ctypes.c_uint() - if ctypes.windll.kernel32.GetConsoleMode(handle, ctypes.byref(cur_mode)) == 0: - raise RuntimeError("unable to query current console mode") - - # ctypes ints lack support for the required bit-OR operation. - # Temporarily convert to Py int, do the OR and convert back. - py_int_mode = int.from_bytes(cur_mode, sys.byteorder) - new_mode = ctypes.c_uint(py_int_mode | ENABLE_VIRTUAL_TERMINAL_PROCESSING) - - if ctypes.windll.kernel32.SetConsoleMode(handle, new_mode) == 0: - raise RuntimeError("unable to set console mode") - - def print_during_input(string: str) -> None: sys.stdout.write( # Save cursor position @@ -116,17 +84,10 @@ def main() -> None: if args.uri is None: parser.error("the following arguments are required: ") - # If we're on Windows, enable VT100 terminal support. - if sys.platform == "win32": - try: - win_enable_vt100() - except RuntimeError as exc: - sys.stderr.write( - f"Unable to set terminal to VT100 mode. This is only " - f"supported since Win10 anniversary update. Expect " - f"weird symbols on the terminal.\nError: {exc}\n" - ) - sys.stderr.flush() + # Enable VT100 to support ANSI escape codes in Command Prompt on Windows. + # See https://github.com/python/cpython/issues/74261 for why this works. + if sys.platform == "win32": # pragma: no cover + os.system("") try: websocket = connect(args.uri) From a1ba01db142459db0ea6f7659b3a5f4962749fa6 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 16 Feb 2025 11:32:49 +0100 Subject: [PATCH 761/762] Rewrite interactive client (again) without threads. Fix #1592. --- src/websockets/__main__.py | 138 ++++++++++++++++++++++++------------- 1 file changed, 91 insertions(+), 47 deletions(-) diff --git a/src/websockets/__main__.py b/src/websockets/__main__.py index c356510c5..5043e1e29 100644 --- a/src/websockets/__main__.py +++ b/src/websockets/__main__.py @@ -1,19 +1,16 @@ from __future__ import annotations import argparse +import asyncio import os -import signal import sys -import threading - - -try: - import readline # noqa: F401 -except ImportError: # readline isn't available on all platforms - pass +from typing import Generator +from .asyncio.client import ClientConnection, connect +from .asyncio.messages import SimpleQueue +from .exceptions import ConnectionClosed from .frames import Close -from .sync.client import ClientConnection, connect +from .streams import StreamReader from .version import version as websockets_version @@ -49,24 +46,94 @@ def print_over_input(string: str) -> None: sys.stdout.flush() -def print_incoming_messages(websocket: ClientConnection, stop: threading.Event) -> None: - for message in websocket: +class ReadLines(asyncio.Protocol): + def __init__(self) -> None: + self.reader = StreamReader() + self.messages: SimpleQueue[str] = SimpleQueue() + + def parse(self) -> Generator[None, None, None]: + while True: + sys.stdout.write("> ") + sys.stdout.flush() + line = yield from self.reader.read_line(sys.maxsize) + self.messages.put(line.decode().rstrip("\r\n")) + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + self.parser = self.parse() + next(self.parser) + + def data_received(self, data: bytes) -> None: + self.reader.feed_data(data) + next(self.parser) + + def eof_received(self) -> None: + self.reader.feed_eof() + # next(self.parser) isn't useful and would raise EOFError. + + def connection_lost(self, exc: Exception | None) -> None: + self.reader.discard() + self.messages.abort() + + +async def print_incoming_messages(websocket: ClientConnection) -> None: + async for message in websocket: if isinstance(message, str): print_during_input("< " + message) else: print_during_input("< (binary) " + message.hex()) - if not stop.is_set(): - # When the server closes the connection, raise KeyboardInterrupt - # in the main thread to exit the program. - if sys.platform == "win32": - ctrl_c = signal.CTRL_C_EVENT - else: - ctrl_c = signal.SIGINT - os.kill(os.getpid(), ctrl_c) + + +async def send_outgoing_messages( + websocket: ClientConnection, + messages: SimpleQueue[str], +) -> None: + while True: + try: + message = await messages.get() + except EOFError: + break + try: + await websocket.send(message) + except ConnectionClosed: + break + + +async def interactive_client(uri: str) -> None: + try: + websocket = await connect(uri) + except Exception as exc: + print(f"Failed to connect to {uri}: {exc}.") + sys.exit(1) + else: + print(f"Connected to {uri}.") + + loop = asyncio.get_running_loop() + transport, protocol = await loop.connect_read_pipe(ReadLines, sys.stdin) + incoming = asyncio.create_task( + print_incoming_messages(websocket), + ) + outgoing = asyncio.create_task( + send_outgoing_messages(websocket, protocol.messages), + ) + try: + await asyncio.wait( + [incoming, outgoing], + return_when=asyncio.FIRST_COMPLETED, + ) + except (KeyboardInterrupt, EOFError): # ^C, ^D + pass + finally: + incoming.cancel() + outgoing.cancel() + transport.close() + + await websocket.close() + assert websocket.close_code is not None and websocket.close_reason is not None + close_status = Close(websocket.close_code, websocket.close_reason) + print_over_input(f"Connection closed: {close_status}.") def main() -> None: - # Parse command line arguments. parser = argparse.ArgumentParser( prog="python -m websockets", description="Interactive WebSocket client.", @@ -90,34 +157,11 @@ def main() -> None: os.system("") try: - websocket = connect(args.uri) - except Exception as exc: - print(f"Failed to connect to {args.uri}: {exc}.") - sys.exit(1) - else: - print(f"Connected to {args.uri}.") - - stop = threading.Event() - - # Start the thread that reads messages from the connection. - thread = threading.Thread(target=print_incoming_messages, args=(websocket, stop)) - thread.start() - - # Read from stdin in the main thread in order to receive signals. - try: - while True: - # Since there's no size limit, put_nowait is identical to put. - message = input("> ") - websocket.send(message) - except (KeyboardInterrupt, EOFError): # ^C, ^D - stop.set() - websocket.close() - - assert websocket.close_code is not None and websocket.close_reason is not None - close_status = Close(websocket.close_code, websocket.close_reason) - print_over_input(f"Connection closed: {close_status}.") + import readline # noqa: F401 + except ImportError: # readline isn't available on all platforms + pass - thread.join() + asyncio.run(interactive_client(args.uri)) if __name__ == "__main__": From 7ac73c645329055a3c352077b8055e6ed65fa46c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 16 Feb 2025 11:47:18 +0100 Subject: [PATCH 762/762] Release version 15.0. --- docs/project/changelog.rst | 2 +- src/websockets/version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 7dc79cf7d..287d2fe31 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -30,7 +30,7 @@ notice. 15.0 ---- -*In development* +*February 16, 2025* Backwards-incompatible changes .............................. diff --git a/src/websockets/version.py b/src/websockets/version.py index 611e7d238..738f2cac1 100644 --- a/src/websockets/version.py +++ b/src/websockets/version.py @@ -18,7 +18,7 @@ # When tagging a release, set `released = True`. # After tagging a release, set `released = False` and increment `tag`. -released = False +released = True tag = version = commit = "15.0"